tesseract 4.1.1
input.h
Go to the documentation of this file.
1
2// File: input.h
3// Description: Input layer class for neural network implementations.
4// Author: Ray Smith
5//
6// (C) Copyright 2014, Google Inc.
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10// http://www.apache.org/licenses/LICENSE-2.0
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
17
18#ifndef TESSERACT_LSTM_INPUT_H_
19#define TESSERACT_LSTM_INPUT_H_
20
21#include "network.h"
22
23class ScrollView;
24
25namespace tesseract {
26
27class Input : public Network {
28 public:
29 Input(const STRING& name, int ni, int no);
30 Input(const STRING& name, const StaticShape& shape);
31 ~Input() override = default;
32
33 STRING spec() const override {
35 spec.add_str_int("", shape_.batch());
36 spec.add_str_int(",", shape_.height());
37 spec.add_str_int(",", shape_.width());
38 spec.add_str_int(",", shape_.depth());
39 return spec;
40 }
41
42 // Returns the required shape input to the network.
43 StaticShape InputShape() const override { return shape_; }
44 // Returns the shape output from the network given an input shape (which may
45 // be partially unknown ie zero).
46 StaticShape OutputShape(const StaticShape& input_shape) const override {
47 return shape_;
48 }
49 // Writes to the given file. Returns false in case of error.
50 // Should be overridden by subclasses, but called by their Serialize.
51 bool Serialize(TFile* fp) const override;
52 // Reads from the given file. Returns false in case of error.
53 bool DeSerialize(TFile* fp) override;
54
55 // Returns an integer reduction factor that the network applies to the
56 // time sequence. Assumes that any 2-d is already eliminated. Used for
57 // scaling bounding boxes of truth data.
58 // WARNING: if GlobalMinimax is used to vary the scale, this will return
59 // the last used scale factor. Call it before any forward, and it will return
60 // the minimum scale factor of the paths through the GlobalMinimax.
61 int XScaleFactor() const override;
62
63 // Provides the (minimum) x scale factor to the network (of interest only to
64 // input units) so they can determine how to scale bounding boxes.
65 void CacheXScaleFactor(int factor) override;
66
67 // Runs forward propagation of activations on the input line.
68 // See Network for a detailed discussion of the arguments.
69 void Forward(bool debug, const NetworkIO& input,
70 const TransposedArray* input_transpose,
71 NetworkScratch* scratch, NetworkIO* output) override;
72
73 // Runs backward propagation of errors on the deltas line.
74 // See Network for a detailed discussion of the arguments.
75 bool Backward(bool debug, const NetworkIO& fwd_deltas,
76 NetworkScratch* scratch,
77 NetworkIO* back_deltas) override;
78 // Creates and returns a Pix of appropriate size for the network from the
79 // image_data. If non-null, *image_scale returns the image scale factor used.
80 // Returns nullptr on error.
81 /* static */
82 static Pix* PrepareLSTMInputs(const ImageData& image_data,
83 const Network* network, int min_width,
84 TRand* randomizer, float* image_scale);
85 // Converts the given pix to a NetworkIO of height and depth appropriate to
86 // the given StaticShape:
87 // If depth == 3, convert to 24 bit color, otherwise normalized grey.
88 // Scale to target height, if the shape's height is > 1, or its depth if the
89 // height == 1. If height == 0 then no scaling.
90 // NOTE: It isn't safe for multiple threads to call this on the same pix.
91 static void PreparePixInput(const StaticShape& shape, const Pix* pix,
92 TRand* randomizer, NetworkIO* input);
93
94 private:
95 void DebugWeights() override {
96 tprintf("Must override Network::DebugWeights for type %d\n", type_);
97 }
98
99 // Input shape determines how images are dealt with.
100 StaticShape shape_;
101 // Cached total network x scale factor for scaling bounding boxes.
102 int cached_x_scale_;
103};
104
105} // namespace tesseract.
106
107#endif // TESSERACT_LSTM_INPUT_H_
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
STRING
Definition: strngs.h:45
void add_str_int(const char *str, int number)
Definition: strngs.cpp:377
StaticShape InputShape() const override
Definition: input.h:43
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: input.cpp:64
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: input.h:46
static void PreparePixInput(const StaticShape &shape, const Pix *pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:111
int XScaleFactor() const override
Definition: input.cpp:52
static Pix * PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:83
void CacheXScaleFactor(int factor) override
Definition: input.cpp:58
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: input.cpp:72
~Input() override=default
bool Serialize(TFile *fp) const override
Definition: input.cpp:40
bool DeSerialize(TFile *fp) override
Definition: input.cpp:45
STRING spec() const override
Definition: input.h:33
NetworkType type_
Definition: network.h:293
const STRING & name() const
Definition: network.h:138