tesseract 4.1.1
Loading...
Searching...
No Matches
plumbing.cpp
Go to the documentation of this file.
1
2// File: plumbing.cpp
3// Description: Base class for networks that organize other networks
4// eg series or parallel.
5// Author: Ray Smith
6// Created: Mon May 12 08:17:34 PST 2014
7//
8// (C) Copyright 2014, Google Inc.
9// Licensed under the Apache License, Version 2.0 (the "License");
10// you may not use this file except in compliance with the License.
11// You may obtain a copy of the License at
12// http://www.apache.org/licenses/LICENSE-2.0
13// Unless required by applicable law or agreed to in writing, software
14// distributed under the License is distributed on an "AS IS" BASIS,
15// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16// See the License for the specific language governing permissions and
17// limitations under the License.
19
20#include "plumbing.h"
21
22namespace tesseract {
23
24// ni_ and no_ will be set by AddToStack.
26 : Network(NT_PARALLEL, name, 0, 0) {
27}
28
29// Suspends/Enables training by setting the training_ flag. Serialize and
30// DeSerialize only operate on the run-time data if state is false.
33 for (int i = 0; i < stack_.size(); ++i)
34 stack_[i]->SetEnableTraining(state);
35}
36
37// Sets flags that control the action of the network. See NetworkFlags enum
38// for bit values.
39void Plumbing::SetNetworkFlags(uint32_t flags) {
41 for (int i = 0; i < stack_.size(); ++i)
42 stack_[i]->SetNetworkFlags(flags);
43}
44
45// Sets up the network for training. Initializes weights using weights of
46// scale `range` picked according to the random number generator `randomizer`.
47// Note that randomizer is a borrowed pointer that should outlive the network
48// and should not be deleted by any of the networks.
49// Returns the number of weights initialized.
50int Plumbing::InitWeights(float range, TRand* randomizer) {
51 num_weights_ = 0;
52 for (int i = 0; i < stack_.size(); ++i)
53 num_weights_ += stack_[i]->InitWeights(range, randomizer);
54 return num_weights_;
55}
56
57// Recursively searches the network for softmaxes with old_no outputs,
58// and remaps their outputs according to code_map. See network.h for details.
59int Plumbing::RemapOutputs(int old_no, const std::vector<int>& code_map) {
60 num_weights_ = 0;
61 for (int i = 0; i < stack_.size(); ++i) {
62 num_weights_ += stack_[i]->RemapOutputs(old_no, code_map);
63 }
64 return num_weights_;
65}
66
67// Converts a float network to an int network.
69 for (int i = 0; i < stack_.size(); ++i)
70 stack_[i]->ConvertToInt();
71}
72
73// Provides a pointer to a TRand for any networks that care to use it.
74// Note that randomizer is a borrowed pointer that should outlive the network
75// and should not be deleted by any of the networks.
76void Plumbing::SetRandomizer(TRand* randomizer) {
77 for (int i = 0; i < stack_.size(); ++i)
78 stack_[i]->SetRandomizer(randomizer);
79}
80
81// Adds the given network to the stack.
83 if (stack_.empty()) {
84 ni_ = network->NumInputs();
85 no_ = network->NumOutputs();
86 } else if (type_ == NT_SERIES) {
87 // ni is input of first, no output of last, others match output to input.
88 ASSERT_HOST(no_ == network->NumInputs());
89 no_ = network->NumOutputs();
90 } else {
91 // All parallel types. Output is sum of outputs, inputs all match.
92 ASSERT_HOST(ni_ == network->NumInputs());
93 no_ += network->NumOutputs();
94 }
95 stack_.push_back(network);
96}
97
98// Sets needs_to_backprop_ to needs_backprop and calls on sub-network
99// according to needs_backprop || any weights in this network.
100bool Plumbing::SetupNeedsBackprop(bool needs_backprop) {
101 if (IsTraining()) {
102 needs_to_backprop_ = needs_backprop;
103 bool retval = needs_backprop;
104 for (int i = 0; i < stack_.size(); ++i) {
105 if (stack_[i]->SetupNeedsBackprop(needs_backprop)) retval = true;
106 }
107 return retval;
108 }
109 // Frozen networks don't do backprop.
110 needs_to_backprop_ = false;
111 return false;
112}
113
114// Returns an integer reduction factor that the network applies to the
115// time sequence. Assumes that any 2-d is already eliminated. Used for
116// scaling bounding boxes of truth data.
117// WARNING: if GlobalMinimax is used to vary the scale, this will return
118// the last used scale factor. Call it before any forward, and it will return
119// the minimum scale factor of the paths through the GlobalMinimax.
121 return stack_[0]->XScaleFactor();
122}
123
124// Provides the (minimum) x scale factor to the network (of interest only to
125// input units) so they can determine how to scale bounding boxes.
127 for (int i = 0; i < stack_.size(); ++i) {
128 stack_[i]->CacheXScaleFactor(factor);
129 }
130}
131
132// Provides debug output on the weights.
134 for (int i = 0; i < stack_.size(); ++i)
135 stack_[i]->DebugWeights();
136}
137
138// Returns a set of strings representing the layer-ids of all layers below.
140 GenericVector<STRING>* layers) const {
141 for (int i = 0; i < stack_.size(); ++i) {
142 STRING layer_name;
143 if (prefix) layer_name = *prefix;
144 layer_name.add_str_int(":", i);
145 if (stack_[i]->IsPlumbingType()) {
146 auto* plumbing = static_cast<Plumbing*>(stack_[i]);
147 plumbing->EnumerateLayers(&layer_name, layers);
148 } else {
149 layers->push_back(layer_name);
150 }
151 }
152}
153
154// Returns a pointer to the network layer corresponding to the given id.
155Network* Plumbing::GetLayer(const char* id) const {
156 char* next_id;
157 int index = strtol(id, &next_id, 10);
158 if (index < 0 || index >= stack_.size()) return nullptr;
159 if (stack_[index]->IsPlumbingType()) {
160 auto* plumbing = static_cast<Plumbing*>(stack_[index]);
161 ASSERT_HOST(*next_id == ':');
162 return plumbing->GetLayer(next_id + 1);
163 }
164 return stack_[index];
165}
166
167// Returns a pointer to the learning rate for the given layer id.
168float* Plumbing::LayerLearningRatePtr(const char* id) const {
169 char* next_id;
170 int index = strtol(id, &next_id, 10);
171 if (index < 0 || index >= stack_.size()) return nullptr;
172 if (stack_[index]->IsPlumbingType()) {
173 auto* plumbing = static_cast<Plumbing*>(stack_[index]);
174 ASSERT_HOST(*next_id == ':');
175 return plumbing->LayerLearningRatePtr(next_id + 1);
176 }
177 if (index >= learning_rates_.size()) return nullptr;
178 return &learning_rates_[index];
179}
180
181// Writes to the given file. Returns false in case of error.
182bool Plumbing::Serialize(TFile* fp) const {
183 if (!Network::Serialize(fp)) return false;
184 uint32_t size = stack_.size();
185 // Can't use PointerVector::Serialize here as we need a special DeSerialize.
186 if (!fp->Serialize(&size)) return false;
187 for (uint32_t i = 0; i < size; ++i)
188 if (!stack_[i]->Serialize(fp)) return false;
191 return false;
192 }
193 return true;
194}
195
196// Reads from the given file. Returns false in case of error.
198 stack_.truncate(0);
199 no_ = 0; // We will be modifying this as we AddToStack.
200 uint32_t size;
201 if (!fp->DeSerialize(&size)) return false;
202 for (uint32_t i = 0; i < size; ++i) {
203 Network* network = CreateFromFile(fp);
204 if (network == nullptr) return false;
205 AddToStack(network);
206 }
209 return false;
210 }
211 return true;
212}
213
214// Updates the weights using the given learning rate, momentum and adam_beta.
215// num_samples is used in the adam computation iff use_adam_ is true.
216void Plumbing::Update(float learning_rate, float momentum, float adam_beta,
217 int num_samples) {
218 for (int i = 0; i < stack_.size(); ++i) {
220 if (i < learning_rates_.size())
221 learning_rate = learning_rates_[i];
222 else
223 learning_rates_.push_back(learning_rate);
224 }
225 if (stack_[i]->IsTraining()) {
226 stack_[i]->Update(learning_rate, momentum, adam_beta, num_samples);
227 }
228 }
229}
230
231// Sums the products of weight updates in *this and other, splitting into
232// positive (same direction) in *same and negative (different direction) in
233// *changed.
234void Plumbing::CountAlternators(const Network& other, double* same,
235 double* changed) const {
236 ASSERT_HOST(other.type() == type_);
237 const auto* plumbing = static_cast<const Plumbing*>(&other);
238 ASSERT_HOST(plumbing->stack_.size() == stack_.size());
239 for (int i = 0; i < stack_.size(); ++i)
240 stack_[i]->CountAlternators(*plumbing->stack_[i], same, changed);
241}
242
243} // namespace tesseract.
#define ASSERT_HOST(x)
Definition: errcode.h:88
TrainingState
Definition: network.h:92
@ NT_PARALLEL
Definition: network.h:49
@ NT_SERIES
Definition: network.h:54
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:87
int push_back(T object)
bool Serialize(FILE *fp) const
int size() const
Definition: genericvector.h:72
bool DeSerialize(bool swap, FILE *fp)
bool Serialize(const char *data, size_t count=1)
Definition: serialis.cpp:148
bool DeSerialize(char *data, size_t count=1)
Definition: serialis.cpp:104
Definition: strngs.h:45
void add_str_int(const char *str, int number)
Definition: strngs.cpp:377
int32_t network_flags_
Definition: network.h:296
NetworkType type_
Definition: network.h:293
int NumOutputs() const
Definition: network.h:123
bool needs_to_backprop_
Definition: network.h:295
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:110
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:187
bool IsTraining() const
Definition: network.h:115
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
int NumInputs() const
Definition: network.h:120
int32_t num_weights_
Definition: network.h:299
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:124
NetworkType type() const
Definition: network.h:112
PointerVector< Network > stack_
Definition: plumbing.h:136
void SetEnableTraining(TrainingState state) override
Definition: plumbing.cpp:31
bool DeSerialize(TFile *fp) override
Definition: plumbing.cpp:197
void CacheXScaleFactor(int factor) override
Definition: plumbing.cpp:126
int XScaleFactor() const override
Definition: plumbing.cpp:120
void ConvertToInt() override
Definition: plumbing.cpp:68
bool SetupNeedsBackprop(bool needs_backprop) override
Definition: plumbing.cpp:100
Plumbing(const STRING &name)
Definition: plumbing.cpp:25
int InitWeights(float range, TRand *randomizer) override
Definition: plumbing.cpp:50
void SetRandomizer(TRand *randomizer) override
Definition: plumbing.cpp:76
virtual void AddToStack(Network *network)
Definition: plumbing.cpp:82
float * LayerLearningRatePtr(const char *id) const
Definition: plumbing.cpp:168
GenericVector< float > learning_rates_
Definition: plumbing.h:139
void EnumerateLayers(const STRING *prefix, GenericVector< STRING > *layers) const
Definition: plumbing.cpp:139
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: plumbing.cpp:59
void SetNetworkFlags(uint32_t flags) override
Definition: plumbing.cpp:39
void CountAlternators(const Network &other, double *same, double *changed) const override
Definition: plumbing.cpp:234
void DebugWeights() override
Definition: plumbing.cpp:133
Network * GetLayer(const char *id) const
Definition: plumbing.cpp:155
bool Serialize(TFile *fp) const override
Definition: plumbing.cpp:182
bool IsPlumbingType() const override
Definition: plumbing.h:44
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
Definition: plumbing.cpp:216