tesseract 4.1.1
Loading...
Searching...
No Matches
tesseract::FullyConnected Class Reference

#include <fullyconnected.h>

Inheritance diagram for tesseract::FullyConnected:
tesseract::Network

Public Member Functions

 FullyConnected (const STRING &name, int ni, int no, NetworkType type)
 
 ~FullyConnected () override=default
 
StaticShape OutputShape (const StaticShape &input_shape) const override
 
STRING spec () const override
 
void ChangeType (NetworkType type)
 
void SetEnableTraining (TrainingState state) override
 
int InitWeights (float range, TRand *randomizer) override
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
void ConvertToInt () override
 
void DebugWeights () override
 
bool Serialize (TFile *fp) const override
 
bool DeSerialize (TFile *fp) override
 
void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
 
void SetupForward (const NetworkIO &input, const TransposedArray *input_transpose)
 
void ForwardTimeStep (int t, double *output_line)
 
void ForwardTimeStep (const double *d_input, int t, double *output_line)
 
void ForwardTimeStep (const int8_t *i_input, int t, double *output_line)
 
bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
 
void BackwardTimeStep (const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
 
void FinishBackward (const TransposedArray &errors_t)
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
void CountAlternators (const Network &other, double *same, double *changed) const override
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()=default
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
virtual StaticShape OutputShape (const StaticShape &input_shape) const
 
const STRING & name () const
 
virtual STRING spec () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetEnableTraining (TrainingState state)
 
virtual void SetNetworkFlags (uint32_t flags)
 
virtual int InitWeights (float range, TRand *randomizer)
 
virtual int RemapOutputs (int old_no, const std::vector< int > &code_map)
 
virtual void ConvertToInt ()
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
virtual void DebugWeights ()=0
 
virtual bool Serialize (TFile *fp) const
 
virtual bool DeSerialize (TFile *fp)=0
 
virtual void Update (float learning_rate, float momentum, float adam_beta, int num_samples)
 
virtual void CountAlternators (const Network &other, double *same, double *changed) const
 
virtual void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
 
virtual bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Protected Attributes

WeightMatrix weights_
 
TransposedArray source_t_
 
const TransposedArray * external_source_
 
NetworkIO acts_
 
bool int_mode_
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
int32_t network_flags_
 
int32_t ni_
 
int32_t no_
 
int32_t num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 

Detailed Description

Definition at line 28 of file fullyconnected.h.

Constructor & Destructor Documentation

◆ FullyConnected()

tesseract::FullyConnected::FullyConnected ( const STRING & name,
int  ni,
int  no,
NetworkType  type 
)

Definition at line 39 of file fullyconnected.cpp.

41 : Network(type, name, ni, no), external_source_(nullptr), int_mode_(false) {
42}
const TransposedArray * external_source_
const STRING & name() const
Definition: network.h:138
NetworkType type() const
Definition: network.h:112

◆ ~FullyConnected()

tesseract::FullyConnected::~FullyConnected ( )
override default

Member Function Documentation

◆ Backward()

bool tesseract::FullyConnected::Backward ( bool  debug,
const NetworkIO & fwd_deltas,
NetworkScratch * scratch,
NetworkIO * back_deltas 
)
override virtual

Implements tesseract::Network.

Definition at line 221 of file fullyconnected.cpp.

223 {
224 if (debug) DisplayBackward(fwd_deltas);
225 back_deltas->Resize(fwd_deltas, ni_);
227 errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
228 for (int i = 0; i < kNumThreads; ++i) errors[i].Init(no_, scratch);
230 if (needs_to_backprop_) {
231 temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());
232 for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
233 }
234 int width = fwd_deltas.Width();
235 NetworkScratch::GradientStore errors_t;
236 errors_t.Init(no_, width, scratch);
237#ifdef _OPENMP
238#pragma omp parallel for num_threads(kNumThreads)
239 for (int t = 0; t < width; ++t) {
240 int thread_id = omp_get_thread_num();
241#else
242 for (int t = 0; t < width; ++t) {
243 int thread_id = 0;
244#endif
245 double* backprop = nullptr;
246 if (needs_to_backprop_) backprop = temp_backprops[thread_id];
247 double* curr_errors = errors[thread_id];
248 BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
249 if (backprop != nullptr) {
250 back_deltas->WriteTimeStep(t, backprop);
251 }
252 }
253 FinishBackward(*errors_t.get());
254 if (needs_to_backprop_) {
255 back_deltas->ZeroInvalidElements();
256#if DEBUG_DETAIL > 0
257 tprintf("F Backprop:%s\n", name_.string());
258 back_deltas->Print(10);
259#endif
260 return true;
261 }
262 return false; // No point going further back.
263}
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
const int kNumThreads
void init_to_size(int size, const T &t)
const char * string() const
Definition: strngs.cpp:194
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void FinishBackward(const TransposedArray &errors_t)
bool needs_to_backprop_
Definition: network.h:295
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:299

◆ BackwardTimeStep()

void tesseract::FullyConnected::BackwardTimeStep ( const NetworkIO & fwd_deltas,
int  t,
double *  curr_errors,
TransposedArray * errors_t,
double *  backprop 
)

Definition at line 265 of file fullyconnected.cpp.

268 {
269 if (type_ == NT_TANH)
270 acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
271 else if (type_ == NT_LOGISTIC)
272 acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
273 else if (type_ == NT_POSCLIP)
274 acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
275 else if (type_ == NT_SYMCLIP)
276 acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
277 else if (type_ == NT_RELU)
278 acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
279 else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
280 type_ == NT_LINEAR)
281 fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
282 else
283 ASSERT_HOST("Invalid fully-connected type!" == nullptr);
284 // Generate backprop only if needed by the lower layer.
285 if (backprop != nullptr) weights_.VectorDotMatrix(curr_errors, backprop);
286 errors_t->WriteStrided(t, curr_errors);
287}
#define ASSERT_HOST(x)
Definition: errcode.h:88
@ NT_LINEAR
Definition: network.h:67
@ NT_RELU
Definition: network.h:66
@ NT_SOFTMAX
Definition: network.h:68
@ NT_LOGISTIC
Definition: network.h:62
@ NT_SYMCLIP
Definition: network.h:64
@ NT_POSCLIP
Definition: network.h:63
@ NT_SOFTMAX_NO_CTC
Definition: network.h:69
@ NT_TANH
Definition: network.h:65
NetworkType type_
Definition: network.h:293
void FuncMultiply(const NetworkIO &v_io, int t, double *product)
Definition: networkio.h:259
void VectorDotMatrix(const double *u, double *v) const

