///////////////////////////////////////////////////////////////////////
// File:        lstm.cpp
// Description: Long Short-Term Memory recurrent neural network.
// Author:      Ray Smith
// Created:     Wed May 01 17:43:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#include "lstm.h"

#ifdef _OPENMP
#include <omp.h>
#endif
#include <cstdio>
#include <cstdlib>

#if !defined(__GNUC__) && defined(_MSC_VER)
#include <intrin.h>  // _BitScanReverse
#endif

#include "fullyconnected.h"
#include "functions.h"
#include "networkscratch.h"
#include "tprintf.h"
// Macros for openmp code if it is available, otherwise empty macros.
#ifdef _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads)                                  \
  PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
    PRAGMA(omp sections nowait) {                                          \
      PRAGMA(omp section) {
#define SECTION_IF_OPENMP \
  }                       \
  PRAGMA(omp section)     \
  {

#define END_PARALLEL_IF_OPENMP \
  }                            \
  } /* end of sections */      \
  } /* end of parallel section */

// Define the portable PRAGMA macro.
#ifdef _MSC_VER  // Different _Pragma
#define PRAGMA(x) __pragma(x)
#else
#define PRAGMA(x) _Pragma(#x)
#endif  // _MSC_VER

#else  // _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads)
#define SECTION_IF_OPENMP
#define END_PARALLEL_IF_OPENMP
#endif  // _OPENMP

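// Editor's note: with OpenMP enabled, the three macros above open one
// "omp parallel" region containing an "omp sections" block, and each
// SECTION_IF_OPENMP starts the next "omp section"; without OpenMP they vanish
// and the code in between runs sequentially. A sketch of the expansion of
// PARALLEL_IF_OPENMP(2) .. SECTION_IF_OPENMP .. END_PARALLEL_IF_OPENMP,
// written with plain #pragma for readability (illustrative only, not part of
// the original file):
#if 0
static void MacroExpansionSketch() {
#pragma omp parallel if (2 > 1) num_threads(2)
  {
#pragma omp sections nowait
    {
#pragma omp section
      { /* first gate computation */ }
#pragma omp section
      { /* second gate computation, concurrent with the first */ }
    }  // end of sections
  }    // end of parallel section
}
#endif
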
namespace tesseract {

// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const double kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
const double kErrClip = 1.0;

// Calculate ceil(log2(n)).
static inline uint32_t ceil_log2(uint32_t n)
{
  // l2 = (unsigned)log2(n).
#if defined(__GNUC__)
  // Use the fast compiler builtin for gcc or clang.
  uint32_t l2 = 31 - __builtin_clz(n);
#elif defined(_MSC_VER)
  // Use fast intrinsic function for MS compiler.
  unsigned long l2 = 0;
  _BitScanReverse(&l2, n);
#else
  if (n == 0) return UINT_MAX;
  if (n == 1) return 0;
  uint32_t val = n;
  uint32_t l2 = 0;
  while (val > 1) {
    val >>= 1;
    l2++;
  }
#endif
  // Round up if n is not a power of 2.
  return (n == (1u << l2)) ? l2 : l2 + 1;
}
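
// Editor's note: ceil_log2 gives the number of bits needed to distinguish n
// values; NT_LSTM_SOFTMAX_ENCODED uses it below to size the binary-coded
// softmax feedback (nf_). A few spot checks (hypothetical test, not part of
// the original file):
#if 0
#include <cassert>
static void CeilLog2Checks() {
  assert(ceil_log2(1) == 0);    // exact power of two: 2^0 == 1
  assert(ceil_log2(2) == 1);
  assert(ceil_log2(3) == 2);    // not a power of two, so it rounds up
  assert(ceil_log2(64) == 6);
  assert(ceil_log2(111) == 7);  // a 111-class softmax needs 7 code bits
}
#endif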

LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
           NetworkType type)
    : Network(type, name, ni, no),
      na_(ni + ns),
      ns_(ns),
      nf_(0),
      is_2d_(two_dimensional),
      softmax_(nullptr),
      input_width_(0) {
  if (two_dimensional) na_ += ns_;
  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
    nf_ = 0;
    // networkbuilder ensures this is always true.
    ASSERT_HOST(no == ns);
  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
    softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
  } else {
    tprintf("%d is invalid type of LSTM!\n", type);
    ASSERT_HOST(false);
  }
  na_ += nf_;
}
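
// Editor's note: na_ ends up as the full width of the concatenated gate
// input: ni_ raw inputs + ns_ recurrent outputs (+ another ns_ for the second
// dimension when 2-D) + nf_ softmax-feedback features. For a hypothetical 1-D
// NT_LSTM_SOFTMAX_ENCODED cell with ni = 100, ns = 96, no = 111:
// nf_ = ceil_log2(111) = 7, so na_ = 100 + 96 + 7 = 203.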

LSTM::~LSTM() { delete softmax_; }

// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
  StaticShape result = input_shape;
  result.set_depth(no_);
  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
  if (softmax_ != nullptr) return softmax_->OutputShape(result);
  return result;
}
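
// Editor's note (hypothetical values): an input shape of width W and depth
// ni_ yields (width W, depth no_) for plain NT_LSTM, but (width 1, depth no_)
// for NT_LSTM_SUMMARY, which emits a single output per row; with a softmax
// attached, the final depth is whatever the softmax layer reports.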

// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
void LSTM::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) training_ = state;
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED) {
      for (int w = 0; w < WT_COUNT; ++w) {
        if (w == GFS && !Is2D()) continue;
        gate_weights_[w].InitBackward();
      }
    }
    training_ = state;
  }
  if (softmax_ != nullptr) softmax_->SetEnableTraining(state);
}

// Sets up the network for training. Initializes weights to random values of
// scale `range` picked according to the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand* randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = 0;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    num_weights_ += gate_weights_[w].InitWeightsFloat(
        ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
  }
  if (softmax_ != nullptr) {
    num_weights_ += softmax_->InitWeights(range, randomizer);
  }
  return num_weights_;
}
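
// Editor's note: each gate thus owns an ns_ x (na_ + 1) matrix; the extra
// input column is the bias, which is why InitWeightsFloat receives na_ + 1.
// With NF_ADAM set, the weight matrices also allocate per-weight Adam state.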

// Recursively searches the network for softmaxes with old_no outputs,
// and remaps their outputs according to code_map. See network.h for details.
int LSTM::RemapOutputs(int old_no, const std::vector<int>& code_map) {
  if (softmax_ != nullptr) {
    num_weights_ -= softmax_->num_weights();
    num_weights_ += softmax_->RemapOutputs(old_no, code_map);
  }
  return num_weights_;
}

// Converts a float network to an int network.
void LSTM::ConvertToInt() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].ConvertToInt();
  }
  if (softmax_ != nullptr) {
    softmax_->ConvertToInt();
  }
}

// Provides debug output on the weights.
void LSTM::DebugWeights() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    STRING msg = name_;
    msg.add_str_int(" Gate weights ", w);
    gate_weights_[w].Debug2D(msg.string());
  }
  if (softmax_ != nullptr) {
    softmax_->DebugWeights();
  }
}

// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile* fp) const {
  if (!Network::Serialize(fp)) return false;
  if (!fp->Serialize(&na_)) return false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
  }
  if (softmax_ != nullptr && !softmax_->Serialize(fp)) return false;
  return true;
}

// Reads from the given file. Returns false in case of error.
bool LSTM::DeSerialize(TFile* fp) {
  if (!fp->DeSerialize(&na_)) return false;
  if (type_ == NT_LSTM_SOFTMAX) {
    nf_ = no_;
  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = ceil_log2(no_);
  } else {
    nf_ = 0;
  }
  is_2d_ = false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) return false;
    if (w == CI) {
      ns_ = gate_weights_[CI].NumOutputs();
      is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
    }
  }
  delete softmax_;
  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    softmax_ = static_cast<FullyConnected*>(Network::CreateFromFile(fp));
    if (softmax_ == nullptr) return false;
  } else {
    softmax_ = nullptr;
  }
  return true;
}
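
// Editor's note: DeSerialize recovers the geometry from the stored data:
// ns_ comes from the CI matrix, and is_2d_ is deduced from the serialized
// na_, since na_ = ni_ + nf_ + ns_ in 1-D but ni_ + nf_ + 2 * ns_ in 2-D.
// Continuing the hypothetical ni = 100, ns = 96, nf = 7 example: na_ = 299
// satisfies 299 - 7 == 100 + 2 * 96, so is_2d_ = true; na_ = 203 would not.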

// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void LSTM::Forward(bool debug, const NetworkIO& input,
                   const TransposedArray* input_transpose,
                   NetworkScratch* scratch, NetworkIO* output) {
  input_map_ = input.stride_map();
  input_width_ = input.Width();
  if (softmax_ != nullptr)
    output->ResizeFloat(input, no_);
  else if (type_ == NT_LSTM_SUMMARY)
    output->ResizeXTo1(input, no_);
  else
    output->Resize(input, no_);
  ResizeForward(input);
  // Temporary storage of forward computation for each gate.
  NetworkScratch::FloatVec temp_lines[WT_COUNT];
  for (auto& temp_line : temp_lines) temp_line.Init(ns_, scratch);
  // Single timestep buffers for the current/recurrent output and state.
  NetworkScratch::FloatVec curr_state, curr_output;
  curr_state.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_state);
  curr_output.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_output);
  // Rotating buffers of width buf_width allow storage of the state and output
  // for the other dimension, used only when working in true 2D mode. The width
  // is enough to hold an entire strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> states, outputs;
  if (Is2D()) {
    states.init_to_size(buf_width, NetworkScratch::FloatVec());
    outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int i = 0; i < buf_width; ++i) {
      states[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, states[i]);
      outputs[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, outputs[i]);
    }
  }
  // Used only if a softmax LSTM.
  NetworkScratch::FloatVec softmax_output;
  NetworkScratch::IO int_output;
  if (softmax_ != nullptr) {
    softmax_output.Init(no_, scratch);
    ZeroVector<double>(no_, softmax_output);
    int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
    if (input.int_mode())
      int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
    softmax_->SetupForward(input, nullptr);
  }
  NetworkScratch::FloatVec curr_input;
  curr_input.Init(na_, scratch);
  StrideMap::Index src_index(input_map_);
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index dest_index(output->stride_map());
  do {
    int t = src_index.t();
    // True if there is a valid old state for the 2nd dimension.
    bool valid_2d = Is2D();
    if (valid_2d) {
      StrideMap::Index dim_index(src_index);
      if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
    }
    // Index of the 2-D revolving buffers (outputs, states).
    int mod_t = Modulo(t, buf_width);  // Current timestep.
    // Setup the padded input in source.
    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
    if (softmax_ != nullptr) {
      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
    }
    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
    if (Is2D())
      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
    if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
    // Matrix multiply the inputs with the source.
    PARALLEL_IF_OPENMP(GFS)
    // It looks inefficient to create the threads on each t iteration, but the
    // alternative of putting the parallel outside the t loop, a single around
    // the t-loop and then tasks in place of the sections is a *lot* slower.
    // Cell inputs.
    if (source_.int_mode())
      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
    else
      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
    FuncInplace<GFunc>(ns_, temp_lines[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    if (source_.int_mode())
      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
    else
      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
    FuncInplace<FFunc>(ns_, temp_lines[GI]);

    SECTION_IF_OPENMP
    // 1-D forget gates.
    if (source_.int_mode())
      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
    else
      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
    FuncInplace<FFunc>(ns_, temp_lines[GF1]);

    // 2-D forget gates.
    if (Is2D()) {
      if (source_.int_mode())
        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
      else
        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    if (source_.int_mode())
      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
    else
      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
    FuncInplace<FFunc>(ns_, temp_lines[GO]);
    END_PARALLEL_IF_OPENMP

    // Apply forget gate to state.
    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
    if (Is2D()) {
      // Max-pool the forget gates (in 2-d) instead of blindly adding.
      int8_t* which_fg_col = which_fg_[t];
      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
      if (valid_2d) {
        const double* stepped_state = states[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
            which_fg_col[i] = 2;
          }
        }
      }
    }
    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
    // Clip curr_state to a sane range.
    ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
    if (IsTraining()) {
      // Save the gate node values.
      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
      if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
    }
    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
    if (IsTraining()) state_.WriteTimeStep(t, curr_state);
    if (softmax_ != nullptr) {
      if (input.int_mode()) {
        int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
        softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
      } else {
        softmax_->ForwardTimeStep(curr_output, t, softmax_output);
      }
      output->WriteTimeStep(t, softmax_output);
      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
        CodeInBinary(no_, nf_, softmax_output);
      }
    } else if (type_ == NT_LSTM_SUMMARY) {
      // Output only at the end of a row.
      if (src_index.IsLast(FD_WIDTH)) {
        output->WriteTimeStep(dest_index.t(), curr_output);
        dest_index.Increment();
      }
    } else {
      output->WriteTimeStep(t, curr_output);
    }
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_state, states[mod_t]);
      CopyVector(ns_, curr_output, outputs[mod_t]);
    }
    // Always zero the states at the end of every row, but only for the major
    // direction. The 2-D state remains intact.
    if (src_index.IsLast(FD_WIDTH)) {
      ZeroVector<double>(ns_, curr_state);
      ZeroVector<double>(ns_, curr_output);
    }
  } while (src_index.Increment());
#if DEBUG_DETAIL > 0
  tprintf("Source:%s\n", name_.string());
  source_.Print(10);
  tprintf("State:%s\n", name_.string());
  state_.Print(10);
  tprintf("Output:%s\n", name_.string());
  output->Print(10);
#endif
  if (debug) DisplayForward(*output);
}
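
// Editor's note: per timestep the loop above performs the standard LSTM cell
// update, except that in 2-D mode the two forget contributions are max-pooled
// (recorded in which_fg_) rather than summed. A scalar sketch of the 1-D
// update (hypothetical helper, not part of the original file); inputs are the
// post-nonlinearity gate values, assuming the usual tanh GFunc/HFunc and
// sigmoid FFunc:
#if 0
#include <algorithm>
#include <cmath>
static double LstmCellStepSketch(double ci, double gi, double gf1, double go,
                                 double* state) {
  *state = gf1 * *state + ci * gi;                     // forget old, admit new
  *state = std::max(-100.0, std::min(100.0, *state));  // kStateClip
  return go * std::tanh(*state);                       // output-gated h(state)
}
#endif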

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
                    NetworkScratch* scratch,
                    NetworkIO* back_deltas) {
  if (debug) DisplayBackward(fwd_deltas);
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // ======Scratch space.======
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<double>(ns_, curr_stateerr);
  ZeroVector<double>(na_, curr_sourceerr);
  // Errors in the gates.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (auto& gate_error : gate_errors) gate_error.Init(ns_, scratch);
  // Rotating buffers of width buf_width allow storage of the recurrent time-
  // steps used only for true 2-D. Stores one full strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<double>(ns_, stateerr[t]);
      ZeroVector<double>(na_, sourceerr[t]);
    }
  }
  // Parallel-generated sourceerr from each of the gates.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (auto& sourceerr_temp : sourceerr_temps)
    sourceerr_temp.Init(na_, scratch);
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (auto& w : gate_errors_t) {
    w.Init(ns_, width, scratch);
  }
  // Used only if softmax_ != nullptr.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != nullptr) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  double state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.string());
  fwd_deltas.Print(10);
#endif
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
    // valid if >= 0, which is true if 2d and not on the top/bottom.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
      }
    }
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width);  // Current timestep.
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<double>(na_, curr_sourceerr);
      ZeroVector<double>(ns_, curr_stateerr);
    }
    // Setup the outputerr.
    if (type_ == NT_LSTM_SUMMARY) {
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<double>(ns_, outputerr);
      }
    } else if (softmax_ == nullptr) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
                                 softmax_errors_t.get(), outputerr);
    }
    if (!at_last_x)
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    if (down_pos >= 0)
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    // Apply the 1-d forget gates.
    if (!at_last_x) {
      const float* next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
    if (Is2D() && t + 1 < width) {
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
      }
      if (down_pos >= 0) {
        const float* right_node_gfs = node_values_[GFS].f(down_pos);
        const double* right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
                                    curr_stateerr);
    // Clip stateerr_ to a sane range.
    ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i)
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
                curr_sourceerr[ni_ + nf_ + i]);
      tprintf("\n");
    }
#endif
    // Matrix multiply to get the source errors.
    PARALLEL_IF_OPENMP(GFS)

    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
                                           curr_stateerr, gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
                                           curr_stateerr, gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget Gates.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
                                              gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
                                         sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);

    // 2-D forget Gates.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
                                              gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
                                         sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
                                         gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
               sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
               curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    tprintf("%s gate errors[%d]\n", name_.string(), w);
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed source_ used to speed-up SumOuter.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
#ifdef _OPENMP
#pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
  if (softmax_ != nullptr) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  return needs_to_backprop_;
}
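
// Editor's note: walking timesteps in reverse, the loop above maintains the
// recurrent state error as approximately
//   stateerr[t] = stateerr[t+1] * gf1[t+1] + outputerr[t] * go[t] * h'(state[t]),
// zeroing the recurrent term at row ends, routing each 2-D contribution only
// through whichever forget gate won the forward max-pool (which_fg_), and
// clipping to +/- state_clip. Each gate error is further clipped to
// +/- kErrClip before VectorDotMatrix propagates it back into sourceerr.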

// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
void LSTM::Update(float learning_rate, float momentum, float adam_beta,
                  int num_samples) {
#if DEBUG_DETAIL > 3
  PrintW();
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
  }
  if (softmax_ != nullptr) {
    softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
  }
#if DEBUG_DETAIL > 3
  PrintDW();
#endif
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void LSTM::CountAlternators(const Network& other, double* same,
                            double* changed) const {
  ASSERT_HOST(other.type() == type_);
  const LSTM* lstm = static_cast<const LSTM*>(&other);
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
  }
  if (softmax_ != nullptr) {
    softmax_->CountAlternators(*lstm->softmax_, same, changed);
  }
}

// Prints the weights for debug purposes.
void LSTM::PrintW() {
  tprintf("Weight state:%s\n", name_.string());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s)
      tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
    tprintf("\n");
  }
}

// Prints the weight deltas for debug purposes.
void LSTM::PrintDW() {
  tprintf("Delta state:%s\n", name_.string());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s)
      tprintf(" %g", gate_weights_[w].GetDW(s, na_));
    tprintf("\n");
  }
}

// Resizes forward data to cope with an input image of the given width.
void LSTM::ResizeForward(const NetworkIO& input) {
  int rounded_inputs = gate_weights_[CI].RoundInputs(na_);
  source_.Resize(input, rounded_inputs);
  which_fg_.ResizeNoInit(input.Width(), ns_);
  if (IsTraining()) {
    state_.ResizeFloat(input, ns_);
    for (int w = 0; w < WT_COUNT; ++w) {
      if (w == GFS && !Is2D()) continue;
      node_values_[w].ResizeFloat(input, ns_);
    }
  }
}


}  // namespace tesseract.