tesseract 4.1.1
Loading...
Searching...
No Matches
networkscratch.h
Go to the documentation of this file.
1
2// File: networkscratch.h
3// Description: Scratch space for Network layers that hides distinction
4// between float/int implementations.
5// Author: Ray Smith
6//
7// (C) Copyright 2014, Google Inc.
8// Licensed under the Apache License, Version 2.0 (the "License");
9// you may not use this file except in compliance with the License.
10// You may obtain a copy of the License at
11// http://www.apache.org/licenses/LICENSE-2.0
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
18
19#ifndef TESSERACT_LSTM_NETWORKSCRATCH_H_
20#define TESSERACT_LSTM_NETWORKSCRATCH_H_
21
22#include "genericvector.h"
23#include "matrix.h"
24#include "networkio.h"
25#include "svutil.h"
26
27namespace tesseract {
28
// Generic scratch space for network layers. Provides NetworkIO that can store
// a complete set (over time) of intermediates, and GenericVector<double>
// scratch space that auto-frees after use. The aim here is to provide a set
// of temporary buffers to network layers that can be reused between layers
// and don't have to be reallocated on each call.
35 public:
36 NetworkScratch() : int_mode_(false) {}
37 ~NetworkScratch() = default;
38
39 // Sets the network representation. If the representation is integer, then
40 // default (integer) NetworkIOs are separated from the always-float variety.
41 // This saves memory by having separate int-specific and float-specific
42 // stacks. If the network representation is float, then all NetworkIOs go
43 // to the float stack.
44 void set_int_mode(bool int_mode) {
45 int_mode_ = int_mode;
46 }
47
48 // Class that acts like a NetworkIO (by having an implicit cast operator),
49 // yet actually holds a pointer to NetworkIOs in the source NetworkScratch,
50 // and knows how to unstack the borrowed pointers on destruction.
51 class IO {
52 public:
53 // The NetworkIO should be sized after construction.
54 IO(const NetworkIO& src, NetworkScratch* scratch)
55 : int_mode_(scratch->int_mode_ && src.int_mode()),
56 scratch_space_(scratch) {
57 network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
58 : scratch_space_->float_stack_.Borrow();
59 }
60 // Default constructor for arrays. Use one of the Resize functions
61 // below to initialize and size.
62 IO() : int_mode_(false), network_io_(nullptr), scratch_space_(nullptr) {}
63
64 ~IO() {
65 if (scratch_space_ == nullptr) {
66 ASSERT_HOST(network_io_ == nullptr);
67 } else if (int_mode_) {
68 scratch_space_->int_stack_.Return(network_io_);
69 } else {
70 scratch_space_->float_stack_.Return(network_io_);
71 }
72 }
73 // Resizes the array (and stride), avoiding realloc if possible, to the
74 // size from various size specs:
75 // Same time size, given number of features.
76 void Resize(const NetworkIO& src, int num_features,
77 NetworkScratch* scratch) {
78 if (scratch_space_ == nullptr) {
79 int_mode_ = scratch->int_mode_ && src.int_mode();
80 scratch_space_ = scratch;
81 network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
82 : scratch_space_->float_stack_.Borrow();
83 }
84 network_io_->Resize(src, num_features);
85 }
86 // Resizes to a specific size as a temp buffer. No batches, no y-dim.
87 void Resize2d(bool int_mode, int width, int num_features,
88 NetworkScratch* scratch) {
89 if (scratch_space_ == nullptr) {
90 int_mode_ = scratch->int_mode_ && int_mode;
91 scratch_space_ = scratch;
92 network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
93 : scratch_space_->float_stack_.Borrow();
94 }
95 network_io_->Resize2d(int_mode, width, num_features);
96 }
97 // Resize forcing a float representation with the width of src and the given
98 // number of features.
99 void ResizeFloat(const NetworkIO& src, int num_features,
100 NetworkScratch* scratch) {
101 if (scratch_space_ == nullptr) {
102 int_mode_ = false;
103 scratch_space_ = scratch;
104 network_io_ = scratch_space_->float_stack_.Borrow();
105 }
106 network_io_->ResizeFloat(src, num_features);
107 }
108
109 // Returns a ref to a NetworkIO that enables *this to be treated as if
110 // it were just a NetworkIO*.
112 return *network_io_;
113 }
115 return network_io_;
116 }
117 operator NetworkIO*() {
118 return network_io_;
119 }
120
121 private:
// True if network_io_ was borrowed from the int stack, otherwise it came from the float stack.
123 bool int_mode_;
124 // The NetworkIO that we have borrowed from the scratch_space_.
125 NetworkIO* network_io_;
126 // The source scratch_space_. Borrowed pointer, used to free the
127 // NetworkIO. Don't delete!
128 NetworkScratch* scratch_space_;
129 }; // class IO.
130
// Class that acts like a fixed array of double, yet actually uses space
// from a GenericVector<double> in the source NetworkScratch, and knows how
// to unstack the borrowed vector on destruction.
134 class FloatVec {
135 public:
136 // The array will have size elements in it, uninitialized.
137 FloatVec(int size, NetworkScratch* scratch)
138 : vec_(nullptr), scratch_space_(scratch) {
139 Init(size, scratch);
140 }
141 // Default constructor is for arrays. Use Init to setup.
142 FloatVec() : vec_(nullptr), data_(nullptr), scratch_space_(nullptr) {}
144 if (scratch_space_ != nullptr) scratch_space_->vec_stack_.Return(vec_);
145 }
146
147 void Init(int size, NetworkScratch* scratch) {
148 if (scratch_space_ != nullptr && vec_ != nullptr)
149 scratch_space_->vec_stack_.Return(vec_);
150 scratch_space_ = scratch;
151 vec_ = scratch_space_->vec_stack_.Borrow();
152 vec_->resize_no_init(size);
153 data_ = &(*vec_)[0];
154 }
155
156 // Use the cast operator instead of operator[] so the FloatVec can be used
157 // as a double* argument to a function call.
158 operator double*() const { return data_; }
159 double* get() { return data_; }
160
161 private:
162 // Vector borrowed from the scratch space. Use Return to free it.
164 // Short-cut pointer to the underlying array.
165 double* data_;
166 // The source scratch_space_. Borrowed pointer, used to free the
167 // vector. Don't delete!
168 NetworkScratch* scratch_space_;
169 }; // class FloatVec
170
171 // Class that acts like a 2-D array of double, yet actually uses space
172 // from the source NetworkScratch, and knows how to unstack the borrowed
173 // array on destruction.
175 public:
176 // Default constructor is for arrays. Use Init to setup.
177 GradientStore() : array_(nullptr), scratch_space_(nullptr) {}
179 if (scratch_space_ != nullptr) scratch_space_->array_stack_.Return(array_);
180 }
181
182 void Init(int size1, int size2, NetworkScratch* scratch) {
183 if (scratch_space_ != nullptr && array_ != nullptr)
184 scratch_space_->array_stack_.Return(array_);
185 scratch_space_ = scratch;
186 array_ = scratch_space_->array_stack_.Borrow();
187 array_->Resize(size1, size2, 0.0);
188 }
189
190 // Accessors to get to the underlying TransposedArray.
191 TransposedArray* get() const { return array_; }
192 const TransposedArray& operator*() const { return *array_; }
193
194 private:
195 // Array borrowed from the scratch space. Use Return to free it.
196 TransposedArray* array_;
197 // The source scratch_space_. Borrowed pointer, used to free the
198 // vector. Don't delete!
199 NetworkScratch* scratch_space_;
200 }; // class GradientStore
201
202 // Class that does the work of holding a stack of objects, a stack pointer
203 // and a vector of in-use flags, so objects can be returned out of order.
204 // It is safe to attempt to Borrow/Return in multiple threads.
205 template<typename T> class Stack {
206 public:
207 Stack() : stack_top_(0) {
208 }
209
210 // Lends out the next free item, creating one if none available, sets
211 // the used flags and increments the stack top.
212 T* Borrow() {
213 SVAutoLock lock(&mutex_);
214 if (stack_top_ == stack_.size()) {
215 stack_.push_back(new T);
216 flags_.push_back(false);
217 }
218 flags_[stack_top_] = true;
219 return stack_[stack_top_++];
220 }
221 // Takes back the given item, and marks it free. Item does not have to be
222 // the most recently lent out, but free slots don't get re-used until the
223 // blocking item is returned. The assumption is that there will only be
224 // small, temporary variations from true stack use. (Determined by the order
225 // of destructors within a local scope.)
226 void Return(T* item) {
227 SVAutoLock lock(&mutex_);
228 // Linear search will do.
229 int index = stack_top_ - 1;
230 while (index >= 0 && stack_[index] != item) --index;
231 if (index >= 0) flags_[index] = false;
232 while (stack_top_ > 0 && !flags_[stack_top_ - 1]) --stack_top_;
233 }
234
235 private:
236 PointerVector<T> stack_;
237 GenericVector<bool> flags_;
238 int stack_top_;
239 SVMutex mutex_;
240 }; // class Stack.
241
242 private:
243 // If true, the network weights are int8_t, if false, float.
244 bool int_mode_;
// Stacks of NetworkIO, GenericVector<double> and TransposedArray. Once
// allocated, they are not deleted until the NetworkScratch is deleted.
247 Stack<NetworkIO> int_stack_;
248 Stack<NetworkIO> float_stack_;
249 Stack<GenericVector<double> > vec_stack_;
250 Stack<TransposedArray> array_stack_;
251};
252
253} // namespace tesseract.
254
255#endif // TESSERACT_LSTM_NETWORKSCRATCH_H_
#define ASSERT_HOST(x)
Definition: errcode.h:88
void resize_no_init(int size)
Definition: genericvector.h:66
int push_back(T object)
void Resize(int size1, int size2, const T &empty)
Definition: matrix.h:108
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
bool int_mode() const
Definition: networkio.h:127
void ResizeFloat(const NetworkIO &src, int num_features)
Definition: networkio.h:52
void Resize2d(bool int_mode, int width, int num_features)
Definition: networkio.cpp:35
void set_int_mode(bool int_mode)
void Resize(const NetworkIO &src, int num_features, NetworkScratch *scratch)
void Resize2d(bool int_mode, int width, int num_features, NetworkScratch *scratch)
void ResizeFloat(const NetworkIO &src, int num_features, NetworkScratch *scratch)
IO(const NetworkIO &src, NetworkScratch *scratch)
void Init(int size, NetworkScratch *scratch)
FloatVec(int size, NetworkScratch *scratch)
const TransposedArray & operator*() const
void Init(int size1, int size2, NetworkScratch *scratch)
Definition: svutil.h:68