tesseract 4.1.1
Loading...
Searching...
No Matches
static_shape.h
Go to the documentation of this file.
1
// File: static_shape.h
// Description: Defines the size of the 4-d tensor input/output from a network.
// Author: Ray Smith
// Created: Fri Oct 14 09:07:31 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
18
19#ifndef TESSERACT_LSTM_STATIC_SHAPE_H_
20#define TESSERACT_LSTM_STATIC_SHAPE_H_
21
22#include "serialis.h" // for TFile
23#include "tprintf.h" // for tprintf
24
25namespace tesseract {
26
// Enum describing the loss function to apply during training and/or the
// decoding method to apply at runtime.
// The value is written to and read back from serialized networks as an
// int32_t (see Serialize/DeSerialize in this file), so existing enumerators
// must not be renumbered or reordered.
enum LossType {
  LT_NONE,      // Undefined.
  LT_CTC,       // Softmax with standard CTC for training/decoding.
  LT_SOFTMAX,   // Outputs sum to 1 in fixed positions.
  LT_LOGISTIC,  // Logistic outputs with independent values.
};
35
36// Simple class to hold the tensor shape that is known at network build time
37// and the LossType of the loss function.
39 public:
41 : batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {}
42 int batch() const { return batch_; }
43 void set_batch(int value) { batch_ = value; }
44 int height() const { return height_; }
45 void set_height(int value) { height_ = value; }
46 int width() const { return width_; }
47 void set_width(int value) { width_ = value; }
48 int depth() const { return depth_; }
49 void set_depth(int value) { depth_ = value; }
50 LossType loss_type() const { return loss_type_; }
51 void set_loss_type(LossType value) { loss_type_ = value; }
52 void SetShape(int batch, int height, int width, int depth) {
53 batch_ = batch;
54 height_ = height;
55 width_ = width;
56 depth_ = depth;
57 }
58
59 void Print() const {
60 tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_,
61 height_, width_, depth_, loss_type_);
62 }
63
64 bool DeSerialize(TFile *fp) {
65 int32_t tmp = LT_NONE;
66 bool result =
67 fp->DeSerialize(&batch_) &&
68 fp->DeSerialize(&height_) &&
69 fp->DeSerialize(&width_) &&
70 fp->DeSerialize(&depth_) &&
71 fp->DeSerialize(&tmp);
72 loss_type_ = static_cast<LossType>(tmp);
73 return result;
74 }
75
76 bool Serialize(TFile *fp) const {
77 int32_t tmp = loss_type_;
78 return
79 fp->Serialize(&batch_) &&
80 fp->Serialize(&height_) &&
81 fp->Serialize(&width_) &&
82 fp->Serialize(&depth_) &&
83 fp->Serialize(&tmp);
84 }
85
86 private:
87 // Size of the 4-D tensor input/output to a network. A value of zero is
88 // allowed for all except depth_ and means to be determined at runtime, and
89 // regarded as variable.
90 // Number of elements in a batch, or number of frames in a video stream.
91 int32_t batch_;
92 // Height of the image.
93 int32_t height_;
94 // Width of the image.
95 int32_t width_;
96 // Depth of the image. (Number of "nodes").
97 int32_t depth_;
98 // How to train/interpret the output.
99 LossType loss_type_;
100};
101
102} // namespace tesseract
103
104#endif // TESSERACT_LSTM_STATIC_SHAPE_H_
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
bool Serialize(const char *data, size_t count=1)
Definition: serialis.cpp:148
bool DeSerialize(char *data, size_t count=1)
Definition: serialis.cpp:104
void set_batch(int value)
Definition: static_shape.h:43
void set_loss_type(LossType value)
Definition: static_shape.h:51
void SetShape(int batch, int height, int width, int depth)
Definition: static_shape.h:52
void set_depth(int value)
Definition: static_shape.h:49
LossType loss_type() const
Definition: static_shape.h:50
bool Serialize(TFile *fp) const
Definition: static_shape.h:76
void set_width(int value)
Definition: static_shape.h:47
void set_height(int value)
Definition: static_shape.h:45
bool DeSerialize(TFile *fp)
Definition: static_shape.h:64