///////////////////////////////////////////////////////////////////////
// File: network.h
// Description: Base class for neural network implementations.
// Author: Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_NETWORK_H_
#define TESSERACT_LSTM_NETWORK_H_

#include <cstdio>
#include <cmath>

#include "genericvector.h"
#include "helpers.h"
#include "matrix.h"
#include "networkio.h"
#include "serialis.h"
#include "static_shape.h"
#include "tprintf.h"

struct Pix;
class ScrollView;
class TBOX;

namespace tesseract {

class ImageData;
class NetworkScratch;

// Enum to store the run-time type of a Network. Keep in sync with kTypeNames.
enum NetworkType {
  NT_NONE,   // The naked base class.
  NT_INPUT,  // Inputs from an image.
  // Plumbing networks combine other networks or rearrange the inputs.
  NT_CONVOLVE,     // Duplicates inputs in a sliding window neighborhood.
  NT_MAXPOOL,      // Chooses the max result from a rectangle.
  NT_PARALLEL,     // Runs networks in parallel.
  NT_REPLICATED,   // Runs identical networks in parallel.
  NT_PAR_RL_LSTM,  // Runs LTR and RTL LSTMs in parallel.
  NT_PAR_UD_LSTM,  // Runs Up and Down LSTMs in parallel.
  NT_PAR_2D_LSTM,  // Runs 4 LSTMs in parallel.
  NT_SERIES,       // Executes a sequence of layers.
  NT_RECONFIG,     // Scales the time/y size but makes the output deeper.
  NT_XREVERSED,    // Reverses the x-direction of the inputs/outputs.
  NT_YREVERSED,    // Reverses the y-direction of the inputs/outputs.
  NT_XYTRANSPOSE,  // Transposes x and y (for just a single op).
  // Functional networks actually calculate stuff.
  NT_LSTM,          // Long-Short-Term-Memory block.
  NT_LSTM_SUMMARY,  // LSTM that only keeps its last output.
  NT_LOGISTIC,      // Fully connected logistic nonlinearity.
  NT_POSCLIP,       // Fully connected rect lin version of logistic.
  NT_SYMCLIP,       // Fully connected rect lin version of tanh.
  NT_TANH,          // Fully connected with tanh nonlinearity.
  NT_RELU,          // Fully connected with rectifier nonlinearity.
  NT_LINEAR,        // Fully connected with no nonlinearity.
  NT_SOFTMAX,         // Softmax uses exponential normalization, with CTC.
  NT_SOFTMAX_NO_CTC,  // Softmax uses exponential normalization, no CTC.
  // The SOFTMAX LSTMs both have an extra softmax layer on top, but inside,
  // with the outputs fed back to the input of the LSTM at the next timestep.
  // The ENCODED version binary-encodes the softmax outputs, providing log2 of
  // the number of outputs as additional inputs, and the other version just
  // provides all the softmax outputs as additional inputs.
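  // For example (illustrative arithmetic, not stated in this header): with
  // 100 softmax outputs, the ENCODED version would feed back
  // ceil(log2(100)) = 7 binary inputs per timestep instead of all 100
  // softmax activations.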
  NT_LSTM_SOFTMAX,          // 1-d LSTM with built-in fully connected softmax.
  NT_LSTM_SOFTMAX_ENCODED,  // 1-d LSTM with built-in binary-encoded softmax.
  // A TensorFlow graph encapsulated as a Tesseract network.
  NT_TENSORFLOW,

  NT_COUNT  // Array size.
};
// Enum of Network behavior flags. Can in theory be set for each individual
// network element.
enum NetworkFlags {
  // Network forward/backprop behavior.
  NF_LAYER_SPECIFIC_LR = 64,  // Separate learning rate for each layer.
  NF_ADAM = 128,              // Weight-specific learning rate.
};

// State of training and desired state used in SetEnableTraining.
enum TrainingState {
  // Valid states of training_.
  TS_DISABLED,      // Disabled permanently.
  TS_ENABLED,       // Enabled for backprop and to write a training dump.
                    // Re-enable from ANY disabled state.
  TS_TEMP_DISABLE,  // Temporarily disabled to write a recognition dump.
  // Valid only for SetEnableTraining.
  TS_RE_ENABLE,  // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};
// Base class for network types. Not quite an abstract base class, but almost.
// Most of the time no isolated Network exists, except prior to
// deserialization.
class Network {
 public:
  Network();
  Network(NetworkType type, const STRING& name, int ni, int no);
  virtual ~Network() = default;

  // Accessors.
  NetworkType type() const { return type_; }
  bool IsTraining() const { return training_ == TS_ENABLED; }
  bool needs_to_backprop() const { return needs_to_backprop_; }
  int num_weights() const { return num_weights_; }
  int NumInputs() const { return ni_; }
  int NumOutputs() const { return no_; }

  // Returns the required shape of the input to the network.
  virtual StaticShape InputShape() const {
    StaticShape result;
    return result;
  }
  // Returns the shape output from the network given an input shape (which may
  // be partially unknown, i.e. zero).
  virtual StaticShape OutputShape(const StaticShape& input_shape) const {
    StaticShape result(input_shape);
    result.set_depth(no_);
    return result;
  }
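  // Illustrative sketch (assuming StaticShape carries batch/height/width/depth
  // dimensions, per static_shape.h): this base implementation maps an input
  // shape of (batch, height, width, depth=ni_) to the same shape with
  // depth=no_, changing only the depth.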
  const STRING& name() const { return name_; }
  virtual STRING spec() const { return "?"; }
  bool TestFlag(NetworkFlags flag) const {
    return (network_flags_ & flag) != 0;
  }
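  // Usage sketch (illustrative, not part of the API): callers can branch on
  // behavior flags, e.g.
  //   if (net.TestFlag(NF_ADAM)) {
  //     // Weight-specific learning rates are in effect.
  //   }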
  // Initialization and administrative functions that are mostly provided
  // by Plumbing.
  // Returns true if the given type is derived from Plumbing, and thus contains
  // multiple sub-networks that can have their own learning rate.
  virtual bool IsPlumbingType() const { return false; }

  // Suspends/Enables/Permanently disables training by setting the training_
  // flag. Serialize and DeSerialize only operate on the run-time data if the
  // state is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
  // temporarily disable layers in state TS_ENABLED, allowing a trainer to
  // serialize as if it were a recognizer.
  // TS_RE_ENABLE will re-enable layers that were previously in any disabled
  // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
  // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
  // recognizer can be converted back to a trainer.
  virtual void SetEnableTraining(TrainingState state);
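  // Illustrative sketch of the trainer-to-recognizer round trip described
  // above (`net` and the open TFile `fp` are hypothetical):
  //   net->SetEnableTraining(TS_TEMP_DISABLE);  // Suspend training state.
  //   net->Serialize(fp);                       // Writes a recognition dump.
  //   net->SetEnableTraining(TS_RE_ENABLE);     // Resume training afterward.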
  // Sets flags that control the action of the network. See the NetworkFlags
  // enum for bit values.
  virtual void SetNetworkFlags(uint32_t flags);

  // Sets up the network for training. Initializes weights with random values
  // of scale `range` picked according to the random number generator
  // `randomizer`. Note that randomizer is a borrowed pointer that should
  // outlive the network and should not be deleted by any of the networks.
  // Returns the number of weights initialized.
  virtual int InitWeights(float range, TRand* randomizer);

  // Changes the number of outputs to the outside world to the size of the
  // given code_map. Recursively searches the entire network for Softmax layers
  // that have exactly old_no outputs, and operates only on those, leaving all
  // others unchanged. This enables networks with multiple output layers to get
  // all their softmaxes updated, but if an internal layer uses one of those
  // softmaxes for input, then the inputs will effectively be scrambled.
  // TODO(rays) Fix this before any such network is implemented.
  // The softmaxes are resized by copying the old weight matrix entries for
  // each output from code_map[output] where non-negative, and using the mean
  // (over all outputs) of the existing weights for all outputs with negative
  // code_map entries. Returns the new number of weights.
  virtual int RemapOutputs(int old_no, const std::vector<int>& code_map) {
    return 0;
  }
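  // Illustrative example of the code_map semantics described above: with
  // old_no = 3 and code_map = {2, -1, 0}, new output 0 copies the weights of
  // old output 2, new output 1 gets the mean of the existing weights (its
  // entry is negative), and new output 2 copies the weights of old output 0.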
  // Converts a float network to an int network.
  virtual void ConvertToInt() {}

  // Provides a pointer to a TRand for any networks that care to use it.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  virtual void SetRandomizer(TRand* randomizer);

  // Sets needs_to_backprop_ to needs_backprop and returns true if
  // needs_backprop || any weights in this network, so the next layer forward
  // can be told to produce backprop for this layer if needed.
  virtual bool SetupNeedsBackprop(bool needs_backprop);

  // Returns the most recent reduction factor that the network applied to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data and calculating result bounding
  // boxes.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will
  // return the minimum scale factor of the paths through the GlobalMinimax.
  virtual int XScaleFactor() const { return 1; }

  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  virtual void CacheXScaleFactor(int factor) {}

  // Provides debug output on the weights.
  virtual void DebugWeights() = 0;

  // Writes to the given file. Returns false in case of error.
  // Should be overridden by subclasses, but called by their Serialize.
  virtual bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // Should be overridden by subclasses, but NOT called by their DeSerialize.
  virtual bool DeSerialize(TFile* fp) = 0;

  // Updates the weights using the given learning rate, momentum and adam_beta.
  // num_samples is used in the adam computation iff use_adam_ is true.
  virtual void Update(float learning_rate, float momentum, float adam_beta,
                      int num_samples) {}
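  // Call sketch (illustrative values only; `net` and `batch_size` are
  // hypothetical): after accumulating gradients with Backward,
  //   net->Update(/*learning_rate=*/1e-4f, /*momentum=*/0.5f,
  //               /*adam_beta=*/0.999f, /*num_samples=*/batch_size);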
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  virtual void CountAlternators(const Network& other, double* same,
                                double* changed) const {}

  // Reads from the given file. Returns nullptr in case of error.
  // Determines the type of the serialized class and calls its DeSerialize
  // on a new object of the appropriate type, which is returned.
  static Network* CreateFromFile(TFile* fp);
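  // Usage sketch (illustrative; `fp` is a hypothetical TFile opened for
  // reading):
  //   Network* net = Network::CreateFromFile(fp);
  //   if (net == nullptr) {
  //     // Handle the deserialization error.
  //   }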
  // Runs forward propagation of activations on the input line.
  // Note that input and output are both 2-d arrays.
  // The 1st index is the time element. In a 1-d network, it might be the pixel
  // position on the textline. In a 2-d network, the linearization is defined
  // by the stride_map (see networkio.h).
  // The 2nd index of input is the network inputs/outputs, and the dimension
  // of the input must match NumInputs() of this network.
  // The output array will be resized as needed so that its 1st dimension is
  // always equal to the number of time steps, and its 2nd dimension is
  // always NumOutputs(). Note that all this detail is encapsulated away inside
  // NetworkIO, as are the internals of the scratch memory space used by the
  // network. See networkscratch.h for that.
  // If input_transpose is not nullptr, then it contains the transpose of
  // input, and the caller guarantees that it will still be valid on the next
  // call to backward. The callee is therefore at liberty to save the pointer
  // and reference it on a call to backward. This is a bit ugly, but it makes
  // it possible for a replicating parallel to calculate the input transpose
  // once instead of all the replicated networks having to do it.
  virtual void Forward(bool debug, const NetworkIO& input,
                       const TransposedArray* input_transpose,
                       NetworkScratch* scratch, NetworkIO* output) = 0;
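  // Call sketch (illustrative; Network is effectively abstract, so `net`
  // points at a concrete subclass, and `input`/`scratch` are hypothetical):
  //   NetworkIO output;
  //   net->Forward(/*debug=*/false, input, nullptr, &scratch, &output);
  // Passing nullptr for input_transpose is the common case; only a caller
  // that keeps the transpose alive across the matching Backward passes it.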
  // Runs backward propagation of errors on fwd_deltas.
  // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
  // Returns false if back_deltas was not set, due to there being no point in
  // propagating further backwards. Thus most complete networks will always
  // return false from Backward!
  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
                        NetworkScratch* scratch,
                        NetworkIO* back_deltas) = 0;
  // === Debug image display methods. ===
  // Displays the image of the matrix to the forward window.
  void DisplayForward(const NetworkIO& matrix);
  // Displays the image of the matrix to the backward window.
  void DisplayBackward(const NetworkIO& matrix);
  // Creates the window if needed, otherwise clears it.
  static void ClearWindow(bool tess_coords, const char* window_name,
                          int width, int height, ScrollView** window);
  // Displays the pix in the given window and returns the height of the pix.
  // The pix is pixDestroyed.
  static int DisplayImage(Pix* pix, ScrollView* window);
 protected:
  // Returns a random number in [-range, range].
  double Random(double range);

 protected:
  NetworkType type_;        // Type of the derived network class.
  TrainingState training_;  // Are we currently training?
  bool needs_to_backprop_;  // This network needs to output back_deltas.
  int32_t network_flags_;   // Behavior control flags in NetworkFlags.
  int32_t ni_;              // Number of input values.
  int32_t no_;              // Number of output values.
  int32_t num_weights_;     // Number of weights in this and sub-network.
  STRING name_;             // A unique name for this layer.

  // NOT-serialized debug data.
  ScrollView* forward_win_;   // Recognition debug display window.
  ScrollView* backward_win_;  // Training debug display window.
  TRand* randomizer_;         // Random number generator.
};

}  // namespace tesseract.

#endif  // TESSERACT_LSTM_NETWORK_H_