///////////////////////////////////////////////////////////////////////
// File:        lstmtrainer.h
// Description: Top-level line trainer class for LSTM-based networks.
// Author:      Ray Smith
// Created:     Fri May 03 09:07:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_LSTMTRAINER_H_
#define TESSERACT_LSTM_LSTMTRAINER_H_

#include "imagedata.h"
#include "lstmrecognizer.h"
#include "rect.h"
#include "tesscallback.h"

namespace tesseract {

class LSTM;
class LSTMTrainer;
class Parallel;
class Reversed;
class Softmax;
class Series;
// Enum for the types of errors that are counted.
enum ErrorTypes {
  ET_RMS,          // RMS activation error.
  ET_DELTA,        // Number of big errors in deltas.
  ET_WORD_RECERR,  // Output text string word recall error.
  ET_CHAR_ERROR,   // Output text string total char error.
  ET_SKIP_RATIO,   // Fraction of samples skipped.
  ET_COUNT         // For array sizing.
};

// Enum for the trainability_ flags.
enum Trainability {
  TRAINABLE,         // Non-zero delta error.
  PERFECT,           // Zero delta error.
  UNENCODABLE,       // Not trainable due to coding/alignment trouble.
  HI_PRECISION_ERR,  // Hi confidence disagreement.
  NOT_BOXED,         // Early in training and has no character boxes.
};

// Enum to define the amount of data to get serialized.
enum SerializeAmount {
  LIGHT,            // Minimal data for remote training.
  NO_BEST_TRAINER,  // Save an empty vector in place of best_trainer_.
  FULL,             // All data including best_trainer_.
};

// Enum to indicate how the sub_trainer_ training went.
enum SubTrainerResult {
  STR_NONE,     // Did nothing as not good enough.
  STR_UPDATED,  // Subtrainer was updated, but didn't replace *this.
  STR_REPLACED  // Subtrainer replaced *this.
};
// Function to restore the trainer state from a given checkpoint.
// Returns false on failure.
typedef TessResultCallback2<bool, const GenericVector<char>&, LSTMTrainer*>*
    CheckPointReader;
// Function to save a checkpoint of the current trainer state.
// Returns false on failure. SerializeAmount determines the amount of the
// trainer to serialize, typically used for saving the best state.
typedef TessResultCallback3<bool, SerializeAmount, const LSTMTrainer*,
                            GenericVector<char>*>* CheckPointWriter;
// Function to compute and record error rates on some external test set(s).
// Args are: iteration, mean errors, model, training stage.
// Returns a STRING containing logging information about the tests.
typedef TessResultCallback4<STRING, int, const double*, const TessdataManager&,
                            int>* TestCallback;
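// Illustrative sketch (editorial, not part of this header): a custom
// CheckPointWriter could be bound from a free function with a matching
// signature, e.g. via NewPermanentTessCallback from tesscallback.h:
//
//   bool WriteCheckpoint(SerializeAmount amount, const LSTMTrainer* trainer,
//                        GenericVector<char>* data) {
//     return trainer->SaveTrainingDump(amount, trainer, data);
//   }
//   CheckPointWriter writer = NewPermanentTessCallback(WriteCheckpoint);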
// Trainer class for LSTM networks. Most of the effort is in creating the
// ideal target outputs from the transcription. A box file is used if it is
// available, otherwise estimates of the char widths from the unicharset are
// used to guide a DP search for the best fit to the transcription.
class LSTMTrainer : public LSTMRecognizer {
 public:
  LSTMTrainer();
  // Callbacks may be null, in which case defaults are used.
  LSTMTrainer(FileReader file_reader, FileWriter file_writer,
              CheckPointReader checkpoint_reader,
              CheckPointWriter checkpoint_writer,
              const char* model_base, const char* checkpoint_name,
              int debug_interval, int64_t max_memory);
  virtual ~LSTMTrainer();

  // Tries to deserialize a trainer from the given file and silently returns
  // false in case of failure. If old_traineddata is not null, then it is
  // assumed that the character set is to be re-mapped from old_traineddata to
  // the new, with consequent change in weight matrices etc.
  bool TryLoadingCheckpoint(const char* filename, const char* old_traineddata);

  // Initializes the character set encode/decode mechanism directly from a
  // previously setup traineddata containing dawgs, UNICHARSET and
  // UnicharCompress. Note: Call before InitNetwork!
  void InitCharSet(const std::string& traineddata_path) {
    ASSERT_HOST(mgr_.Init(traineddata_path.c_str()));
    InitCharSet();
  }
  void InitCharSet(const TessdataManager& mgr) {
    mgr_ = mgr;
    InitCharSet();
  }
  // Initializes the trainer with a network_spec in the network description.
  // net_flags control network behavior according to the NetworkFlags enum,
  // and there isn't really much difference between them - only where the
  // effects are implemented.
  // For other args see NetworkBuilder::InitNetwork.
  // Note: Be sure to call InitCharSet before InitNetwork!
  bool InitNetwork(const STRING& network_spec, int append_index, int net_flags,
                   float weight_range, float learning_rate, float momentum,
                   float adam_beta);
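  // Illustrative setup sketch (editorial assumption; the traineddata path,
  // VGSL spec, and hyperparameter values below are placeholders, not
  // prescribed values):
  //
  //   LSTMTrainer trainer;
  //   trainer.InitCharSet("eng.traineddata");  // Must precede InitNetwork.
  //   trainer.InitNetwork("[1,36,0,1 Lfx96 O1c1]", -1, NF_LAYER_SPECIFIC_LR,
  //                       0.1f, 0.001f, 0.5f, 0.999f);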
  // Initializes a trainer from a serialized TFNetworkModel proto.
  // Returns the global step of TensorFlow graph or 0 if failed.
  // Building a compatible TF graph: See tfnetwork.proto.
  int InitTensorFlowNetwork(const std::string& tf_proto);

  // Resets all the iteration counters for fine tuning or training a head,
  // where we want the error reporting to reset.
  void InitIterations();

  // Accessors.
  double ActivationError() const { return error_rates_[ET_DELTA]; }
  double CharError() const { return error_rates_[ET_CHAR_ERROR]; }
  const double* error_rates() const { return error_rates_; }
  double best_error_rate() const { return best_error_rate_; }
  int best_iteration() const { return best_iteration_; }
  int learning_iteration() const { return learning_iteration_; }
  int32_t improvement_steps() const { return improvement_steps_; }
  void set_perfect_delay(int delay) { perfect_delay_ = delay; }
  const GenericVector<char>& best_trainer() const { return best_trainer_; }
  // Returns the error that was just calculated by PrepareForBackward.
  double NewSingleError(ErrorTypes type) const {
    return error_buffers_[type][training_iteration() % kRollingBufferSize_];
  }
  // Returns the error that was just calculated by TrainOnLine. Since
  // TrainOnLine rolls the error buffers, this is one further back than
  // NewSingleError.
  double LastSingleError(ErrorTypes type) const {
    return error_buffers_[type]
                         [(training_iteration() + kRollingBufferSize_ - 1) %
                          kRollingBufferSize_];
  }
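  // Worked example (editorial, using kRollingBufferSize_ == 1000 as defined
  // below): at training_iteration() == 2005, NewSingleError reads slot
  // 2005 % 1000 == 5, while LastSingleError reads slot
  // (2005 + 1000 - 1) % 1000 == 4, i.e. the previous iteration's error.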
  const DocumentCache& training_data() const { return training_data_; }
  DocumentCache* mutable_training_data() { return &training_data_; }

  // If the training sample is usable, grid searches for the optimal
  // dict_ratio/cert_offset, and returns the results in a string of space-
  // separated triplets of ratio,offset=worderr.
  Trainability GridSearchDictParams(
      const ImageData* trainingdata, int iteration, double min_dict_ratio,
      double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
      double cert_offset_step, double max_cert_offset, STRING* results);

  // Provides output on the distribution of weight values.
  void DebugNetwork();
  // Loads a set of lstmf files (created by running tesseract with the
  // lstm.train config) into memory, ready for training. Returns false if
  // nothing was loaded.
  bool LoadAllTrainingData(const GenericVector<STRING>& filenames,
                           CachingStrategy cache_strategy,
                           bool randomly_rotate);
  // Keeps track of best and locally worst error rate, using internally
  // computed values. See MaintainCheckpointsSpecific for more detail.
  bool MaintainCheckpoints(TestCallback tester, STRING* log_msg);
  // Keeps track of best and locally worst error_rate (whatever it is) and
  // launches tests using rec_model, when a new min or max is reached.
  // Writes checkpoints using train_model at appropriate times and builds and
  // returns a log message to indicate progress. Returns false if nothing
  // interesting happened.
  bool MaintainCheckpointsSpecific(int iteration,
                                   const GenericVector<char>* train_model,
                                   const GenericVector<char>* rec_model,
                                   TestCallback tester, STRING* log_msg);
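  // Illustrative training-loop sketch (editorial assumption, modeled loosely
  // on the lstmtraining tool; max_iterations and tester are caller-supplied):
  //
  //   STRING log_msg;
  //   while (trainer.training_iteration() < max_iterations) {
  //     trainer.TrainOnLine(&trainer, false);
  //     if (trainer.MaintainCheckpoints(tester, &log_msg))
  //       tprintf("%s\n", log_msg.string());
  //   }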
  // Builds a string containing a progress message with current error rates.
  void PrepareLogMsg(STRING* log_msg) const;
  // Appends <intro_str> iteration learning_iteration()/training_iteration()/
  // sample_iteration() to the log_msg.
  void LogIterations(const char* intro_str, STRING* log_msg) const;

  // TODO(rays) Add curriculum learning.
  // Returns true and increments the training_stage_ if the error rate has just
  // passed through the given threshold for the first time.
  bool TransitionTrainingStage(float error_threshold);
  // Returns the current training stage.
  int CurrentTrainingStage() const { return training_stage_; }

  // Writes to the given file. Returns false in case of error.
  bool Serialize(SerializeAmount serialize_amount,
                 const TessdataManager* mgr, TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  bool DeSerialize(const TessdataManager* mgr, TFile* fp);

  // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
  // learning rates (by scaling reduction, or layer specific, according to
  // NF_LAYER_SPECIFIC_LR).
  void StartSubtrainer(STRING* log_msg);
  // While the sub_trainer_ is behind the current training iteration and its
  // training error is at least kSubTrainerMarginFraction better than the
  // current training error, trains the sub_trainer_, and returns STR_UPDATED
  // if it did anything. If it catches up, and has a better error rate than the
  // current best, as well as a margin over the current error rate, then the
  // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
  // returned. STR_NONE is returned if the subtrainer wasn't good enough to
  // receive any training iterations.
  SubTrainerResult UpdateSubtrainer(STRING* log_msg);
  // Reduces network learning rates, either for everything, or for layers
  // independently, according to NF_LAYER_SPECIFIC_LR.
  void ReduceLearningRates(LSTMTrainer* samples_trainer, STRING* log_msg);
  // Considers reducing the learning rate independently for each layer down by
  // factor (< 1), or leaving it the same, by double-training the given number
  // of samples and minimizing the amount of changing of sign of weight
  // updates. Even if it looks like all weights should remain the same, an
  // adjustment will be made to guarantee a different result when reverting to
  // an old best. Returns the number of layer learning rates that were reduced.
  int ReduceLayerLearningRates(double factor, int num_samples,
                               LSTMTrainer* samples_trainer);

  // Converts the string to integer class labels, with appropriate null_char_s
  // in between if not in SimpleTextOutput mode. Returns false on failure.
  bool EncodeString(const STRING& str, GenericVector<int>* labels) const {
    return EncodeString(str, GetUnicharset(),
                        IsRecoding() ? &recoder_ : nullptr,
                        SimpleTextOutput(), null_char_, labels);
  }
  // Static version operates on supplied unicharset, encoder, simple_text.
  static bool EncodeString(const STRING& str, const UNICHARSET& unicharset,
                           const UnicharCompress* recoder, bool simple_text,
                           int null_char, GenericVector<int>* labels);
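  // Illustrative use (editorial sketch): encode a ground-truth string into
  // class labels before computing targets:
  //
  //   GenericVector<int> labels;
  //   if (!trainer.EncodeString(STRING("hello"), &labels)) {
  //     // The string contains characters outside the trained character set.
  //   }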
  // Performs forward-backward on the given trainingdata.
  // Returns the sample that was used or nullptr if the next sample was deemed
  // unusable. samples_trainer could be this or an alternative trainer that
  // holds the training samples.
  const ImageData* TrainOnLine(LSTMTrainer* samples_trainer, bool batch) {
    int sample_index = sample_iteration();
    const ImageData* image =
        samples_trainer->training_data_.GetPageBySerial(sample_index);
    if (image != nullptr) {
      Trainability trainable = TrainOnLine(image, batch);
      if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
        return nullptr;  // Sample was unusable.
      }
    } else {
      ++sample_iteration_;
    }
    return image;
  }
  Trainability TrainOnLine(const ImageData* trainingdata, bool batch);

  // Prepares the ground truth, runs forward, and prepares the targets.
  // Returns a Trainability enum to indicate the suitability of the sample.
  Trainability PrepareForBackward(const ImageData* trainingdata,
                                  NetworkIO* fwd_outputs, NetworkIO* targets);

  // Writes the trainer to memory, so that the current training state can be
  // restored. *this must always be the master trainer that retains the only
  // copy of the training data and language model. trainer is the model that
  // is actually serialized.
  bool SaveTrainingDump(SerializeAmount serialize_amount,
                        const LSTMTrainer* trainer,
                        GenericVector<char>* data) const;
  // Reads previously saved trainer from memory. *this must always be the
  // master trainer that retains the only copy of the training data and
  // language model. trainer is the model that is restored.
  bool ReadTrainingDump(const GenericVector<char>& data,
                        LSTMTrainer* trainer) const {
    if (data.empty()) return false;
    return ReadSizedTrainingDump(&data[0], data.size(), trainer);
  }
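  // Illustrative round trip (editorial sketch): snapshot the current state
  // and restore it later, e.g. to revert to a previous best:
  //
  //   GenericVector<char> dump;
  //   trainer.SaveTrainingDump(LIGHT, &trainer, &dump);
  //   ...
  //   trainer.ReadTrainingDump(dump, &trainer);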
  bool ReadSizedTrainingDump(const char* data, int size,
                             LSTMTrainer* trainer) const {
    return trainer->ReadLocalTrainingDump(&mgr_, data, size);
  }
  // Restores the model to *this.
  bool ReadLocalTrainingDump(const TessdataManager* mgr, const char* data,
                             int size);
  // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
  void SetupCheckpointInfo();
  // Writes the full recognition traineddata to the given filename.
  bool SaveTraineddata(const STRING& filename);
  // Writes the recognizer to memory, so that it can be used for testing later.
  void SaveRecognitionDump(GenericVector<char>* data) const;
  // Returns a suitable filename for a training dump, based on the model_base_,
  // the iteration and the error rates.
  STRING DumpFilename() const;
  // Fills the whole error buffer of the given type with the given value.
  void FillErrorBuffer(double new_error, ErrorTypes type);
  // Helper generates a map from each current recoder_ code (i.e. softmax
  // index) to the corresponding old_recoder code, or -1 if there isn't one.
  std::vector<int> MapRecoder(const UNICHARSET& old_chset,
                              const UnicharCompress& old_recoder) const;

 protected:
  // Private version of InitCharSet above finishes the job after initializing
  // the mgr_ data member.
  void InitCharSet();
  // Helper computes and sets the null_char_.
  void SetNullChar();
  // Factored sub-constructor sets up reasonable default values.
  void EmptyConstructor();
  // Outputs the string and periodically displays the given network inputs
  // as an image in the given window, and the corresponding labels at the
  // corresponding x_starts.
  // Returns false if the truth string is empty.
  bool DebugLSTMTraining(const NetworkIO& inputs,
                         const ImageData& trainingdata,
                         const NetworkIO& fwd_outputs,
                         const GenericVector<int>& truth_labels,
                         const NetworkIO& outputs);
  // Displays the network targets as a line graph.
  void DisplayTargets(const NetworkIO& targets, const char* window_name,
                      ScrollView** window);
  // Builds a no-compromises target where the first positions should be the
  // truth labels and the rest is padded with the null_char_.
  bool ComputeTextTargets(const NetworkIO& outputs,
                          const GenericVector<int>& truth_labels,
                          NetworkIO* targets);
  // Builds a target using standard CTC. truth_labels should be pre-padded with
  // nulls wherever desired. They don't have to be between all labels.
  // outputs is input-output, as it gets clipped to minimum probability.
  bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
                         NetworkIO* outputs, NetworkIO* targets);

  // Computes network errors, and stores the results in the rolling buffers,
  // along with the supplied text_error.
  // Returns the delta error of the current sample (not running average.)
  double ComputeErrorRates(const NetworkIO& deltas, double char_error,
                           double word_error);
  // Computes the network activation RMS error rate.
  double ComputeRMSError(const NetworkIO& deltas);
  // Computes network activation winner error rate. (Number of values that are
  // in error by >= 0.5 divided by number of time-steps.) More closely related
  // to final character error than RMS, but still directly calculable from
  // just the deltas. Because of the binary nature of the targets, zero winner
  // error is a sufficient but not necessary condition for zero char error.
  double ComputeWinnerError(const NetworkIO& deltas);
  // Computes a very simple bag of chars char error rate.
  double ComputeCharError(const GenericVector<int>& truth_str,
                          const GenericVector<int>& ocr_str);
  // Computes a very simple bag of words word recall error rate.
  // NOTE that this is destructive on both input strings.
  double ComputeWordError(STRING* truth_str, STRING* ocr_str);
  // Updates the error buffer and corresponding mean of the given type with
  // the new_error.
  void UpdateErrorBuffer(double new_error, ErrorTypes type);
  // Rolls error buffers and reports the current means.
  void RollErrorBuffers();
  // Given that error_rate is either a new min or max, updates the best/worst
  // error rates, and record of progress.
  STRING UpdateErrorGraph(int iteration, double error_rate,
                          const GenericVector<char>& model_data,
                          TestCallback tester);
 protected:
  // Alignment display window.
  ScrollView* align_win_;
  // CTC target display window.
  ScrollView* target_win_;
  // CTC output display window.
  ScrollView* ctc_win_;
  // Reconstructed image window.
  ScrollView* recon_win_;
  // How often to display a debug image.
  int debug_interval_;
  // Iteration at which the last checkpoint was dumped.
  int checkpoint_iteration_;
  // Basename of files to save best models to.
  STRING model_base_;
  // Checkpoint filename.
  STRING checkpoint_name_;
  // Training data.
  bool randomly_rotate_;
  DocumentCache training_data_;
  // Name to use when saving best_trainer_.
  STRING best_model_name_;
  // Number of available training stages.
  int num_training_stages_;
  // Checkpointing callbacks.
  FileReader file_reader_;
  FileWriter file_writer_;
  // TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr
  // when we can commit to c++11.
  CheckPointReader checkpoint_reader_;
  CheckPointWriter checkpoint_writer_;

  // ===Serialized data to ensure that a restart produces the same results.===
  // These members are only serialized when serialize_amount != LIGHT.
  // Best error rate so far.
  double best_error_rate_;
  // Snapshot of all error rates at best_iteration_.
  double best_error_rates_[ET_COUNT];
  // Iteration of best_error_rate_.
  int best_iteration_;
  // Worst error rate since best_error_rate_.
  double worst_error_rate_;
  // Snapshot of all error rates at worst_iteration_.
  double worst_error_rates_[ET_COUNT];
  // Iteration of worst_error_rate_.
  int worst_iteration_;
  // Iteration at which the process will be thought stalled.
  int stall_iteration_;
  // Saved recognition models for computing test error for graph points.
  GenericVector<char> best_model_data_;
  GenericVector<char> worst_model_data_;
  // Saved trainer for reverting back to last known best.
  GenericVector<char> best_trainer_;
  // A subsidiary trainer running with a different learning rate until either
  // *this or sub_trainer_ hits a new best.
  LSTMTrainer* sub_trainer_;
  // Error rate at which last best model was dumped.
  float error_rate_of_last_saved_best_;
  // Current stage of training.
  int training_stage_;
  // History of best error rate against iteration. Used for computing the
  // number of steps to each 2% improvement.
  GenericVector<double> best_error_history_;
  GenericVector<int> best_error_iterations_;
  // Number of iterations since the best_error_rate_ was 2% more than it is
  // now.
  int32_t improvement_steps_;
  // Number of iterations that yielded a non-zero delta error and thus provided
  // significant learning. learning_iteration_ <= training_iteration_.
  // learning_iteration_ is used to measure rate of learning progress.
  int learning_iteration_;
  // Saved value of sample_iteration_ before looking for the next sample.
  int prev_sample_iteration_;
  // How often to include a PERFECT training sample in backprop.
  // A PERFECT training sample is used if the current
  // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
  // so with perfect_delay_ == 0, all samples are used, and with
  // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
  int perfect_delay_;
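  // Worked example (editorial): with perfect_delay_ == 4 and
  // last_perfect_training_iteration_ == 100, PERFECT samples are skipped at
  // iterations 101-104 and the next one is accepted at iteration 105.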
  // Value of training_iteration_ at which the last PERFECT training sample
  // was used in back prop.
  int last_perfect_training_iteration_;
  // Rolling buffers storing recent training errors are indexed by
  // training_iteration % kRollingBufferSize_.
  static const int kRollingBufferSize_ = 1000;
  GenericVector<double> error_buffers_[ET_COUNT];
  // Rounded mean percent trailing training errors in the buffers.
  double error_rates_[ET_COUNT];
  // Traineddata file with optional dawgs + UNICHARSET and recoder.
  TessdataManager mgr_;
};

}  // namespace tesseract.

#endif  // TESSERACT_LSTM_LSTMTRAINER_H_