tesseract 4.1.1
Loading...
Searching...
No Matches
tfnetwork.h
Go to the documentation of this file.
1
2// File: tfnetwork.h
3// Description: Encapsulation of an entire tensorflow graph as a
4// Tesseract Network.
5// Author: Ray Smith
6//
7// (C) Copyright 2016, Google Inc.
8// Licensed under the Apache License, Version 2.0 (the "License");
9// you may not use this file except in compliance with the License.
10// You may obtain a copy of the License at
11// http://www.apache.org/licenses/LICENSE-2.0
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
18
19#ifndef TESSERACT_LSTM_TFNETWORK_H_
20#define TESSERACT_LSTM_TFNETWORK_H_
21
22#ifdef INCLUDE_TENSORFLOW
23
24#include <memory>
25#include <string>
26
27#include "network.h"
28#include "static_shape.h"
29#include "tfnetwork.pb.h"
30#include "tensorflow/core/framework/graph.pb.h"
31#include "tensorflow/core/public/session.h"
32
33namespace tesseract {
34
35class TFNetwork : public Network {
36 public:
37 explicit TFNetwork(const STRING& name);
38 virtual ~TFNetwork() = default;
39
40 // Returns the required shape input to the network.
41 StaticShape InputShape() const override { return input_shape_; }
42 // Returns the shape output from the network given an input shape (which may
43 // be partially unknown ie zero).
44 StaticShape OutputShape(const StaticShape& input_shape) const override {
45 return output_shape_;
46 }
47
48 STRING spec() const override { return spec_.c_str(); }
49
50 // Deserializes *this from a serialized TFNetwork proto. Returns 0 if failed,
51 // otherwise the global step of the serialized graph.
52 int InitFromProtoStr(const std::string& proto_str);
53 // The number of classes in this network should be equal to those in the
54 // recoder_ in LSTMRecognizer.
55 int num_classes() const { return output_shape_.depth(); }
56
57 // Writes to the given file. Returns false in case of error.
58 // Should be overridden by subclasses, but called by their Serialize.
59 bool Serialize(TFile* fp) const override;
60 // Reads from the given file. Returns false in case of error.
61 // Should be overridden by subclasses, but NOT called by their DeSerialize.
62 bool DeSerialize(TFile* fp) override;
63
64 // Runs forward propagation of activations on the input line.
65 // See Network for a detailed discussion of the arguments.
66 void Forward(bool debug, const NetworkIO& input,
67 const TransposedArray* input_transpose,
68 NetworkScratch* scratch, NetworkIO* output) override;
69
70 private:
71 // Runs backward propagation of errors on the deltas line.
72 // See Network for a detailed discussion of the arguments.
73 bool Backward(bool debug, const NetworkIO& fwd_deltas,
74 NetworkScratch* scratch,
75 NetworkIO* back_deltas) override {
76 tprintf("Must override Network::Backward for type %d\n", type_);
77 return false;
78 }
79
80 void DebugWeights() override {
81 tprintf("Must override Network::DebugWeights for type %d\n", type_);
82 }
83
84 int InitFromProto();
85
86 // The original network definition for reference.
87 std::string spec_;
88 // Input tensor parameters.
89 StaticShape input_shape_;
90 // Output tensor parameters.
91 StaticShape output_shape_;
92 // The tensor flow graph is contained in here.
93 std::unique_ptr<tensorflow::Session> session_;
94 // The serialized graph is also contained in here.
95 TFNetworkModel model_proto_;
96};
97
98} // namespace tesseract.
99
100#endif // ifdef INCLUDE_TENSORFLOW
101
102#endif // TESSERACT_LSTM_TFNETWORK_H_
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35
bool DeSerialize(FILE *fp, char *data, size_t n)
Definition: serialis.cpp:28
bool Serialize(FILE *fp, const char *data, size_t n)
Definition: serialis.cpp:60
Definition: strngs.h:45
const char * c_str() const
Definition: strngs.cpp:205