◆ ChangeType()

void tesseract::FullyConnected::ChangeType ( NetworkType  type)
inline

Definition at line 60 of file fullyconnected.h.

60 {
61 type_ = type;
62 }

◆ ConvertToInt()

void tesseract::FullyConnected::ConvertToInt ( )
override virtual

Reimplemented from tesseract::Network.

Definition at line 96 of file fullyconnected.cpp.

◆ CountAlternators()

void tesseract::FullyConnected::CountAlternators ( const Network & other,
double *  same,
double *  changed 
) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 306 of file fullyconnected.cpp.

307 {
308 ASSERT_HOST(other.type() == type_);
309 const auto* fc = static_cast<const FullyConnected*>(&other);
310 weights_.CountAlternators(fc->weights_, same, changed);
311}
FullyConnected(const STRING &name, int ni, int no, NetworkType type)
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const

◆ DebugWeights()

void tesseract::FullyConnected::DebugWeights ( )
override virtual

Implements tesseract::Network.

Definition at line 101 of file fullyconnected.cpp.

101 {
103}
void Debug2D(const char *msg)

◆ DeSerialize()

bool tesseract::FullyConnected::DeSerialize ( TFile * fp)
override virtual

Implements tesseract::Network.

Definition at line 113 of file fullyconnected.cpp.

113 {
114 return weights_.DeSerialize(IsTraining(), fp);
115}
bool IsTraining() const
Definition: network.h:115
bool DeSerialize(bool training, TFile *fp)

◆ FinishBackward()

void tesseract::FullyConnected::FinishBackward ( const TransposedArray & errors_t)

Definition at line 289 of file fullyconnected.cpp.

289 {
290 if (external_source_ == nullptr)
291 weights_.SumOuterTransposed(errors_t, source_t_, true);
292 else
294}
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)

◆ Forward()

void tesseract::FullyConnected::Forward ( bool  debug,
const NetworkIO & input,
const TransposedArray * input_transpose,
NetworkScratch * scratch,
NetworkIO * output 
)
override virtual

Implements tesseract::Network.

Definition at line 119 of file fullyconnected.cpp.

