ctc.h 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: ctc.h
  3. // Description: Slightly improved standard CTC to compute the targets.
  4. // Author: Ray Smith
  5. // Created: Wed Jul 13 15:17:06 PDT 2016
  6. //
  7. // (C) Copyright 2016, Google Inc.
  8. // Licensed under the Apache License, Version 2.0 (the "License");
  9. // you may not use this file except in compliance with the License.
  10. // You may obtain a copy of the License at
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. // See the License for the specific language governing permissions and
  16. // limitations under the License.
  17. ///////////////////////////////////////////////////////////////////////
  18. #ifndef TESSERACT_LSTM_CTC_H_
  19. #define TESSERACT_LSTM_CTC_H_
  #include <cmath>

  #include "genericvector.h"
  #include "network.h"
  #include "networkio.h"
  #include "scrollview.h"
  24. namespace tesseract {
  25. // Class to encapsulate CTC and simple target generation.
  26. class CTC {
  27. public:
  28. // Normalizes the probabilities such that no target has a prob below min_prob,
  29. // and, provided that the initial total is at least min_total_prob, then all
  30. // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
  31. // probability is thus 1 - (num_classes-1)*min_prob.
  32. static void NormalizeProbs(NetworkIO* probs) {
  33. NormalizeProbs(probs->mutable_float_array());
  34. }
  35. // Builds a target using CTC. Slightly improved as follows:
  36. // Includes normalizations and clipping for stability.
  37. // labels should be pre-padded with nulls wherever desired, but they don't
  38. // have to be between all labels. Allows for multi-label codes with no
  39. // nulls between.
  40. // labels can be longer than the time sequence, but the total number of
  41. // essential labels (non-null plus nulls between equal labels) must not exceed
  42. // the number of timesteps in outputs.
  43. // outputs is the output of the network, and should have already been
  44. // normalized with NormalizeProbs.
  45. // On return targets is filled with the computed targets.
  46. // Returns false if there is insufficient time for the labels.
  47. static bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
  48. int null_char,
  49. const GENERIC_2D_ARRAY<float>& outputs,
  50. NetworkIO* targets);
  51. private:
  52. // Constructor is private as the instance only holds information specific to
  53. // the current labels, outputs etc, and is built by the static function.
  54. CTC(const GenericVector<int>& labels, int null_char,
  55. const GENERIC_2D_ARRAY<float>& outputs);
  56. // Computes vectors of min and max label index for each timestep, based on
  57. // whether skippability of nulls makes it possible to complete a valid path.
  58. bool ComputeLabelLimits();
  59. // Computes targets based purely on the labels by spreading the labels evenly
  60. // over the available timesteps.
  61. void ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const;
  62. // Computes mean positions and half widths of the simple targets by spreading
  63. // the labels even over the available timesteps.
  64. void ComputeWidthsAndMeans(GenericVector<float>* half_widths,
  65. GenericVector<int>* means) const;
  66. // Calculates and returns a suitable fraction of the simple targets to add
  67. // to the network outputs.
  68. float CalculateBiasFraction();
  69. // Runs the forward CTC pass, filling in log_probs.
  70. void Forward(GENERIC_2D_ARRAY<double>* log_probs) const;
  71. // Runs the backward CTC pass, filling in log_probs.
  72. void Backward(GENERIC_2D_ARRAY<double>* log_probs) const;
  73. // Normalizes and brings probs out of log space with a softmax over time.
  74. void NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const;
  75. // For each timestep computes the max prob for each class over all
  76. // instances of the class in the labels_, and sets the targets to
  77. // the max observed prob.
  78. void LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
  79. NetworkIO* targets) const;
  80. // Normalizes the probabilities such that no target has a prob below min_prob,
  81. // and, provided that the initial total is at least min_total_prob, then all
  82. // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
  83. // probability is thus 1 - (num_classes-1)*min_prob.
  84. static void NormalizeProbs(GENERIC_2D_ARRAY<float>* probs);
  85. // Returns true if the label at index is a needed null.
  86. bool NeededNull(int index) const;
  87. // Returns exp(clipped(x)), clipping x to a reasonable range to prevent over/
  88. // underflow.
  89. static double ClippedExp(double x) {
  90. if (x < -kMaxExpArg_) return exp(-kMaxExpArg_);
  91. if (x > kMaxExpArg_) return exp(kMaxExpArg_);
  92. return exp(x);
  93. }
  94. // Minimum probability limit for softmax input to ctc_loss.
  95. static const float kMinProb_;
  96. // Maximum absolute argument to exp().
  97. static const double kMaxExpArg_;
  98. // Minimum probability for total prob in time normalization.
  99. static const double kMinTotalTimeProb_;
  100. // Minimum probability for total prob in final normalization.
  101. static const double kMinTotalFinalProb_;
  102. // The truth label indices that are to be matched to outputs_.
  103. const GenericVector<int>& labels_;
  104. // The network outputs.
  105. GENERIC_2D_ARRAY<float> outputs_;
  106. // The null or "blank" label.
  107. int null_char_;
  108. // Number of timesteps in outputs_.
  109. int num_timesteps_;
  110. // Number of classes in outputs_.
  111. int num_classes_;
  112. // Number of labels in labels_.
  113. int num_labels_;
  114. // Min and max valid label indices for each timestep.
  115. GenericVector<int> min_labels_;
  116. GenericVector<int> max_labels_;
  117. };
  118. } // namespace tesseract
  119. #endif // TESSERACT_LSTM_CTC_H_