tesseract 4.1.1
Loading...
Searching...
No Matches
lstmtester.cpp
Go to the documentation of this file.
1
2// File: lstmtester.cpp
3// Description: Top-level line evaluation class for LSTM-based networks.
4// Author: Ray Smith
5// Created: Wed Nov 23 11:18:06 PST 2016
6//
7// (C) Copyright 2016, Google Inc.
8// Licensed under the Apache License, Version 2.0 (the "License");
9// you may not use this file except in compliance with the License.
10// You may obtain a copy of the License at
11// http://www.apache.org/licenses/LICENSE-2.0
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
18
19#include <thread> // for std::thread
20#include "fileio.h" // for LoadFileLinesToStrings
21#include "lstmtester.h"
22#include "genericvector.h"
23
24namespace tesseract {
25
26LSTMTester::LSTMTester(int64_t max_memory)
27 : test_data_(max_memory), total_pages_(0), async_running_(false) {}
28
29// Loads a set of lstmf files that were created using the lstm.train config to
30// tesseract into memory ready for testing. Returns false if nothing was
31// loaded. The arg is a filename of a file that lists the filenames.
32bool LSTMTester::LoadAllEvalData(const STRING& filenames_file) {
33 GenericVector<STRING> filenames;
34 if (!LoadFileLinesToStrings(filenames_file.c_str(), &filenames)) {
35 tprintf("Failed to load list of eval filenames from %s\n",
36 filenames_file.string());
37 return false;
38 }
39 return LoadAllEvalData(filenames);
40}
41
42// Loads a set of lstmf files that were created using the lstm.train config to
43// tesseract into memory ready for testing. Returns false if nothing was
44// loaded.
46 test_data_.Clear();
47 bool result = test_data_.LoadDocuments(filenames, CS_SEQUENTIAL, nullptr);
48 total_pages_ = test_data_.TotalPages();
49 return result;
50}
51
52// Runs an evaluation asynchronously on the stored data and returns a string
53// describing the results of the previous test.
54STRING LSTMTester::RunEvalAsync(int iteration, const double* training_errors,
55 const TessdataManager& model_mgr,
56 int training_stage) {
57 STRING result;
58 if (total_pages_ == 0) {
59 result.add_str_int("No test data at iteration", iteration);
60 return result;
61 }
62 if (!LockIfNotRunning()) {
63 result.add_str_int("Previous test incomplete, skipping test at iteration",
64 iteration);
65 return result;
66 }
67 // Save the args.
68 STRING prev_result = test_result_;
69 test_result_ = "";
70 if (training_errors != nullptr) {
71 test_iteration_ = iteration;
72 test_training_errors_ = training_errors;
73 test_model_mgr_ = model_mgr;
74 test_training_stage_ = training_stage;
75 SVSync::StartThread(&LSTMTester::ThreadFunc, this);
76 } else {
77 UnlockRunning();
78 }
79 return prev_result;
80}
81
82// Runs an evaluation synchronously on the stored data and returns a string
83// describing the results.
84STRING LSTMTester::RunEvalSync(int iteration, const double* training_errors,
85 const TessdataManager& model_mgr,
86 int training_stage, int verbosity) {
87 LSTMTrainer trainer;
88 trainer.InitCharSet(model_mgr);
89 TFile fp;
90 if (!model_mgr.GetComponent(TESSDATA_LSTM, &fp) ||
91 !trainer.DeSerialize(&model_mgr, &fp)) {
92 return "Deserialize failed";
93 }
94 int eval_iteration = 0;
95 double char_error = 0.0;
96 double word_error = 0.0;
97 int error_count = 0;
98 while (error_count < total_pages_) {
99 const ImageData* trainingdata = test_data_.GetPageBySerial(eval_iteration);
100 trainer.SetIteration(++eval_iteration);
101 NetworkIO fwd_outputs, targets;
102 Trainability result =
103 trainer.PrepareForBackward(trainingdata, &fwd_outputs, &targets);
104 if (result != UNENCODABLE) {
105 char_error += trainer.NewSingleError(tesseract::ET_CHAR_ERROR);
106 word_error += trainer.NewSingleError(tesseract::ET_WORD_RECERR);
107 ++error_count;
108 if (verbosity > 1 || (verbosity > 0 && result != PERFECT)) {
109 tprintf("Truth:%s\n", trainingdata->transcription().string());
110 GenericVector<int> ocr_labels;
111 GenericVector<int> xcoords;
112 trainer.LabelsFromOutputs(fwd_outputs, &ocr_labels, &xcoords);
113 STRING ocr_text = trainer.DecodeLabels(ocr_labels);
114 tprintf("OCR :%s\n", ocr_text.string());
115 }
116 }
117 }
118 char_error *= 100.0 / total_pages_;
119 word_error *= 100.0 / total_pages_;
120 STRING result;
121 result.add_str_int("At iteration ", iteration);
122 result.add_str_int(", stage ", training_stage);
123 result.add_str_double(", Eval Char error rate=", char_error);
124 result.add_str_double(", Word error rate=", word_error);
125 return result;
126}
127
128// Static helper thread function for RunEvalAsync, with a specific signature
129// required by SVSync::StartThread. Actually a member function pretending to
130// be static, its arg is a this pointer that it will cast back to LSTMTester*
131// to call RunEvalSync using the stored args that RunEvalAsync saves in *this.
132// LockIfNotRunning must have returned true before calling ThreadFunc, and
133// it will call UnlockRunning to release the lock after RunEvalSync completes.
134/* static */
135void* LSTMTester::ThreadFunc(void* lstmtester_void) {
136 LSTMTester* lstmtester = static_cast<LSTMTester*>(lstmtester_void);
137 lstmtester->test_result_ = lstmtester->RunEvalSync(
138 lstmtester->test_iteration_, lstmtester->test_training_errors_,
139 lstmtester->test_model_mgr_, lstmtester->test_training_stage_,
140 /*verbosity*/ 0);
141 lstmtester->UnlockRunning();
142 return lstmtester_void;
143}
144
145// Returns true if there is currently nothing running, and takes the lock
146// if there is nothing running.
147bool LSTMTester::LockIfNotRunning() {
148 SVAutoLock lock(&running_mutex_);
149 if (async_running_) return false;
150 async_running_ = true;
151 return true;
152}
153
154// Releases the running lock.
155void LSTMTester::UnlockRunning() {
156 SVAutoLock lock(&running_mutex_);
157 async_running_ = false;
158}
159
160} // namespace tesseract
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
@ ET_WORD_RECERR
Definition: lstmtrainer.h:40
@ ET_CHAR_ERROR
Definition: lstmtrainer.h:41
@ CS_SEQUENTIAL
Definition: imagedata.h:49
bool LoadFileLinesToStrings(const char *filename, GenericVector< STRING > *lines)
Definition: fileio.h:31
const STRING & transcription() const
Definition: imagedata.h:147
const ImageData * GetPageBySerial(int serial)
Definition: imagedata.h:344
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
Definition: imagedata.cpp:580
Definition: strngs.h:45
const char * c_str() const
Definition: strngs.cpp:205
void add_str_int(const char *str, int number)
Definition: strngs.cpp:377
void add_str_double(const char *str, double number)
Definition: strngs.cpp:387
const char * string() const
Definition: strngs.cpp:194
bool GetComponent(TessdataType type, TFile *fp)
STRING DecodeLabels(const GenericVector< int > &labels)
void SetIteration(int iteration)
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:154
void InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:109
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
LSTMTester(int64_t max_memory)
Definition: lstmtester.cpp:26
bool LoadAllEvalData(const STRING &filenames_file)
Definition: lstmtester.cpp:32
STRING RunEvalSync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage, int verbosity)
Definition: lstmtester.cpp:84
STRING RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)
Definition: lstmtester.cpp:54
static void StartThread(void *(*func)(void *), void *arg)
Create new thread.
Definition: svutil.cpp:81