tesseract 4.1.1
Loading...
Searching...
No Matches
functions.h
Go to the documentation of this file.
1
2// File: functions.h
3// Description: Collection of function-objects used by the network layers.
4// Author: Ray Smith
5//
6// (C) Copyright 2014, Google Inc.
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10// http://www.apache.org/licenses/LICENSE-2.0
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
17
18#ifndef TESSERACT_LSTM_FUNCTIONS_H_
19#define TESSERACT_LSTM_FUNCTIONS_H_
20
21#include "helpers.h"
22
23// Setting this to 1 or more causes massive dumps of debug data: weights,
24// updates, internal calculations etc, and reduces the number of test iterations
25// to a small number, so outputs can be diffed.
26#define DEBUG_DETAIL 0
27#if DEBUG_DETAIL > 0
28#undef _OPENMP // Disable open mp to get the outputs in sync.
29#endif
30
31namespace tesseract {
32
// Number of entries in each of the precomputed lookup tables declared below.
constexpr int kTableSize{4096};
// Multiplier that converts a non-negative floating-point argument into a
// (fractional) table position, i.e. 256 table steps per unit of input.
constexpr double kScaleFactor{256.0};

// Precomputed function samples used by Tanh() and Logistic(). Only declared
// here; the arrays themselves are defined in a separate source file.
extern const double TanhTable[];
extern const double LogisticTable[];
41
42// Non-linearity (sigmoid) functions with cache tables and clipping.
43inline double Tanh(double x) {
44 if (x < 0.0) return -Tanh(-x);
45 x *= kScaleFactor;
46 unsigned index = static_cast<unsigned>(x);
47 if (index >= (kTableSize - 1)) return 1.0;
48 double tanh_i0 = TanhTable[index];
49 double tanh_i1 = TanhTable[index + 1];
50 // Linear interpolation.
51 return tanh_i0 + (tanh_i1 - tanh_i0) * (x - index);
52}
53
54inline double Logistic(double x) {
55 if (x < 0.0) return 1.0 - Logistic(-x);
56 x *= kScaleFactor;
57 unsigned index = static_cast<unsigned>(x);
58 if (index >= (kTableSize - 1)) return 1.0;
59 double l0 = LogisticTable[index];
60 double l1 = LogisticTable[index + 1];
61 // Linear interpolation.
62 return l0 + (l1 - l0) * (x - index);
63}
64
65// Non-linearity (sigmoid) functions and their derivatives.
66struct FFunc {
67 inline double operator()(double x) const { return Logistic(x); }
68};
// Function object returning y * (1 - y), the logistic derivative written in
// terms of the function's output y.
struct FPrime {
  double operator()(double y) const {
    const double complement = 1.0 - y;
    return y * complement;
  }
};
// Function object that clamps its argument to the interval [0, 1].
struct ClipFFunc {
  double operator()(double x) const {
    // Values at or below 0 become 0, values at or above 1 become 1, and
    // everything else passes through untouched.
    return x <= 0.0 ? 0.0 : (x >= 1.0 ? 1.0 : x);
  }
};
// Function object returning 1.0 when y lies strictly inside (0, 1) and 0.0
// otherwise.
struct ClipFPrime {
  double operator()(double y) const {
    const bool strictly_inside = 0.0 < y && y < 1.0;
    if (strictly_inside) return 1.0;
    return 0.0;
  }
};
// Function object that replaces arguments at or below zero with 0.0 and
// returns every other argument unchanged.
struct Relu {
  double operator()(double x) const {
    return x <= 0.0 ? 0.0 : x;
  }
};
// Function object returning 1.0 for strictly positive y and 0.0 otherwise.
struct ReluPrime {
  double operator()(double y) const {
    if (0.0 < y) return 1.0;
    return 0.0;
  }
};
93struct GFunc {
94 inline double operator()(double x) const { return Tanh(x); }
95};
// Function object returning 1 - y^2, the tanh derivative written in terms of
// the function's output y.
struct GPrime {
  double operator()(double y) const {
    return 1.0 - y * y;
  }
};
// Function object that clamps its argument to the interval [-1, 1].
struct ClipGFunc {
  double operator()(double x) const {
    // Values at or below -1 become -1, values at or above 1 become 1, and
    // everything else passes through untouched.
    return x <= -1.0 ? -1.0 : (x >= 1.0 ? 1.0 : x);
  }
};
// Function object returning 1.0 when y lies strictly inside (-1, 1) and 0.0
// otherwise.
// NOTE(review): the struct header line was missing here, leaving an orphaned
// member function. The name ClipGPrime follows the ClipFFunc/ClipFPrime
// pairing used for the neighbouring functors; confirm against callers.
struct ClipGPrime {
  inline double operator()(double y) const {
    return -1.0 < y && y < 1.0 ? 1.0 : 0.0;
  }
};
111struct HFunc {
112 inline double operator()(double x) const { return Tanh(x); }
113};
114struct HPrime {
115 inline double operator()(double y) const {
116 double u = Tanh(y);
117 return 1.0 - u * u;
118 }
119};
// Function object that ignores its argument and always yields 1.0.
struct UnityFunc {
  double operator()(double) const {
    return 1.0;
  }
};
// Function object that returns its argument unchanged.
// NOTE(review): the struct header line was missing here, leaving an orphaned
// member function. The name IdentityFunc is reconstructed; confirm against
// callers.
struct IdentityFunc {
  inline double operator()(double x) const { return x; }
};
126
// Overwrites each of the n elements of inout with the result of passing it
// through a single default-constructed Func. Does nothing to the data when
// n <= 0.
template <class Func>
inline void FuncInplace(int n, double* inout) {
  Func transform;
  if (n <= 0) return;
  double* const end = inout + n;
  for (double* cur = inout; cur != end; ++cur) {
    *cur = transform(*cur);
  }
}
// For each of the n positions, passes u[i] through a single
// default-constructed Func, multiplies the result by v[i] and stores the
// product in out[i].
template <class Func>
inline void FuncMultiply(const double* u, const double* v, int n, double* out) {
  Func transform;
  int pos = 0;
  while (pos < n) {
    out[pos] = transform(u[pos]) * v[pos];
    ++pos;
  }
}
// Replaces the n values of inout with their softmax. The largest input is
// subtracted from every element before exponentiation, and the shifted value
// is clipped to [-kMaxSoftmaxActivation, 0] on its way into exp. Does nothing
// when n <= 0.
template <typename T>
inline void SoftmaxInPlace(int n, T* inout) {
  if (n <= 0) return;
  // A limit on the negative range input to exp to guarantee non-zero output.
  const T kMaxSoftmaxActivation = 86.0f;

  // Pass 1: locate the largest activation.
  T peak = inout[0];
  for (int i = 1; i < n; ++i) {
    if (inout[i] > peak) peak = inout[i];
  }
  // Pass 2: exponentiate the shifted, clipped activations and total them.
  T total = 0.0;
  for (int i = 0; i < n; ++i) {
    const T shifted = inout[i] - peak;
    const T weight =
        exp(ClipToRange(shifted, -kMaxSoftmaxActivation, static_cast<T>(0)));
    inout[i] = weight;
    total += weight;
  }
  // Pass 3: normalise, unless the total is not positive.
  if (!(total > 0.0)) return;
  for (int i = 0; i < n; ++i) inout[i] /= total;
}
167
// Copies the first n values of src into dest. The two ranges must not
// overlap, because the copy is done with memcpy. A non-positive n is a no-op,
// matching the loop-based helpers in this file; without the guard a negative
// n would be converted to an enormous unsigned byte count.
inline void CopyVector(int n, const double* src, double* dest) {
  if (n <= 0) return;
  memcpy(dest, src, n * sizeof(dest[0]));
}
172
// Adds src[i] to dest[i] for each of the first n positions.
inline void AccumulateVector(int n, const double* src, double* dest) {
  int pos = 0;
  while (pos < n) {
    dest[pos] += src[pos];
    ++pos;
  }
}
177
// Scales inout[i] by src[i] for each of the first n positions, storing the
// product back into inout. Does nothing when n <= 0.
inline void MultiplyVectorsInPlace(int n, const double* src, double* inout) {
  if (n <= 0) return;
  const double* factor = src;
  double* const end = inout + n;
  for (double* cur = inout; cur != end; ++cur, ++factor) {
    *cur *= *factor;
  }
}
182
// Adds the element-wise product u[i] * v[i] onto out[i] for each of the first
// n positions.
inline void MultiplyAccumulate(int n, const double* u, const double* v,
                               double* out) {
  int pos = 0;
  while (pos < n) {
    out[pos] += u[pos] * v[pos];
    ++pos;
  }
}
190
// Writes the element-wise sum of the five n-vectors v1..v5 into sum. The
// terms are added in the order v1, v2, v3, v4, v5.
inline void SumVectors(int n, const double* v1, const double* v2,
                       const double* v3, const double* v4, const double* v5,
                       double* sum) {
  int pos = 0;
  while (pos < n) {
    double total = v1[pos];
    total += v2[pos];
    total += v3[pos];
    total += v4[pos];
    total += v5[pos];
    sum[pos] = total;
    ++pos;
  }
}
199
// Sets the first n elements of vec to zero by clearing their bytes with
// memset, so T must be a type for which all-zero bytes represent zero. A
// non-positive n is a no-op, matching the loop-based helpers in this file;
// without the guard a negative n would be converted to an enormous unsigned
// byte count.
template <typename T>
inline void ZeroVector(int n, T* vec) {
  if (n <= 0) return;
  memset(vec, 0, n * sizeof(*vec));
}
205
// Replaces each of the first n elements of vec with the result of
// ClipToRange(element, lower, upper).
template <typename T>
inline void ClipVector(int n, T lower, T upper, T* vec) {
  int pos = 0;
  while (pos < n) {
    vec[pos] = ClipToRange(vec[pos], lower, upper);
    ++pos;
  }
}
211
// Finds the index of the largest of the n values in vec (the first one wins
// on ties) and overwrites vec[0..nf-1] with the binary encoding of that
// index, least-significant bit first, as 1.0 / 0.0 values. Elements from nf
// onwards are left untouched. Does nothing if nf <= 0 or n < nf.
inline void CodeInBinary(int n, int nf, double* vec) {
  if (nf <= 0 || n < nf) return;
  int index = 0;
  double best_score = vec[0];
  for (int i = 1; i < n; ++i) {
    if (vec[i] > best_score) {
      best_score = vec[i];
      index = i;
    }
  }
  // Peel bits off an unsigned copy with a right shift. A doubling signed mask
  // would overflow (undefined behaviour) once nf reaches the width of int;
  // here the remaining bits simply become zero.
  unsigned remaining = static_cast<unsigned>(index);
  for (int i = 0; i < nf; ++i, remaining >>= 1) {
    vec[i] = (remaining & 1u) ? 1.0 : 0.0;
  }
}
229
230} // namespace tesseract.
231
232#endif // TESSERACT_LSTM_FUNCTIONS_H_
T ClipToRange(const T &x, const T &lower_bound, const T &upper_bound)
Definition: helpers.h:108
const double TanhTable[]
Definition: functions.cpp:4
constexpr double kScaleFactor
Definition: functions.h:36
void SumVectors(int n, const double *v1, const double *v2, const double *v3, const double *v4, const double *v5, double *sum)
Definition: functions.h:192
void FuncMultiply(const double *u, const double *v, int n, double *out)
Definition: functions.h:138
void AccumulateVector(int n, const double *src, double *dest)
Definition: functions.h:174
void SoftmaxInPlace(int n, T *inout)
Definition: functions.h:146
void CodeInBinary(int n, int nf, double *vec)
Definition: functions.h:214
double Logistic(double x)
Definition: functions.h:54
void FuncInplace(int n, double *inout)
Definition: functions.h:129
const double LogisticTable[]
Definition: functions.cpp:4102
void MultiplyVectorsInPlace(int n, const double *src, double *inout)
Definition: functions.h:179
double Tanh(double x)
Definition: functions.h:43
void CopyVector(int n, const double *src, double *dest)
Definition: functions.h:169
constexpr int kTableSize
Definition: functions.h:34
void ClipVector(int n, T lower, T upper, T *vec)
Definition: functions.h:208
void MultiplyAccumulate(int n, const double *u, const double *v, double *out)
Definition: functions.h:184
void ZeroVector(int n, T *vec)
Definition: functions.h:202
double operator()(double x) const
Definition: functions.h:67
double operator()(double y) const
Definition: functions.h:70
double operator()(double x) const
Definition: functions.h:73
double operator()(double y) const
Definition: functions.h:80
double operator()(double x) const
Definition: functions.h:85
double operator()(double y) const
Definition: functions.h:91
double operator()(double x) const
Definition: functions.h:94
double operator()(double y) const
Definition: functions.h:97
double operator()(double x) const
Definition: functions.h:100
double operator()(double y) const
Definition: functions.h:107
double operator()(double x) const
Definition: functions.h:112
double operator()(double y) const
Definition: functions.h:115
double operator()(double) const
Definition: functions.h:121
double operator()(double x) const
Definition: functions.h:124