networkbuilder.h 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: networkbuilder.h
  3. // Description: Class to parse the network description language and
  4. // build a corresponding network.
  5. // Author: Ray Smith
  6. // Created: Wed Jul 16 18:35:38 PST 2014
  7. //
  8. // (C) Copyright 2014, Google Inc.
  9. // Licensed under the Apache License, Version 2.0 (the "License");
  10. // you may not use this file except in compliance with the License.
  11. // You may obtain a copy of the License at
  12. // http://www.apache.org/licenses/LICENSE-2.0
  13. // Unless required by applicable law or agreed to in writing, software
  14. // distributed under the License is distributed on an "AS IS" BASIS,
  15. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. // See the License for the specific language governing permissions and
  17. // limitations under the License.
  18. ///////////////////////////////////////////////////////////////////////
  19. #ifndef TESSERACT_LSTM_NETWORKBUILDER_H_
  20. #define TESSERACT_LSTM_NETWORKBUILDER_H_
  21. #include "static_shape.h"
  22. #include "stridemap.h"
  23. class STRING;
  24. class UNICHARSET;
  25. namespace tesseract {
  26. class Input;
  27. class Network;
  28. class Parallel;
  29. class TRand;
  30. class NetworkBuilder {
  31. public:
  32. explicit NetworkBuilder(int num_softmax_outputs)
  33. : num_softmax_outputs_(num_softmax_outputs) {}
  34. // Builds a network with a network_spec in the network description
  35. // language, to recognize a character set of num_outputs size.
  36. // If append_index is non-negative, then *network must be non-null and the
  37. // given network_spec will be appended to *network AFTER append_index, with
  38. // the top of the input *network discarded.
  39. // Note that network_spec is call by value to allow a non-const char* pointer
  40. // into the string for BuildFromString.
  41. // net_flags control network behavior according to the NetworkFlags enum.
  42. // The resulting network is returned via **network.
  43. // Returns false if something failed.
  44. static bool InitNetwork(int num_outputs, STRING network_spec,
  45. int append_index, int net_flags, float weight_range,
  46. TRand* randomizer, Network** network);
  47. // Parses the given string and returns a network according to the following
  48. // language:
  49. // ============ Syntax of description below: ============
  50. // <d> represents a number.
  51. // <net> represents any single network element, including (recursively) a
  52. // [...] series or (...) parallel construct.
  53. // (s|t|r|l|m) (regex notation) represents a single required letter.
  54. // NOTE THAT THROUGHOUT, x and y are REVERSED from conventional mathematics,
  55. // to use the same convention as Tensor Flow. The reason TF adopts this
  56. // convention is to eliminate the need to transpose images on input, since
  57. // adjacent memory locations in images increase x and then y, while adjacent
  58. // memory locations in tensors in TF, and NetworkIO in tesseract increase the
  59. // rightmost index first, then the next-left and so-on, like C arrays.
  60. // ============ INPUTS ============
  61. // <b>,<h>,<w>,<d> A batch of b images with height h, width w, and depth d.
  62. // b, h and/or w may be zero, to indicate variable size. Some network layer
  63. // (summarizing LSTM) must be used to make a variable h known.
  64. // d may be 1 for greyscale, 3 for color.
  65. // NOTE that throughout the constructed network, the inputs/outputs are all of
  66. // the same [batch,height,width,depth] dimensions, even if a different size.
  67. // ============ PLUMBING ============
  68. // [...] Execute ... networks in series (layers).
  69. // (...) Execute ... networks in parallel, with their output depths added.
  70. // R<d><net> Execute d replicas of net in parallel, with their output depths
  71. // added.
  72. // Rx<net> Execute <net> with x-dimension reversal.
  73. // Ry<net> Execute <net> with y-dimension reversal.
  74. // S<y>,<x> Rescale 2-D input by shrink factor x,y, rearranging the data by
  75. // increasing the depth of the input by factor xy.
  76. // Mp<y>,<x> Maxpool the input, reducing the size by an (x,y) rectangle.
  77. // ============ FUNCTIONAL UNITS ============
  78. // C(s|t|r|l|m)<y>,<x>,<d> Convolves using a (x,y) window, with no shrinkage,
  79. // random infill, producing d outputs, then applies a non-linearity:
  80. // s: Sigmoid, t: Tanh, r: Relu, l: Linear, m: Softmax.
  81. // F(s|t|r|l|m)<d> Truly fully-connected with s|t|r|l|m non-linearity and d
  82. // outputs. Connects to every x,y,depth position of the input, reducing
  83. // height, width to 1, producing a single <d> vector as the output.
  84. // Input height and width must be constant.
  85. // For a sliding-window linear or non-linear map that connects just to the
  86. // input depth, and leaves the input image size as-is, use a 1x1 convolution
  87. // eg. Cr1,1,64 instead of Fr64.
  88. // L(f|r|b)(x|y)[s]<n> LSTM cell with n states/outputs.
  89. // The LSTM must have one of:
  90. // f runs the LSTM forward only.
  91. // r runs the LSTM reversed only.
  92. // b runs the LSTM bidirectionally.
  93. // It will operate on either the x- or y-dimension, treating the other
  94. // dimension independently (as if part of the batch).
  95. // s (optional) summarizes the output in the requested dimension,
  96. // outputting only the final step, collapsing the dimension to a
  97. // single element.
  98. // LS<n> Forward-only LSTM cell in the x-direction, with built-in Softmax.
  99. // LE<n> Forward-only LSTM cell in the x-direction, with built-in softmax,
  100. // with binary Encoding.
  101. // L2xy<n> Full 2-d LSTM operating in quad-directions (bidi in x and y) and
  102. // all the output depths added.
  103. // ============ OUTPUTS ============
  104. // The network description must finish with an output specification:
  105. // O(2|1|0)(l|s|c)<n> output layer with n classes
  106. // 2 (heatmap) Output is a 2-d vector map of the input (possibly at
  107. // different scale).
  108. // 1 (sequence) Output is a 1-d sequence of vector values.
  109. // 0 (category) Output is a 0-d single vector value.
  110. // l uses a logistic non-linearity on the output, allowing multiple
  111. // hot elements in any output vector value.
  112. // s uses a softmax non-linearity, with one-hot output in each value.
  113. // c uses a softmax with CTC. Can only be used with s (sequence).
  114. // NOTE1: Only O1s and O1c are currently supported.
  115. // NOTE2: n is totally ignored, and for compatibility purposes only. The
  116. // output number of classes is obtained automatically from the
  117. // unicharset.
  118. Network* BuildFromString(const StaticShape& input_shape, char** str);
  119. private:
  120. // Parses an input specification and returns the result, which may include a
  121. // series.
  122. Network* ParseInput(char** str);
  123. // Parses a sequential series of networks, defined by [<net><net>...].
  124. Network* ParseSeries(const StaticShape& input_shape, Input* input_layer,
  125. char** str);
  126. // Parses a parallel set of networks, defined by (<net><net>...).
  127. Network* ParseParallel(const StaticShape& input_shape, char** str);
  128. // Parses a network that begins with 'R'.
  129. Network* ParseR(const StaticShape& input_shape, char** str);
  130. // Parses a network that begins with 'S'.
  131. Network* ParseS(const StaticShape& input_shape, char** str);
  132. // Parses a network that begins with 'C'.
  133. Network* ParseC(const StaticShape& input_shape, char** str);
  134. // Parses a network that begins with 'M'.
  135. Network* ParseM(const StaticShape& input_shape, char** str);
  136. // Parses an LSTM network, either individual, bi- or quad-directional.
  137. Network* ParseLSTM(const StaticShape& input_shape, char** str);
  138. // Builds a set of 4 lstms with t and y reversal, running in true parallel.
  139. static Network* BuildLSTMXYQuad(int num_inputs, int num_states);
  140. // Parses a Fully connected network.
  141. Network* ParseFullyConnected(const StaticShape& input_shape, char** str);
  142. // Parses an Output spec.
  143. Network* ParseOutput(const StaticShape& input_shape, char** str);
  144. private:
  145. int num_softmax_outputs_;
  146. };
  147. } // namespace tesseract.
  148. #endif // TESSERACT_LSTM_NETWORKBUILDER_H_