tesseract 4.1.1
Loading...
Searching...
No Matches
lstmtrainer.h
Go to the documentation of this file.
1
2// File: lstmtrainer.h
3// Description: Top-level line trainer class for LSTM-based networks.
4// Author: Ray Smith
5// Created: Fri May 03 09:07:06 PST 2013
6//
7// (C) Copyright 2013, Google Inc.
8// Licensed under the Apache License, Version 2.0 (the "License");
9// you may not use this file except in compliance with the License.
10// You may obtain a copy of the License at
11// http://www.apache.org/licenses/LICENSE-2.0
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
18
19#ifndef TESSERACT_LSTM_LSTMTRAINER_H_
20#define TESSERACT_LSTM_LSTMTRAINER_H_
21
22#include "imagedata.h"
23#include "lstmrecognizer.h"
24#include "rect.h"
25#include "tesscallback.h"
26
27namespace tesseract {
28
29class LSTM;
30class LSTMTrainer;
31class Parallel;
32class Reversed;
33class Softmax;
34class Series;
35
36// Enum for the types of errors that are counted.
38 ET_RMS, // RMS activation error.
39 ET_DELTA, // Number of big errors in deltas.
40 ET_WORD_RECERR, // Output text string word recall error.
41 ET_CHAR_ERROR, // Output text string total char error.
42 ET_SKIP_RATIO, // Fraction of samples skipped.
43 ET_COUNT // For array sizing.
44};
45
46// Enum for the trainability_ flags.
48 TRAINABLE, // Non-zero delta error.
49 PERFECT, // Zero delta error.
50 UNENCODABLE, // Not trainable due to coding/alignment trouble.
51 HI_PRECISION_ERR, // Hi confidence disagreement.
52 NOT_BOXED, // Early in training and has no character boxes.
53};
54
55// Enum to define the amount of data to get serialized.
57 LIGHT, // Minimal data for remote training.
58 NO_BEST_TRAINER, // Save an empty vector in place of best_trainer_.
59 FULL, // All data including best_trainer_.
60};
61
62// Enum to indicate how the sub_trainer_ training went.
64 STR_NONE, // Did nothing as not good enough.
65 STR_UPDATED, // Subtrainer was updated, but didn't replace *this.
66 STR_REPLACED // Subtrainer replaced *this.
67};
68
69class LSTMTrainer;
70// Function to restore the trainer state from a given checkpoint.
71// Returns false on failure.
74// Function to save a checkpoint of the current trainer state.
75// Returns false on failure. SerializeAmount determines the amount of the
76// trainer to serialize, typically used for saving the best state.
79// Function to compute and record error rates on some external test set(s).
80// Args are: iteration, mean errors, model, training stage.
81// Returns a STRING containing logging information about the tests.
82typedef TessResultCallback4<STRING, int, const double*, const TessdataManager&,
84
85// Trainer class for LSTM networks. Most of the effort is in creating the
86// ideal target outputs from the transcription. A box file is used if it is
87// available, otherwise estimates of the char widths from the unicharset are
88// used to guide a DP search for the best fit to the transcription.
90 public:
92 // Callbacks may be null, in which case defaults are used.
93 LSTMTrainer(FileReader file_reader, FileWriter file_writer,
94 CheckPointReader checkpoint_reader,
95 CheckPointWriter checkpoint_writer,
96 const char* model_base, const char* checkpoint_name,
97 int debug_interval, int64_t max_memory);
98 virtual ~LSTMTrainer();
99
100 // Tries to deserialize a trainer from the given file and silently returns
101 // false in case of failure. If old_traineddata is not null, then it is
102 // assumed that the character set is to be re-mapped from old_traineddata to
103 // the new, with consequent change in weight matrices etc.
104 bool TryLoadingCheckpoint(const char* filename, const char* old_traineddata);
105
106 // Initializes the character set encode/decode mechanism directly from a
107 // previously setup traineddata containing dawgs, UNICHARSET and
108 // UnicharCompress. Note: Call before InitNetwork!
109 void InitCharSet(const std::string& traineddata_path) {
// ASSERT_HOST aborts the process if the traineddata file fails to load.
110 ASSERT_HOST(mgr_.Init(traineddata_path.c_str()));
111 InitCharSet();
112 }
// Overload: copies an already-initialized TessdataManager into mgr_, then
// finishes setup via the protected InitCharSet(). Also call before InitNetwork!
113 void InitCharSet(const TessdataManager& mgr) {
114 mgr_ = mgr;
115 InitCharSet();
116 }
117
118 // Initializes the trainer with a network_spec in the network description language.
119 // net_flags control network behavior according to the NetworkFlags enum.
120 // There isn't really much difference between them - only where the effects
121 // are implemented.
122 // For other args see NetworkBuilder::InitNetwork.
123 // Note: Be sure to call InitCharSet before InitNetwork!
124 bool InitNetwork(const STRING& network_spec, int append_index, int net_flags,
125 float weight_range, float learning_rate, float momentum,
126 float adam_beta);
127 // Initializes a trainer from a serialized TFNetworkModel proto.
128 // Returns the global step of TensorFlow graph or 0 if failed.
129 // Building a compatible TF graph: See tfnetwork.proto.
130 int InitTensorFlowNetwork(const std::string& tf_proto);
131 // Resets all the iteration counters for fine tuning or training a head,
132 // where we want the error reporting to reset.
133 void InitIterations();
134
135 // Accessors.
// Rolling mean delta (activation) error — note: ET_DELTA, not ET_RMS.
136 double ActivationError() const {
137 return error_rates_[ET_DELTA];
138 }
// Rolling mean character error rate.
139 double CharError() const { return error_rates_[ET_CHAR_ERROR]; }
// Full array of rolling mean error rates, indexed by ErrorTypes.
140 const double* error_rates() const {
141 return error_rates_;
142 }
// Best (lowest) error rate seen so far, and the iteration it occurred at.
143 double best_error_rate() const {
144 return best_error_rate_;
145 }
146 int best_iteration() const {
147 return best_iteration_;
148 }
// Iterations since best_error_rate_ was 2% worse than now (see improvement_steps_).
150 int32_t improvement_steps() const { return improvement_steps_; }
// Sets how often a PERFECT training sample may be used in backprop
// (see perfect_delay_: 0 = always, 4 = at most 1 in 5 samples).
151 void set_perfect_delay(int delay) { perfect_delay_ = delay; }
153 // Returns the error that was just calculated by PrepareForBackward.
154 double NewSingleError(ErrorTypes type) const {
156 }
157 // Returns the error that was just calculated by TrainOnLine. Since
158 // TrainOnLine rolls the error buffers, this is one further back than
159 // NewSingleError.
160 double LastSingleError(ErrorTypes type) const {
161 return error_buffers_[type]
164 }
166 return training_data_;
167 }
169
170 // If the training sample is usable, grid searches for the optimal
171 // dict_ratio/cert_offset, and returns the results in a string of space-
172 // separated triplets of ratio,offset=worderr.
174 const ImageData* trainingdata, int iteration, double min_dict_ratio,
175 double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
176 double cert_offset_step, double max_cert_offset, STRING* results);
177
178 // Provides output on the distribution of weight values.
179 void DebugNetwork();
180
181 // Loads a set of lstmf files that were created using the lstm.train config to
182 // tesseract into memory ready for training. Returns false if nothing was
183 // loaded.
184 bool LoadAllTrainingData(const GenericVector<STRING>& filenames,
185 CachingStrategy cache_strategy,
186 bool randomly_rotate);
187
188 // Keeps track of best and locally worst error rate, using internally computed
189 // values. See MaintainCheckpointsSpecific for more detail.
190 bool MaintainCheckpoints(TestCallback tester, STRING* log_msg);
191 // Keeps track of best and locally worst error_rate (whatever it is) and
192 // launches tests using rec_model, when a new min or max is reached.
193 // Writes checkpoints using train_model at appropriate times and builds and
194 // returns a log message to indicate progress. Returns false if nothing
195 // interesting happened.
196 bool MaintainCheckpointsSpecific(int iteration,
197 const GenericVector<char>* train_model,
198 const GenericVector<char>* rec_model,
199 TestCallback tester, STRING* log_msg);
200 // Builds a string containing a progress message with current error rates.
201 void PrepareLogMsg(STRING* log_msg) const;
202 // Appends <intro_str> iteration learning_iteration()/training_iteration()/
203 // sample_iteration() to the log_msg.
204 void LogIterations(const char* intro_str, STRING* log_msg) const;
205
206 // TODO(rays) Add curriculum learning.
207 // Returns true and increments the training_stage_ if the error rate has just
208 // passed through the given threshold for the first time.
209 bool TransitionTrainingStage(float error_threshold);
210 // Returns the current training stage.
211 int CurrentTrainingStage() const { return training_stage_; }
212
213 // Writes to the given file. Returns false in case of error.
214 bool Serialize(SerializeAmount serialize_amount,
215 const TessdataManager* mgr, TFile* fp) const;
216 // Reads from the given file. Returns false in case of error.
217 bool DeSerialize(const TessdataManager* mgr, TFile* fp);
218
219 // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
220 // learning rates (by scaling reduction, or layer specific, according to
221 // NF_LAYER_SPECIFIC_LR).
222 void StartSubtrainer(STRING* log_msg);
223 // While the sub_trainer_ is behind the current training iteration and its
224 // training error is at least kSubTrainerMarginFraction better than the
225 // current training error, trains the sub_trainer_, and returns STR_UPDATED if
226 // it did anything. If it catches up, and has a better error rate than the
227 // current best, as well as a margin over the current error rate, then the
228 // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
229 // returned. STR_NONE is returned if the subtrainer wasn't good enough to
230 // receive any training iterations.
232 // Reduces network learning rates, either for everything, or for layers
233 // independently, according to NF_LAYER_SPECIFIC_LR.
234 void ReduceLearningRates(LSTMTrainer* samples_trainer, STRING* log_msg);
235 // Considers reducing the learning rate independently for each layer down by
236 // factor(<1), or leaving it the same, by double-training the given number of
237 // samples and minimizing the amount of changing of sign of weight updates.
238 // Even if it looks like all weights should remain the same, an adjustment
239 // will be made to guarantee a different result when reverting to an old best.
240 // Returns the number of layer learning rates that were reduced.
241 int ReduceLayerLearningRates(double factor, int num_samples,
242 LSTMTrainer* samples_trainer);
243
244 // Converts the string to integer class labels, with appropriate null_char_s
245 // in between if not in SimpleTextOutput mode. Returns false on failure.
246 bool EncodeString(const STRING& str, GenericVector<int>* labels) const {
247 return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : nullptr,
248 SimpleTextOutput(), null_char_, labels);
249 }
250 // Static version operates on supplied unicharset, encoder, simple_text.
251 static bool EncodeString(const STRING& str, const UNICHARSET& unicharset,
252 const UnicharCompress* recoder, bool simple_text,
253 int null_char, GenericVector<int>* labels);
254
255 // Performs forward-backward on the given trainingdata.
256 // Returns the sample that was used or nullptr if the next sample was deemed
257 // unusable. samples_trainer could be this or an alternative trainer that
258 // holds the training samples.
259 const ImageData* TrainOnLine(LSTMTrainer* samples_trainer, bool batch) {
260 int sample_index = sample_iteration();
261 const ImageData* image =
262 samples_trainer->training_data_.GetPageBySerial(sample_index);
263 if (image != nullptr) {
264 Trainability trainable = TrainOnLine(image, batch);
265 if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
266 return nullptr; // Sample was unusable.
267 }
268 } else {
270 }
271 return image;
272 }
273 Trainability TrainOnLine(const ImageData* trainingdata, bool batch);
274
275 // Prepares the ground truth, runs forward, and prepares the targets.
276 // Returns a Trainability enum to indicate the suitability of the sample.
277 Trainability PrepareForBackward(const ImageData* trainingdata,
278 NetworkIO* fwd_outputs, NetworkIO* targets);
279
280 // Writes the trainer to memory, so that the current training state can be
281 // restored. *this must always be the master trainer that retains the only
282 // copy of the training data and language model. trainer is the model that is
283 // actually serialized.
284 bool SaveTrainingDump(SerializeAmount serialize_amount,
285 const LSTMTrainer* trainer,
286 GenericVector<char>* data) const;
287
288 // Reads previously saved trainer from memory. *this must always be the
289 // master trainer that retains the only copy of the training data and
290 // language model. trainer is the model that is restored.
292 LSTMTrainer* trainer) const {
293 if (data.empty()) return false;
294 return ReadSizedTrainingDump(&data[0], data.size(), trainer);
295 }
// As ReadTrainingDump, but takes a raw pointer + size. Forwards to the target
// trainer's ReadLocalTrainingDump using this (master) trainer's mgr_.
296 bool ReadSizedTrainingDump(const char* data, int size,
297 LSTMTrainer* trainer) const {
298 return trainer->ReadLocalTrainingDump(&mgr_, data, size);
299 }
300 // Restores the model to *this.
301 bool ReadLocalTrainingDump(const TessdataManager* mgr, const char* data,
302 int size);
303
304 // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
306
307 // Writes the full recognition traineddata to the given filename.
308 bool SaveTraineddata(const STRING& filename);
309
310 // Writes the recognizer to memory, so that it can be used for testing later.
312
313 // Returns a suitable filename for a training dump, based on the model_base_,
314 // the iteration and the error rates.
315 STRING DumpFilename() const;
316
317 // Fills the whole error buffer of the given type with the given value.
318 void FillErrorBuffer(double new_error, ErrorTypes type);
319 // Helper generates a map from each current recoder_ code (ie softmax index)
320 // to the corresponding old_recoder code, or -1 if there isn't one.
321 std::vector<int> MapRecoder(const UNICHARSET& old_chset,
322 const UnicharCompress& old_recoder) const;
323
324 protected:
325 // Private version of InitCharSet above finishes the job after initializing
326 // the mgr_ data member.
327 void InitCharSet();
328 // Helper computes and sets the null_char_.
329 void SetNullChar();
330
331 // Factored sub-constructor sets up reasonable default values.
332 void EmptyConstructor();
333
334 // Outputs the string and periodically displays the given network inputs
335 // as an image in the given window, and the corresponding labels at the
336 // corresponding x_starts.
337 // Returns false if the truth string is empty.
338 bool DebugLSTMTraining(const NetworkIO& inputs,
339 const ImageData& trainingdata,
340 const NetworkIO& fwd_outputs,
341 const GenericVector<int>& truth_labels,
342 const NetworkIO& outputs);
343 // Displays the network targets as a line graph.
344 void DisplayTargets(const NetworkIO& targets, const char* window_name,
345 ScrollView** window);
346
347 // Builds a no-compromises target where the first positions should be the
348 // truth labels and the rest is padded with the null_char_.
349 bool ComputeTextTargets(const NetworkIO& outputs,
350 const GenericVector<int>& truth_labels,
351 NetworkIO* targets);
352
353 // Builds a target using standard CTC. truth_labels should be pre-padded with
354 // nulls wherever desired. They don't have to be between all labels.
355 // outputs is input-output, as it gets clipped to minimum probability.
356 bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
357 NetworkIO* outputs, NetworkIO* targets);
358
359 // Computes network errors, and stores the results in the rolling buffers,
360 // along with the supplied text_error.
361 // Returns the delta error of the current sample (not running average.)
362 double ComputeErrorRates(const NetworkIO& deltas, double char_error,
363 double word_error);
364
365 // Computes the network activation RMS error rate.
366 double ComputeRMSError(const NetworkIO& deltas);
367
368 // Computes network activation winner error rate. (Number of values that are
369 // in error by >= 0.5 divided by number of time-steps.) More closely related
370 // to final character error than RMS, but still directly calculable from
371 // just the deltas. Because of the binary nature of the targets, zero winner
372 // error is a sufficient but not necessary condition for zero char error.
373 double ComputeWinnerError(const NetworkIO& deltas);
374
375 // Computes a very simple bag of chars char error rate.
376 double ComputeCharError(const GenericVector<int>& truth_str,
377 const GenericVector<int>& ocr_str);
378 // Computes a very simple bag of words word recall error rate.
379 // NOTE that this is destructive on both input strings.
380 double ComputeWordError(STRING* truth_str, STRING* ocr_str);
381
382 // Updates the error buffer and corresponding mean of the given type with
383 // the new_error.
384 void UpdateErrorBuffer(double new_error, ErrorTypes type);
385
386 // Rolls error buffers and reports the current means.
387 void RollErrorBuffers();
388
389 // Given that error_rate is either a new min or max, updates the best/worst
390 // error rates, and record of progress.
391 STRING UpdateErrorGraph(int iteration, double error_rate,
392 const GenericVector<char>& model_data,
393 TestCallback tester);
394
395 protected:
396 // Alignment display window.
398 // CTC target display window.
400 // CTC output display window.
402 // Reconstructed image window.
404 // How often to display a debug image.
406 // Iteration at which the last checkpoint was dumped.
408 // Basename of files to save best models to.
410 // Checkpoint filename.
412 // Training data.
415 // Name to use when saving best_trainer_.
417 // Number of available training stages.
419 // Checkpointing callbacks.
422 // TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr
423 // when we can commit to c++11.
426
427 // ===Serialized data to ensure that a restart produces the same results.===
428 // These members are only serialized when serialize_amount != LIGHT.
429 // Best error rate so far.
431 // Snapshot of all error rates at best_iteration_.
433 // Iteration of best_error_rate_.
435 // Worst error rate since best_error_rate_.
437 // Snapshot of all error rates at worst_iteration_.
439 // Iteration of worst_error_rate_.
441 // Iteration at which the process will be thought stalled.
443 // Saved recognition models for computing test error for graph points.
446 // Saved trainer for reverting back to last known best.
448 // A subsidiary trainer running with a different learning rate until either
449 // *this or sub_trainer_ hits a new best.
451 // Error rate at which last best model was dumped.
453 // Current stage of training.
455 // History of best error rate against iteration. Used for computing the
456 // number of steps to each 2% improvement.
459 // Number of iterations since the best_error_rate_ was 2% more than it is now.
461 // Number of iterations that yielded a non-zero delta error and thus provided
462 // significant learning. learning_iteration_ <= training_iteration_.
463 // learning_iteration_ is used to measure rate of learning progress.
465 // Saved value of sample_iteration_ before looking for the next sample.
467 // How often to include a PERFECT training sample in backprop.
468 // A PERFECT training sample is used if the current
469 // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
470 // so with perfect_delay_ == 0, all samples are used, and with
471 // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
473 // Value of training_iteration_ at which the last PERFECT training sample
474 // was used in back prop.
476 // Rolling buffers storing recent training errors are indexed by
477 // training_iteration % kRollingBufferSize_.
478 static const int kRollingBufferSize_ = 1000;
480 // Rounded mean percent trailing training errors in the buffers.
481 double error_rates_[ET_COUNT]; // RMS training error.
482 // Traineddata file with optional dawgs + UNICHARSET and recoder.
484};
485
486} // namespace tesseract.
487
488#endif // TESSERACT_LSTM_LSTMTRAINER_H_
#define ASSERT_HOST(x)
Definition: errcode.h:88
TessResultCallback4< STRING, int, const double *, const TessdataManager &, int > * TestCallback
Definition: lstmtrainer.h:83
@ ET_WORD_RECERR
Definition: lstmtrainer.h:40
@ ET_SKIP_RATIO
Definition: lstmtrainer.h:42
@ ET_CHAR_ERROR
Definition: lstmtrainer.h:41
@ HI_PRECISION_ERR
Definition: lstmtrainer.h:51
TessResultCallback2< bool, const GenericVector< char > &, LSTMTrainer * > * CheckPointReader
Definition: lstmtrainer.h:73
bool(*)(const STRING &, GenericVector< char > *) FileReader
Definition: serialis.h:49
@ STR_REPLACED
Definition: lstmtrainer.h:66
TessResultCallback3< bool, SerializeAmount, const LSTMTrainer *, GenericVector< char > * > * CheckPointWriter
Definition: lstmtrainer.h:78
bool(*)(const GenericVector< char > &, const STRING &) FileWriter
Definition: serialis.h:52
CachingStrategy
Definition: imagedata.h:42
@ NO_BEST_TRAINER
Definition: lstmtrainer.h:58
bool empty() const
Definition: genericvector.h:91
int size() const
Definition: genericvector.h:72
const ImageData * GetPageBySerial(int serial)
Definition: imagedata.h:344
Definition: strngs.h:45
bool Init(const char *data_file_name)
double learning_rate() const
const UNICHARSET & GetUnicharset() const
bool SaveTraineddata(const STRING &filename)
bool TransitionTrainingStage(float error_threshold)
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
void PrepareLogMsg(STRING *log_msg) const
ScrollView * target_win_
Definition: lstmtrainer.h:399
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
const double * error_rates() const
Definition: lstmtrainer.h:140
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:444
int InitTensorFlowNetwork(const std::string &tf_proto)
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:154
double CharError() const
Definition: lstmtrainer.h:139
void StartSubtrainer(STRING *log_msg)
bool MaintainCheckpointsSpecific(int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:438
double best_error_rate() const
Definition: lstmtrainer.h:143
double LastSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:160
DocumentCache * mutable_training_data()
Definition: lstmtrainer.h:168
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:291
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:452
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
ScrollView * recon_win_
Definition: lstmtrainer.h:403
void FillErrorBuffer(double new_error, ErrorTypes type)
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
int learning_iteration() const
Definition: lstmtrainer.h:149
void InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:109
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
double ComputeRMSError(const NetworkIO &deltas)
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:457
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:151
const GenericVector< char > & best_trainer() const
Definition: lstmtrainer.h:152
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
double ComputeWinnerError(const NetworkIO &deltas)
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:447
void SaveRecognitionDump(GenericVector< char > *data) const
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
void UpdateErrorBuffer(double new_error, ErrorTypes type)
ScrollView * ctc_win_
Definition: lstmtrainer.h:401
int CurrentTrainingStage() const
Definition: lstmtrainer.h:211
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:458
double ActivationError() const
Definition: lstmtrainer.h:136
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:425
void InitCharSet(const TessdataManager &mgr)
Definition: lstmtrainer.h:113
DocumentCache training_data_
Definition: lstmtrainer.h:414
STRING DumpFilename() const
static const int kRollingBufferSize_
Definition: lstmtrainer.h:478
const DocumentCache & training_data() const
Definition: lstmtrainer.h:165
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:445
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
int32_t improvement_steps() const
Definition: lstmtrainer.h:150
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
void LogIterations(const char *intro_str, STRING *log_msg) const
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:246
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:259
int best_iteration() const
Definition: lstmtrainer.h:146
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
TessdataManager mgr_
Definition: lstmtrainer.h:483
ScrollView * align_win_
Definition: lstmtrainer.h:397
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:296
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)