tesseract 4.1.1
Loading...
Searching...
No Matches
networkbuilder.cpp
Go to the documentation of this file.
1
2// File: networkbuilder.cpp
3// Description: Class to parse the network description language and
4// build a corresponding network.
5// Author: Ray Smith
6// Created: Wed Jul 16 18:35:38 PST 2014
7//
8// (C) Copyright 2014, Google Inc.
9// Licensed under the Apache License, Version 2.0 (the "License");
10// you may not use this file except in compliance with the License.
11// You may obtain a copy of the License at
12// http://www.apache.org/licenses/LICENSE-2.0
13// Unless required by applicable law or agreed to in writing, software
14// distributed under the License is distributed on an "AS IS" BASIS,
15// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16// See the License for the specific language governing permissions and
17// limitations under the License.
19
20#include "networkbuilder.h"
21#include "convolve.h"
22#include "fullyconnected.h"
23#include "input.h"
24#include "lstm.h"
25#include "maxpool.h"
26#include "network.h"
27#include "parallel.h"
28#include "reconfig.h"
29#include "reversed.h"
30#include "series.h"
31#include "unicharset.h"
32
33namespace tesseract {
34
35// Builds a network with a network_spec in the network description
36// language, to recognize a character set of num_outputs size.
37// If append_index is non-negative, then *network must be non-null and the
38// given network_spec will be appended to *network AFTER append_index, with
39// the top of the input *network discarded.
40// Note that network_spec is call by value to allow a non-const char* pointer
41// into the string for BuildFromString.
42// net_flags control network behavior according to the NetworkFlags enum.
43// The resulting network is returned via **network.
44// Returns false if something failed.
45bool NetworkBuilder::InitNetwork(int num_outputs, STRING network_spec,
46 int append_index, int net_flags,
47 float weight_range, TRand* randomizer,
48 Network** network) {
49 NetworkBuilder builder(num_outputs);
50 Series* bottom_series = nullptr;
51 StaticShape input_shape;
52 if (append_index >= 0) {
53 // Split the current network after the given append_index.
54 ASSERT_HOST(*network != nullptr && (*network)->type() == NT_SERIES);
55 auto* series = static_cast<Series*>(*network);
56 Series* top_series = nullptr;
57 series->SplitAt(append_index, &bottom_series, &top_series);
58 if (bottom_series == nullptr || top_series == nullptr) {
59 tprintf("Yikes! Splitting current network failed!!\n");
60 return false;
61 }
62 input_shape = bottom_series->OutputShape(input_shape);
63 delete top_series;
64 }
65 char* str_ptr = &network_spec[0];
66 *network = builder.BuildFromString(input_shape, &str_ptr);
67 if (*network == nullptr) return false;
68 (*network)->SetNetworkFlags(net_flags);
69 (*network)->InitWeights(weight_range, randomizer);
70 (*network)->SetupNeedsBackprop(false);
71 if (bottom_series != nullptr) {
72 bottom_series->AppendSeries(*network);
73 *network = bottom_series;
74 }
75 (*network)->CacheXScaleFactor((*network)->XScaleFactor());
76 return true;
77}
78
// Helper advances *str past any leading run of spaces, tabs and newlines.
// No other character (e.g. '\r') is treated as whitespace.
static void SkipWhitespace(char** str) {
  char* cursor = *str;
  for (;;) {
    const char ch = *cursor;
    if (ch != ' ' && ch != '\t' && ch != '\n') break;
    ++cursor;
  }
  *str = cursor;
}
83
84// Parses the given string and returns a network according to the network
85// description language in networkbuilder.h
87 char** str) {
88 SkipWhitespace(str);
89 char code_ch = **str;
90 if (code_ch == '[') {
91 return ParseSeries(input_shape, nullptr, str);
92 }
93 if (input_shape.depth() == 0) {
94 // There must be an input at this point.
95 return ParseInput(str);
96 }
97 switch (code_ch) {
98 case '(':
99 return ParseParallel(input_shape, str);
100 case 'R':
101 return ParseR(input_shape, str);
102 case 'S':
103 return ParseS(input_shape, str);
104 case 'C':
105 return ParseC(input_shape, str);
106 case 'M':
107 return ParseM(input_shape, str);
108 case 'L':
109 return ParseLSTM(input_shape, str);
110 case 'F':
111 return ParseFullyConnected(input_shape, str);
112 case 'O':
113 return ParseOutput(input_shape, str);
114 default:
115 tprintf("Invalid network spec:%s\n", *str);
116 return nullptr;
117 }
118 return nullptr;
119}
120
121// Parses an input specification and returns the result, which may include a
122// series.
123Network* NetworkBuilder::ParseInput(char** str) {
124 // There must be an input at this point.
125 int length = 0;
126 int batch, height, width, depth;
127 int num_converted =
128 sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
129 StaticShape shape;
130 shape.SetShape(batch, height, width, depth);
131 // num_converted may or may not include the length.
132 if (num_converted != 4 && num_converted != 5) {
133 tprintf("Must specify an input layer as the first layer, not %s!!\n", *str);
134 return nullptr;
135 }
136 *str += length;
137 Input* input = new Input("Input", shape);
138 // We want to allow [<input>rest of net... or <input>[rest of net... so we
139 // have to check explicitly for '[' here.
140 SkipWhitespace(str);
141 if (**str == '[') return ParseSeries(shape, input, str);
142 return input;
143}
144
145// Parses a sequential series of networks, defined by [<net><net>...].
146Network* NetworkBuilder::ParseSeries(const StaticShape& input_shape,
147 Input* input_layer, char** str) {
148 StaticShape shape = input_shape;
149 Series* series = new Series("Series");
150 ++*str;
151 if (input_layer != nullptr) {
152 series->AddToStack(input_layer);
153 shape = input_layer->OutputShape(shape);
154 }
155 Network* network = nullptr;
156 while (**str != '\0' && **str != ']' &&
157 (network = BuildFromString(shape, str)) != nullptr) {
158 shape = network->OutputShape(shape);
159 series->AddToStack(network);
160 }
161 if (**str != ']') {
162 tprintf("Missing ] at end of [Series]!\n");
163 delete series;
164 return nullptr;
165 }
166 ++*str;
167 return series;
168}
169
170// Parses a parallel set of networks, defined by (<net><net>...).
171Network* NetworkBuilder::ParseParallel(const StaticShape& input_shape,
172 char** str) {
173 Parallel* parallel = new Parallel("Parallel", NT_PARALLEL);
174 ++*str;
175 Network* network = nullptr;
176 while (**str != '\0' && **str != ')' &&
177 (network = BuildFromString(input_shape, str)) != nullptr) {
178 parallel->AddToStack(network);
179 }
180 if (**str != ')') {
181 tprintf("Missing ) at end of (Parallel)!\n");
182 delete parallel;
183 return nullptr;
184 }
185 ++*str;
186 return parallel;
187}
188
189// Parses a network that begins with 'R'.
190Network* NetworkBuilder::ParseR(const StaticShape& input_shape, char** str) {
191 char dir = (*str)[1];
192 if (dir == 'x' || dir == 'y') {
193 STRING name = "Reverse";
194 name += dir;
195 *str += 2;
196 Network* network = BuildFromString(input_shape, str);
197 if (network == nullptr) return nullptr;
198 auto* rev =
199 new Reversed(name, dir == 'y' ? NT_YREVERSED : NT_XREVERSED);
200 rev->SetNetwork(network);
201 return rev;
202 }
203 int replicas = strtol(*str + 1, str, 10);
204 if (replicas <= 0) {
205 tprintf("Invalid R spec!:%s\n", *str);
206 return nullptr;
207 }
208 Parallel* parallel = new Parallel("Replicated", NT_REPLICATED);
209 char* str_copy = *str;
210 for (int i = 0; i < replicas; ++i) {
211 str_copy = *str;
212 Network* network = BuildFromString(input_shape, &str_copy);
213 if (network == nullptr) {
214 tprintf("Invalid replicated network!\n");
215 delete parallel;
216 return nullptr;
217 }
218 parallel->AddToStack(network);
219 }
220 *str = str_copy;
221 return parallel;
222}
223
224// Parses a network that begins with 'S'.
225Network* NetworkBuilder::ParseS(const StaticShape& input_shape, char** str) {
226 int y = strtol(*str + 1, str, 10);
227 if (**str == ',') {
228 int x = strtol(*str + 1, str, 10);
229 if (y <= 0 || x <= 0) {
230 tprintf("Invalid S spec!:%s\n", *str);
231 return nullptr;
232 }
233 return new Reconfig("Reconfig", input_shape.depth(), x, y);
234 } else if (**str == '(') {
235 // TODO(rays) Add Generic reshape.
236 tprintf("Generic reshape not yet implemented!!\n");
237 return nullptr;
238 }
239 tprintf("Invalid S spec!:%s\n", *str);
240 return nullptr;
241}
242
243// Helper returns the fully-connected type for the character code.
244static NetworkType NonLinearity(char func) {
245 switch (func) {
246 case 's':
247 return NT_LOGISTIC;
248 case 't':
249 return NT_TANH;
250 case 'r':
251 return NT_RELU;
252 case 'l':
253 return NT_LINEAR;
254 case 'm':
255 return NT_SOFTMAX;
256 case 'p':
257 return NT_POSCLIP;
258 case 'n':
259 return NT_SYMCLIP;
260 default:
261 return NT_NONE;
262 }
263}
264
265// Parses a network that begins with 'C'.
266Network* NetworkBuilder::ParseC(const StaticShape& input_shape, char** str) {
267 NetworkType type = NonLinearity((*str)[1]);
268 if (type == NT_NONE) {
269 tprintf("Invalid nonlinearity on C-spec!: %s\n", *str);
270 return nullptr;
271 }
272 int y = 0, x = 0, d = 0;
273 if ((y = strtol(*str + 2, str, 10)) <= 0 || **str != ',' ||
274 (x = strtol(*str + 1, str, 10)) <= 0 || **str != ',' ||
275 (d = strtol(*str + 1, str, 10)) <= 0) {
276 tprintf("Invalid C spec!:%s\n", *str);
277 return nullptr;
278 }
279 if (x == 1 && y == 1) {
280 // No actual convolution. Just a FullyConnected on the current depth, to
281 // be slid over all batch,y,x.
282 return new FullyConnected("Conv1x1", input_shape.depth(), d, type);
283 }
284 Series* series = new Series("ConvSeries");
285 Convolve* convolve =
286 new Convolve("Convolve", input_shape.depth(), x / 2, y / 2);
287 series->AddToStack(convolve);
288 StaticShape fc_input = convolve->OutputShape(input_shape);
289 series->AddToStack(new FullyConnected("ConvNL", fc_input.depth(), d, type));
290 return series;
291}
292
293// Parses a network that begins with 'M'.
294Network* NetworkBuilder::ParseM(const StaticShape& input_shape, char** str) {
295 int y = 0, x = 0;
296 if ((*str)[1] != 'p' || (y = strtol(*str + 2, str, 10)) <= 0 ||
297 **str != ',' || (x = strtol(*str + 1, str, 10)) <= 0) {
298 tprintf("Invalid Mp spec!:%s\n", *str);
299 return nullptr;
300 }
301 return new Maxpool("Maxpool", input_shape.depth(), x, y);
302}
303
304// Parses an LSTM network, either individual, bi- or quad-directional.
305Network* NetworkBuilder::ParseLSTM(const StaticShape& input_shape, char** str) {
306 bool two_d = false;
308 char* spec_start = *str;
309 int chars_consumed = 1;
310 int num_outputs = 0;
311 char key = (*str)[chars_consumed], dir = 'f', dim = 'x';
312 if (key == 'S') {
314 num_outputs = num_softmax_outputs_;
315 ++chars_consumed;
316 } else if (key == 'E') {
318 num_outputs = num_softmax_outputs_;
319 ++chars_consumed;
320 } else if (key == '2' && (((*str)[2] == 'x' && (*str)[3] == 'y') ||
321 ((*str)[2] == 'y' && (*str)[3] == 'x'))) {
322 chars_consumed = 4;
323 dim = (*str)[3];
324 two_d = true;
325 } else if (key == 'f' || key == 'r' || key == 'b') {
326 dir = key;
327 dim = (*str)[2];
328 if (dim != 'x' && dim != 'y') {
329 tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str);
330 return nullptr;
331 }
332 chars_consumed = 3;
333 if ((*str)[chars_consumed] == 's') {
334 ++chars_consumed;
336 }
337 } else {
338 tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str);
339 return nullptr;
340 }
341 int num_states = strtol(*str + chars_consumed, str, 10);
342 if (num_states <= 0) {
343 tprintf("Invalid number of states in L Spec!:%s\n", *str);
344 return nullptr;
345 }
346 Network* lstm = nullptr;
347 if (two_d) {
348 lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
349 } else {
350 if (num_outputs == 0) num_outputs = num_states;
351 STRING name(spec_start, *str - spec_start);
352 lstm = new LSTM(name, input_shape.depth(), num_states, num_outputs, false,
353 type);
354 if (dir != 'f') {
355 Reversed* rev = new Reversed("RevLSTM", NT_XREVERSED);
356 rev->SetNetwork(lstm);
357 lstm = rev;
358 }
359 if (dir == 'b') {
360 name += "LTR";
361 Parallel* parallel = new Parallel("BidiLSTM", NT_PAR_RL_LSTM);
362 parallel->AddToStack(new LSTM(name, input_shape.depth(), num_states,
363 num_outputs, false, type));
364 parallel->AddToStack(lstm);
365 lstm = parallel;
366 }
367 }
368 if (dim == 'y') {
369 Reversed* rev = new Reversed("XYTransLSTM", NT_XYTRANSPOSE);
370 rev->SetNetwork(lstm);
371 lstm = rev;
372 }
373 return lstm;
374}
375
376// Builds a set of 4 lstms with x and y reversal, running in true parallel.
377Network* NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) {
378 Parallel* parallel = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM);
379 parallel->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states,
380 num_states, true, NT_LSTM));
381 Reversed* rev = new Reversed("L2DLTRXRev", NT_XREVERSED);
382 rev->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states,
383 true, NT_LSTM));
384 parallel->AddToStack(rev);
385 rev = new Reversed("L2DRTLYRev", NT_YREVERSED);
386 rev->SetNetwork(
387 new LSTM("L2DRTLUp", num_inputs, num_states, num_states, true, NT_LSTM));
388 Reversed* rev2 = new Reversed("L2DXRevU", NT_XREVERSED);
389 rev2->SetNetwork(rev);
390 parallel->AddToStack(rev2);
391 rev = new Reversed("L2DXRevY", NT_YREVERSED);
392 rev->SetNetwork(new LSTM("L2DLTRDown", num_inputs, num_states, num_states,
393 true, NT_LSTM));
394 parallel->AddToStack(rev);
395 return parallel;
396}
397
398// Helper builds a truly (0-d) fully connected layer of the given type.
399static Network* BuildFullyConnected(const StaticShape& input_shape,
400 NetworkType type, const STRING& name,
401 int depth) {
402 if (input_shape.height() == 0 || input_shape.width() == 0) {
403 tprintf("Fully connected requires positive height and width, had %d,%d\n",
404 input_shape.height(), input_shape.width());
405 return nullptr;
406 }
407 int input_size = input_shape.height() * input_shape.width();
408 int input_depth = input_size * input_shape.depth();
409 Network* fc = new FullyConnected(name, input_depth, depth, type);
410 if (input_size > 1) {
411 Series* series = new Series("FCSeries");
412 series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(),
413 input_shape.width(), input_shape.height()));
414 series->AddToStack(fc);
415 fc = series;
416 }
417 return fc;
418}
419
420// Parses a Fully connected network.
421Network* NetworkBuilder::ParseFullyConnected(const StaticShape& input_shape,
422 char** str) {
423 char* spec_start = *str;
424 NetworkType type = NonLinearity((*str)[1]);
425 if (type == NT_NONE) {
426 tprintf("Invalid nonlinearity on F-spec!: %s\n", *str);
427 return nullptr;
428 }
429 int depth = strtol(*str + 2, str, 10);
430 if (depth <= 0) {
431 tprintf("Invalid F spec!:%s\n", *str);
432 return nullptr;
433 }
434 STRING name(spec_start, *str - spec_start);
435 return BuildFullyConnected(input_shape, type, name, depth);
436}
437
438// Parses an Output spec.
439Network* NetworkBuilder::ParseOutput(const StaticShape& input_shape,
440 char** str) {
441 char dims_ch = (*str)[1];
442 if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') {
443 tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str);
444 return nullptr;
445 }
446 char type_ch = (*str)[2];
447 if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') {
448 tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str);
449 return nullptr;
450 }
451 int depth = strtol(*str + 3, str, 10);
452 if (depth != num_softmax_outputs_) {
453 tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth,
454 num_softmax_outputs_);
455 depth = num_softmax_outputs_;
456 }
458 if (type_ch == 'l')
460 else if (type_ch == 's')
462 if (dims_ch == '0') {
463 // Same as standard fully connected.
464 return BuildFullyConnected(input_shape, type, "Output", depth);
465 } else if (dims_ch == '2') {
466 // We don't care if x and/or y are variable.
467 return new FullyConnected("Output2d", input_shape.depth(), depth, type);
468 }
469 // For 1-d y has to be fixed, and if not 1, moved to depth.
470 if (input_shape.height() == 0) {
471 tprintf("Fully connected requires fixed height!\n");
472 return nullptr;
473 }
474 int input_size = input_shape.height();
475 int input_depth = input_size * input_shape.depth();
476 Network* fc = new FullyConnected("Output", input_depth, depth, type);
477 if (input_size > 1) {
478 Series* series = new Series("FCSeries");
479 series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1,
480 input_shape.height()));
481 series->AddToStack(fc);
482 fc = series;
483 }
484 return fc;
485}
486
487} // namespace tesseract.
#define ASSERT_HOST(x)
Definition: errcode.h:88
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
NetworkType
Definition: network.h:43
@ NT_LINEAR
Definition: network.h:67
@ NT_RELU
Definition: network.h:66
@ NT_XREVERSED
Definition: network.h:56
@ NT_LSTM
Definition: network.h:60
@ NT_SOFTMAX
Definition: network.h:68
@ NT_NONE
Definition: network.h:44
@ NT_LOGISTIC
Definition: network.h:62
@ NT_LSTM_SOFTMAX_ENCODED
Definition: network.h:76
@ NT_PARALLEL
Definition: network.h:49
@ NT_SYMCLIP
Definition: network.h:64
@ NT_PAR_2D_LSTM
Definition: network.h:53
@ NT_LSTM_SUMMARY
Definition: network.h:61
@ NT_YREVERSED
Definition: network.h:57
@ NT_POSCLIP
Definition: network.h:63
@ NT_LSTM_SOFTMAX
Definition: network.h:75
@ NT_XYTRANSPOSE
Definition: network.h:58
@ NT_SERIES
Definition: network.h:54
@ NT_SOFTMAX_NO_CTC
Definition: network.h:69
@ NT_TANH
Definition: network.h:65
@ NT_PAR_RL_LSTM
Definition: network.h:51
@ NT_REPLICATED
Definition: network.h:50
Definition: strngs.h:45
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:124
NetworkType type() const
Definition: network.h:112
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
Network * BuildFromString(const StaticShape &input_shape, char **str)
void AppendSeries(Network *src)
Definition: series.cpp:190
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: series.cpp:35
void CacheXScaleFactor(int factor) override
Definition: series.cpp:101
void SplitAt(int last_start, Series **start, Series **end)
Definition: series.cpp:160
void SetShape(int batch, int height, int width, int depth)
Definition: static_shape.h:52