tesseract 4.1.1
Loading...
Searching...
No Matches
mastertrainer.h
Go to the documentation of this file.
1// Copyright 2010 Google Inc. All Rights Reserved.
2// Author: rays@google.com (Ray Smith)
4// File: mastertrainer.h
5// Description: Trainer to build the MasterClassifier.
6// Author: Ray Smith
7// Created: Wed Nov 03 18:07:01 PDT 2010
8//
9// (C) Copyright 2010, Google Inc.
10// Licensed under the Apache License, Version 2.0 (the "License");
11// you may not use this file except in compliance with the License.
12// You may obtain a copy of the License at
13// http://www.apache.org/licenses/LICENSE-2.0
14// Unless required by applicable law or agreed to in writing, software
15// distributed under the License is distributed on an "AS IS" BASIS,
16// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17// See the License for the specific language governing permissions and
18// limitations under the License.
19//
21
22#ifndef TESSERACT_TRAINING_MASTERTRAINER_H_
23#define TESSERACT_TRAINING_MASTERTRAINER_H_
24
28#include "classify.h"
29#include "cluster.h"
30#include "intfx.h"
31#include "elst.h"
32#include "errorcounter.h"
33#include "featdefs.h"
34#include "fontinfo.h"
35#include "indexmapbidi.h"
36#include "intfeaturespace.h"
37#include "intfeaturemap.h"
38#include "intmatcher.h"
39#include "params.h"
40#include "shapetable.h"
41#include "trainingsample.h"
42#include "trainingsampleset.h"
43#include "unicharset.h"
44
45namespace tesseract {
46
47class ShapeClassifier;
48
// Simple struct to hold the distance between two shapes during clustering.
// shape1 and shape2 identify the two shapes (presumably indices into a
// ShapeTable, as taken by MasterTrainer::ShapeDistance — confirm at callers);
// distance is the distance between them.
struct ShapeDist {
  // Default: both shape ids 0 and a zero distance.
  ShapeDist() : shape1(0), shape2(0), distance(0.0f) {}
  ShapeDist(int s1, int s2, float dist)
    : shape1(s1), shape2(s2), distance(dist) {}

  // Sort operator to sort in ascending order of distance.
  // Only distance takes part in the ordering: shape1/shape2 are ignored, so
  // two entries with equal distance compare as equivalent.
  bool operator<(const ShapeDist& other) const {
    return distance < other.distance;
  }

  int shape1;
  int shape2;
  float distance;
};
64
65// Class to encapsulate training processes that use the TrainingSampleSet.
66// Initially supports shape clustering and mftrainining.
67// Other important features of the MasterTrainer are conditioning the data
68// by outlier elimination, replication with perturbation, and serialization.
70 public:
71 MasterTrainer(NormalizationMode norm_mode, bool shape_analysis,
72 bool replicate_samples, int debug_level);
74
75 // Writes to the given file. Returns false in case of error.
76 bool Serialize(FILE* fp) const;
77
78 // Loads an initial unicharset, or sets one up if the file cannot be read.
79 void LoadUnicharset(const char* filename);
80
81 // Sets the feature space definition.
83 feature_space_ = fs;
84 feature_map_.Init(fs);
85 }
86
87 // Reads the samples and their features from the given file,
88 // adding them to the trainer with the font_id from the content of the file.
89 // If verification, then these are verification samples, not training.
90 void ReadTrainingSamples(const char* page_name,
92 bool verification);
93
94 // Adds the given single sample to the trainer, setting the classid
95 // appropriately from the given unichar_str.
96 void AddSample(bool verification, const char* unichar_str,
98
99 // Loads all pages from the given tif filename and append to page_images_.
100 // Must be called after ReadTrainingSamples, as the current number of images
101 // is used as an offset for page numbers in the samples.
102 void LoadPageImages(const char* filename);
103
104 // Cleans up the samples after initial load from the tr files, and prior to
105 // saving the MasterTrainer:
106 // Remaps fragmented chars if running shape analysis.
107 // Sets up the samples appropriately for class/fontwise access.
108 // Deletes outlier samples.
109 void PostLoadCleanup();
110
111 // Gets the samples ready for training. Use after both
112 // ReadTrainingSamples+PostLoadCleanup or DeSerialize.
113 // Re-indexes the features and computes canonical and cloud features.
114 void PreTrainingSetup();
115
116 // Sets up the master_shapes_ table, which tells which fonts should stay
117 // together until they get to a leaf node classifier.
118 void SetupMasterShapes();
119
120 // Adds the junk_samples_ to the main samples_ set. Junk samples are initially
121 // fragments and n-grams (all incorrectly segmented characters).
122 // Various training functions may result in incorrectly segmented characters
123 // being added to the unicharset of the main samples, perhaps because they
124 // form a "radical" decomposition of some (Indic) grapheme, or because they
125 // just look the same as a real character (like rn/m)
126 // This function moves all the junk samples, to the main samples_ set, but
127 // desirable junk, being any sample for which the unichar already exists in
128 // the samples_ unicharset gets the unichar-ids re-indexed to match, but
129 // anything else gets re-marked as unichar_id 0 (space character) to identify
130 // it as junk to the error counter.
131 void IncludeJunk();
132
133 // Replicates the samples and perturbs them if the enable_replication_ flag
134 // is set. MUST be used after the last call to OrganizeByFontAndClass on
135 // the training samples, ie after IncludeJunk if it is going to be used, as
136 // OrganizeByFontAndClass will eat the replicated samples into the regular
137 // samples.
139
140 // Loads the basic font properties file into fontinfo_table_.
141 // Returns false on failure.
142 bool LoadFontInfo(const char* filename);
143
144 // Loads the xheight font properties file into xheights_.
145 // Returns false on failure.
146 bool LoadXHeights(const char* filename);
147
148 // Reads spacing stats from filename and adds them to fontinfo_table.
149 // Returns false on failure.
150 bool AddSpacingInfo(const char *filename);
151
152 // Returns the font id corresponding to the given font name.
153 // Returns -1 if the font cannot be found.
154 int GetFontInfoId(const char* font_name);
155 // Returns the font_id of the closest matching font name to the given
156 // filename. It is assumed that a substring of the filename will match
157 // one of the fonts. If more than one is matched, the longest is returned.
158 int GetBestMatchingFontInfoId(const char* filename);
159
160 // Returns the filename of the tr file corresponding to the command-line
161 // argument with the given index.
162 const STRING& GetTRFileName(int index) const {
163 return tr_filenames_[index];
164 }
165
166 // Sets up a flat shapetable with one shape per class/font combination.
167 void SetupFlatShapeTable(ShapeTable* shape_table);
168
169 // Sets up a Clusterer for mftraining on a single shape_id.
170 // Call FreeClusterer on the return value after use.
171 CLUSTERER* SetupForClustering(const ShapeTable& shape_table,
173 int shape_id, int* num_samples);
174
175 // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp
176 // to the given inttemp_file, and the corresponding pffmtable.
177 // The unicharset is the original encoding of graphemes, and shape_set should
178 // match the size of the shape_table, and may possibly be totally fake.
180 const UNICHARSET& shape_set,
181 const ShapeTable& shape_table,
182 CLASS_STRUCT* float_classes,
183 const char* inttemp_file,
184 const char* pffmtable_file);
185
186 const UNICHARSET& unicharset() const {
187 return samples_.unicharset();
188 }
190 return &samples_;
191 }
192 const ShapeTable& master_shapes() const {
193 return master_shapes_;
194 }
195
196 // Generates debug output relating to the canonical distance between the
197 // two given UTF8 grapheme strings.
198 void DebugCanonical(const char* unichar_str1, const char* unichar_str2);
199 #ifndef GRAPHICS_DISABLED
200 // Debugging for cloud/canonical features.
201 // Displays a Features window containing:
202 // If unichar_str2 is in the unicharset, and canonical_font is non-negative,
203 // displays the canonical features of the char/font combination in red.
204 // If unichar_str1 is in the unicharset, and cloud_font is non-negative,
205 // displays the cloud feature of the char/font combination in green.
206 // The canonical features are drawn first to show which ones have no
207 // matches in the cloud features.
208 // Until the features window is destroyed, each click in the features window
209 // will display the samples that have that feature in a separate window.
210 void DisplaySamples(const char* unichar_str1, int cloud_font,
211 const char* unichar_str2, int canonical_font);
212 #endif // GRAPHICS_DISABLED
213
214 void TestClassifierVOld(bool replicate_samples,
215 ShapeClassifier* test_classifier,
216 ShapeClassifier* old_classifier);
217
218 // Tests the given test_classifier on the internal samples.
219 // See TestClassifier for details.
220 void TestClassifierOnSamples(CountTypes error_mode,
221 int report_level,
222 bool replicate_samples,
223 ShapeClassifier* test_classifier,
224 STRING* report_string);
225 // Tests the given test_classifier on the given samples
226 // error_mode indicates what counts as an error.
227 // report_levels:
228 // 0 = no output.
229 // 1 = bottom-line error rate.
230 // 2 = bottom-line error rate + time.
231 // 3 = font-level error rate + time.
232 // 4 = list of all errors + short classifier debug output on 16 errors.
233 // 5 = list of all errors + short classifier debug output on 25 errors.
234 // If replicate_samples is true, then the test is run on an extended test
235 // sample including replicated and systematically perturbed samples.
236 // If report_string is non-nullptr, a summary of the results for each font
237 // is appended to the report_string.
238 double TestClassifier(CountTypes error_mode,
239 int report_level,
240 bool replicate_samples,
241 TrainingSampleSet* samples,
242 ShapeClassifier* test_classifier,
243 STRING* report_string);
244
245 // Returns the average (in some sense) distance between the two given
246 // shapes, which may contain multiple fonts and/or unichars.
247 // This function is public to facilitate testing.
248 float ShapeDistance(const ShapeTable& shapes, int s1, int s2);
249
250 private:
251 // Replaces samples that are always fragmented with the corresponding
252 // fragment samples.
253 void ReplaceFragmentedSamples();
254
255 // Runs a hierarchical agglomerative clustering to merge shapes in the given
256 // shape_table, while satisfying the given constraints:
257 // * End with at least min_shapes left in shape_table,
258 // * No shape shall have more than max_shape_unichars in it,
259 // * Don't merge shapes where the distance between them exceeds max_dist.
260 void ClusterShapes(int min_shapes, int max_shape_unichars,
261 float max_dist, ShapeTable* shape_table);
262
263 private:
264 NormalizationMode norm_mode_;
265 // Character set we are training for.
266 UNICHARSET unicharset_;
267 // Original feature space. Subspace mapping is contained in feature_map_.
268 IntFeatureSpace feature_space_;
269 TrainingSampleSet samples_;
270 TrainingSampleSet junk_samples_;
271 TrainingSampleSet verify_samples_;
272 // Master shape table defines what fonts stay together until the leaves.
273 ShapeTable master_shapes_;
274 // Flat shape table has each unichar/font id pair in a separate shape.
275 ShapeTable flat_shapes_;
276 // Font metrics gathered from multiple files.
277 FontInfoTable fontinfo_table_;
278 // Array of xheights indexed by font ids in fontinfo_table_;
279 GenericVector<int32_t> xheights_;
280
281 // Non-serialized data initialized by other means or used temporarily
282 // during loading of training samples.
283 // Number of different class labels in unicharset_.
284 int charsetsize_;
285 // Flag to indicate that we are running shape analysis and need fragments
286 // fixing.
287 bool enable_shape_analysis_;
288 // Flag to indicate that sample replication is required.
289 bool enable_replication_;
290 // Array of classids of fragments that replace the correctly segmented chars.
291 int* fragments_;
292 // Classid of previous correctly segmented sample that was added.
293 int prev_unichar_id_;
294 // Debug output control.
295 int debug_level_;
296 // Feature map used to construct reduced feature spaces for compact
297 // classifiers.
298 IntFeatureMap feature_map_;
299 // Vector of Pix pointers used for classifiers that need the image.
300 // Indexed by page_num_ in the samples.
301 // These images are owned by the trainer and need to be pixDestroyed.
302 GenericVector<Pix*> page_images_;
303 // Vector of filenames of loaded tr files.
304 GenericVector<STRING> tr_filenames_;
305};
306
307} // namespace tesseract.
308
309#endif // TESSERACT_TRAINING_MASTERTRAINER_H_
FEATURE_DEFS_STRUCT feature_defs
NormalizationMode
Definition: normalis.h:42
Definition: strngs.h:45
Definition: cluster.h:32
void Init(const IntFeatureSpace &feature_space)
const UNICHARSET & unicharset() const
bool operator<(const ShapeDist &other) const
Definition: mastertrainer.h:56
ShapeDist(int s1, int s2, float dist)
Definition: mastertrainer.h:52
bool LoadFontInfo(const char *filename)
int GetBestMatchingFontInfoId(const char *filename)
TrainingSampleSet * GetSamples()
void DisplaySamples(const char *unichar_str1, int cloud_font, const char *unichar_str2, int canonical_font)
void LoadUnicharset(const char *filename)
void ReplicateAndRandomizeSamplesIfRequired()
void LoadPageImages(const char *filename)
int GetFontInfoId(const char *font_name)
bool Serialize(FILE *fp) const
float ShapeDistance(const ShapeTable &shapes, int s1, int s2)
void AddSample(bool verification, const char *unichar_str, TrainingSample *sample)
void SetFeatureSpace(const IntFeatureSpace &fs)
Definition: mastertrainer.h:82
const STRING & GetTRFileName(int index) const
bool LoadXHeights(const char *filename)
void WriteInttempAndPFFMTable(const UNICHARSET &unicharset, const UNICHARSET &shape_set, const ShapeTable &shape_table, CLASS_STRUCT *float_classes, const char *inttemp_file, const char *pffmtable_file)
void TestClassifierVOld(bool replicate_samples, ShapeClassifier *test_classifier, ShapeClassifier *old_classifier)
void DebugCanonical(const char *unichar_str1, const char *unichar_str2)
const UNICHARSET & unicharset() const
void SetupFlatShapeTable(ShapeTable *shape_table)
void TestClassifierOnSamples(CountTypes error_mode, int report_level, bool replicate_samples, ShapeClassifier *test_classifier, STRING *report_string)
void ReadTrainingSamples(const char *page_name, const FEATURE_DEFS_STRUCT &feature_defs, bool verification)
double TestClassifier(CountTypes error_mode, int report_level, bool replicate_samples, TrainingSampleSet *samples, ShapeClassifier *test_classifier, STRING *report_string)
const ShapeTable & master_shapes() const
bool AddSpacingInfo(const char *filename)
CLUSTERER * SetupForClustering(const ShapeTable &shape_table, const FEATURE_DEFS_STRUCT &feature_defs, int shape_id, int *num_samples)