tesseract 4.1.1
ctc.cpp
// File:        ctc.cpp
// Description: Slightly improved standard CTC to compute the targets.
// Author:      Ray Smith
// Created:     Wed Jul 13 15:50:06 PDT 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ctc.h"

#include <algorithm>
#include <cfloat>  // for FLT_MAX
#include <memory>

#include "genericvector.h"
#include "matrix.h"
#include "networkio.h"

#include "network.h"
#include "scrollview.h"

namespace tesseract {

// Magic constants that keep CTC stable.
// Minimum probability limit for softmax input to ctc_loss.
const float CTC::kMinProb_ = 1e-12;
// Maximum absolute argument to exp().
const double CTC::kMaxExpArg_ = 80.0;
// Minimum probability for total prob in time normalization.
const double CTC::kMinTotalTimeProb_ = 1e-8;
// Minimum probability for total prob in final normalization.
const double CTC::kMinTotalFinalProb_ = 1e-6;

// Builds a target using CTC. Slightly improved as follows:
// Includes normalizations and clipping for stability.
// labels should be pre-padded with nulls everywhere.
// labels can be longer than the time sequence, but the total number of
// essential labels (non-null plus nulls between equal labels) must not exceed
// the number of timesteps in outputs.
// outputs is the output of the network, and should have already been
// normalized with NormalizeProbs.
// On return targets is filled with the computed targets.
// Returns false if there is insufficient time for the labels.
/* static */
bool CTC::ComputeCTCTargets(const GenericVector<int>& labels, int null_char,
                            const GENERIC_2D_ARRAY<float>& outputs,
                            NetworkIO* targets) {
  std::unique_ptr<CTC> ctc(new CTC(labels, null_char, outputs));
  if (!ctc->ComputeLabelLimits()) {
    return false;  // Not enough time.
  }
  // Generate simple targets purely from the truth labels by spreading them
  // evenly over time.
  GENERIC_2D_ARRAY<float> simple_targets;
  ctc->ComputeSimpleTargets(&simple_targets);
  // Add the simple targets as a starter bias to the network outputs.
  float bias_fraction = ctc->CalculateBiasFraction();
  simple_targets *= bias_fraction;
  ctc->outputs_ += simple_targets;
  NormalizeProbs(&ctc->outputs_);
  // Run regular CTC on the biased outputs.
  // Run forward and backward.
  GENERIC_2D_ARRAY<double> log_alphas, log_betas;
  ctc->Forward(&log_alphas);
  ctc->Backward(&log_betas);
  // Normalize and come out of log space with a clipped softmax over time.
  log_alphas += log_betas;
  ctc->NormalizeSequence(&log_alphas);
  ctc->LabelsToClasses(log_alphas, targets);
  NormalizeProbs(targets);
  return true;
}
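
// Illustrative usage sketch (not part of the original file; the variable names
// below are hypothetical). labels must already be null-padded and outputs must
// already be softmax-normalized:
//
//   GenericVector<int> padded_labels;     // null-padded truth labels
//   GENERIC_2D_ARRAY<float> net_outputs;  // num_timesteps x num_classes probs
//   NetworkIO targets;                    // sized to match net_outputs
//   if (!CTC::ComputeCTCTargets(padded_labels, null_char, net_outputs,
//                               &targets)) {
//     // Too few timesteps for this label sequence; skip the training sample.
//   }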

CTC::CTC(const GenericVector<int>& labels, int null_char,
         const GENERIC_2D_ARRAY<float>& outputs)
    : labels_(labels), outputs_(outputs), null_char_(null_char) {
  num_timesteps_ = outputs.dim1();
  num_classes_ = outputs.dim2();
  num_labels_ = labels_.size();
}

// Computes vectors of min and max label index for each timestep, based on
// whether skippability of nulls makes it possible to complete a valid path.
bool CTC::ComputeLabelLimits() {
  min_labels_.init_to_size(num_timesteps_, 0);
  max_labels_.init_to_size(num_timesteps_, 0);
  int min_u = num_labels_ - 1;
  if (labels_[min_u] == null_char_) --min_u;
  for (int t = num_timesteps_ - 1; t >= 0; --t) {
    min_labels_[t] = min_u;
    if (min_u > 0) {
      --min_u;
      if (labels_[min_u] == null_char_ && min_u > 0 &&
          labels_[min_u + 1] != labels_[min_u - 1]) {
        --min_u;
      }
    }
  }
  int max_u = labels_[0] == null_char_;
  for (int t = 0; t < num_timesteps_; ++t) {
    max_labels_[t] = max_u;
    if (max_labels_[t] < min_labels_[t]) return false;  // Not enough room.
    if (max_u + 1 < num_labels_) {
      ++max_u;
      if (labels_[max_u] == null_char_ && max_u + 1 < num_labels_ &&
          labels_[max_u + 1] != labels_[max_u - 1]) {
        ++max_u;
      }
    }
  }
  return true;
}
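
// Worked example (added for illustration, traced from the code above): with
// labels_ = {null, A, null, A, null}, the null between the two identical A's
// cannot be skipped, so a valid path needs at least 3 timesteps (A, null, A).
// With num_timesteps_ == 3 the bands are min_labels_ = {1, 2, 3} and
// max_labels_ = {1, 2, 3}; with num_timesteps_ == 2, max_labels_[0] == 1 falls
// below min_labels_[0] == 2 and ComputeLabelLimits() returns false.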

// Computes targets based purely on the labels by spreading the labels evenly
// over the available timesteps.
void CTC::ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const {
  // Initialize all targets to zero.
  targets->Resize(num_timesteps_, num_classes_, 0.0f);
  GenericVector<float> half_widths;
  GenericVector<int> means;
  ComputeWidthsAndMeans(&half_widths, &means);
  for (int l = 0; l < num_labels_; ++l) {
    int label = labels_[l];
    float left_half_width = half_widths[l];
    float right_half_width = left_half_width;
    int mean = means[l];
    if (label == null_char_) {
      if (!NeededNull(l)) {
        if ((l > 0 && mean == means[l - 1]) ||
            (l + 1 < num_labels_ && mean == means[l + 1])) {
          continue;  // Drop overlapping null.
        }
      }
      // Make sure that no space is left unoccupied and that non-nulls always
      // peak at 1 by stretching nulls to meet their neighbors.
      if (l > 0) left_half_width = mean - means[l - 1];
      if (l + 1 < num_labels_) right_half_width = means[l + 1] - mean;
    }
    if (mean >= 0 && mean < num_timesteps_) targets->put(mean, label, 1.0f);
    for (int offset = 1; offset < left_half_width && mean >= offset; ++offset) {
      float prob = 1.0f - offset / left_half_width;
      if (mean - offset < num_timesteps_ &&
          prob > targets->get(mean - offset, label)) {
        targets->put(mean - offset, label, prob);
      }
    }
    for (int offset = 1;
         offset < right_half_width && mean + offset < num_timesteps_;
         ++offset) {
      float prob = 1.0f - offset / right_half_width;
      if (mean + offset >= 0 && prob > targets->get(mean + offset, label)) {
        targets->put(mean + offset, label, prob);
      }
    }
  }
}
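
// Note for the reader (added, derived from the loops above): each label gets a
// triangular target profile: probability 1.0 at its mean timestep, decaying
// linearly as 1 - offset / half_width on either side, with overlapping values
// resolved by taking the maximum.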

// Computes mean positions and half widths of the simple targets by spreading
// the labels evenly over the available timesteps.
void CTC::ComputeWidthsAndMeans(GenericVector<float>* half_widths,
                                GenericVector<int>* means) const {
  // Count the number of labels of each type: in regexp terms, "plus" labels
  // (non-null, or necessary null, which must occur at least once) and "star"
  // labels (optional null).
  int num_plus = 0, num_star = 0;
  for (int i = 0; i < num_labels_; ++i) {
    if (labels_[i] != null_char_ || NeededNull(i))
      ++num_plus;
    else
      ++num_star;
  }
  // Compute the size for each type. If there is enough space for everything
  // to have size>=1, then all are equal, otherwise plus_size=1 and star gets
  // whatever is left-over.
  float plus_size = 1.0f, star_size = 0.0f;
  float total_floating = num_plus + num_star;
  if (total_floating <= num_timesteps_) {
    plus_size = star_size = num_timesteps_ / total_floating;
  } else if (num_star > 0) {
    star_size = static_cast<float>(num_timesteps_ - num_plus) / num_star;
  }
  // Set the width and compute the mean of each.
  float mean_pos = 0.0f;
  for (int i = 0; i < num_labels_; ++i) {
    float half_width;
    if (labels_[i] != null_char_ || NeededNull(i)) {
      half_width = plus_size / 2.0f;
    } else {
      half_width = star_size / 2.0f;
    }
    mean_pos += half_width;
    means->push_back(static_cast<int>(mean_pos));
    mean_pos += half_width;
    half_widths->push_back(half_width);
  }
}
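
// Example (added for illustration): with 3 "plus" labels, 2 "star" labels and
// num_timesteps_ = 10, plus_size = star_size = 10 / 5 = 2.0, every half_width
// is 1.0, and the means land at timesteps 1, 3, 5, 7 and 9.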

// Helper returns the index of the highest probability label at timestep t.
static int BestLabel(const GENERIC_2D_ARRAY<float>& outputs, int t) {
  int result = 0;
  int num_classes = outputs.dim2();
  const float* outputs_t = outputs[t];
  for (int c = 1; c < num_classes; ++c) {
    if (outputs_t[c] > outputs_t[result]) result = c;
  }
  return result;
}

// Calculates and returns a suitable fraction of the simple targets to add
// to the network outputs.
float CTC::CalculateBiasFraction() {
  // Compute output labels via basic decoding.
  GenericVector<int> output_labels;
  for (int t = 0; t < num_timesteps_; ++t) {
    int label = BestLabel(outputs_, t);
    while (t + 1 < num_timesteps_ && BestLabel(outputs_, t + 1) == label) ++t;
    if (label != null_char_) output_labels.push_back(label);
  }
  // Simple bag of labels error calculation.
  GenericVector<int> truth_counts(num_classes_, 0);
  GenericVector<int> output_counts(num_classes_, 0);
  for (int l = 0; l < num_labels_; ++l) {
    ++truth_counts[labels_[l]];
  }
  for (int l = 0; l < output_labels.size(); ++l) {
    ++output_counts[output_labels[l]];
  }
  // Count the number of true and false positive non-nulls and truth labels.
  int true_pos = 0, false_pos = 0, total_labels = 0;
  for (int c = 0; c < num_classes_; ++c) {
    if (c == null_char_) continue;
    int truth_count = truth_counts[c];
    int ocr_count = output_counts[c];
    if (truth_count > 0) {
      total_labels += truth_count;
      if (ocr_count > truth_count) {
        true_pos += truth_count;
        false_pos += ocr_count - truth_count;
      } else {
        true_pos += ocr_count;
      }
    }
    // We don't need to count classes that don't exist in the truth as
    // false positives, because they don't affect CTC at all.
  }
  if (total_labels == 0) return 0.0f;
  return exp(std::max(true_pos - false_pos, 1) * log(kMinProb_) / total_labels);
}
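
// Reading the formula above (added note): the returned fraction equals
// kMinProb_ ^ (max(true_pos - false_pos, 1) / total_labels). A perfect
// bag-of-labels match (true_pos - false_pos == total_labels) gives
// kMinProb_ = 1e-12, i.e. essentially no bias, while a poor match with, say,
// total_labels = 12 gives exp(log(1e-12) / 12) = 0.1, a much stronger pull
// toward the simple targets.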

// Given ln(x) and ln(y), returns ln(x + y), using:
// ln(x + y) = ln(x) + ln(1 + exp(ln(y) - ln(x))), ensuring that ln(x) is the
// bigger number to maximize precision.
static double LogSumExp(double ln_x, double ln_y) {
  if (ln_x >= ln_y) {
    return ln_x + log1p(exp(ln_y - ln_x));
  } else {
    return ln_y + log1p(exp(ln_x - ln_y));
  }
}
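
// Example (added for illustration): LogSumExp(log(0.6), log(0.3)) returns
// log(0.6) + log1p(0.3 / 0.6) = log(0.6 * 1.5) = log(0.9), i.e. ln(0.6 + 0.3).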

// Runs the forward CTC pass, filling in log_probs.
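// Written in probability space (added note, derived from the loop below), the
// quantity computed is the standard CTC forward recurrence:
//   alpha_t(u) = (alpha_{t-1}(u) + alpha_{t-1}(u-1)
//                 + alpha_{t-1}(u-2) if labels_[u-1] is a skippable null)
//                * outputs_(t, labels_[u])
// but everything is kept as logs and combined with LogSumExp for stability.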
void CTC::Forward(GENERIC_2D_ARRAY<double>* log_probs) const {
  log_probs->Resize(num_timesteps_, num_labels_, -FLT_MAX);
  log_probs->put(0, 0, log(outputs_(0, labels_[0])));
  if (labels_[0] == null_char_)
    log_probs->put(0, 1, log(outputs_(0, labels_[1])));
  for (int t = 1; t < num_timesteps_; ++t) {
    const float* outputs_t = outputs_[t];
    for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
      // Continuing the same label.
      double log_sum = log_probs->get(t - 1, u);
      // Change from previous label.
      if (u > 0) {
        log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 1));
      }
      // Skip the null if allowed.
      if (u >= 2 && labels_[u - 1] == null_char_ &&
          labels_[u] != labels_[u - 2]) {
        log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 2));
      }
      // Add in the log prob of the current label.
      double label_prob = outputs_t[labels_[u]];
      log_sum += log(label_prob);
      log_probs->put(t, u, log_sum);
    }
  }
}

// Runs the backward CTC pass, filling in log_probs.
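// The mirror of Forward() (added note, derived from the loop below): in
// probability space,
//   beta_t(u) = beta_{t+1}(u) * y_{t+1}(labels_[u])
//             + beta_{t+1}(u+1) * y_{t+1}(labels_[u+1])
//             + beta_{t+1}(u+2) * y_{t+1}(labels_[u+2]) if labels_[u+1] is a
//               skippable null,
// again evaluated in log space with LogSumExp.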
void CTC::Backward(GENERIC_2D_ARRAY<double>* log_probs) const {
  log_probs->Resize(num_timesteps_, num_labels_, -FLT_MAX);
  log_probs->put(num_timesteps_ - 1, num_labels_ - 1, 0.0);
  if (labels_[num_labels_ - 1] == null_char_)
    log_probs->put(num_timesteps_ - 1, num_labels_ - 2, 0.0);
  for (int t = num_timesteps_ - 2; t >= 0; --t) {
    const float* outputs_tp1 = outputs_[t + 1];
    for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
      // Continuing the same label.
      double log_sum = log_probs->get(t + 1, u) + log(outputs_tp1[labels_[u]]);
      // Change to the next label.
      if (u + 1 < num_labels_) {
        double prev_prob = outputs_tp1[labels_[u + 1]];
        log_sum =
            LogSumExp(log_sum, log_probs->get(t + 1, u + 1) + log(prev_prob));
      }
      // Skip the null if allowed.
      if (u + 2 < num_labels_ && labels_[u + 1] == null_char_ &&
          labels_[u] != labels_[u + 2]) {
        double skip_prob = outputs_tp1[labels_[u + 2]];
        log_sum =
            LogSumExp(log_sum, log_probs->get(t + 1, u + 2) + log(skip_prob));
      }
      log_probs->put(t, u, log_sum);
    }
  }
}

// Normalizes and brings probs out of log space with a softmax over time.
void CTC::NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const {
  double max_logprob = probs->Max();
  for (int u = 0; u < num_labels_; ++u) {
    double total = 0.0;
    for (int t = 0; t < num_timesteps_; ++t) {
      // Separate impossible path from unlikely probs.
      double prob = probs->get(t, u);
      if (prob > -FLT_MAX)
        prob = ClippedExp(prob - max_logprob);
      else
        prob = 0.0;
      total += prob;
      probs->put(t, u, prob);
    }
    // Note that although this is a probability distribution over time and
    // therefore should sum to 1, it is important to allow some labels to be
    // all zero (or at least tiny), as it is necessary to skip some blanks.
    if (total < kMinTotalTimeProb_) total = kMinTotalTimeProb_;
    for (int t = 0; t < num_timesteps_; ++t)
      probs->put(t, u, probs->get(t, u) / total);
  }
}

// For each timestep computes the max prob for each class over all
// instances of the class in the labels_, and sets the targets to
// the max observed prob.
void CTC::LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
                          NetworkIO* targets) const {
  // For each timestep compute the max prob for each class over all
  // instances of the class in the labels_.
  GenericVector<double> class_probs;
  for (int t = 0; t < num_timesteps_; ++t) {
    float* targets_t = targets->f(t);
    class_probs.init_to_size(num_classes_, 0.0);
    for (int u = 0; u < num_labels_; ++u) {
      double prob = probs(t, u);
      // Note that although Graves specifies a sum over all labels of the same
      // class, we need to allow skipped blanks to go to zero so that they
      // don't interfere with the non-blanks; hence max is better than sum.
      if (prob > class_probs[labels_[u]]) class_probs[labels_[u]] = prob;
      // class_probs[labels_[u]] += prob;
    }
    int best_class = 0;
    for (int c = 0; c < num_classes_; ++c) {
      targets_t[c] = class_probs[c];
      if (class_probs[c] > class_probs[best_class]) best_class = c;
    }
  }
}

// Normalizes the probabilities such that no target has a prob below min_prob,
// and, provided that the initial total is at least min_total_prob, then all
// probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
// probability is thus 1 - (num_classes-1)*min_prob.
/* static */
void CTC::NormalizeProbs(GENERIC_2D_ARRAY<float>* probs) {
  int num_timesteps = probs->dim1();
  int num_classes = probs->dim2();
  for (int t = 0; t < num_timesteps; ++t) {
    float* probs_t = (*probs)[t];
    // Compute the total and clip that to prevent amplification of noise.
    double total = 0.0;
    for (int c = 0; c < num_classes; ++c) total += probs_t[c];
    if (total < kMinTotalFinalProb_) total = kMinTotalFinalProb_;
    // Compute the increased total as a result of clipping.
    double increment = 0.0;
    for (int c = 0; c < num_classes; ++c) {
      double prob = probs_t[c] / total;
      if (prob < kMinProb_) increment += kMinProb_ - prob;
    }
    // Now normalize with clipping. Any additional clipping is negligible.
    total += increment;
    for (int c = 0; c < num_classes; ++c) {
      float prob = probs_t[c] / total;
      probs_t[c] = std::max(prob, kMinProb_);
    }
  }
}
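
// Numeric check (added for illustration): with num_classes = 3 and
// probs_t = {1.0f, 0.0f, 0.0f}, total stays 1.0, increment = 2 * kMinProb_,
// and the result is approximately {1 - 2e-12, 1e-12, 1e-12}, matching the
// documented maximum of 1 - (num_classes - 1) * kMinProb_ while still
// summing to 1.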

// Returns true if the label at index is a needed null.
bool CTC::NeededNull(int index) const {
  return labels_[index] == null_char_ && index > 0 && index + 1 < num_labels_ &&
         labels_[index + 1] == labels_[index - 1];
}

}  // namespace tesseract