recodebeam.h 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: recodebeam.h
  3. // Description: Beam search to decode from the re-encoded CJK as a sequence of
  4. // smaller numbers in place of a single large code.
  5. // Author: Ray Smith
  6. //
  7. // (C) Copyright 2015, Google Inc.
  8. // Licensed under the Apache License, Version 2.0 (the "License");
  9. // you may not use this file except in compliance with the License.
  10. // You may obtain a copy of the License at
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. // See the License for the specific language governing permissions and
  16. // limitations under the License.
  17. //
  18. ///////////////////////////////////////////////////////////////////////
  19. #ifndef THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
  20. #define THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
  21. #include "dawg.h"
  22. #include "dict.h"
  23. #include "genericheap.h"
  24. #include "kdpair.h"
  25. #include "networkio.h"
  26. #include "ratngs.h"
  27. #include "unicharcompress.h"
  28. #include <deque>
  29. #include <set>
  30. #include <tuple>
  31. #include <vector>
  32. namespace tesseract {
  // Enum describing what can follow the current node.
  // Consider the following softmax outputs:
  // Timestep    0     1     2     3     4     5     6     7     8
  // X-score    0.01  0.55  0.98  0.42  0.01  0.01  0.40  0.95  0.01
  // Y-score    0.00  0.01  0.01  0.01  0.01  0.97  0.59  0.04  0.01
  // Null-score 0.99  0.44  0.01  0.57  0.98  0.02  0.01  0.01  0.98
  // Then the correct CTC decoding (in which adjacent equal classes are folded,
  // and then all nulls are dropped) is clearly XYX, but simple decoding (taking
  // the max at each timestep) leads to:
  // Null@0.99 X@0.55 X@0.98 Null@0.57 Null@0.98 Y@0.97 Y@0.59 X@0.95 Null@0.98,
  // which folds to the correct XYX. The conversion to Tesseract rating and
  // certainty uses the sum of the log probs (log of the product of probabilities)
  // for the Rating and the minimum log prob for the certainty, but that yields a
  // minimum certainty of log(0.55), which is poor for such an obvious case.
  // CTC says that the probability of the result is the SUM of the products of the
  // probabilities over ALL PATHS that decode to the same result, which includes:
  // NXXNNYYXN, NNXNNYYXN, NXXXNYYXN, NNXXNYXXN, and others including XXXXXYYXX.
  // That is intractable, so some compromise between simple and ideal is needed.
  // Observing that evenly split timesteps rarely happen next to each other, we
  // allow scores at a transition between classes to be added for decoding thus:
  // N@0.99 (N+X)@0.99 X@0.98 (N+X)@0.99 N@0.98 Y@0.97 (X+Y+N)@1.00 X@0.95 N@0.98.
  // This works because NNX and NXX both decode to X, so in the middle we can use
  // N+X. Note that the classes either side of a sum must stand alone, i.e. use a
  // single score, to force all paths to pass through them and decode to the same
  // result. Also in the special case of a transition from X to Y, with only one
  // timestep between, it is possible to add X+Y+N, since XXY, XYY, and XNY all
  // decode to XY.
  // An important condition is that we cannot combine X and Null between two
  // stand-alone Xs, since that can decode as XNX->XX or XXX->X, so the scores for
  // X and Null have to go in separate paths. Combining scores in this way
  // provides a much better minimum certainty of log(0.95).
  // In the implementation of the beam search, we have to place the possibilities
  // X, X+N and X+Y+N in the beam under appropriate conditions of the previous
  // node, and constrain what can follow, to enforce the rules explained above.
  // We therefore have 3 different types of node determined by what can follow:
  enum NodeContinuation {
    NC_ANYTHING,  // This node used just its own score, so anything can follow.
    NC_ONLY_DUP,  // The current node combined another score with the score for
                  // itself, without a stand-alone duplicate before, so must be
                  // followed by a stand-alone duplicate.
    NC_NO_DUP,    // The current node combined another score with the score for
                  // itself, after a stand-alone, so can only be followed by
                  // something other than a duplicate of the current node.
    NC_COUNT      // Number of continuation types above; not itself a type. Used
                  // to size per-continuation arrays and beam indices.
  };
  // Classification of a single network output (code) by where it ranks among
  // the outputs of the current timestep.
  enum TopNState {
    // Ranked first or second.
    TN_TOP2 = 0,
    // Inside the top-n, but ranked neither first nor second.
    TN_TOPN = 1,
    // Outside the top-n.
    TN_ALSO_RAN = 2,
    // How many states are listed above; not a state in its own right.
    TN_COUNT = 3
  };
  85. // Lattice element for Re-encode beam search.
  86. struct RecodeNode {
  87. RecodeNode()
  88. : code(-1),
  89. unichar_id(INVALID_UNICHAR_ID),
  90. permuter(TOP_CHOICE_PERM),
  91. start_of_dawg(false),
  92. start_of_word(false),
  93. end_of_word(false),
  94. duplicate(false),
  95. certainty(0.0f),
  96. score(0.0f),
  97. prev(nullptr),
  98. dawgs(nullptr),
  99. code_hash(0) {}
  100. RecodeNode(int c, int uni_id, PermuterType perm, bool dawg_start,
  101. bool word_start, bool end, bool dup, float cert, float s,
  102. const RecodeNode* p, DawgPositionVector* d, uint64_t hash)
  103. : code(c),
  104. unichar_id(uni_id),
  105. permuter(perm),
  106. start_of_dawg(dawg_start),
  107. start_of_word(word_start),
  108. end_of_word(end),
  109. duplicate(dup),
  110. certainty(cert),
  111. score(s),
  112. prev(p),
  113. dawgs(d),
  114. code_hash(hash) {}
  115. // NOTE: If we could use C++11, then this would be a move constructor.
  116. // Instead we have copy constructor that does a move!! This is because we
  117. // don't want to copy the whole DawgPositionVector each time, and true
  118. // copying isn't necessary for this struct. It does get moved around a lot
  119. // though inside the heap and during heap push, hence the move semantics.
  120. RecodeNode(RecodeNode& src) : dawgs(nullptr) {
  121. *this = src;
  122. ASSERT_HOST(src.dawgs == nullptr);
  123. }
  124. RecodeNode& operator=(RecodeNode& src) {
  125. delete dawgs;
  126. memcpy(this, &src, sizeof(src));
  127. src.dawgs = nullptr;
  128. return *this;
  129. }
  130. ~RecodeNode() { delete dawgs; }
  131. // Prints details of the node.
  132. void Print(int null_char, const UNICHARSET& unicharset, int depth) const;
  133. // The re-encoded code here = index to network output.
  134. int code;
  135. // The decoded unichar_id is only valid for the final code of a sequence.
  136. int unichar_id;
  137. // The type of permuter active at this point. Intervals between start_of_word
  138. // and end_of_word make valid words of type given by permuter where
  139. // end_of_word is true. These aren't necessarily delimited by spaces.
  140. PermuterType permuter;
  141. // True if this is the initial dawg state. May be attached to a space or,
  142. // in a non-space-delimited lang, the end of the previous word.
  143. bool start_of_dawg;
  144. // True if this is the first node in a dictionary word.
  145. bool start_of_word;
  146. // True if this represents a valid candidate end of word position. Does not
  147. // necessarily mark the end of a word, since a word can be extended beyond a
  148. // candidate end by a continuation, eg 'the' continues to 'these'.
  149. bool end_of_word;
  150. // True if this->code is a duplicate of prev->code. Some training modes
  151. // allow the network to output duplicate characters and crush them with CTC,
  152. // but that would mess up the dictionary search, so we just smash them
  153. // together on the fly using the duplicate flag.
  154. bool duplicate;
  155. // Certainty (log prob) of (just) this position.
  156. float certainty;
  157. // Total certainty of the path to this position.
  158. float score;
  159. // The previous node in this chain. Borrowed pointer.
  160. const RecodeNode* prev;
  161. // The currently active dawgs at this position. Owned pointer.
  162. DawgPositionVector* dawgs;
  163. // A hash of all codes in the prefix and this->code as well. Used for
  164. // duplicate path removal.
  165. uint64_t code_hash;
  166. };
  167. using RecodePair = KDPairInc<double, RecodeNode>;
  168. using RecodeHeap = GenericHeap<RecodePair>;
  169. // Class that holds the entire beam search for recognition of a text line.
  170. class RecodeBeamSearch {
  171. public:
  172. // Borrows the pointer, which is expected to survive until *this is deleted.
  173. RecodeBeamSearch(const UnicharCompress& recoder, int null_char,
  174. bool simple_text, Dict* dict);
  175. // Decodes the set of network outputs, storing the lattice internally.
  176. // If charset is not null, it enables detailed debugging of the beam search.
  177. void Decode(const NetworkIO& output, double dict_ratio, double cert_offset,
  178. double worst_dict_cert, const UNICHARSET* charset,
  179. int lstm_choice_mode = 0);
  180. void Decode(const GENERIC_2D_ARRAY<float>& output, double dict_ratio,
  181. double cert_offset, double worst_dict_cert,
  182. const UNICHARSET* charset);
  183. // Returns the best path as labels/scores/xcoords similar to simple CTC.
  184. void ExtractBestPathAsLabels(GenericVector<int>* labels,
  185. GenericVector<int>* xcoords) const;
  186. // Returns the best path as unichar-ids/certs/ratings/xcoords skipping
  187. // duplicates, nulls and intermediate parts.
  188. void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET* unicharset,
  189. GenericVector<int>* unichar_ids,
  190. GenericVector<float>* certs,
  191. GenericVector<float>* ratings,
  192. GenericVector<int>* xcoords) const;
  193. // Returns the best path as a set of WERD_RES.
  194. void ExtractBestPathAsWords(const TBOX& line_box, float scale_factor,
  195. bool debug, const UNICHARSET* unicharset,
  196. PointerVector<WERD_RES>* words,
  197. int lstm_choice_mode = 0);
  198. // Generates debug output of the content of the beams after a Decode.
  199. void DebugBeams(const UNICHARSET& unicharset) const;
  200. // Stores the alternative characters of every timestep together with their
  201. // probability.
  202. std::vector< std::vector<std::pair<const char*, float>>> timesteps;
  203. // Clipping value for certainty inside Tesseract. Reflects the minimum value
  204. // of certainty that will be returned by ExtractBestPathAsUnicharIds.
  205. // Supposedly on a uniform scale that can be compared across languages and
  206. // engines.
  207. static const float kMinCertainty;
  208. // Number of different code lengths for which we have a separate beam.
  209. static const int kNumLengths = RecodedCharID::kMaxCodeLen + 1;
  210. // Total number of beams: dawg/nodawg * number of NodeContinuation * number
  211. // of different lengths.
  212. static const int kNumBeams = 2 * NC_COUNT * kNumLengths;
  213. // Returns the relevant factor in the beams_ index.
  214. static int LengthFromBeamsIndex(int index) { return index % kNumLengths; }
  215. static NodeContinuation ContinuationFromBeamsIndex(int index) {
  216. return static_cast<NodeContinuation>((index / kNumLengths) % NC_COUNT);
  217. }
  218. static bool IsDawgFromBeamsIndex(int index) {
  219. return index / (kNumLengths * NC_COUNT) > 0;
  220. }
  221. // Computes a beams_ index from the given factors.
  222. static int BeamIndex(bool is_dawg, NodeContinuation cont, int length) {
  223. return (is_dawg * NC_COUNT + cont) * kNumLengths + length;
  224. }
  225. private:
  226. // Struct for the Re-encode beam search. This struct holds the data for
  227. // a single time-step position of the output. Use a PointerVector<RecodeBeam>
  228. // to hold all the timesteps and prevent reallocation of the individual heaps.
  229. struct RecodeBeam {
  230. // Resets to the initial state without deleting all the memory.
  231. void Clear() {
  232. for (auto & beam : beams_) {
  233. beam.clear();
  234. }
  235. RecodeNode empty;
  236. for (auto & best_initial_dawg : best_initial_dawgs_) {
  237. best_initial_dawg = empty;
  238. }
  239. }
  240. // A separate beam for each combination of code length,
  241. // NodeContinuation, and dictionary flag. Separating out all these types
  242. // allows the beam to be quite narrow, and yet still have a low chance of
  243. // losing the best path.
  244. // We have to keep all these beams separate, since the highest scoring paths
  245. // come from the paths that are most likely to dead-end at any time, like
  246. // dawg paths, NC_ONLY_DUP etc.
  247. // Each heap is stored with the WORST result at the top, so we can quickly
  248. // get the top-n values.
  249. RecodeHeap beams_[kNumBeams];
  250. // While the language model is only a single word dictionary, we can use
  251. // word starts as a choke point in the beam, and keep only a single dict
  252. // start node at each step (for each NodeContinuation type), so we find the
  253. // best one here and push it on the heap, if it qualifies, after processing
  254. // all of the step.
  255. RecodeNode best_initial_dawgs_[NC_COUNT];
  256. };
  257. using TopPair = KDPairInc<float, int>;
  258. // Generates debug output of the content of a single beam position.
  259. void DebugBeamPos(const UNICHARSET& unicharset, const RecodeHeap& heap) const;
  260. // Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
  261. // duplicates, nulls and intermediate parts.
  262. static void ExtractPathAsUnicharIds(
  263. const GenericVector<const RecodeNode*>& best_nodes,
  264. GenericVector<int>* unichar_ids, GenericVector<float>* certs,
  265. GenericVector<float>* ratings, GenericVector<int>* xcoords,
  266. std::deque<std::tuple<int, int>>* best_choices = nullptr);
  267. // Sets up a word with the ratings matrix and fake blobs with boxes in the
  268. // right places.
  269. WERD_RES* InitializeWord(bool leading_space, const TBOX& line_box,
  270. int word_start, int word_end, float space_certainty,
  271. const UNICHARSET* unicharset,
  272. const GenericVector<int>& xcoords,
  273. float scale_factor);
  274. // Fills top_n_flags_ with bools that are true iff the corresponding output
  275. // is one of the top_n.
  276. void ComputeTopN(const float* outputs, int num_outputs, int top_n);
  277. // Adds the computation for the current time-step to the beam. Call at each
  278. // time-step in sequence from left to right. outputs is the activation vector
  279. // for the current timestep.
  280. void DecodeStep(const float* outputs, int t, double dict_ratio,
  281. double cert_offset, double worst_dict_cert,
  282. const UNICHARSET* charset, bool debug = false);
  283. //Saves the most certain choices for the current time-step
  284. void SaveMostCertainChoices(const float* outputs, int num_outputs, const UNICHARSET* charset, int xCoord);
  285. // Adds to the appropriate beams the legal (according to recoder)
  286. // continuations of context prev, which is from the given index to beams_,
  287. // using the given network outputs to provide scores to the choices. Uses only
  288. // those choices for which top_n_flags[code] == top_n_flag.
  289. void ContinueContext(const RecodeNode* prev, int index, const float* outputs,
  290. TopNState top_n_flag, const UNICHARSET* unicharset,
  291. double dict_ratio, double cert_offset,
  292. double worst_dict_cert, RecodeBeam* step);
  293. // Continues for a new unichar, using dawg or non-dawg as per flag.
  294. void ContinueUnichar(int code, int unichar_id, float cert,
  295. float worst_dict_cert, float dict_ratio, bool use_dawgs,
  296. NodeContinuation cont, const RecodeNode* prev,
  297. RecodeBeam* step);
  298. // Adds a RecodeNode composed of the args to the correct heap in step if
  299. // unichar_id is a valid dictionary continuation of whatever is in prev.
  300. void ContinueDawg(int code, int unichar_id, float cert, NodeContinuation cont,
  301. const RecodeNode* prev, RecodeBeam* step);
  302. // Sets the correct best_initial_dawgs_ with a RecodeNode composed of the args
  303. // if better than what is already there.
  304. void PushInitialDawgIfBetter(int code, int unichar_id, PermuterType permuter,
  305. bool start, bool end, float cert,
  306. NodeContinuation cont, const RecodeNode* prev,
  307. RecodeBeam* step);
  308. // Adds a RecodeNode composed of the args to the correct heap in step for
  309. // partial unichar or duplicate if there is room or if better than the
  310. // current worst element if already full.
  311. void PushDupOrNoDawgIfBetter(int length, bool dup, int code, int unichar_id,
  312. float cert, float worst_dict_cert,
  313. float dict_ratio, bool use_dawgs,
  314. NodeContinuation cont, const RecodeNode* prev,
  315. RecodeBeam* step);
  316. // Adds a RecodeNode composed of the args to the correct heap in step if there
  317. // is room or if better than the current worst element if already full.
  318. void PushHeapIfBetter(int max_size, int code, int unichar_id,
  319. PermuterType permuter, bool dawg_start, bool word_start,
  320. bool end, bool dup, float cert, const RecodeNode* prev,
  321. DawgPositionVector* d, RecodeHeap* heap);
  322. // Adds a RecodeNode to heap if there is room
  323. // or if better than the current worst element if already full.
  324. void PushHeapIfBetter(int max_size, RecodeNode* node, RecodeHeap* heap);
  325. // Searches the heap for an entry matching new_node, and updates the entry
  326. // with reshuffle if needed. Returns true if there was a match.
  327. bool UpdateHeapIfMatched(RecodeNode* new_node, RecodeHeap* heap);
  328. // Computes and returns the code-hash for the given code and prev.
  329. uint64_t ComputeCodeHash(int code, bool dup, const RecodeNode* prev) const;
  330. // Backtracks to extract the best path through the lattice that was built
  331. // during Decode. On return the best_nodes vector essentially contains the set
  332. // of code, score pairs that make the optimal path with the constraint that
  333. // the recoder can decode the code sequence back to a sequence of unichar-ids.
  334. void ExtractBestPaths(GenericVector<const RecodeNode*>* best_nodes,
  335. GenericVector<const RecodeNode*>* second_nodes) const;
  336. // Helper backtracks through the lattice from the given node, storing the
  337. // path and reversing it.
  338. void ExtractPath(const RecodeNode* node,
  339. GenericVector<const RecodeNode*>* path) const;
  340. // Helper prints debug information on the given lattice path.
  341. void DebugPath(const UNICHARSET* unicharset,
  342. const GenericVector<const RecodeNode*>& path) const;
  343. // Helper prints debug information on the given unichar path.
  344. void DebugUnicharPath(const UNICHARSET* unicharset,
  345. const GenericVector<const RecodeNode*>& path,
  346. const GenericVector<int>& unichar_ids,
  347. const GenericVector<float>& certs,
  348. const GenericVector<float>& ratings,
  349. const GenericVector<int>& xcoords) const;
  350. static const int kBeamWidths[RecodedCharID::kMaxCodeLen + 1];
  351. // The encoder/decoder that we will be using.
  352. const UnicharCompress& recoder_;
  353. // The beam for each timestep in the output.
  354. PointerVector<RecodeBeam> beam_;
  355. // The number of timesteps valid in beam_;
  356. int beam_size_;
  357. // A flag to indicate which outputs are the top-n choices. Current timestep
  358. // only.
  359. GenericVector<TopNState> top_n_flags_;
  360. // A record of the highest and second scoring codes.
  361. int top_code_;
  362. int second_code_;
  363. // Heap used to compute the top_n_flags_.
  364. GenericHeap<TopPair> top_heap_;
  365. // Borrowed pointer to the dictionary to use in the search.
  366. Dict* dict_;
  367. // True if the language is space-delimited, which is true for most languages
  368. // except chi*, jpn, tha.
  369. bool space_delimited_;
  370. // True if the input is simple text, ie adjacent equal chars are not to be
  371. // eliminated.
  372. bool is_simple_text_;
  373. // The encoded (class label) of the null/reject character.
  374. int null_char_;
  375. };
  376. } // namespace tesseract.
  377. #endif // THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_