121 {
122 int width = input.Width();
123 if (type_ == NT_SOFTMAX)
124 output->ResizeFloat(input, no_);
125 else
126 output->Resize(input, no_);
127 SetupForward(input, input_transpose);
129 temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
131 curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
132 for (int i = 0; i < kNumThreads; ++i) {
133 temp_lines[i].Init(no_, scratch);
134 curr_input[i].Init(ni_, scratch);
135 }
136#ifdef _OPENMP
137#pragma omp parallel for num_threads(kNumThreads)
138 for (int t = 0; t < width; ++t) {
139 // Thread-local pointer to temporary storage.
140 int thread_id = omp_get_thread_num();
141#else
142 for (int t = 0; t < width; ++t) {
143 // Thread-local pointer to temporary storage.
144 int thread_id = 0;
145#endif
146 double* temp_line = temp_lines[thread_id];
147 if (input.int_mode()) {
148 ForwardTimeStep(input.i(t), t, temp_line);
149 } else {
150 input.ReadTimeStep(t, curr_input[thread_id]);
151 ForwardTimeStep(curr_input[thread_id], t, temp_line);
152 }
153 output->WriteTimeStep(t, temp_line);
154 if (IsTraining() && type_ != NT_SOFTMAX) {
155 acts_.CopyTimeStepFrom(t, *output, t);
156 }
157 }
158 // Zero all the elements that are in the padding around images that allows
159 // multiple different-sized images to exist in a single array.
160 // acts_ is only used if this is not a softmax op.
161 if (IsTraining() && type_ != NT_SOFTMAX) {
163 }
164 output->ZeroInvalidElements();
165#if DEBUG_DETAIL > 0
166 tprintf("F Output:%s\n", name_.string());
167 output->Print(10);
168#endif
169 if (debug) DisplayForward(*output);
170}
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
void ForwardTimeStep(int t, double *output_line)
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:288
void ZeroInvalidElements()
Definition: networkio.cpp:88
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:383

◆ ForwardTimeStep() [1/3]

void tesseract::FullyConnected::ForwardTimeStep ( const double *  d_input,
int  t,
double *  output_line 
)

Definition at line 203 of file fullyconnected.cpp.

204 {
205 // input is copied to source_ line-by-line for cache coherency.
206 if (IsTraining() && external_source_ == nullptr)
207 source_t_.WriteStrided(t, d_input);
208 weights_.MatrixDotVector(d_input, output_line);
209 ForwardTimeStep(t, output_line);
210}
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:39
void MatrixDotVector(const double *u, double *v) const

◆ ForwardTimeStep() [2/3]

void tesseract::FullyConnected::ForwardTimeStep ( const int8_t *  i_input,
int  t,
double *  output_line 
)

Definition at line 212 of file fullyconnected.cpp.

213 {
214 // input is copied to source_ line-by-line for cache coherency.
215 weights_.MatrixDotVector(i_input, output_line);
216 ForwardTimeStep(t, output_line);
217}

◆ ForwardTimeStep() [3/3]

void tesseract::FullyConnected::ForwardTimeStep ( int  t,
double *  output_line 
)

Definition at line 185 of file fullyconnected.cpp.

185 {
186 if (type_ == NT_TANH) {
187 FuncInplace<GFunc>(no_, output_line);
188 } else if (type_ == NT_LOGISTIC) {
189 FuncInplace<FFunc>(no_, output_line);
190 } else if (type_ == NT_POSCLIP) {
191 FuncInplace<ClipFFunc>(no_, output_line);
192 } else if (type_ == NT_SYMCLIP) {
193 FuncInplace<ClipGFunc>(no_, output_line);
194 } else if (type_ == NT_RELU) {
195 FuncInplace<Relu>(no_, output_line);
196 } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
197 SoftmaxInPlace(no_, output_line);
198 } else if (type_ != NT_LINEAR) {
199 ASSERT_HOST("Invalid fully-connected type!" == nullptr);
200 }
201}
void SoftmaxInPlace(int n, T *inout)
Definition: functions.h:146

◆ InitWeights()

int tesseract::FullyConnected::InitWeights ( float  range,
TRand * randomizer 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 77 of file fullyconnected.cpp.

77 {
78 Network::SetRandomizer(randomizer);
80 range, randomizer);
81 return num_weights_;
82}
@ NF_ADAM
Definition: network.h:88
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
int32_t num_weights_
Definition: network.h:299
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:138
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)

◆ OutputShape()

StaticShape tesseract::FullyConnected::OutputShape ( const StaticShape & input_shape) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 46 of file fullyconnected.cpp.

