tesseract 4.1.1
Loading...
Searching...
No Matches
network.cpp
Go to the documentation of this file.
1
2// File: network.cpp
3// Description: Base class for neural network implementations.
4// Author: Ray Smith
5//
6// (C) Copyright 2013, Google Inc.
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10// http://www.apache.org/licenses/LICENSE-2.0
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
17
18// Include automatically generated configuration file if running autoconf.
19#ifdef HAVE_CONFIG_H
20#include "config_auto.h"
21#endif
22
23#include "network.h"
24
25#include <cstdlib>
26
27// This base class needs to know about all its sub-classes because of the
28// factory deserializing method: CreateFromFile.
29#include "allheaders.h"
30#include "convolve.h"
31#include "fullyconnected.h"
32#include "input.h"
33#include "lstm.h"
34#include "maxpool.h"
35#include "parallel.h"
36#include "reconfig.h"
37#include "reversed.h"
38#include "scrollview.h"
39#include "series.h"
40#include "statistc.h"
41#ifdef INCLUDE_TENSORFLOW
42#include "tfnetwork.h"
43#endif
44#include "tprintf.h"
45
46namespace tesseract {
47
// Bounds on the content size of debug windows created by ClearWindow: the
// shorter side is scaled up towards kMinWinSize and each side is capped at
// kMaxWinSize.
constexpr int kMinWinSize = 500;
constexpr int kMaxWinSize = 2000;
// Extra room added around the content for the window decorations.
constexpr int kXWinFrameSize = 30;
constexpr int kYWinFrameSize = 80;
54
55// String names corresponding to the NetworkType enum.
56// Keep in sync with NetworkType.
57// Names used in Serialization to allow re-ordering/addition/deletion of
58// layer types in NetworkType without invalidating existing network files.
59static char const* const kTypeNames[NT_COUNT] = {
60 "Invalid", "Input",
61 "Convolve", "Maxpool",
62 "Parallel", "Replicated",
63 "ParBidiLSTM", "DepParUDLSTM",
64 "Par2dLSTM", "Series",
65 "Reconfig", "RTLReversed",
66 "TTBReversed", "XYTranspose",
67 "LSTM", "SummLSTM",
68 "Logistic", "LinLogistic",
69 "LinTanh", "Tanh",
70 "Relu", "Linear",
71 "Softmax", "SoftmaxNoCTC",
72 "LSTMSoftmax", "LSTMBinarySoftmax",
73 "TensorFlow",
74};
75
77 : type_(NT_NONE),
78 training_(TS_ENABLED),
79 needs_to_backprop_(true),
80 network_flags_(0),
81 ni_(0),
82 no_(0),
83 num_weights_(0),
84 forward_win_(nullptr),
85 backward_win_(nullptr),
86 randomizer_(nullptr) {}
87Network::Network(NetworkType type, const STRING& name, int ni, int no)
88 : type_(type),
89 training_(TS_ENABLED),
90 needs_to_backprop_(true),
91 network_flags_(0),
92 ni_(ni),
93 no_(no),
94 num_weights_(0),
95 name_(name),
96 forward_win_(nullptr),
97 backward_win_(nullptr),
98 randomizer_(nullptr) {}
99
100
101// Suspends/Enables/Permanently disables training by setting the training_
102// flag. Serialize and DeSerialize only operate on the run-time data if state
103// is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
104// temporarily disable layers in state TS_ENABLED, allowing a trainer to
105// serialize as if it were a recognizer.
106// TS_RE_ENABLE will re-enable layers that were previously in any disabled
107// state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
108// TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
109// recognizer can be converted back to a trainer.
111 if (state == TS_RE_ENABLE) {
112 // Enable only from temp disabled.
114 } else if (state == TS_TEMP_DISABLE) {
115 // Temp disable only from enabled.
116 if (training_ == TS_ENABLED) training_ = state;
117 } else {
118 training_ = state;
119 }
120}
121
122// Sets flags that control the action of the network. See NetworkFlags enum
123// for bit values.
124void Network::SetNetworkFlags(uint32_t flags) {
125 network_flags_ = flags;
126}
127
128// Sets up the network for training. Initializes weights using weights of
129// scale `range` picked according to the random number generator `randomizer`.
130int Network::InitWeights(float range, TRand* randomizer) {
131 randomizer_ = randomizer;
132 return 0;
133}
134
135// Provides a pointer to a TRand for any networks that care to use it.
136// Note that randomizer is a borrowed pointer that should outlive the network
137// and should not be deleted by any of the networks.
138void Network::SetRandomizer(TRand* randomizer) {
139 randomizer_ = randomizer;
140}
141
142// Sets needs_to_backprop_ to needs_backprop and returns true if
143// needs_backprop || any weights in this network so the next layer forward
144// can be told to produce backprop for this layer if needed.
145bool Network::SetupNeedsBackprop(bool needs_backprop) {
146 needs_to_backprop_ = needs_backprop;
147 return needs_backprop || num_weights_ > 0;
148}
149
150// Writes to the given file. Returns false in case of error.
151bool Network::Serialize(TFile* fp) const {
152 int8_t data = NT_NONE;
153 if (!fp->Serialize(&data)) return false;
154 STRING type_name = kTypeNames[type_];
155 if (!type_name.Serialize(fp)) return false;
156 data = training_;
157 if (!fp->Serialize(&data)) return false;
158 data = needs_to_backprop_;
159 if (!fp->Serialize(&data)) return false;
160 if (!fp->Serialize(&network_flags_)) return false;
161 if (!fp->Serialize(&ni_)) return false;
162 if (!fp->Serialize(&no_)) return false;
163 if (!fp->Serialize(&num_weights_)) return false;
164 if (!name_.Serialize(fp)) return false;
165 return true;
166}
167
168static NetworkType getNetworkType(TFile* fp) {
169 int8_t data;
170 if (!fp->DeSerialize(&data)) return NT_NONE;
171 if (data == NT_NONE) {
172 STRING type_name;
173 if (!type_name.DeSerialize(fp)) return NT_NONE;
174 for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
175 }
176 if (data == NT_COUNT) {
177 tprintf("Invalid network layer type:%s\n", type_name.string());
178 return NT_NONE;
179 }
180 }
181 return static_cast<NetworkType>(data);
182}
183
184// Reads from the given file. Returns nullptr in case of error.
185// Determines the type of the serialized class and calls its DeSerialize
186// on a new object of the appropriate type, which is returned.
188 NetworkType type; // Type of the derived network class.
189 TrainingState training; // Are we currently training?
190 bool needs_to_backprop; // This network needs to output back_deltas.
191 int32_t network_flags; // Behavior control flags in NetworkFlags.
192 int32_t ni; // Number of input values.
193 int32_t no; // Number of output values.
194 int32_t num_weights; // Number of weights in this and sub-network.
195 STRING name; // A unique name for this layer.
196 int8_t data;
197 Network* network = nullptr;
198 type = getNetworkType(fp);
199 if (!fp->DeSerialize(&data)) return nullptr;
200 training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
201 if (!fp->DeSerialize(&data)) return nullptr;
202 needs_to_backprop = data != 0;
203 if (!fp->DeSerialize(&network_flags)) return nullptr;
204 if (!fp->DeSerialize(&ni)) return nullptr;
205 if (!fp->DeSerialize(&no)) return nullptr;
206 if (!fp->DeSerialize(&num_weights)) return nullptr;
207 if (!name.DeSerialize(fp)) return nullptr;
208
209 switch (type) {
210 case NT_CONVOLVE:
211 network = new Convolve(name, ni, 0, 0);
212 break;
213 case NT_INPUT:
214 network = new Input(name, ni, no);
215 break;
216 case NT_LSTM:
217 case NT_LSTM_SOFTMAX:
219 case NT_LSTM_SUMMARY:
220 network =
221 new LSTM(name, ni, no, no, false, type);
222 break;
223 case NT_MAXPOOL:
224 network = new Maxpool(name, ni, 0, 0);
225 break;
226 // All variants of Parallel.
227 case NT_PARALLEL:
228 case NT_REPLICATED:
229 case NT_PAR_RL_LSTM:
230 case NT_PAR_UD_LSTM:
231 case NT_PAR_2D_LSTM:
232 network = new Parallel(name, type);
233 break;
234 case NT_RECONFIG:
235 network = new Reconfig(name, ni, 0, 0);
236 break;
237 // All variants of reversed.
238 case NT_XREVERSED:
239 case NT_YREVERSED:
240 case NT_XYTRANSPOSE:
241 network = new Reversed(name, type);
242 break;
243 case NT_SERIES:
244 network = new Series(name);
245 break;
246 case NT_TENSORFLOW:
247#ifdef INCLUDE_TENSORFLOW
248 network = new TFNetwork(name);
249#else
250 tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
251#endif
252 break;
253 // All variants of FullyConnected.
254 case NT_SOFTMAX:
256 case NT_RELU:
257 case NT_TANH:
258 case NT_LINEAR:
259 case NT_LOGISTIC:
260 case NT_POSCLIP:
261 case NT_SYMCLIP:
262 network = new FullyConnected(name, ni, no, type);
263 break;
264 default:
265 break;
266 }
267 if (network) {
268 network->training_ = training;
270 network->network_flags_ = network_flags;
271 network->num_weights_ = num_weights;
272 if (!network->DeSerialize(fp)) {
273 delete network;
274 network = nullptr;
275 }
276 }
277 return network;
278}
279
280// Returns a random number in [-range, range].
281double Network::Random(double range) {
282 ASSERT_HOST(randomizer_ != nullptr);
283 return randomizer_->SignedRand(range);
284}
285
286// === Debug image display methods. ===
287// Displays the image of the matrix to the forward window.
289#ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
290 Pix* image = matrix.ToPix();
291 ClearWindow(false, name_.string(), pixGetWidth(image),
292 pixGetHeight(image), &forward_win_);
295#endif // GRAPHICS_DISABLED
296}
297
298// Displays the image of the matrix to the backward window.
300#ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
301 Pix* image = matrix.ToPix();
302 STRING window_name = name_ + "-back";
303 ClearWindow(false, window_name.string(), pixGetWidth(image),
304 pixGetHeight(image), &backward_win_);
307#endif // GRAPHICS_DISABLED
308}
309
310#ifndef GRAPHICS_DISABLED
311// Creates the window if needed, otherwise clears it.
312void Network::ClearWindow(bool tess_coords, const char* window_name,
313 int width, int height, ScrollView** window) {
314 if (*window == nullptr) {
315 int min_size = std::min(width, height);
316 if (min_size < kMinWinSize) {
317 if (min_size < 1) min_size = 1;
318 width = width * kMinWinSize / min_size;
319 height = height * kMinWinSize / min_size;
320 }
321 width += kXWinFrameSize;
322 height += kYWinFrameSize;
323 if (width > kMaxWinSize) width = kMaxWinSize;
324 if (height > kMaxWinSize) height = kMaxWinSize;
325 *window = new ScrollView(window_name, 80, 100, width, height, width, height,
326 tess_coords);
327 tprintf("Created window %s of size %d, %d\n", window_name, width, height);
328 } else {
329 (*window)->Clear();
330 }
331}
332
333// Displays the pix in the given window. and returns the height of the pix.
334// The pix is pixDestroyed.
335int Network::DisplayImage(Pix* pix, ScrollView* window) {
336 int height = pixGetHeight(pix);
337 window->Image(pix, 0, 0);
338 pixDestroy(&pix);
339 return height;
340}
341#endif // GRAPHICS_DISABLED
342
343} // namespace tesseract.
#define ASSERT_HOST(x)
Definition: errcode.h:88
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
const int kXWinFrameSize
Definition: network.cpp:52
const int kYWinFrameSize
Definition: network.cpp:53
const int kMinWinSize
Definition: network.cpp:49
TrainingState
Definition: network.h:92
@ TS_TEMP_DISABLE
Definition: network.h:97
@ TS_ENABLED
Definition: network.h:95
@ TS_DISABLED
Definition: network.h:94
@ TS_RE_ENABLE
Definition: network.h:99
NetworkType
Definition: network.h:43
@ NT_LINEAR
Definition: network.h:67
@ NT_MAXPOOL
Definition: network.h:48
@ NT_RELU
Definition: network.h:66
@ NT_XREVERSED
Definition: network.h:56
@ NT_LSTM
Definition: network.h:60
@ NT_CONVOLVE
Definition: network.h:47
@ NT_SOFTMAX
Definition: network.h:68
@ NT_NONE
Definition: network.h:44
@ NT_LOGISTIC
Definition: network.h:62
@ NT_PAR_UD_LSTM
Definition: network.h:52
@ NT_LSTM_SOFTMAX_ENCODED
Definition: network.h:76
@ NT_PARALLEL
Definition: network.h:49
@ NT_SYMCLIP
Definition: network.h:64
@ NT_PAR_2D_LSTM
Definition: network.h:53
@ NT_LSTM_SUMMARY
Definition: network.h:61
@ NT_YREVERSED
Definition: network.h:57
@ NT_RECONFIG
Definition: network.h:55
@ NT_INPUT
Definition: network.h:45
@ NT_TENSORFLOW
Definition: network.h:78
@ NT_POSCLIP
Definition: network.h:63
@ NT_LSTM_SOFTMAX
Definition: network.h:75
@ NT_XYTRANSPOSE
Definition: network.h:58
@ NT_SERIES
Definition: network.h:54
@ NT_SOFTMAX_NO_CTC
Definition: network.h:69
@ NT_TANH
Definition: network.h:65
@ NT_PAR_RL_LSTM
Definition: network.h:51
@ NT_COUNT
Definition: network.h:80
@ NT_REPLICATED
Definition: network.h:50
const int kMaxWinSize
Definition: network.cpp:50
double SignedRand(double range)
Definition: helpers.h:55
bool Serialize(const char *data, size_t count=1)
Definition: serialis.cpp:148
bool DeSerialize(char *data, size_t count=1)
Definition: serialis.cpp:104
Definition: strngs.h:45
bool Serialize(FILE *fp) const
Definition: strngs.cpp:146
const char * string() const
Definition: strngs.cpp:194
bool DeSerialize(bool swap, FILE *fp)
Definition: strngs.cpp:159
int32_t network_flags_
Definition: network.h:296
NetworkType type_
Definition: network.h:293
bool needs_to_backprop_
Definition: network.h:295
int num_weights() const
Definition: network.h:119
virtual bool SetupNeedsBackprop(bool needs_backprop)
Definition: network.cpp:145
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:312
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:288
double Random(double range)
Definition: network.cpp:281
virtual bool DeSerialize(TFile *fp)=0
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:299
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:110
bool needs_to_backprop() const
Definition: network.h:116
ScrollView * forward_win_
Definition: network.h:303
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:187
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:335
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
ScrollView * backward_win_
Definition: network.h:304
const STRING & name() const
Definition: network.h:138
int32_t num_weights_
Definition: network.h:299
virtual int InitWeights(float range, TRand *randomizer)
Definition: network.cpp:130
TrainingState training_
Definition: network.h:294
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:124
NetworkType type() const
Definition: network.h:112
TRand * randomizer_
Definition: network.h:305
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:138
Pix * ToPix() const
Definition: networkio.cpp:286
static void Update()
Definition: scrollview.cpp:709
void Image(struct Pix *image, int x_pos, int y_pos)
Definition: scrollview.cpp:765