tesseract 4.1.1
Loading...
Searching...
No Matches
parallel.h
Go to the documentation of this file.

// File:        parallel.h
// Description: Runs networks in parallel on the same input.
// Author:      Ray Smith
// Created:     Thu May 02 08:02:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

19#ifndef TESSERACT_LSTM_PARALLEL_H_
20#define TESSERACT_LSTM_PARALLEL_H_
21
22#include "plumbing.h"
23
24namespace tesseract {
25
26// Runs multiple networks in parallel, interlacing their outputs.
27class Parallel : public Plumbing {
28 public:
29 // ni_ and no_ will be set by AddToStack.
31 ~Parallel() override = default;
32
33 // Returns the shape output from the network given an input shape (which may
34 // be partially unknown ie zero).
35 StaticShape OutputShape(const StaticShape& input_shape) const override;
36
37 STRING spec() const override {
39 if (type_ == NT_PAR_2D_LSTM) {
40 // We have 4 LSTMs operating in parallel here, so the size of each is
41 // the number of outputs/4.
42 spec.add_str_int("L2xy", no_ / 4);
43 } else if (type_ == NT_PAR_RL_LSTM) {
44 // We have 2 LSTMs operating in parallel here, so the size of each is
45 // the number of outputs/2.
46 if (stack_[0]->type() == NT_LSTM_SUMMARY)
47 spec.add_str_int("Lbxs", no_ / 2);
48 else
49 spec.add_str_int("Lbx", no_ / 2);
50 } else {
51 if (type_ == NT_REPLICATED) {
52 spec.add_str_int("R", stack_.size());
53 spec += "(";
54 spec += stack_[0]->spec();
55 } else {
56 spec = "(";
57 for (int i = 0; i < stack_.size(); ++i) spec += stack_[i]->spec();
58 }
59 spec += ")";
60 }
61 return spec;
62 }
63
64 // Runs forward propagation of activations on the input line.
65 // See Network for a detailed discussion of the arguments.
66 void Forward(bool debug, const NetworkIO& input,
67 const TransposedArray* input_transpose,
68 NetworkScratch* scratch, NetworkIO* output) override;
69
70 // Runs backward propagation of errors on the deltas line.
71 // See Network for a detailed discussion of the arguments.
72 bool Backward(bool debug, const NetworkIO& fwd_deltas,
73 NetworkScratch* scratch,
74 NetworkIO* back_deltas) override;
75
76 private:
77 // If *this is a NT_REPLICATED, then it feeds a replicated network with
78 // identical inputs, and it would be extremely wasteful for them to each
79 // calculate and store the same transpose of the inputs, so Parallel does it
80 // and passes a pointer to the replicated network, allowing it to use the
81 // transpose on the next call to Backward.
82 TransposedArray transposed_input_;
83};
84
85} // namespace tesseract.
86
87#endif // TESSERACT_LSTM_PARALLEL_H_
NetworkType
Definition: network.h:43
@ NT_PAR_2D_LSTM
Definition: network.h:53
@ NT_LSTM_SUMMARY
Definition: network.h:61
@ NT_PAR_RL_LSTM
Definition: network.h:51
@ NT_REPLICATED
Definition: network.h:50
Definition: strngs.h:45
void add_str_int(const char *str, int number)
Definition: strngs.cpp:377
NetworkType type_
Definition: network.h:293
const STRING & name() const
Definition: network.h:138
NetworkType type() const
Definition: network.h:112
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: parallel.cpp:49
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: parallel.cpp:110
STRING spec() const override
Definition: parallel.h:37
~Parallel() override=default
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: parallel.cpp:37
PointerVector< Network > stack_
Definition: plumbing.h:136