46 {
47 LossType loss_type = LT_NONE;
48 if (type_ == NT_SOFTMAX)
49 loss_type = LT_CTC;
50 else if (type_ == NT_SOFTMAX_NO_CTC)
51 loss_type = LT_SOFTMAX;
52 else if (type_ == NT_LOGISTIC)
53 loss_type = LT_LOGISTIC;
54 StaticShape result(input_shape);
55 result.set_depth(no_);
56 result.set_loss_type(loss_type);
57 return result;
58}

◆ RemapOutputs()

int tesseract::FullyConnected::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 87 of file fullyconnected.cpp.

87 {
88 if (type_ == NT_SOFTMAX && no_ == old_no) {
90 no_ = code_map.size();
91 }
92 return num_weights_;
93}
int RemapOutputs(const std::vector< int > &code_map)

◆ Serialize()

bool tesseract::FullyConnected::Serialize ( TFile * fp) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 106 of file fullyconnected.cpp.

106 {
107 if (!Network::Serialize(fp)) return false;
108 if (!weights_.Serialize(IsTraining(), fp)) return false;
109 return true;
110}
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
bool Serialize(bool training, TFile *fp) const

◆ SetEnableTraining()

void tesseract::FullyConnected::SetEnableTraining ( TrainingState  state)
override virtual

Reimplemented from tesseract::Network.

Definition at line 61 of file fullyconnected.cpp.

61 {
62 if (state == TS_RE_ENABLE) {
63 // Enable only from temp disabled.
65 } else if (state == TS_TEMP_DISABLE) {
66 // Temp disable only from enabled.
67 if (training_ == TS_ENABLED) training_ = state;
68 } else {
69 if (state == TS_ENABLED && training_ != TS_ENABLED)
71 training_ = state;
72 }
73}
@ TS_TEMP_DISABLE
Definition: network.h:97
@ TS_ENABLED
Definition: network.h:95
@ TS_RE_ENABLE
Definition: network.h:99
TrainingState training_
Definition: network.h:294

◆ SetupForward()

void tesseract::FullyConnected::SetupForward ( const NetworkIO & input,
const TransposedArray * input_transpose 
)

Definition at line 173 of file fullyconnected.cpp.

174 {
175 // Softmax output is always float, so save the input type.
176 int_mode_ = input.int_mode();
177 if (IsTraining()) {
178 acts_.Resize(input, no_);
179 // Source_ is a transposed copy of input. It isn't needed if provided.
180 external_source_ = input_transpose;
181 if (external_source_ == nullptr) source_t_.ResizeNoInit(ni_, input.Width());
182 }
183}
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:94
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45

◆ spec()

STRING tesseract::FullyConnected::spec ( ) const
inline override virtual

Reimplemented from tesseract::Network.

Definition at line 37 of file fullyconnected.h.

37 {
39 if (type_ == NT_TANH)
40 spec.add_str_int("Ft", no_);
41 else if (type_ == NT_LOGISTIC)
42 spec.add_str_int("Fs", no_);
43 else if (type_ == NT_RELU)
44 spec.add_str_int("Fr", no_);
45 else if (type_ == NT_LINEAR)
46 spec.add_str_int("Fl", no_);
47 else if (type_ == NT_POSCLIP)
48 spec.add_str_int("Fp", no_);
49 else if (type_ == NT_SYMCLIP)
50 spec.add_str_int("Fs", no_);
51 else if (type_ == NT_SOFTMAX)
52 spec.add_str_int("Fc", no_);
53 else
54 spec.add_str_int("Fm", no_);
55 return spec;
56 }
Definition: strngs.h:45
void add_str_int(const char *str, int number)
Definition: strngs.cpp:377
STRING spec() const override

◆ Update()

void tesseract::FullyConnected::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 298 of file fullyconnected.cpp.

299 {
300 weights_.Update(learning_rate, momentum, adam_beta, num_samples);
301}
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)

Member Data Documentation

◆ acts_

NetworkIO tesseract::FullyConnected::acts_
protected

Definition at line 126 of file fullyconnected.h.

◆ external_source_

const TransposedArray* tesseract::FullyConnected::external_source_
protected

Definition at line 124 of file fullyconnected.h.

◆ int_mode_

bool tesseract::FullyConnected::int_mode_
protected

Definition at line 129 of file fullyconnected.h.

◆ source_t_

TransposedArray tesseract::FullyConnected::source_t_
protected

Definition at line 121 of file fullyconnected.h.

◆ weights_

WeightMatrix tesseract::FullyConnected::weights_
protected

Definition at line 119 of file fullyconnected.h.


The documentation for this class was generated from the following files: