tesseract 4.1.1
Loading...
Searching...
No Matches
tesseract::NetworkBuilder Class Reference

#include <networkbuilder.h>

Public Member Functions

 NetworkBuilder (int num_softmax_outputs)
 
Network * BuildFromString (const StaticShape &input_shape, char **str)
 

Static Public Member Functions

static bool InitNetwork (int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
 

Detailed Description

Definition at line 36 of file networkbuilder.h.

Constructor & Destructor Documentation

◆ NetworkBuilder()

tesseract::NetworkBuilder::NetworkBuilder ( int  num_softmax_outputs)
inline explicit

Definition at line 38 of file networkbuilder.h.

39 : num_softmax_outputs_(num_softmax_outputs) {}

Member Function Documentation

◆ BuildFromString()

Network * tesseract::NetworkBuilder::BuildFromString ( const StaticShape &  input_shape,
char **  str 
)

Definition at line 86 of file networkbuilder.cpp.

87 {
88 SkipWhitespace(str);
89 char code_ch = **str;
90 if (code_ch == '[') {
91 return ParseSeries(input_shape, nullptr, str);
92 }
93 if (input_shape.depth() == 0) {
94 // There must be an input at this point.
95 return ParseInput(str);
96 }
97 switch (code_ch) {
98 case '(':
99 return ParseParallel(input_shape, str);
100 case 'R':
101 return ParseR(input_shape, str);
102 case 'S':
103 return ParseS(input_shape, str);
104 case 'C':
105 return ParseC(input_shape, str);
106 case 'M':
107 return ParseM(input_shape, str);
108 case 'L':
109 return ParseLSTM(input_shape, str);
110 case 'F':
111 return ParseFullyConnected(input_shape, str);
112 case 'O':
113 return ParseOutput(input_shape, str);
114 default:
115 tprintf("Invalid network spec:%s\n", *str);
116 return nullptr;
117 }
118 return nullptr;
119}
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:35

◆ InitNetwork()

bool tesseract::NetworkBuilder::InitNetwork ( int  num_outputs,
STRING  network_spec,
int  append_index,
int  net_flags,
float  weight_range,
TRand *  randomizer,
Network **  network 
)
static

Definition at line 45 of file networkbuilder.cpp.

48 {
49 NetworkBuilder builder(num_outputs);
50 Series* bottom_series = nullptr;
51 StaticShape input_shape;
52 if (append_index >= 0) {
53 // Split the current network after the given append_index.
54 ASSERT_HOST(*network != nullptr && (*network)->type() == NT_SERIES);
55 auto* series = static_cast<Series*>(*network);
56 Series* top_series = nullptr;
57 series->SplitAt(append_index, &bottom_series, &top_series);
58 if (bottom_series == nullptr || top_series == nullptr) {
59 tprintf("Yikes! Splitting current network failed!!\n");
60 return false;
61 }
62 input_shape = bottom_series->OutputShape(input_shape);
63 delete top_series;
64 }
65 char* str_ptr = &network_spec[0];
66 *network = builder.BuildFromString(input_shape, &str_ptr);
67 if (*network == nullptr) return false;
68 (*network)->SetNetworkFlags(net_flags);
69 (*network)->InitWeights(weight_range, randomizer);
70 (*network)->SetupNeedsBackprop(false);
71 if (bottom_series != nullptr) {
72 bottom_series->AppendSeries(*network);
73 *network = bottom_series;
74 }
75 (*network)->CacheXScaleFactor((*network)->XScaleFactor());
76 return true;
77}
#define ASSERT_HOST(x)
Definition: errcode.h:88
@ NT_SERIES
Definition: network.h:54
NetworkBuilder(int num_softmax_outputs)

The documentation for this class was generated from the following files: