lstmrecognizer.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: lstmrecognizer.h
  3. // Description: Top-level line recognizer class for LSTM-based networks.
  4. // Author: Ray Smith
  5. // Created: Thu May 02 08:57:06 PST 2013
  6. //
  7. // (C) Copyright 2013, Google Inc.
  8. // Licensed under the Apache License, Version 2.0 (the "License");
  9. // you may not use this file except in compliance with the License.
  10. // You may obtain a copy of the License at
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. // See the License for the specific language governing permissions and
  16. // limitations under the License.
  17. ///////////////////////////////////////////////////////////////////////
  18. #ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
  19. #define TESSERACT_LSTM_LSTMRECOGNIZER_H_
  20. #include "ccutil.h"
  21. #include "helpers.h"
  22. #include "imagedata.h"
  23. #include "matrix.h"
  24. #include "network.h"
  25. #include "networkscratch.h"
  26. #include "params.h"
  27. #include "recodebeam.h"
  28. #include "series.h"
  29. #include "strngs.h"
  30. #include "unicharcompress.h"
  31. class BLOB_CHOICE_IT;
  32. struct Pix;
  33. class ROW_RES;
  34. class ScrollView;
  35. class TBOX;
  36. class WERD_RES;
  37. namespace tesseract {
  38. class Dict;
  39. class ImageData;
  40. // Enum indicating training mode control flags.
  41. enum TrainingFlags {
  42. TF_INT_MODE = 1,
  43. TF_COMPRESS_UNICHARSET = 64,
  44. };
  45. // Top-level line recognizer class for LSTM-based networks.
  46. // Note that a sub-class, LSTMTrainer is used for training.
  47. class LSTMRecognizer {
  48. public:
  49. LSTMRecognizer();
  50. ~LSTMRecognizer();
  51. int NumOutputs() const { return network_->NumOutputs(); }
  52. int training_iteration() const { return training_iteration_; }
  53. int sample_iteration() const { return sample_iteration_; }
  54. double learning_rate() const { return learning_rate_; }
  55. LossType OutputLossType() const {
  56. if (network_ == nullptr) return LT_NONE;
  57. StaticShape shape;
  58. shape = network_->OutputShape(shape);
  59. return shape.loss_type();
  60. }
  61. bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
  62. bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
  63. // True if recoder_ is active to re-encode text to a smaller space.
  64. bool IsRecoding() const {
  65. return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
  66. }
  67. // Returns true if the network is a TensorFlow network.
  68. bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
  69. // Returns a vector of layer ids that can be passed to other layer functions
  70. // to access a specific layer.
  71. GenericVector<STRING> EnumerateLayers() const {
  72. ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
  73. auto* series = static_cast<Series*>(network_);
  74. GenericVector<STRING> layers;
  75. series->EnumerateLayers(nullptr, &layers);
  76. return layers;
  77. }
  78. // Returns a specific layer from its id (from EnumerateLayers).
  79. Network* GetLayer(const STRING& id) const {
  80. ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
  81. ASSERT_HOST(id.length() > 1 && id[0] == ':');
  82. auto* series = static_cast<Series*>(network_);
  83. return series->GetLayer(&id[1]);
  84. }
  85. // Returns the learning rate of the layer from its id.
  86. float GetLayerLearningRate(const STRING& id) const {
  87. ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
  88. if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
  89. ASSERT_HOST(id.length() > 1 && id[0] == ':');
  90. auto* series = static_cast<Series*>(network_);
  91. return series->LayerLearningRate(&id[1]);
  92. } else {
  93. return learning_rate_;
  94. }
  95. }
  96. // Multiplies the all the learning rate(s) by the given factor.
  97. void ScaleLearningRate(double factor) {
  98. ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
  99. learning_rate_ *= factor;
  100. if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
  101. GenericVector<STRING> layers = EnumerateLayers();
  102. for (int i = 0; i < layers.size(); ++i) {
  103. ScaleLayerLearningRate(layers[i], factor);
  104. }
  105. }
  106. }
  107. // Multiplies the learning rate of the layer with id, by the given factor.
  108. void ScaleLayerLearningRate(const STRING& id, double factor) {
  109. ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
  110. ASSERT_HOST(id.length() > 1 && id[0] == ':');
  111. auto* series = static_cast<Series*>(network_);
  112. series->ScaleLayerLearningRate(&id[1], factor);
  113. }
  114. // Converts the network to int if not already.
  115. void ConvertToInt() {
  116. if ((training_flags_ & TF_INT_MODE) == 0) {
  117. network_->ConvertToInt();
  118. training_flags_ |= TF_INT_MODE;
  119. }
  120. }
  121. // Provides access to the UNICHARSET that this classifier works with.
  122. const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
  123. // Provides access to the UnicharCompress that this classifier works with.
  124. const UnicharCompress& GetRecoder() const { return recoder_; }
  125. // Provides access to the Dict that this classifier works with.
  126. const Dict* GetDict() const { return dict_; }
  127. // Sets the sample iteration to the given value. The sample_iteration_
  128. // determines the seed for the random number generator. The training
  129. // iteration is incremented only by a successful training iteration.
  130. void SetIteration(int iteration) { sample_iteration_ = iteration; }
  131. // Accessors for textline image normalization.
  132. int NumInputs() const { return network_->NumInputs(); }
  133. int null_char() const { return null_char_; }
  134. // Loads a model from mgr, including the dictionary only if lang is not null.
  135. bool Load(const ParamsVectors* params, const char* lang,
  136. TessdataManager* mgr);
  137. // Writes to the given file. Returns false in case of error.
  138. // If mgr contains a unicharset and recoder, then they are not encoded to fp.
  139. bool Serialize(const TessdataManager* mgr, TFile* fp) const;
  140. // Reads from the given file. Returns false in case of error.
  141. // If mgr contains a unicharset and recoder, then they are taken from there,
  142. // otherwise, they are part of the serialization in fp.
  143. bool DeSerialize(const TessdataManager* mgr, TFile* fp);
  144. // Loads the charsets from mgr.
  145. bool LoadCharsets(const TessdataManager* mgr);
  146. // Loads the Recoder.
  147. bool LoadRecoder(TFile* fp);
  148. // Loads the dictionary if possible from the traineddata file.
  149. // Prints a warning message, and returns false but otherwise fails silently
  150. // and continues to work without it if loading fails.
  151. // Note that dictionary load is independent from DeSerialize, but dependent
  152. // on the unicharset matching. This enables training to deserialize a model
  153. // from checkpoint or restore without having to go back and reload the
  154. // dictionary.
  155. bool LoadDictionary(const ParamsVectors* params, const char* lang,
  156. TessdataManager* mgr);
  157. // Recognizes the line image, contained within image_data, returning the
  158. // recognized tesseract WERD_RES for the words.
  159. // If invert, tries inverted as well if the normal interpretation doesn't
  160. // produce a good enough result. The line_box is used for computing the
  161. // box_word in the output words. worst_dict_cert is the worst certainty that
  162. // will be used in a dictionary word.
  163. void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
  164. double worst_dict_cert, const TBOX& line_box,
  165. PointerVector<WERD_RES>* words, int lstm_choice_mode = 0);
  166. // Helper computes min and mean best results in the output.
  167. void OutputStats(const NetworkIO& outputs, float* min_output,
  168. float* mean_output, float* sd);
  169. // Recognizes the image_data, returning the labels,
  170. // scores, and corresponding pairs of start, end x-coords in coords.
  171. // Returned in scale_factor is the reduction factor
  172. // between the image and the output coords, for computing bounding boxes.
  173. // If re_invert is true, the input is inverted back to its original
  174. // photometric interpretation if inversion is attempted but fails to
  175. // improve the results. This ensures that outputs contains the correct
  176. // forward outputs for the best photometric interpretation.
  177. // inputs is filled with the used inputs to the network.
  178. bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
  179. bool re_invert, bool upside_down, float* scale_factor,
  180. NetworkIO* inputs, NetworkIO* outputs);
  181. // Converts an array of labels to utf-8, whether or not the labels are
  182. // augmented with character boundaries.
  183. STRING DecodeLabels(const GenericVector<int>& labels);
  184. // Displays the forward results in a window with the characters and
  185. // boundaries as determined by the labels and label_coords.
  186. void DisplayForward(const NetworkIO& inputs, const GenericVector<int>& labels,
  187. const GenericVector<int>& label_coords,
  188. const char* window_name, ScrollView** window);
  189. // Converts the network output to a sequence of labels. Outputs labels, scores
  190. // and start xcoords of each char, and each null_char_, with an additional
  191. // final xcoord for the end of the output.
  192. // The conversion method is determined by internal state.
  193. void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
  194. GenericVector<int>* xcoords);
  195. protected:
  196. // Sets the random seed from the sample_iteration_;
  197. void SetRandomSeed() {
  198. int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
  199. randomizer_.set_seed(seed);
  200. randomizer_.IntRand();
  201. }
  202. // Displays the labels and cuts at the corresponding xcoords.
  203. // Size of labels should match xcoords.
  204. void DisplayLSTMOutput(const GenericVector<int>& labels,
  205. const GenericVector<int>& xcoords, int height,
  206. ScrollView* window);
  207. // Prints debug output detailing the activation path that is implied by the
  208. // xcoords.
  209. void DebugActivationPath(const NetworkIO& outputs,
  210. const GenericVector<int>& labels,
  211. const GenericVector<int>& xcoords);
  212. // Prints debug output detailing activations and 2nd choice over a range
  213. // of positions.
  214. void DebugActivationRange(const NetworkIO& outputs, const char* label,
  215. int best_choice, int x_start, int x_end);
  216. // As LabelsViaCTC except that this function constructs the best path that
  217. // contains only legal sequences of subcodes for recoder_.
  218. void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
  219. GenericVector<int>* xcoords);
  220. // Converts the network output to a sequence of labels, with scores, using
  221. // the simple character model (each position is a char, and the null_char_ is
  222. // mainly intended for tail padding.)
  223. void LabelsViaSimpleText(const NetworkIO& output, GenericVector<int>* labels,
  224. GenericVector<int>* xcoords);
  225. // Returns a string corresponding to the label starting at start. Sets *end
  226. // to the next start and if non-null, *decoded to the unichar id.
  227. const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
  228. int* decoded);
  229. // Returns a string corresponding to a given single label id, falling back to
  230. // a default of ".." for part of a multi-label unichar-id.
  231. const char* DecodeSingleLabel(int label);
  232. protected:
  233. // The network hierarchy.
  234. Network* network_;
  235. // The unicharset. Only the unicharset element is serialized.
  236. // Has to be a CCUtil, so Dict can point to it.
  237. CCUtil ccutil_;
  238. // For backward compatibility, recoder_ is serialized iff
  239. // training_flags_ & TF_COMPRESS_UNICHARSET.
  240. // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
  241. UnicharCompress recoder_;
  242. // ==Training parameters that are serialized to provide a record of them.==
  243. STRING network_str_;
  244. // Flags used to determine the training method of the network.
  245. // See enum TrainingFlags above.
  246. int32_t training_flags_;
  247. // Number of actual backward training steps used.
  248. int32_t training_iteration_;
  249. // Index into training sample set. sample_iteration >= training_iteration_.
  250. int32_t sample_iteration_;
  251. // Index in softmax of null character. May take the value UNICHAR_BROKEN or
  252. // ccutil_.unicharset.size().
  253. int32_t null_char_;
  254. // Learning rate and momentum multipliers of deltas in backprop.
  255. float learning_rate_;
  256. float momentum_;
  257. // Smoothing factor for 2nd moment of gradients.
  258. float adam_beta_;
  259. // === NOT SERIALIZED.
  260. TRand randomizer_;
  261. NetworkScratch scratch_space_;
  262. // Language model (optional) to use with the beam search.
  263. Dict* dict_;
  264. // Beam search held between uses to optimize memory allocation/use.
  265. RecodeBeamSearch* search_;
  266. // == Debugging parameters.==
  267. // Recognition debug display window.
  268. ScrollView* debug_win_;
  269. };
  270. } // namespace tesseract.
  271. #endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_