tesseract 4.1.1
Loading...
Searching...
No Matches
tfnetwork.cpp
Go to the documentation of this file.
// File:        tfnetwork.cpp
// Description: Encapsulation of an entire tensorflow graph as a
//              Tesseract Network.
// Author:      Ray Smith
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
18#ifdef INCLUDE_TENSORFLOW
19
20#include "tfnetwork.h"
21
22#include "allheaders.h"
23#include "input.h"
24#include "networkscratch.h"
25
26using tensorflow::Status;
27using tensorflow::Tensor;
28using tensorflow::TensorShape;
29
30namespace tesseract {
31
32TFNetwork::TFNetwork(const STRING& name) : Network(NT_TENSORFLOW, name, 0, 0) {}
33
34int TFNetwork::InitFromProtoStr(const std::string& proto_str) {
35 if (!model_proto_.ParseFromString(proto_str)) return 0;
36 return InitFromProto();
37}
38
39// Writes to the given file. Returns false in case of error.
40// Should be overridden by subclasses, but called by their Serialize.
41bool TFNetwork::Serialize(TFile* fp) const {
42 if (!Network::Serialize(fp)) return false;
43 std::string proto_str;
44 model_proto_.SerializeToString(&proto_str);
46 data.resize_no_init(proto_str.size());
47 memcpy(&data[0], proto_str.data(), proto_str.size());
48 if (!data.Serialize(fp)) return false;
49 return true;
50}
51
52// Reads from the given file. Returns false in case of error.
53// Should be overridden by subclasses, but NOT called by their DeSerialize.
54bool TFNetwork::DeSerialize(TFile* fp) {
56 if (!data.DeSerialize(fp)) return false;
57 if (!model_proto_.ParseFromArray(&data[0], data.size())) {
58 return false;
59 }
60 return InitFromProto();
61}
62
63// Runs forward propagation of activations on the input line.
64// See Network for a detailed discussion of the arguments.
65void TFNetwork::Forward(bool debug, const NetworkIO& input,
66 const TransposedArray* input_transpose,
67 NetworkScratch* scratch, NetworkIO* output) {
68 std::vector<std::pair<std::string, Tensor>> tf_inputs;
69 int depth = input_shape_.depth();
70 ASSERT_HOST(depth == input.NumFeatures());
71 // TODO(rays) Allow batching. For now batch_size = 1.
72 const StrideMap& stride_map = input.stride_map();
73 // TF requires a tensor of shape float[batch, height, width, depth].
74 TensorShape shape{1, stride_map.Size(FD_HEIGHT), stride_map.Size(FD_WIDTH),
75 depth};
76 Tensor input_tensor(tensorflow::DT_FLOAT, shape);
77 // The flat() member gives a 1d array, with a data() member to get the data.
78 auto eigen_tensor = input_tensor.flat<float>();
79 memcpy(eigen_tensor.data(), input.f(0),
80 input.Width() * depth * sizeof(input.f(0)[0]));
81 // Add the tensor to the vector of inputs.
82 tf_inputs.emplace_back(model_proto_.image_input(), input_tensor);
83
84 // Provide tensors giving the width and/or height of the image if they are
85 // required. Some tf ops require a separate tensor with knowledge of the
86 // size of the input as they cannot obtain it from the input tensor. This is
87 // usually true in the case of ops that process a batch of variable-sized
88 // objects.
89 if (!model_proto_.image_widths().empty()) {
90 TensorShape size_shape{1};
91 Tensor width_tensor(tensorflow::DT_INT64, size_shape);
92 auto eigen_wtensor = width_tensor.flat<tensorflow::int64>();
93 *eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
94 tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
95 }
96 if (!model_proto_.image_heights().empty()) {
97 TensorShape size_shape{1};
98 Tensor height_tensor(tensorflow::DT_INT64, size_shape);
99 auto eigen_htensor = height_tensor.flat<tensorflow::int64>();
100 *eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
101 tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
102 }
103 std::vector<std::string> target_layers = {model_proto_.output_layer()};
104 std::vector<Tensor> outputs;
105 Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
106 if (!s.ok()) tprintf("session->Run failed:%s\n", s.error_message().c_str());
107 ASSERT_HOST(s.ok());
108 ASSERT_HOST(outputs.size() == 1);
109 const Tensor& output_tensor = outputs[0];
110 // Check the dimensions of the output.
111 ASSERT_HOST(output_tensor.shape().dims() == 3);
112 int output_batch = output_tensor.shape().dim_size(0);
113 int output_steps = output_tensor.shape().dim_size(1);
114 int output_depth = output_tensor.shape().dim_size(2);
115 ASSERT_HOST(output_batch == 1);
116 ASSERT_HOST(output_depth == output_shape_.depth());
117 output->Resize2d(false, output_steps, output_depth);
118 auto eigen_output = output_tensor.flat<float>();
119 memcpy(output->f(0), eigen_output.data(),
120 output_steps * output_depth * sizeof(output->f(0)[0]));
121}
122
123int TFNetwork::InitFromProto() {
124 spec_ = model_proto_.spec();
125 input_shape_.SetShape(
126 model_proto_.batch_size(), std::max(0, model_proto_.y_size()),
127 std::max(0, model_proto_.x_size()), model_proto_.depth());
128 output_shape_.SetShape(model_proto_.batch_size(), 1, 0,
129 model_proto_.num_classes());
130 output_shape_.set_loss_type(model_proto_.using_ctc() ? LT_CTC : LT_SOFTMAX);
131 ni_ = input_shape_.height();
132 no_ = output_shape_.depth();
133 // Initialize the session_ with the graph. Since we can't get the graph
134 // back from the session_, we have to keep the proto as well
135 tensorflow::SessionOptions options;
136 session_.reset(NewSession(options));
137 Status s = session_->Create(model_proto_.graph());
138 if (s.ok()) return model_proto_.global_step();
139 tprintf("Session_->Create returned '%s'\n", s.error_message().c_str());
140 return 0;
141}
142
143} // namespace tesseract
144
145#endif // ifdef INCLUDE_TENSORFLOW
#define ASSERT_HOST(x)
Definition: errcode.h:88
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
@ NT_TENSORFLOW
Definition: network.h:78
void resize_no_init(int size)
Definition: genericvector.h:66
bool Serialize(FILE *fp) const
int size() const
Definition: genericvector.h:72
bool DeSerialize(bool swap, FILE *fp)
Definition: strngs.h:45