///////////////////////////////////////////////////////////////////////
// File:        lstmtrainer.cpp
// Description: Top-level line trainer class for LSTM-based networks.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#define _USE_MATH_DEFINES  // needed to get definition of M_SQRT1_2

// Include automatically generated configuration file if running autoconf.
#ifdef HAVE_CONFIG_H
#include "config_auto.h"
#endif

#include "lstmtrainer.h"
#include <string>
#include <unordered_map>  // for std::unordered_map used in ComputeWordError.

#include "allheaders.h"
#include "boxread.h"
#include "ctc.h"
#include "imagedata.h"
#include "input.h"
#include "networkbuilder.h"
#include "ratngs.h"
#include "recodebeam.h"
#ifdef INCLUDE_TENSORFLOW
#include "tfnetwork.h"
#endif
#include "tprintf.h"

#include "callcpp.h"

namespace tesseract {

// Min actual error rate increase to constitute divergence.
const double kMinDivergenceRate = 50.0;
// Min iterations since last best before acting on a stall.
const int kMinStallIterations = 10000;
// Fraction of current char error rate that sub_trainer_ has to be ahead
// before we declare the sub_trainer_ a success and switch to it.
const double kSubTrainerMarginFraction = 3.0 / 128;
// Factor to reduce learning rate on divergence.
const double kLearningRateDecay = M_SQRT1_2;
// LR adjustment iterations.
const int kNumAdjustmentIterations = 100;
// How often to add data to the error_graph_.
const int kErrorGraphInterval = 1000;
// Number of training images to train between calls to MaintainCheckpoints.
const int kNumPagesPerBatch = 100;
// Min percent error rate to consider start-up phase over.
const int kMinStartedErrorRate = 75;
// Error rate at which to transition to stage 1.
const double kStageTransitionThreshold = 10.0;
// Confidence beyond which the truth is more likely wrong than the recognizer.
const double kHighConfidence = 0.9375;  // 15/16.
// Fraction of weight sign-changing total to constitute a definite improvement.
const double kImprovementFraction = 15.0 / 16.0;
// Fraction of last written best to make it worth writing another.
const double kBestCheckpointFraction = 31.0 / 32.0;
// Scale factor for display of target activations of CTC.
const int kTargetXScale = 5;
const int kTargetYScale = 100;

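// Example with the constants above: if the last written best model was at
// 4.0% char error, kBestCheckpointFraction (31/32) means another best file
// is only written once the error drops below 4.0 * 31/32 = 3.875%, which
// throttles model dumps near convergence.
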
LSTMTrainer::LSTMTrainer()
    : randomly_rotate_(false),
      training_data_(0),
      file_reader_(LoadDataFromFile),
      file_writer_(SaveDataToFile),
      checkpoint_reader_(
          NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump)),
      checkpoint_writer_(
          NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump)),
      sub_trainer_(nullptr) {
  EmptyConstructor();
  debug_interval_ = 0;
}

LSTMTrainer::LSTMTrainer(FileReader file_reader, FileWriter file_writer,
                         CheckPointReader checkpoint_reader,
                         CheckPointWriter checkpoint_writer,
                         const char* model_base, const char* checkpoint_name,
                         int debug_interval, int64_t max_memory)
    : randomly_rotate_(false),
      training_data_(max_memory),
      file_reader_(file_reader),
      file_writer_(file_writer),
      checkpoint_reader_(checkpoint_reader),
      checkpoint_writer_(checkpoint_writer),
      sub_trainer_(nullptr),
      mgr_(file_reader) {
  EmptyConstructor();
  if (file_reader_ == nullptr) file_reader_ = LoadDataFromFile;
  if (file_writer_ == nullptr) file_writer_ = SaveDataToFile;
  if (checkpoint_reader_ == nullptr) {
    checkpoint_reader_ =
        NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump);
  }
  if (checkpoint_writer_ == nullptr) {
    checkpoint_writer_ =
        NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump);
  }
  debug_interval_ = debug_interval;
  model_base_ = model_base;
  checkpoint_name_ = checkpoint_name;
}

LSTMTrainer::~LSTMTrainer() {
  delete align_win_;
  delete target_win_;
  delete ctc_win_;
  delete recon_win_;
  delete checkpoint_reader_;
  delete checkpoint_writer_;
  delete sub_trainer_;
}

// Tries to deserialize a trainer from the given file and silently returns
// false in case of failure.
bool LSTMTrainer::TryLoadingCheckpoint(const char* filename,
                                       const char* old_traineddata) {
  GenericVector<char> data;
  if (!(*file_reader_)(filename, &data)) return false;
  tprintf("Loaded file %s, unpacking...\n", filename);
  if (!checkpoint_reader_->Run(data, this)) return false;
  StaticShape shape = network_->OutputShape(network_->InputShape());
  if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
       network_->NumOutputs() == recoder_.code_range()) ||
      filename == old_traineddata) {
    return true;  // Normal checkpoint load complete.
  }
  tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(),
          recoder_.code_range());
  if (old_traineddata == nullptr || *old_traineddata == '\0') {
    tprintf("Must supply the old traineddata for code conversion!\n");
    return false;
  }
  TessdataManager old_mgr;
  ASSERT_HOST(old_mgr.Init(old_traineddata));
  TFile fp;
  if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false;
  UNICHARSET old_chset;
  if (!old_chset.load_from_file(&fp, false)) return false;
  if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false;
  UnicharCompress old_recoder;
  if (!old_recoder.DeSerialize(&fp)) return false;
  std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
  // Set the null_char_ to the new value.
  int old_null_char = null_char_;
  SetNullChar();
  // Map the softmax(s) in the network.
  network_->RemapOutputs(old_recoder.code_range(), code_map);
  tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
  return true;
}

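// Note: this conversion path is what makes fine-tuning with a changed
// character set possible (e.g. lstmtraining --continue_from with
// --old_traineddata): MapRecoder builds a code_map from the old codes and
// RemapOutputs rebuilds the softmax to match the new recoder.
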
// Initializes the trainer with a network_spec in the network description.
// net_flags control network behavior according to the NetworkFlags enum.
// There isn't really much difference between them - only where the effects
// are implemented.
// For other args see NetworkBuilder::InitNetwork.
// Note: Be sure to call InitCharSet before InitNetwork!
bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index,
                              int net_flags, float weight_range,
                              float learning_rate, float momentum,
                              float adam_beta) {
  mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec.string());
  adam_beta_ = adam_beta;
  learning_rate_ = learning_rate;
  momentum_ = momentum;
  SetNullChar();
  if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
                                   append_index, net_flags, weight_range,
                                   &randomizer_, &network_)) {
    return false;
  }
  network_str_ += network_spec;
  tprintf("Built network:%s from request %s\n",
          network_->spec().string(), network_spec.string());
  tprintf(
      "Training parameters:\n  Debug interval = %d,"
      " weights = %g, learning rate = %g, momentum=%g\n",
      debug_interval_, weight_range, learning_rate_, momentum_);
  tprintf("null char=%d\n", null_char_);
  return true;
}

// Initializes a trainer from a serialized TFNetworkModel proto.
// Returns the global step of the TensorFlow graph, or 0 on failure.
#ifdef INCLUDE_TENSORFLOW
int LSTMTrainer::InitTensorFlowNetwork(const std::string& tf_proto) {
  delete network_;
  TFNetwork* tf_net = new TFNetwork("TensorFlow");
  training_iteration_ = tf_net->InitFromProtoStr(tf_proto);
  if (training_iteration_ == 0) {
    tprintf("InitFromProtoStr failed!!\n");
    return 0;
  }
  network_ = tf_net;
  ASSERT_HOST(recoder_.code_range() == tf_net->num_classes());
  return training_iteration_;
}
#endif

// Resets all the iteration counters for fine tuning or training a head,
// where we want the error reporting to reset.
void LSTMTrainer::InitIterations() {
  sample_iteration_ = 0;
  training_iteration_ = 0;
  learning_iteration_ = 0;
  prev_sample_iteration_ = 0;
  best_error_rate_ = 100.0;
  best_iteration_ = 0;
  worst_error_rate_ = 0.0;
  worst_iteration_ = 0;
  stall_iteration_ = kMinStallIterations;
  improvement_steps_ = kMinStallIterations;
  perfect_delay_ = 0;
  last_perfect_training_iteration_ = 0;
  for (int i = 0; i < ET_COUNT; ++i) {
    best_error_rates_[i] = 100.0;
    worst_error_rates_[i] = 0.0;
    error_buffers_[i].init_to_size(kRollingBufferSize_, 0.0);
    error_rates_[i] = 100.0;
  }
  error_rate_of_last_saved_best_ = kMinStartedErrorRate;
}

// If the training sample is usable, grid searches for the optimal
// dict_ratio/cert_offset, and returns the results in a string of space-
// separated triplets of ratio,offset=worderr.
Trainability LSTMTrainer::GridSearchDictParams(
    const ImageData* trainingdata, int iteration, double min_dict_ratio,
    double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
    double cert_offset_step, double max_cert_offset, STRING* results) {
  sample_iteration_ = iteration;
  NetworkIO fwd_outputs, targets;
  Trainability result =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == nullptr)
    return result;

  // Encode/decode the truth to get the normalization.
  GenericVector<int> truth_labels, ocr_labels, xcoords;
  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
  // NO-dict error.
  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(),
                               nullptr);
  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
                     nullptr);
  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
  STRING truth_text = DecodeLabels(truth_labels);
  STRING ocr_text = DecodeLabels(ocr_labels);
  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
  results->add_str_double("0,0=", baseline_error);

  RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_);
  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
    for (double c = min_cert_offset; c < max_cert_offset;
         c += cert_offset_step) {
      search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty,
                    nullptr);
      search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
      truth_text = DecodeLabels(truth_labels);
      ocr_text = DecodeLabels(ocr_labels);
      // This is destructive on both strings.
      double word_error = ComputeWordError(&truth_text, &ocr_text);
      if ((r == min_dict_ratio && c == min_cert_offset) ||
          !std::isfinite(word_error)) {
        STRING t = DecodeLabels(truth_labels);
        STRING o = DecodeLabels(ocr_labels);
        tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
                t.string(), o.string(), word_error, truth_labels[0]);
      }
      results->add_str_double(" ", r);
      results->add_str_double(",", c);
      results->add_str_double("=", word_error);
    }
  }
  return result;
}

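// The results string therefore reads as space-separated "ratio,offset=worderr"
// triplets, e.g. "0,0=0.25 0.5,-2.5=0.2 ...", where the leading 0,0 entry is
// the no-dictionary baseline word error computed above.
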
// Provides output on the distribution of weight values.
void LSTMTrainer::DebugNetwork() {
  network_->DebugWeights();
}

// Loads a set of lstmf files that were created using the lstm.train config to
// tesseract into memory ready for training. Returns false if nothing was
// loaded.
bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames,
                                      CachingStrategy cache_strategy,
                                      bool randomly_rotate) {
  randomly_rotate_ = randomly_rotate;
  training_data_.Clear();
  return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_);
}

// Keeps track of best and locally worst char error_rate and launches tests
// using tester, when a new min or max is reached.
// Writes checkpoints at appropriate times and builds and returns a log message
// to indicate progress. Returns false if nothing interesting happened.
bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) {
  PrepareLogMsg(log_msg);
  double error_rate = CharError();
  int iteration = learning_iteration();
  if (iteration >= stall_iteration_ &&
      error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
      sub_trainer_ == nullptr) {
    // It hasn't got any better in a long while, and is a margin worse than the
    // best, so go back to the best model and try a different learning rate.
    StartSubtrainer(log_msg);
  }
  SubTrainerResult sub_trainer_result = STR_NONE;
  if (sub_trainer_ != nullptr) {
    sub_trainer_result = UpdateSubtrainer(log_msg);
    if (sub_trainer_result == STR_REPLACED) {
      // Reset the inputs, as we have overwritten *this.
      error_rate = CharError();
      iteration = learning_iteration();
      PrepareLogMsg(log_msg);
    }
  }
  bool result = true;  // Something interesting happened.
  GenericVector<char> rec_model_data;
  if (error_rate < best_error_rate_) {
    SaveRecognitionDump(&rec_model_data);
    log_msg->add_str_double(" New best char error = ", error_rate);
    *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
    // just overwrote *this. In either case, we have finished with it.
    delete sub_trainer_;
    sub_trainer_ = nullptr;
    stall_iteration_ = learning_iteration() + kMinStallIterations;
    if (TransitionTrainingStage(kStageTransitionThreshold)) {
      log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage());
    }
    checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_);
    if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) {
      STRING best_model_name = DumpFilename();
      if (!(*file_writer_)(best_trainer_, best_model_name.c_str())) {
        *log_msg += " failed to write best model:";
      } else {
        *log_msg += " wrote best model:";
        error_rate_of_last_saved_best_ = best_error_rate_;
      }
      *log_msg += best_model_name;
    }
  } else if (error_rate > worst_error_rate_) {
    SaveRecognitionDump(&rec_model_data);
    log_msg->add_str_double(" New worst char error = ", error_rate);
    *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    if (worst_error_rate_ > best_error_rate_ + kMinDivergenceRate &&
        best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
      // Error rate has ballooned. Go back to the best model.
      *log_msg += "\nDivergence! ";
      // Copy best_trainer_ before reading it, as it will get overwritten.
      GenericVector<char> revert_data(best_trainer_);
      if (checkpoint_reader_->Run(revert_data, this)) {
        LogIterations("Reverted to", log_msg);
        ReduceLearningRates(this, log_msg);
      } else {
        LogIterations("Failed to Revert at", log_msg);
      }
      // If it fails again, we will wait twice as long before reverting again.
      stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
      // Re-save the best trainer with the new learning rates and stall
      // iteration.
      checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_);
    }
  } else {
    // Something interesting happened only if the sub_trainer_ was trained.
    result = sub_trainer_result != STR_NONE;
  }
  if (checkpoint_writer_ != nullptr && file_writer_ != nullptr &&
      checkpoint_name_.length() > 0) {
    // Write a current checkpoint.
    GenericVector<char> checkpoint;
    if (!checkpoint_writer_->Run(FULL, this, &checkpoint) ||
        !(*file_writer_)(checkpoint, checkpoint_name_.c_str())) {
      *log_msg += " failed to write checkpoint.";
    } else {
      *log_msg += " wrote checkpoint.";
    }
  }
  *log_msg += "\n";
  return result;
}

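// Note: MaintainCheckpoints is invoked once per kNumPagesPerBatch (100)
// training images, so the FULL checkpoint above is rewritten at that cadence,
// while best-model files are only dumped on a sufficient new minimum (see
// kBestCheckpointFraction).
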
// Builds a string containing a progress message with current error rates.
void LSTMTrainer::PrepareLogMsg(STRING* log_msg) const {
  LogIterations("At", log_msg);
  log_msg->add_str_double(", Mean rms=", error_rates_[ET_RMS]);
  log_msg->add_str_double("%, delta=", error_rates_[ET_DELTA]);
  log_msg->add_str_double("%, char train=", error_rates_[ET_CHAR_ERROR]);
  log_msg->add_str_double("%, word train=", error_rates_[ET_WORD_RECERR]);
  log_msg->add_str_double("%, skip ratio=", error_rates_[ET_SKIP_RATIO]);
  *log_msg += "%, ";
}

// Appends <intro_str> iteration learning_iteration()/training_iteration()/
// sample_iteration() to the log_msg.
void LSTMTrainer::LogIterations(const char* intro_str, STRING* log_msg) const {
  *log_msg += intro_str;
  log_msg->add_str_int(" iteration ", learning_iteration());
  log_msg->add_str_int("/", training_iteration());
  log_msg->add_str_int("/", sample_iteration());
}

// Returns true and increments the training_stage_ if the error rate has just
// passed through the given threshold for the first time.
bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
  if (best_error_rate_ < error_threshold &&
      training_stage_ + 1 < num_training_stages_) {
    ++training_stage_;
    return true;
  }
  return false;
}

// Writes to the given file. Returns false in case of error.
bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
                            const TessdataManager* mgr, TFile* fp) const {
  if (!LSTMRecognizer::Serialize(mgr, fp)) return false;
  if (!fp->Serialize(&learning_iteration_)) return false;
  if (!fp->Serialize(&prev_sample_iteration_)) return false;
  if (!fp->Serialize(&perfect_delay_)) return false;
  if (!fp->Serialize(&last_perfect_training_iteration_)) return false;
  for (const auto& error_buffer : error_buffers_) {
    if (!error_buffer.Serialize(fp)) return false;
  }
  if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) return false;
  if (!fp->Serialize(&training_stage_)) return false;
  uint8_t amount = serialize_amount;
  if (!fp->Serialize(&amount)) return false;
  if (serialize_amount == LIGHT) return true;  // We are done.
  if (!fp->Serialize(&best_error_rate_)) return false;
  if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_)))
    return false;
  if (!fp->Serialize(&best_iteration_)) return false;
  if (!fp->Serialize(&worst_error_rate_)) return false;
  if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_)))
    return false;
  if (!fp->Serialize(&worst_iteration_)) return false;
  if (!fp->Serialize(&stall_iteration_)) return false;
  if (!best_model_data_.Serialize(fp)) return false;
  if (!worst_model_data_.Serialize(fp)) return false;
  if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp))
    return false;
  GenericVector<char> sub_data;
  if (sub_trainer_ != nullptr &&
      !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data))
    return false;
  if (!sub_data.Serialize(fp)) return false;
  if (!best_error_history_.Serialize(fp)) return false;
  if (!best_error_iterations_.Serialize(fp)) return false;
  return fp->Serialize(&improvement_steps_);
}

// Reads from the given file. Returns false in case of error.
// NOTE: It is assumed that the trainer is never read cross-endian.
bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
  if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false;
  if (!fp->DeSerialize(&learning_iteration_)) {
    // Special case. If we successfully decoded the recognizer, but fail here
    // then it means we were just given a recognizer, so issue a warning and
    // allow it.
    tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
    learning_iteration_ = 0;
    network_->SetEnableTraining(TS_ENABLED);
    return true;
  }
  if (!fp->DeSerialize(&prev_sample_iteration_)) return false;
  if (!fp->DeSerialize(&perfect_delay_)) return false;
  if (!fp->DeSerialize(&last_perfect_training_iteration_)) return false;
  for (auto& error_buffer : error_buffers_) {
    if (!error_buffer.DeSerialize(fp)) return false;
  }
  if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) return false;
  if (!fp->DeSerialize(&training_stage_)) return false;
  uint8_t amount;
  if (!fp->DeSerialize(&amount)) return false;
  if (amount == LIGHT) return true;  // Don't read the rest.
  if (!fp->DeSerialize(&best_error_rate_)) return false;
  if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_)))
    return false;
  if (!fp->DeSerialize(&best_iteration_)) return false;
  if (!fp->DeSerialize(&worst_error_rate_)) return false;
  if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_)))
    return false;
  if (!fp->DeSerialize(&worst_iteration_)) return false;
  if (!fp->DeSerialize(&stall_iteration_)) return false;
  if (!best_model_data_.DeSerialize(fp)) return false;
  if (!worst_model_data_.DeSerialize(fp)) return false;
  if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
  GenericVector<char> sub_data;
  if (!sub_data.DeSerialize(fp)) return false;
  delete sub_trainer_;
  if (sub_data.empty()) {
    sub_trainer_ = nullptr;
  } else {
    sub_trainer_ = new LSTMTrainer();
    if (!ReadTrainingDump(sub_data, sub_trainer_)) return false;
  }
  if (!best_error_history_.DeSerialize(fp)) return false;
  if (!best_error_iterations_.DeSerialize(fp)) return false;
  return fp->DeSerialize(&improvement_steps_);
}

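// Note on SerializeAmount: LIGHT stops right after the error buffers, rates
// and stage are written; the larger amounts continue, with NO_BEST_TRAINER
// omitting only the best_trainer_ blob. Both non-LIGHT amounts append a
// LIGHT dump of any live sub_trainer_, so it survives a checkpoint restore.
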
// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
// learning rates (by scaling reduction, or layer specific, according to
// NF_LAYER_SPECIFIC_LR).
void LSTMTrainer::StartSubtrainer(STRING* log_msg) {
  delete sub_trainer_;
  sub_trainer_ = new LSTMTrainer();
  if (!checkpoint_reader_->Run(best_trainer_, sub_trainer_)) {
    *log_msg += " Failed to revert to previous best for trial!";
    delete sub_trainer_;
    sub_trainer_ = nullptr;
  } else {
    log_msg->add_str_int(" Trial sub_trainer_ from iteration ",
                         sub_trainer_->training_iteration());
    // Reduce learning rate so it doesn't diverge this time.
    sub_trainer_->ReduceLearningRates(this, log_msg);
    // If it fails again, we will wait twice as long before reverting again.
    int stall_offset =
        learning_iteration() - sub_trainer_->learning_iteration();
    stall_iteration_ = learning_iteration() + 2 * stall_offset;
    sub_trainer_->stall_iteration_ = stall_iteration_;
    // Re-save the best trainer with the new learning rates and stall iteration.
    checkpoint_writer_->Run(NO_BEST_TRAINER, sub_trainer_, &best_trainer_);
  }
}

// While the sub_trainer_ is behind the current training iteration and its
// training error is at least kSubTrainerMarginFraction better than the
// current training error, trains the sub_trainer_, and returns STR_UPDATED if
// it did anything. If it catches up, and has a better error rate than the
// current best, as well as a margin over the current error rate, then the
// trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
// returned. STR_NONE is returned if the subtrainer wasn't good enough to
// receive any training iterations.
SubTrainerResult LSTMTrainer::UpdateSubtrainer(STRING* log_msg) {
  double training_error = CharError();
  double sub_error = sub_trainer_->CharError();
  double sub_margin = (training_error - sub_error) / sub_error;
  if (sub_margin >= kSubTrainerMarginFraction) {
    log_msg->add_str_double(" sub_trainer=", sub_error);
    log_msg->add_str_double(" margin=", 100.0 * sub_margin);
    *log_msg += "\n";
    // Catch up to current iteration.
    int end_iteration = training_iteration();
    while (sub_trainer_->training_iteration() < end_iteration &&
           sub_margin >= kSubTrainerMarginFraction) {
      int target_iteration =
          sub_trainer_->training_iteration() + kNumPagesPerBatch;
      while (sub_trainer_->training_iteration() < target_iteration) {
        sub_trainer_->TrainOnLine(this, false);
      }
      STRING batch_log = "Sub:";
      sub_trainer_->PrepareLogMsg(&batch_log);
      batch_log += "\n";
      tprintf("UpdateSubtrainer:%s", batch_log.string());
      *log_msg += batch_log;
      sub_error = sub_trainer_->CharError();
      sub_margin = (training_error - sub_error) / sub_error;
    }
    if (sub_error < best_error_rate_ &&
        sub_margin >= kSubTrainerMarginFraction) {
      // The sub_trainer_ has won the race to a new best. Switch to it.
      GenericVector<char> updated_trainer;
      SaveTrainingDump(LIGHT, sub_trainer_, &updated_trainer);
      ReadTrainingDump(updated_trainer, this);
      log_msg->add_str_int(" Sub trainer wins at iteration ",
                           training_iteration());
      *log_msg += "\n";
      return STR_REPLACED;
    }
    return STR_UPDATED;
  }
  return STR_NONE;
}

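// Worked example: with the main trainer at 5.0% char error and sub_trainer_
// at 4.0%, sub_margin = (5.0 - 4.0) / 4.0 = 0.25 >= kSubTrainerMarginFraction
// (3/128 ~= 0.023), so the sub_trainer_ is trained forward in batches of
// kNumPagesPerBatch and replaces *this if it is still ahead of
// best_error_rate_ once caught up.
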
// Reduces network learning rates, either for everything, or for layers
// independently, according to NF_LAYER_SPECIFIC_LR.
void LSTMTrainer::ReduceLearningRates(LSTMTrainer* samples_trainer,
                                      STRING* log_msg) {
  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
    int num_reduced = ReduceLayerLearningRates(
        kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
    log_msg->add_str_int("\nReduced learning rate on layers: ", num_reduced);
  } else {
    ScaleLearningRate(kLearningRateDecay);
    log_msg->add_str_double("\nReduced learning rate to :", learning_rate_);
  }
  *log_msg += "\n";
}

// Considers reducing the learning rate independently for each layer down by
// factor(<1), or leaving it the same, by double-training the given number of
// samples and minimizing the amount of changing of sign of weight updates.
// Even if it looks like all weights should remain the same, an adjustment
// will be made to guarantee a different result when reverting to an old best.
// Returns the number of layer learning rates that were reduced.
int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
                                          LSTMTrainer* samples_trainer) {
  enum WhichWay {
    LR_DOWN,  // Learning rate will go down by factor.
    LR_SAME,  // Learning rate will stay the same.
    LR_COUNT  // Size of arrays.
  };
  GenericVector<STRING> layers = EnumerateLayers();
  int num_layers = layers.size();
  GenericVector<int> num_weights;
  num_weights.init_to_size(num_layers, 0);
  GenericVector<double> bad_sums[LR_COUNT];
  GenericVector<double> ok_sums[LR_COUNT];
  for (int i = 0; i < LR_COUNT; ++i) {
    bad_sums[i].init_to_size(num_layers, 0.0);
    ok_sums[i].init_to_size(num_layers, 0.0);
  }
  double momentum_factor = 1.0 / (1.0 - momentum_);
  GenericVector<char> orig_trainer;
  samples_trainer->SaveTrainingDump(LIGHT, this, &orig_trainer);
  for (int i = 0; i < num_layers; ++i) {
    Network* layer = GetLayer(layers[i]);
    num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
  }
  int iteration = sample_iteration();
  for (int s = 0; s < num_samples; ++s) {
    // Which way will we modify the learning rate?
    for (int ww = 0; ww < LR_COUNT; ++ww) {
      // Transfer momentum to learning rate and adjust by the ww factor.
      float ww_factor = momentum_factor;
      if (ww == LR_DOWN) ww_factor *= factor;
      // Make a copy of *this, so we can mess about without damaging anything.
      LSTMTrainer copy_trainer;
      samples_trainer->ReadTrainingDump(orig_trainer, &copy_trainer);
      // Clear the updates, doing nothing else.
      copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
      // Adjust the learning rate in each layer.
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) continue;
        copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
      }
      copy_trainer.SetIteration(iteration);
      // Train on the sample, but keep the update in updates_ instead of
      // applying to the weights.
      const ImageData* trainingdata =
          copy_trainer.TrainOnLine(samples_trainer, true);
      if (trainingdata == nullptr) continue;
      // We'll now use this trainer again for each layer.
      GenericVector<char> updated_trainer;
      samples_trainer->SaveTrainingDump(LIGHT, &copy_trainer, &updated_trainer);
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) continue;
        LSTMTrainer layer_trainer;
        samples_trainer->ReadTrainingDump(updated_trainer, &layer_trainer);
        Network* layer = layer_trainer.GetLayer(layers[i]);
        // Update the weights in just the layer, using Adam if enabled.
        layer->Update(0.0, momentum_, adam_beta_,
                      layer_trainer.training_iteration_ + 1);
        // Zero the updates matrix again.
        layer->Update(0.0, 0.0, 0.0, 0);
        // Train again on the same sample, again holding back the updates.
        layer_trainer.TrainOnLine(trainingdata, true);
        // Count the sign changes in the updates in layer vs in copy_trainer.
        float before_bad = bad_sums[ww][i];
        float before_ok = ok_sums[ww][i];
        layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
                                &ok_sums[ww][i], &bad_sums[ww][i]);
        float bad_frac =
            bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
        if (bad_frac > 0.0f)
          bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
      }
    }
    ++iteration;
  }
  int num_lowered = 0;
  for (int i = 0; i < num_layers; ++i) {
    if (num_weights[i] == 0) continue;
    Network* layer = GetLayer(layers[i]);
    float lr = GetLayerLearningRate(layers[i]);
    double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
    double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
    double frac_down = bad_sums[LR_DOWN][i] / total_down;
    double frac_same = bad_sums[LR_SAME][i] / total_same;
    tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().string(),
            lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
    if (frac_down < frac_same * kImprovementFraction) {
      tprintf(" REDUCED\n");
      ScaleLayerLearningRate(layers[i], factor);
      ++num_lowered;
    } else {
      tprintf(" SAME\n");
    }
  }
  if (num_lowered == 0) {
    // Just lower everything to make sure.
    for (int i = 0; i < num_layers; ++i) {
      if (num_weights[i] > 0) {
        ScaleLayerLearningRate(layers[i], factor);
        ++num_lowered;
      }
    }
  }
  return num_lowered;
}

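// In effect each sample runs an A/B experiment per layer: the sample is
// trained twice in each arm (LR_DOWN vs LR_SAME), CountAlternators totals how
// much of the second update reverses the sign of the first, and a layer's
// rate is reduced only when the LR_DOWN sign-change fraction is below 15/16
// (kImprovementFraction) of the LR_SAME fraction.
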
// Converts the string to integer class labels, with appropriate null_char_s
// in between if not in SimpleTextOutput mode. Returns false on failure.
/* static */
bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset,
                               const UnicharCompress* recoder, bool simple_text,
                               int null_char, GenericVector<int>* labels) {
  if (str.string() == nullptr || str.length() <= 0) {
    tprintf("Empty truth string!\n");
    return false;
  }
  int err_index;
  GenericVector<int> internal_labels;
  labels->truncate(0);
  if (!simple_text) labels->push_back(null_char);
  std::string cleaned = unicharset.CleanupString(str.string());
  if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,
                               &err_index)) {
    bool success = true;
    for (int i = 0; i < internal_labels.size(); ++i) {
      if (recoder != nullptr) {
        // Re-encode labels via recoder.
        RecodedCharID code;
        int len = recoder->EncodeUnichar(internal_labels[i], &code);
        if (len > 0) {
          for (int j = 0; j < len; ++j) {
            labels->push_back(code(j));
            if (!simple_text) labels->push_back(null_char);
          }
        } else {
          success = false;
          err_index = 0;
          break;
        }
      } else {
        labels->push_back(internal_labels[i]);
        if (!simple_text) labels->push_back(null_char);
      }
    }
    if (success) return true;
  }
  tprintf("Encoding of string failed! Failure bytes:");
  while (err_index < cleaned.size()) {
    tprintf(" %x", cleaned[err_index++]);
  }
  tprintf("\n");
  return false;
}

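// Worked example (CTC mode, simple_text == false): "ab" encodes to a leading
// null_char followed by each recoder code with a null_char after it, e.g.
// [null, a0, null, a1, null, b0, null] when 'a' recodes to two codes and 'b'
// to one, so the CTC alignment always has a blank available between codes.
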
// Performs forward-backward on the given trainingdata.
// Returns a Trainability enum to indicate the suitability of the sample.
Trainability LSTMTrainer::TrainOnLine(const ImageData* trainingdata,
                                      bool batch) {
  NetworkIO fwd_outputs, targets;
  Trainability trainable =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  ++sample_iteration_;
  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
    return trainable;  // Sample was unusable.
  }
  bool debug = debug_interval_ > 0 &&
      training_iteration() % debug_interval_ == 0;
  // Run backprop on the output.
  NetworkIO bp_deltas;
  if (network_->IsTraining() &&
      (trainable != PERFECT ||
       training_iteration() >
           last_perfect_training_iteration_ + perfect_delay_)) {
    network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
    network_->Update(learning_rate_, batch ? -1.0f : momentum_, adam_beta_,
                     training_iteration_ + 1);
  }
#ifndef GRAPHICS_DISABLED
  if (debug_interval_ == 1 && debug_win_ != nullptr) {
    delete debug_win_->AwaitEvent(SVET_CLICK);
  }
#endif  // GRAPHICS_DISABLED
  // Roll the memory of past means.
  RollErrorBuffers();
  return trainable;
}

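// When batch is true, Update is called with momentum -1.0f; as the caller in
// ReduceLayerLearningRates notes, this keeps the accumulated gradient in
// updates_ instead of applying it to the weights, so the update can be
// inspected and applied (or discarded) separately.
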
// Prepares the ground truth, runs forward, and prepares the targets.
// Returns a Trainability enum to indicate the suitability of the sample.
Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
                                             NetworkIO* fwd_outputs,
                                             NetworkIO* targets) {
  if (trainingdata == nullptr) {
    tprintf("Null trainingdata.\n");
    return UNENCODABLE;
  }
  // Ensure repeatability of random elements even across checkpoints.
  bool debug = debug_interval_ > 0 &&
      training_iteration() % debug_interval_ == 0;
  GenericVector<int> truth_labels;
  if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
    tprintf("Can't encode transcription: '%s' in language '%s'\n",
            trainingdata->transcription().string(),
            trainingdata->language().string());
    return UNENCODABLE;
  }
  bool upside_down = false;
  if (randomly_rotate_) {
    // This ensures consistent training results.
    SetRandomSeed();
    upside_down = randomizer_.SignedRand(1.0) > 0.0;
    if (upside_down) {
      // Modify the truth labels to match the rotation:
      // Apart from space and null, increment the label. This changes the
      // script-id to the same script-id but upside-down.
      // The labels need to be reversed in order, as the first is now the last.
      for (int c = 0; c < truth_labels.size(); ++c) {
        if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_)
          ++truth_labels[c];
      }
      truth_labels.reverse();
    }
  }
  int w = 0;
  while (w < truth_labels.size() &&
         (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_))
    ++w;
  if (w == truth_labels.size()) {
    tprintf("Blank transcription: %s\n",
            trainingdata->transcription().string());
    return UNENCODABLE;
  }
  float image_scale;
  NetworkIO inputs;
  bool invert = trainingdata->boxes().empty();
  if (!RecognizeLine(*trainingdata, invert, debug, invert, upside_down,
                     &image_scale, &inputs, fwd_outputs)) {
    tprintf("Image not trainable\n");
    return UNENCODABLE;
  }
  targets->Resize(*fwd_outputs, network_->NumOutputs());
  LossType loss_type = OutputLossType();
  if (loss_type == LT_SOFTMAX) {
    if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
      tprintf("Compute simple targets failed!\n");
      return UNENCODABLE;
    }
  } else if (loss_type == LT_CTC) {
    if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
      tprintf("Compute CTC targets failed!\n");
      return UNENCODABLE;
    }
  } else {
    tprintf("Logistic outputs not implemented yet!\n");
    return UNENCODABLE;
  }
  GenericVector<int> ocr_labels;
  GenericVector<int> xcoords;
  LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
  // CTC does not produce correct target labels to begin with.
  if (loss_type != LT_CTC) {
    LabelsFromOutputs(*targets, &truth_labels, &xcoords);
  }
  if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
                         *targets)) {
    tprintf("Input width was %d\n", inputs.Width());
    return UNENCODABLE;
  }
  STRING ocr_text = DecodeLabels(ocr_labels);
  STRING truth_text = DecodeLabels(truth_labels);
  targets->SubtractAllFromFloat(*fwd_outputs);
  if (debug_interval_ != 0) {
    if (truth_text != ocr_text) {
      tprintf("Iteration %d: BEST OCR TEXT : %s\n",
              training_iteration(), ocr_text.string());
    }
  }
  double char_error = ComputeCharError(truth_labels, ocr_labels);
  double word_error = ComputeWordError(&truth_text, &ocr_text);
  double delta_error = ComputeErrorRates(*targets, char_error, word_error);
  if (debug_interval_ != 0) {
    tprintf("File %s line %d %s:\n", trainingdata->imagefilename().string(),
            trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
  }
  if (delta_error == 0.0) return PERFECT;
  if (targets->AnySuspiciousTruth(kHighConfidence)) return HI_PRECISION_ERR;
  return TRAINABLE;
}

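// On return, targets holds target - output (via SubtractAllFromFloat): these
// are the error deltas that TrainOnLine passes to Network::Backward, and the
// char/word/delta errors derived from them have already been folded into the
// rolling buffers by ComputeErrorRates.
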
// Writes the trainer to memory, so that the current training state can be
// restored. *this must always be the master trainer that retains the only
// copy of the training data and language model. trainer is the model that is
// actually serialized.
bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount,
                                   const LSTMTrainer* trainer,
                                   GenericVector<char>* data) const {
  TFile fp;
  fp.OpenWrite(data);
  return trainer->Serialize(serialize_amount, &mgr_, &fp);
}

// Restores the model to *this.
bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager* mgr,
                                        const char* data, int size) {
  if (size == 0) {
    tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
    return false;
  }
  TFile fp;
  fp.Open(data, size);
  return DeSerialize(mgr, &fp);
}

// Writes the full recognition traineddata to the given filename.
bool LSTMTrainer::SaveTraineddata(const STRING& filename) {
  GenericVector<char> recognizer_data;
  SaveRecognitionDump(&recognizer_data);
  mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
                      recognizer_data.size());
  return mgr_.SaveFile(filename, file_writer_);
}

// Writes the recognizer to memory, so that it can be used for testing later.
void LSTMTrainer::SaveRecognitionDump(GenericVector<char>* data) const {
  TFile fp;
  fp.OpenWrite(data);
  network_->SetEnableTraining(TS_TEMP_DISABLE);
  ASSERT_HOST(LSTMRecognizer::Serialize(&mgr_, &fp));
  network_->SetEnableTraining(TS_RE_ENABLE);
}

// Returns a suitable filename for a training dump, based on the model_base_,
// the iteration and the error rates.
STRING LSTMTrainer::DumpFilename() const {
  STRING filename;
  filename.add_str_double(model_base_.string(), best_error_rate_);
  filename.add_str_int("_", best_iteration_);
  filename += ".checkpoint";
  return filename;
}

// Fills the whole error buffer of the given type with the given value.
void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) {
  for (int i = 0; i < kRollingBufferSize_; ++i)
    error_buffers_[type][i] = new_error;
  error_rates_[type] = 100.0 * new_error;
}

// Helper generates a map from each current recoder_ code (ie softmax index)
// to the corresponding old_recoder code, or -1 if there isn't one.
std::vector<int> LSTMTrainer::MapRecoder(
    const UNICHARSET& old_chset, const UnicharCompress& old_recoder) const {
  int num_new_codes = recoder_.code_range();
  int num_new_unichars = GetUnicharset().size();
  std::vector<int> code_map(num_new_codes, -1);
  for (int c = 0; c < num_new_codes; ++c) {
    int old_code = -1;
    // Find all new unichar_ids that recode to something that includes c.
    // The <= is to include the null char, which may be beyond the unicharset.
    for (int uid = 0; uid <= num_new_unichars; ++uid) {
      RecodedCharID codes;
      int length = recoder_.EncodeUnichar(uid, &codes);
      int code_index = 0;
      while (code_index < length && codes(code_index) != c) ++code_index;
      if (code_index == length) continue;
      // The old unicharset must have the same unichar.
      int old_uid =
          uid < num_new_unichars
              ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
              : old_chset.size() - 1;
      if (old_uid == INVALID_UNICHAR_ID) continue;
      // The encoding of old_uid at the same code_index is the old code.
      RecodedCharID old_codes;
      if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
        old_code = old_codes(code_index);
        break;
      }
    }
    code_map[c] = old_code;
  }
  return code_map;
}

// Private version of InitCharSet above finishes the job after initializing
// the mgr_ data member.
void LSTMTrainer::InitCharSet() {
  EmptyConstructor();
  training_flags_ = TF_COMPRESS_UNICHARSET;
  // Initialize the unicharset and recoder.
  if (!LoadCharsets(&mgr_)) {
    ASSERT_HOST(
        "Must provide a traineddata containing lstm_unicharset and"
        " lstm_recoder!\n" != nullptr);
  }
  SetNullChar();
}

// Helper computes and sets the null_char_.
void LSTMTrainer::SetNullChar() {
  null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN
                                                   : GetUnicharset().size();
  RecodedCharID code;
  recoder_.EncodeUnichar(null_char_, &code);
  null_char_ = code(0);
}

// Factored sub-constructor sets up reasonable default values.
void LSTMTrainer::EmptyConstructor() {
  align_win_ = nullptr;
  target_win_ = nullptr;
  ctc_win_ = nullptr;
  recon_win_ = nullptr;
  checkpoint_iteration_ = 0;
  training_stage_ = 0;
  num_training_stages_ = 2;
  InitIterations();
}

// Outputs the string and periodically displays the given network inputs
// as an image in the given window, and the corresponding labels at the
// corresponding x_starts.
// Returns false if the truth string is empty.
bool LSTMTrainer::DebugLSTMTraining(const NetworkIO& inputs,
                                    const ImageData& trainingdata,
                                    const NetworkIO& fwd_outputs,
                                    const GenericVector<int>& truth_labels,
                                    const NetworkIO& outputs) {
  const STRING& truth_text = DecodeLabels(truth_labels);
  if (truth_text.string() == nullptr || truth_text.length() <= 0) {
    tprintf("Empty truth string at decode time!\n");
    return false;
  }
  if (debug_interval_ != 0) {
    // Get class labels, xcoords and string.
    GenericVector<int> labels;
    GenericVector<int> xcoords;
    LabelsFromOutputs(outputs, &labels, &xcoords);
    STRING text = DecodeLabels(labels);
    tprintf("Iteration %d: GROUND TRUTH : %s\n",
            training_iteration(), truth_text.string());
    if (truth_text != text) {
      tprintf("Iteration %d: ALIGNED TRUTH : %s\n",
              training_iteration(), text.string());
    }
    if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
      tprintf("TRAINING activation path for truth string %s\n",
              truth_text.string());
      DebugActivationPath(outputs, labels, xcoords);
      DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
      if (OutputLossType() == LT_CTC) {
        DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
        DisplayTargets(outputs, "CTC Targets", &target_win_);
      }
    }
  }
  return true;
}

// Displays the network targets as a line graph.
void LSTMTrainer::DisplayTargets(const NetworkIO& targets,
                                 const char* window_name, ScrollView** window) {
#ifndef GRAPHICS_DISABLED  // do nothing if there's no graphics.
  int width = targets.Width();
  int num_features = targets.NumFeatures();
  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
                       window);
  for (int c = 0; c < num_features; ++c) {
    int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
    (*window)->Pen(static_cast<ScrollView::Color>(color));
    int start_t = -1;
    for (int t = 0; t < width; ++t) {
      double target = targets.f(t)[c];
      target *= kTargetYScale;
      if (target >= 1) {
        if (start_t < 0) {
          (*window)->SetCursor(t - 1, 0);
          start_t = t;
        }
        (*window)->DrawTo(t, target);
      } else if (start_t >= 0) {
        (*window)->DrawTo(t, 0);
        (*window)->DrawTo(start_t - 1, 0);
        start_t = -1;
      }
    }
    if (start_t >= 0) {
      (*window)->DrawTo(width, 0);
      (*window)->DrawTo(start_t - 1, 0);
    }
  }
  (*window)->Update();
#endif  // GRAPHICS_DISABLED
}

// Builds a no-compromises target where the first positions should be the
// truth labels and the rest is padded with the null_char_.
bool LSTMTrainer::ComputeTextTargets(const NetworkIO& outputs,
                                     const GenericVector<int>& truth_labels,
                                     NetworkIO* targets) {
  if (truth_labels.size() > targets->Width()) {
    tprintf("Error: transcription %s too long to fit into target of width %d\n",
            DecodeLabels(truth_labels).string(), targets->Width());
    return false;
  }
  for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) {
    targets->SetActivations(i, truth_labels[i], 1.0);
  }
  for (int i = truth_labels.size(); i < targets->Width(); ++i) {
    targets->SetActivations(i, null_char_, 1.0);
  }
  return true;
}

// Builds a target using standard CTC. truth_labels should be pre-padded with
// nulls wherever desired. They don't have to be between all labels.
// outputs is input-output, as it gets clipped to minimum probability.
bool LSTMTrainer::ComputeCTCTargets(const GenericVector<int>& truth_labels,
                                    NetworkIO* outputs, NetworkIO* targets) {
  // Bottom-clip outputs to a minimum probability.
  CTC::NormalizeProbs(outputs);
  return CTC::ComputeCTCTargets(truth_labels, null_char_,
                                outputs->float_array(), targets);
}

// Computes network errors, and stores the results in the rolling buffers,
// along with the supplied text_error.
// Returns the delta error of the current sample (not running average.)
double LSTMTrainer::ComputeErrorRates(const NetworkIO& deltas,
                                      double char_error, double word_error) {
  UpdateErrorBuffer(ComputeRMSError(deltas), ET_RMS);
  // Delta error is the fraction of timesteps with >0.5 error in the top choice
  // score. If zero, then the top choice characters are guaranteed correct,
  // even when there is residue in the RMS error.
  double delta_error = ComputeWinnerError(deltas);
  UpdateErrorBuffer(delta_error, ET_DELTA);
  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
  // Skip ratio measures the difference between sample_iteration_ and
  // training_iteration_, which reflects the number of unusable samples,
  // usually due to unencodable truth text, or the text not fitting in the
  // space for the output.
  double skip_count = sample_iteration_ - prev_sample_iteration_;
  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
  return delta_error;
}

// Computes the network activation RMS error rate.
double LSTMTrainer::ComputeRMSError(const NetworkIO& deltas) {
  double total_error = 0.0;
  int width = deltas.Width();
  int num_classes = deltas.NumFeatures();
  for (int t = 0; t < width; ++t) {
    const float* class_errs = deltas.f(t);
    for (int c = 0; c < num_classes; ++c) {
      double error = class_errs[c];
      total_error += error * error;
    }
  }
  return sqrt(total_error / (width * num_classes));
}

// Computes network activation winner error rate. (Number of values that are
// in error by >= 0.5 divided by number of time-steps.) More closely related
// to final character error than RMS, but still directly calculable from
// just the deltas. Because of the binary nature of the targets, zero winner
// error is a sufficient but not necessary condition for zero char error.
double LSTMTrainer::ComputeWinnerError(const NetworkIO& deltas) {
  int num_errors = 0;
  int width = deltas.Width();
  int num_classes = deltas.NumFeatures();
  for (int t = 0; t < width; ++t) {
    const float* class_errs = deltas.f(t);
    for (int c = 0; c < num_classes; ++c) {
      float abs_delta = fabs(class_errs[c]);
      // TODO(rays) Filtering cases where the delta is very large to cut out
      // GT errors doesn't work. Find a better way or get better truth.
      if (0.5 <= abs_delta)
        ++num_errors;
    }
  }
  return static_cast<double>(num_errors) / width;
}

// Computes a very simple bag of chars char error rate.
double LSTMTrainer::ComputeCharError(const GenericVector<int>& truth_str,
                                     const GenericVector<int>& ocr_str) {
  GenericVector<int> label_counts;
  label_counts.init_to_size(NumOutputs(), 0);
  int truth_size = 0;
  for (int i = 0; i < truth_str.size(); ++i) {
    if (truth_str[i] != null_char_) {
      ++label_counts[truth_str[i]];
      ++truth_size;
    }
  }
  for (int i = 0; i < ocr_str.size(); ++i) {
    if (ocr_str[i] != null_char_) {
      --label_counts[ocr_str[i]];
    }
  }
  int char_errors = 0;
  for (int i = 0; i < label_counts.size(); ++i) {
    char_errors += abs(label_counts[i]);
  }
  if (truth_size == 0) {
    return (char_errors == 0) ? 0.0 : 1.0;
  }
  return static_cast<double>(char_errors) / truth_size;
}

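// Worked example: truth labels "aba" vs OCR "ab" leave counts {a:1, b:0}
// after the subtraction, so char_errors = 1 and the rate is 1/3. Spurious
// characters count too: truth "ab" vs OCR "abb" gives {a:0, b:-1} and a rate
// of 1/2.
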
// Computes word recall error rate using a very simple bag of words algorithm.
// NOTE that this is destructive on both input strings.
double LSTMTrainer::ComputeWordError(STRING* truth_str, STRING* ocr_str) {
  using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
  GenericVector<STRING> truth_words, ocr_words;
  truth_str->split(' ', &truth_words);
  if (truth_words.empty()) return 0.0;
  ocr_str->split(' ', &ocr_words);
  StrMap word_counts;
  for (int i = 0; i < truth_words.size(); ++i) {
    std::string truth_word(truth_words[i].string());
    auto it = word_counts.find(truth_word);
    if (it == word_counts.end())
      word_counts.insert(std::make_pair(truth_word, 1));
    else
      ++it->second;
  }
  for (int i = 0; i < ocr_words.size(); ++i) {
    std::string ocr_word(ocr_words[i].string());
    auto it = word_counts.find(ocr_word);
    if (it == word_counts.end())
      word_counts.insert(std::make_pair(ocr_word, -1));
    else
      --it->second;
  }
  int word_recall_errs = 0;
  for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
       ++it) {
    if (it->second > 0) word_recall_errs += it->second;
  }
  return static_cast<double>(word_recall_errs) / truth_words.size();
}

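// Worked example: truth "the cat sat" vs OCR "the sat mat" leaves counts
// {the:0, cat:1, sat:0, mat:-1}; only positive residues (truth words never
// matched) are counted, so word_recall_errs = 1 and the error rate is 1/3.
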
// Updates the error buffer and corresponding mean of the given type with
// the new_error.
void LSTMTrainer::UpdateErrorBuffer(double new_error, ErrorTypes type) {
  int index = training_iteration_ % kRollingBufferSize_;
  error_buffers_[type][index] = new_error;
  // Compute the mean error.
  int mean_count =
      std::min(training_iteration_ + 1, error_buffers_[type].size());
  double buffer_sum = 0.0;
  for (int i = 0; i < mean_count; ++i) buffer_sum += error_buffers_[type][i];
  double mean = buffer_sum / mean_count;
  // Trim precision to 1/1000 of 1%.
  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
}

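// The buffers are circular over kRollingBufferSize_ entries, so each
// error_rates_ entry is a moving average over the most recent window (or all
// iterations while the buffer is still filling), quantized to 0.001%.
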
// Rolls error buffers and reports the current means.
void LSTMTrainer::RollErrorBuffers() {
  prev_sample_iteration_ = sample_iteration_;
  if (NewSingleError(ET_DELTA) > 0.0)
    ++learning_iteration_;
  else
    last_perfect_training_iteration_ = training_iteration_;
  ++training_iteration_;
  if (debug_interval_ != 0) {
    tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
            error_rates_[ET_RMS], error_rates_[ET_DELTA],
            error_rates_[ET_CHAR_ERROR], error_rates_[ET_WORD_RECERR],
            error_rates_[ET_SKIP_RATIO]);
  }
}

// Given that error_rate is either a new min or max, updates the best/worst
// error rates, and record of progress.
// Tester is an externally supplied callback function that tests on some
// data set with a given model and records the error rates in a graph.
STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
                                     const GenericVector<char>& model_data,
                                     TestCallback tester) {
  if (error_rate > best_error_rate_
      && iteration < best_iteration_ + kErrorGraphInterval) {
    // Too soon to record a new point.
    if (tester != nullptr && !worst_model_data_.empty()) {
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      return tester->Run(worst_iteration_, nullptr, mgr_,
                         CurrentTrainingStage());
    } else {
      return "";
    }
  }
  STRING result;
  // NOTE: there are 2 asymmetries here:
  // 1. We are computing the global minimum, but the local maximum in between.
  // 2. If the tester returns an empty string, indicating that it is busy,
  //    call it repeatedly on new local maxima to test the previous min, but
  //    not the other way around, as there is little point testing the maxima
  //    between very frequent minima.
  if (error_rate < best_error_rate_) {
    // This is a new (global) minimum.
    if (tester != nullptr && !worst_model_data_.empty()) {
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      result = tester->Run(worst_iteration_, worst_error_rates_, mgr_,
                           CurrentTrainingStage());
      worst_model_data_.truncate(0);
      best_model_data_ = model_data;
    }
    best_error_rate_ = error_rate;
    memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
    best_iteration_ = iteration;
    best_error_history_.push_back(error_rate);
    best_error_iterations_.push_back(iteration);
    // Compute 2% decay time.
    double two_percent_more = error_rate + 2.0;
    int i;
    for (i = best_error_history_.size() - 1;
         i >= 0 && best_error_history_[i] < two_percent_more; --i) {
    }
    int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
    improvement_steps_ = iteration - old_iteration;
    tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
            improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
            old_iteration);
  } else if (error_rate > best_error_rate_) {
    // This is a new (local) maximum.
    if (tester != nullptr) {
      if (!best_model_data_.empty()) {
        mgr_.OverwriteEntry(TESSDATA_LSTM, &best_model_data_[0],
                            best_model_data_.size());
        result = tester->Run(best_iteration_, best_error_rates_, mgr_,
                             CurrentTrainingStage());
      } else if (!worst_model_data_.empty()) {
        // Allow for multiple data points with "worst" error rate.
        mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                            worst_model_data_.size());
        result = tester->Run(worst_iteration_, worst_error_rates_, mgr_,
                             CurrentTrainingStage());
      }
      if (result.length() > 0)
        best_model_data_.truncate(0);
      worst_model_data_ = model_data;
    }
  }
  worst_error_rate_ = error_rate;
  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
  worst_iteration_ = iteration;
  return result;
}

}  // namespace tesseract.