tesseract 4.1.1
Loading...
Searching...
No Matches
series.h
Go to the documentation of this file.
1
// File:        series.h
// Description: Runs networks in series on the same input.
// Author:      Ray Smith
// Created:     Thu May 02 08:20:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
18
19#ifndef TESSERACT_LSTM_SERIES_H_
20#define TESSERACT_LSTM_SERIES_H_
21
22#include "plumbing.h"
23
24namespace tesseract {
25
26// Runs two or more networks in series (layers) on the same input.
27class Series : public Plumbing {
28 public:
29 // ni_ and no_ will be set by AddToStack.
30 explicit Series(const STRING& name);
31 ~Series() override = default;
32
33 // Returns the shape output from the network given an input shape (which may
34 // be partially unknown ie zero).
35 StaticShape OutputShape(const StaticShape& input_shape) const override;
36
37 STRING spec() const override {
38 STRING spec("[");
39 for (int i = 0; i < stack_.size(); ++i)
40 spec += stack_[i]->spec();
41 spec += "]";
42 return spec;
43 }
44
45 // Sets up the network for training. Initializes weights using weights of
46 // scale `range` picked according to the random number generator `randomizer`.
47 // Returns the number of weights initialized.
48 int InitWeights(float range, TRand* randomizer) override;
49 // Recursively searches the network for softmaxes with old_no outputs,
50 // and remaps their outputs according to code_map. See network.h for details.
51 int RemapOutputs(int old_no, const std::vector<int>& code_map) override;
52
53 // Sets needs_to_backprop_ to needs_backprop and returns true if
54 // needs_backprop || any weights in this network so the next layer forward
55 // can be told to produce backprop for this layer if needed.
56 bool SetupNeedsBackprop(bool needs_backprop) override;
57
58 // Returns an integer reduction factor that the network applies to the
59 // time sequence. Assumes that any 2-d is already eliminated. Used for
60 // scaling bounding boxes of truth data.
61 // WARNING: if GlobalMinimax is used to vary the scale, this will return
62 // the last used scale factor. Call it before any forward, and it will return
63 // the minimum scale factor of the paths through the GlobalMinimax.
64 int XScaleFactor() const override;
65
66 // Provides the (minimum) x scale factor to the network (of interest only to
67 // input units) so they can determine how to scale bounding boxes.
68 void CacheXScaleFactor(int factor) override;
69
70 // Runs forward propagation of activations on the input line.
71 // See Network for a detailed discussion of the arguments.
72 void Forward(bool debug, const NetworkIO& input,
73 const TransposedArray* input_transpose, NetworkScratch* scratch,
74 NetworkIO* output) override;
75
76 // Runs backward propagation of errors on the deltas line.
77 // See Network for a detailed discussion of the arguments.
78 bool Backward(bool debug, const NetworkIO& fwd_deltas,
79 NetworkScratch* scratch, NetworkIO* back_deltas) override;
80
81 // Splits the series after the given index, returning the two parts and
82 // deletes itself. The first part, up to network with index last_start, goes
83 // into start, and the rest goes into end.
84 void SplitAt(int last_start, Series** start, Series** end);
85
86 // Appends the elements of the src series to this, removing from src and
87 // deleting it.
88 void AppendSeries(Network* src);
89};
90
91} // namespace tesseract.
92
93#endif // TESSERACT_LSTM_SERIES_H_
Definition: strngs.h:45
const STRING & name() const
Definition: network.h:138
PointerVector< Network > stack_
Definition: plumbing.h:136
STRING spec() const override
Definition: series.h:37
bool SetupNeedsBackprop(bool needs_backprop) override
Definition: series.cpp:79
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: series.cpp:129
void AppendSeries(Network *src)
Definition: series.cpp:190
int XScaleFactor() const override
Definition: series.cpp:92
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: series.cpp:35
void CacheXScaleFactor(int factor) override
Definition: series.cpp:101
int InitWeights(float range, TRand *randomizer) override
Definition: series.cpp:47
void SplitAt(int last_start, Series **start, Series **end)
Definition: series.cpp:160
~Series() override=default
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: series.cpp:62
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: series.cpp:107