tesseract 4.1.1
Loading...
Searching...
No Matches
network.h
Go to the documentation of this file.
1
2// File: network.h
3// Description: Base class for neural network implementations.
4// Author: Ray Smith
5//
6// (C) Copyright 2013, Google Inc.
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10// http://www.apache.org/licenses/LICENSE-2.0
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
17
18#ifndef TESSERACT_LSTM_NETWORK_H_
19#define TESSERACT_LSTM_NETWORK_H_
20
21#include <cstdio>
22#include <cmath>
23
24#include "genericvector.h"
25#include "helpers.h"
26#include "matrix.h"
27#include "networkio.h"
28#include "serialis.h"
29#include "static_shape.h"
30#include "strngs.h" // for STRING
31#include "tprintf.h"
32
33struct Pix;
34class ScrollView;
35class TBOX;
36
37namespace tesseract {
38
39class ImageData;
40class NetworkScratch;
41
// Enum to store the run-time type of a Network. Keep in sync with kTypeNames.
enum NetworkType {
  NT_NONE,   // The naked base class.
  NT_INPUT,  // Inputs from an image.
  // Plumbing networks combine other networks or rearrange the inputs.
  NT_CONVOLVE,     // Duplicates inputs in a sliding window neighborhood.
  NT_MAXPOOL,      // Chooses the max result from a rectangle.
  NT_PARALLEL,     // Runs networks in parallel.
  NT_REPLICATED,   // Runs identical networks in parallel.
  NT_PAR_RL_LSTM,  // Runs LTR and RTL LSTMs in parallel.
  NT_PAR_UD_LSTM,  // Runs Up and Down LSTMs in parallel.
  NT_PAR_2D_LSTM,  // Runs 4 LSTMs in parallel.
  NT_SERIES,       // Executes a sequence of layers.
  NT_RECONFIG,     // Scales the time/y size but makes the output deeper.
  NT_XREVERSED,    // Reverses the x direction of the inputs/outputs.
  NT_YREVERSED,    // Reverses the y-direction of the inputs/outputs.
  NT_XYTRANSPOSE,  // Transposes x and y (for just a single op).
  // Functional networks actually calculate stuff.
  NT_LSTM,          // Long-Short-Term-Memory block.
  NT_LSTM_SUMMARY,  // LSTM that only keeps its last output.
  NT_LOGISTIC,      // Fully connected logistic nonlinearity.
  NT_POSCLIP,       // Fully connected rect lin version of logistic.
  NT_SYMCLIP,       // Fully connected rect lin version of tanh.
  NT_TANH,          // Fully connected with tanh nonlinearity.
  NT_RELU,          // Fully connected with rectifier nonlinearity.
  NT_LINEAR,        // Fully connected with no nonlinearity.
  NT_SOFTMAX,       // Softmax uses exponential normalization, with CTC.
  NT_SOFTMAX_NO_CTC,  // Softmax uses exponential normalization, no CTC.
  // The SOFTMAX LSTMs both have an extra softmax layer on top, but inside, with
  // the outputs fed back to the input of the LSTM at the next timestep.
  // The ENCODED version binary encodes the softmax outputs, providing log2 of
  // the number of outputs as additional inputs, and the other version just
  // provides all the softmax outputs as additional inputs.
  NT_LSTM_SOFTMAX,          // 1-d LSTM with built-in fully connected softmax.
  NT_LSTM_SOFTMAX_ENCODED,  // 1-d LSTM with built-in binary encoded softmax.
  // A TensorFlow graph encapsulated as a Tesseract network.
  NT_TENSORFLOW,

  NT_COUNT  // Array size.
};
82
// Enum of Network behavior flags. Can in theory be set for each individual
// network element.
enum NetworkFlags {
  // Network forward/backprop behavior.
  NF_LAYER_SPECIFIC_LR = 64,  // Separate learning rate for each layer.
  NF_ADAM = 128,              // Weight-specific learning rate.
};
90
// State of training and desired state used in SetEnableTraining.
enum TrainingState {
  // Valid states of training_.
  TS_DISABLED,      // Disabled permanently.
  TS_ENABLED,       // Enabled for backprop and to write a training dump.
                    // Re-enable from ANY disabled state.
  TS_TEMP_DISABLE,  // Temporarily disabled to write a recognition dump.
  // Valid only for SetEnableTraining.
  TS_RE_ENABLE,  // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};
101
102// Base class for network types. Not quite an abstract base class, but almost.
103// Most of the time no isolated Network exists, except prior to
104// deserialization.
105class Network {
106 public:
107 Network();
108 Network(NetworkType type, const STRING& name, int ni, int no);
109 virtual ~Network() = default;
110
111 // Accessors.
113 return type_;
114 }
115 bool IsTraining() const { return training_ == TS_ENABLED; }
116 bool needs_to_backprop() const {
117 return needs_to_backprop_;
118 }
119 int num_weights() const { return num_weights_; }
120 int NumInputs() const {
121 return ni_;
122 }
123 int NumOutputs() const {
124 return no_;
125 }
126 // Returns the required shape input to the network.
127 virtual StaticShape InputShape() const {
128 StaticShape result;
129 return result;
130 }
131 // Returns the shape output from the network given an input shape (which may
132 // be partially unknown ie zero).
133 virtual StaticShape OutputShape(const StaticShape& input_shape) const {
134 StaticShape result(input_shape);
135 result.set_depth(no_);
136 return result;
137 }
138 const STRING& name() const {
139 return name_;
140 }
141 virtual STRING spec() const {
142 return "?";
143 }
144 bool TestFlag(NetworkFlags flag) const {
145 return (network_flags_ & flag) != 0;
146 }
147
148 // Initialization and administrative functions that are mostly provided
149 // by Plumbing.
150 // Returns true if the given type is derived from Plumbing, and thus contains
151 // multiple sub-networks that can have their own learning rate.
152 virtual bool IsPlumbingType() const { return false; }
153
154 // Suspends/Enables/Permanently disables training by setting the training_
155 // flag. Serialize and DeSerialize only operate on the run-time data if state
156 // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
157 // temporarily disable layers in state TS_ENABLED, allowing a trainer to
158 // serialize as if it were a recognizer.
159 // TS_RE_ENABLE will re-enable layers that were previously in any disabled
160 // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
161 // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
162 // recognizer can be converted back to a trainer.
163 virtual void SetEnableTraining(TrainingState state);
164
165 // Sets flags that control the action of the network. See NetworkFlags enum
166 // for bit values.
167 virtual void SetNetworkFlags(uint32_t flags);
168
169 // Sets up the network for training. Initializes weights using weights of
170 // scale `range` picked according to the random number generator `randomizer`.
171 // Note that randomizer is a borrowed pointer that should outlive the network
172 // and should not be deleted by any of the networks.
173 // Returns the number of weights initialized.
174 virtual int InitWeights(float range, TRand* randomizer);
175 // Changes the number of outputs to the outside world to the size of the given
176 // code_map. Recursively searches the entire network for Softmax layers that
177 // have exactly old_no outputs, and operates only on those, leaving all others
178 // unchanged. This enables networks with multiple output layers to get all
179 // their softmaxes updated, but if an internal layer, uses one of those
180 // softmaxes for input, then the inputs will effectively be scrambled.
181 // TODO(rays) Fix this before any such network is implemented.
182 // The softmaxes are resized by copying the old weight matrix entries for each
183 // output from code_map[output] where non-negative, and uses the mean (over
184 // all outputs) of the existing weights for all outputs with negative code_map
185 // entries. Returns the new number of weights.
186 virtual int RemapOutputs(int old_no, const std::vector<int>& code_map) {
187 return 0;
188 }
189
190 // Converts a float network to an int network.
191 virtual void ConvertToInt() {}
192
193 // Provides a pointer to a TRand for any networks that care to use it.
194 // Note that randomizer is a borrowed pointer that should outlive the network
195 // and should not be deleted by any of the networks.
196 virtual void SetRandomizer(TRand* randomizer);
197
198 // Sets needs_to_backprop_ to needs_backprop and returns true if
199 // needs_backprop || any weights in this network so the next layer forward
200 // can be told to produce backprop for this layer if needed.
201 virtual bool SetupNeedsBackprop(bool needs_backprop);
202
203 // Returns the most recent reduction factor that the network applied to the
204 // time sequence. Assumes that any 2-d is already eliminated. Used for
205 // scaling bounding boxes of truth data and calculating result bounding boxes.
206 // WARNING: if GlobalMinimax is used to vary the scale, this will return
207 // the last used scale factor. Call it before any forward, and it will return
208 // the minimum scale factor of the paths through the GlobalMinimax.
209 virtual int XScaleFactor() const {
210 return 1;
211 }
212
213 // Provides the (minimum) x scale factor to the network (of interest only to
214 // input units) so they can determine how to scale bounding boxes.
215 virtual void CacheXScaleFactor(int factor) {}
216
217 // Provides debug output on the weights.
218 virtual void DebugWeights() = 0;
219
220 // Writes to the given file. Returns false in case of error.
221 // Should be overridden by subclasses, but called by their Serialize.
222 virtual bool Serialize(TFile* fp) const;
223 // Reads from the given file. Returns false in case of error.
224 // Should be overridden by subclasses, but NOT called by their DeSerialize.
225 virtual bool DeSerialize(TFile* fp) = 0;
226
227 public:
228 // Updates the weights using the given learning rate, momentum and adam_beta.
229 // num_samples is used in the adam computation iff use_adam_ is true.
230 virtual void Update(float learning_rate, float momentum, float adam_beta,
231 int num_samples) {}
232 // Sums the products of weight updates in *this and other, splitting into
233 // positive (same direction) in *same and negative (different direction) in
234 // *changed.
235 virtual void CountAlternators(const Network& other, double* same,
236 double* changed) const {}
237
238 // Reads from the given file. Returns nullptr in case of error.
239 // Determines the type of the serialized class and calls its DeSerialize
240 // on a new object of the appropriate type, which is returned.
241 static Network* CreateFromFile(TFile* fp);
242
243 // Runs forward propagation of activations on the input line.
244 // Note that input and output are both 2-d arrays.
245 // The 1st index is the time element. In a 1-d network, it might be the pixel
246 // position on the textline. In a 2-d network, the linearization is defined
247 // by the stride_map. (See networkio.h).
248 // The 2nd index of input is the network inputs/outputs, and the dimension
249 // of the input must match NumInputs() of this network.
250 // The output array will be resized as needed so that its 1st dimension is
251 // always equal to the number of output values, and its second dimension is
252 // always NumOutputs(). Note that all this detail is encapsulated away inside
253 // NetworkIO, as are the internals of the scratch memory space used by the
254 // network. See networkscratch.h for that.
255 // If input_transpose is not nullptr, then it contains the transpose of input,
256 // and the caller guarantees that it will still be valid on the next call to
257 // backward. The callee is therefore at liberty to save the pointer and
258 // reference it on a call to backward. This is a bit ugly, but it makes it
259 // possible for a replicating parallel to calculate the input transpose once
260 // instead of all the replicated networks having to do it.
261 virtual void Forward(bool debug, const NetworkIO& input,
262 const TransposedArray* input_transpose,
263 NetworkScratch* scratch, NetworkIO* output) = 0;
264
265 // Runs backward propagation of errors on fwdX_deltas.
266 // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
267 // Returns false if back_deltas was not set, due to there being no point in
268 // propagating further backwards. Thus most complete networks will always
269 // return false from Backward!
270 virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
271 NetworkScratch* scratch,
272 NetworkIO* back_deltas) = 0;
273
274 // === Debug image display methods. ===
275 // Displays the image of the matrix to the forward window.
276 void DisplayForward(const NetworkIO& matrix);
277 // Displays the image of the matrix to the backward window.
278 void DisplayBackward(const NetworkIO& matrix);
279
280 // Creates the window if needed, otherwise clears it.
281 static void ClearWindow(bool tess_coords, const char* window_name,
282 int width, int height, ScrollView** window);
283
284 // Displays the pix in the given window. and returns the height of the pix.
285 // The pix is pixDestroyed.
286 static int DisplayImage(Pix* pix, ScrollView* window);
287
288 protected:
289 // Returns a random number in [-range, range].
290 double Random(double range);
291
292 protected:
293 NetworkType type_; // Type of the derived network class.
294 TrainingState training_; // Are we currently training?
295 bool needs_to_backprop_; // This network needs to output back_deltas.
296 int32_t network_flags_; // Behavior control flags in NetworkFlags.
297 int32_t ni_; // Number of input values.
298 int32_t no_; // Number of output values.
299 int32_t num_weights_; // Number of weights in this and sub-network.
300 STRING name_; // A unique name for this layer.
301
302 // NOT-serialized debug data.
303 ScrollView* forward_win_; // Recognition debug display window.
304 ScrollView* backward_win_; // Training debug display window.
305 TRand* randomizer_; // Random number generator.
306};
307
308} // namespace tesseract.
309
310#endif // TESSERACT_LSTM_NETWORK_H_
TrainingState
Definition: network.h:92
@ TS_TEMP_DISABLE
Definition: network.h:97
@ TS_ENABLED
Definition: network.h:95
@ TS_DISABLED
Definition: network.h:94
@ TS_RE_ENABLE
Definition: network.h:99
NetworkType
Definition: network.h:43
@ NT_LINEAR
Definition: network.h:67
@ NT_MAXPOOL
Definition: network.h:48
@ NT_RELU
Definition: network.h:66
@ NT_XREVERSED
Definition: network.h:56
@ NT_LSTM
Definition: network.h:60
@ NT_CONVOLVE
Definition: network.h:47
@ NT_SOFTMAX
Definition: network.h:68
@ NT_NONE
Definition: network.h:44
@ NT_LOGISTIC
Definition: network.h:62
@ NT_PAR_UD_LSTM
Definition: network.h:52
@ NT_LSTM_SOFTMAX_ENCODED
Definition: network.h:76
@ NT_PARALLEL
Definition: network.h:49
@ NT_SYMCLIP
Definition: network.h:64
@ NT_PAR_2D_LSTM
Definition: network.h:53
@ NT_LSTM_SUMMARY
Definition: network.h:61
@ NT_YREVERSED
Definition: network.h:57
@ NT_RECONFIG
Definition: network.h:55
@ NT_INPUT
Definition: network.h:45
@ NT_TENSORFLOW
Definition: network.h:78
@ NT_POSCLIP
Definition: network.h:63
@ NT_LSTM_SOFTMAX
Definition: network.h:75
@ NT_XYTRANSPOSE
Definition: network.h:58
@ NT_SERIES
Definition: network.h:54
@ NT_SOFTMAX_NO_CTC
Definition: network.h:69
@ NT_TANH
Definition: network.h:65
@ NT_PAR_RL_LSTM
Definition: network.h:51
@ NT_COUNT
Definition: network.h:80
@ NT_REPLICATED
Definition: network.h:50
NetworkFlags
Definition: network.h:85
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:87
@ NF_ADAM
Definition: network.h:88
Definition: rect.h:34
Definition: strngs.h:45
int32_t network_flags_
Definition: network.h:296
NetworkType type_
Definition: network.h:293
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
Definition: network.h:186
virtual int XScaleFactor() const
Definition: network.h:209
int NumOutputs() const
Definition: network.h:123
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
bool needs_to_backprop_
Definition: network.h:295
int num_weights() const
Definition: network.h:119
virtual bool SetupNeedsBackprop(bool needs_backprop)
Definition: network.cpp:145
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:312
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:288
double Random(double range)
Definition: network.cpp:281
virtual bool DeSerialize(TFile *fp)=0
virtual bool IsPlumbingType() const
Definition: network.h:152
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:299
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:110
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
bool needs_to_backprop() const
Definition: network.h:116
ScrollView * forward_win_
Definition: network.h:303
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:187
bool IsTraining() const
Definition: network.h:115
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
Definition: network.h:230
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:335
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
virtual void CacheXScaleFactor(int factor)
Definition: network.h:215
ScrollView * backward_win_
Definition: network.h:304
virtual void DebugWeights()=0
virtual STRING spec() const
Definition: network.h:141
const STRING & name() const
Definition: network.h:138
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
virtual void CountAlternators(const Network &other, double *same, double *changed) const
Definition: network.h:235
int NumInputs() const
Definition: network.h:120
int32_t num_weights_
Definition: network.h:299
virtual int InitWeights(float range, TRand *randomizer)
Definition: network.cpp:130
TrainingState training_
Definition: network.h:294
virtual ~Network()=default
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:124
NetworkType type() const
Definition: network.h:112
TRand * randomizer_
Definition: network.h:305
virtual void ConvertToInt()
Definition: network.h:191
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:138
virtual StaticShape InputShape() const
Definition: network.h:127
void set_depth(int value)
Definition: static_shape.h:49