tesseract 4.1.1
lstmtraining.cpp
///////////////////////////////////////////////////////////////////////
// File:        lstmtraining.cpp
// Description: Training program for LSTM-based networks.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifdef GOOGLE_TESSERACT
#include "base/commandlineflags.h"
#endif
#include <cerrno>
#include <cstring>        // for strerror
#include "commontraining.h"
#include "fileio.h"       // for LoadFileLinesToStrings
#include "lstmtester.h"
#include "lstmtrainer.h"
#include "params.h"
#include "strngs.h"
#include "tprintf.h"

static INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
static STRING_PARAM_FLAG(net_spec, "", "Network specification");
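// --net_spec takes a VGSL network description string. An illustrative spec
// (example only, not a default of this program; layer sizes are placeholders)
// might look like:
//   [1,36,0,1 Ct3,3,16 Mp3,3 Lfys48 Lfx96 Lrx96 Lfx192 O1c<numchars>]
// where O1c<numchars> is the softmax output sized to the unicharset/recoder.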
static INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
static INT_PARAM_FLAG(perfect_sample_delay, 0,
                      "How many imperfect samples between perfect ones.");
static DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
static DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
static DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
static DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
static DOUBLE_PARAM_FLAG(adam_beta, 0.999,
                         "Decay factor for the second-moment estimate in Adam.");
static INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
static STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
static STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
static STRING_PARAM_FLAG(train_listfile, "",
                         "File listing training files in lstmf training format.");
static STRING_PARAM_FLAG(eval_listfile, "",
                         "File listing eval files in lstmf training format.");
static BOOL_PARAM_FLAG(stop_training, false,
                       "Just convert the training model to a runtime model.");
static BOOL_PARAM_FLAG(convert_to_int, false,
                       "Convert the recognition model to an integer model.");
static BOOL_PARAM_FLAG(sequential_training, false,
                       "Use the training files sequentially instead of round-robin.");
static INT_PARAM_FLAG(append_index, -1,
                      "Index in continue_from Network at which to"
                      " attach the new network defined by net_spec");
static BOOL_PARAM_FLAG(debug_network, false,
                       "Get info on distribution of weight values");
static INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
static STRING_PARAM_FLAG(traineddata, "",
                         "Combined Dawgs/Unicharset/Recoder for language model");
static STRING_PARAM_FLAG(old_traineddata, "",
                         "When changing the character set, this specifies the old"
                         " character set that is to be replaced");
static BOOL_PARAM_FLAG(randomly_rotate, false,
                       "Train OSD and randomly turn training samples upside-down");

// Number of training images to train between calls to MaintainCheckpoints.
const int kNumPagesPerBatch = 100;

// Apart from command-line flags, input is a collection of lstmf files that
// were previously created using tesseract with the lstm.train config file.
// The program iterates over the inputs, feeding the data to the network,
// until the error rate reaches a specified target or max_iterations is reached.
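// A typical invocation (the paths below are placeholders, not defaults of
// this program):
//   lstmtraining --traineddata /path/to/eng/eng.traineddata \
//     --model_output /path/to/output/base \
//     --train_listfile /path/to/eng.training_files.txt \
//     --max_iterations 10000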
int main(int argc, char **argv) {
  tesseract::CheckSharedLibraryVersion();
  ParseArguments(&argc, &argv);
  if (FLAGS_model_output.empty()) {
    tprintf("Must provide a --model_output!\n");
    return EXIT_FAILURE;
  }
  if (FLAGS_traineddata.empty()) {
    tprintf("Must provide a --traineddata; see the training wiki.\n");
    return EXIT_FAILURE;
  }

  // Check write permissions.
  STRING test_file = FLAGS_model_output.c_str();
  test_file += "_wtest";
  FILE* f = fopen(test_file.c_str(), "wb");
  if (f != nullptr) {
    fclose(f);
    if (remove(test_file.c_str()) != 0) {
      tprintf("Error, failed to remove %s: %s\n",
              test_file.c_str(), strerror(errno));
      return EXIT_FAILURE;
    }
  } else {
    tprintf("Error, model output cannot be written: %s\n", strerror(errno));
    return EXIT_FAILURE;
  }

  // Setup the trainer.
  STRING checkpoint_file = FLAGS_model_output.c_str();
  checkpoint_file += "_checkpoint";
  STRING checkpoint_bak = checkpoint_file + ".bak";
  tesseract::LSTMTrainer trainer(
      nullptr, nullptr, nullptr, nullptr, FLAGS_model_output.c_str(),
      checkpoint_file.c_str(), FLAGS_debug_interval,
      static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
  trainer.InitCharSet(FLAGS_traineddata.c_str());

  // Reading something from an existing model doesn't require many flags,
  // so do it now and exit.
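  // For example, a finished checkpoint can be packaged into a usable
  // .traineddata like this (paths are placeholders):
  //   lstmtraining --stop_training \
  //     --continue_from /path/to/output/base_checkpoint \
  //     --traineddata /path/to/eng/eng.traineddata \
  //     --model_output /path/to/output/eng.traineddata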
  if (FLAGS_stop_training || FLAGS_debug_network) {
    if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) {
      tprintf("Failed to read continue from: %s\n",
              FLAGS_continue_from.c_str());
      return EXIT_FAILURE;
    }
    if (FLAGS_debug_network) {
      trainer.DebugNetwork();
    } else {
      if (FLAGS_convert_to_int) trainer.ConvertToInt();
      if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
        tprintf("Failed to write recognition model: %s\n",
                FLAGS_model_output.c_str());
      }
    }
    return EXIT_SUCCESS;
  }

  // Get the list of files to process.
  if (FLAGS_train_listfile.empty()) {
    tprintf("Must supply a list of training filenames! --train_listfile\n");
    return EXIT_FAILURE;
  }
  GenericVector<STRING> filenames;
  if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(),
                                         &filenames)) {
    tprintf("Failed to load list of training filenames from %s\n",
            FLAGS_train_listfile.c_str());
    return EXIT_FAILURE;
  }

  // Checkpoints always take priority if they are available.
  if (trainer.TryLoadingCheckpoint(checkpoint_file.string(), nullptr) ||
      trainer.TryLoadingCheckpoint(checkpoint_bak.string(), nullptr)) {
    tprintf("Successfully restored trainer from %s\n",
            checkpoint_file.string());
  } else {
    if (!FLAGS_continue_from.empty()) {
      // Load a past model file to improve upon.
      if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
                                        FLAGS_append_index >= 0
                                            ? FLAGS_continue_from.c_str()
                                            : FLAGS_old_traineddata.c_str())) {
        tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
        return EXIT_FAILURE;
      }
      tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
      trainer.InitIterations();
    }
    if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
      if (FLAGS_append_index >= 0) {
        tprintf("Appending a new network to an old one!!\n");
        if (FLAGS_continue_from.empty()) {
          tprintf("Must set --continue_from for appending!\n");
          return EXIT_FAILURE;
        }
      }
      // We are initializing from scratch.
      if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
                               FLAGS_net_mode, FLAGS_weight_range,
                               FLAGS_learning_rate, FLAGS_momentum,
                               FLAGS_adam_beta)) {
        tprintf("Failed to create network from spec: %s\n",
                FLAGS_net_spec.c_str());
        return EXIT_FAILURE;
      }
      trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
    }
  }
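  // CS_SEQUENTIAL works through each listed .lstmf file in turn, while
  // CS_ROUND_ROBIN (the default) rotates across the files; both are
  // CachingStrategy values declared in imagedata.h.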
  if (!trainer.LoadAllTrainingData(filenames,
                                   FLAGS_sequential_training
                                       ? tesseract::CS_SEQUENTIAL
                                       : tesseract::CS_ROUND_ROBIN,
                                   FLAGS_randomly_rotate)) {
    tprintf("Load of images failed!!\n");
    return EXIT_FAILURE;
  }

  tesseract::LSTMTester tester(static_cast<int64_t>(FLAGS_max_image_MB) *
                               1048576);
  tesseract::TestCallback tester_callback = nullptr;
  if (!FLAGS_eval_listfile.empty()) {
    if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
      tprintf("Failed to load eval data from: %s\n",
              FLAGS_eval_listfile.c_str());
      return EXIT_FAILURE;
    }
    tester_callback =
        NewPermanentTessCallback(&tester, &tesseract::LSTMTester::RunEvalAsync);
  }
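  // Train in batches of kNumPagesPerBatch pages, letting MaintainCheckpoints
  // save/restore checkpoints (and trigger the evaluator, if any) between
  // batches, until the target error rate or the iteration limit is reached.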
  do {
    // Train a few.
    int iteration = trainer.training_iteration();
    for (int target_iteration = iteration + kNumPagesPerBatch;
         iteration < target_iteration &&
         (iteration < FLAGS_max_iterations || FLAGS_max_iterations == 0);
         iteration = trainer.training_iteration()) {
      trainer.TrainOnLine(&trainer, false);
    }
    STRING log_str;
    trainer.MaintainCheckpoints(tester_callback, &log_str);
    tprintf("%s\n", log_str.string());
  } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
           (trainer.training_iteration() < FLAGS_max_iterations ||
            FLAGS_max_iterations == 0));
  delete tester_callback;
  tprintf("Finished! Error rate = %g\n", trainer.best_error_rate());
  return EXIT_SUCCESS;
} /* main */