weightmatrix.h
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: weightmatrix.h
  3. // Description: Hides distinction between float/int implementations.
  4. // Author: Ray Smith
  5. //
  6. // (C) Copyright 2014, Google Inc.
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. ///////////////////////////////////////////////////////////////////////
  17. #ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_
  18. #define TESSERACT_LSTM_WEIGHTMATRIX_H_
  19. #include <memory>
  20. #include "genericvector.h"
  21. #include "intsimdmatrix.h"
  22. #include "matrix.h"
  23. #include "tprintf.h"
  24. namespace tesseract {
  25. // Convenience instantiation of GENERIC_2D_ARRAY<double> with additional
  26. // operations to write a strided vector, so the transposed form of the input
  27. // is memory-contiguous.
  28. class TransposedArray : public GENERIC_2D_ARRAY<double> {
  29. public:
  30. // Copies the whole input transposed, converted to double, into *this.
  31. void Transpose(const GENERIC_2D_ARRAY<double>& input);
  32. // Writes a vector of data representing a timestep (gradients or sources).
  33. // The data is assumed to be of size1 in size (the strided dimension).
  34. ~TransposedArray() override;
  35. void WriteStrided(int t, const float* data) {
  36. int size1 = dim1();
  37. for (int i = 0; i < size1; ++i) put(i, t, data[i]);
  38. }
  39. void WriteStrided(int t, const double* data) {
  40. int size1 = dim1();
  41. for (int i = 0; i < size1; ++i) put(i, t, data[i]);
  42. }
  43. // Prints the first and last num elements of the un-transposed array.
  44. void PrintUnTransposed(int num) {
  45. int num_features = dim1();
  46. int width = dim2();
  47. for (int y = 0; y < num_features; ++y) {
  48. for (int t = 0; t < width; ++t) {
  49. if (num == 0 || t < num || t + num >= width) {
  50. tprintf(" %g", (*this)(y, t));
  51. }
  52. }
  53. tprintf("\n");
  54. }
  55. }
  56. }; // class TransposedArray
  57. // Generic weight matrix for network layers. Can store the matrix as either
  58. // an array of floats or int8_t. Provides functions to compute the forward and
  59. // backward steps with the matrix and updates to the weights.
  60. class WeightMatrix {
  61. public:
  62. WeightMatrix() : int_mode_(false), use_adam_(false) {}
  63. // Sets up the network for training. Initializes weights using weights of
  64. // scale `range` picked according to the random number generator `randomizer`.
  65. // Note the order is outputs, inputs, as this is the order of indices to
  66. // the matrix, so the adjacent elements are multiplied by the input during
  67. // a forward operation.
  68. int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range,
  69. TRand* randomizer);
  70. // Changes the number of outputs to the size of the given code_map, copying
  71. // the old weight matrix entries for each output from code_map[output] where
  72. // non-negative, and uses the mean (over all outputs) of the existing weights
  73. // for all outputs with negative code_map entries. Returns the new number of
  74. // weights.
  75. int RemapOutputs(const std::vector<int>& code_map);
  76. // Converts a float network to an int network. Each set of input weights that
  77. // corresponds to a single output weight is converted independently:
  78. // Compute the max absolute value of the weight set.
  79. // Scale so the max absolute value becomes INT8_MAX.
  80. // Round to integer.
  81. // Store a multiplicative scale factor (as a float) that will reproduce
  82. // the original value, subject to rounding errors.
  83. void ConvertToInt();
  84. // Returns the size rounded up to an internal factor used by the SIMD
  85. // implementation for its input.
  86. int RoundInputs(int size) const {
  87. if (!int_mode_ || !IntSimdMatrix::intSimdMatrix) return size;
  88. return IntSimdMatrix::intSimdMatrix->RoundInputs(size);
  89. }
  90. // Accessors.
  91. bool is_int_mode() const {
  92. return int_mode_;
  93. }
  94. int NumOutputs() const { return int_mode_ ? wi_.dim1() : wf_.dim1(); }
  95. // Provides one set of weights. Only used by peep weight maxpool.
  96. const double* GetWeights(int index) const { return wf_[index]; }
  97. // Provides access to the deltas (dw_).
  98. double GetDW(int i, int j) const { return dw_(i, j); }
  99. // Allocates any needed memory for running Backward, and zeroes the deltas,
  100. // thus eliminating any existing momentum.
  101. void InitBackward();
  102. // Writes to the given file. Returns false in case of error.
  103. bool Serialize(bool training, TFile* fp) const;
  104. // Reads from the given file. Returns false in case of error.
  105. bool DeSerialize(bool training, TFile* fp);
  106. // As DeSerialize, but reads an old (float) format WeightMatrix for
  107. // backward compatibility.
  108. bool DeSerializeOld(bool training, TFile* fp);
  109. // Computes matrix.vector v = Wu.
  110. // u is of size W.dim2() - 1 and the output v is of size W.dim1().
  111. // u is imagined to have an extra element at the end with value 1, to
  112. // implement the bias, but it doesn't actually have it.
  113. // Asserts that the call matches what we have.
  114. void MatrixDotVector(const double* u, double* v) const;
  115. void MatrixDotVector(const int8_t* u, double* v) const;
  116. // MatrixDotVector for peep weights, MultiplyAccumulate adds the
  117. // component-wise products of *this[0] and v to inout.
  118. void MultiplyAccumulate(const double* v, double* inout);
  119. // Computes vector.matrix v = uW.
  120. // u is of size W.dim1() and the output v is of size W.dim2() - 1.
  121. // The last result is discarded, as v is assumed to have an imaginary
  122. // last value of 1, as with MatrixDotVector.
  123. void VectorDotMatrix(const double* u, double* v) const;
  124. // Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements
  125. // from u and v, starting with u[i][offset] and v[j][offset].
  126. // Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
  127. // Runs parallel if requested. Note that inputs must be transposed.
  128. void SumOuterTransposed(const TransposedArray& u, const TransposedArray& v,
  129. bool parallel);
  130. // Updates the weights using the given learning rate, momentum and adam_beta.
  131. // num_samples is used in the Adam correction factor.
  132. void Update(double learning_rate, double momentum, double adam_beta,
  133. int num_samples);
  134. // Adds the dw_ in other to the dw_ is *this.
  135. void AddDeltas(const WeightMatrix& other);
  136. // Sums the products of weight updates in *this and other, splitting into
  137. // positive (same direction) in *same and negative (different direction) in
  138. // *changed.
  139. void CountAlternators(const WeightMatrix& other, double* same,
  140. double* changed) const;
  141. void Debug2D(const char* msg);
  142. // Utility function converts an array of float to the corresponding array
  143. // of double.
  144. static void FloatToDouble(const GENERIC_2D_ARRAY<float>& wf,
  145. GENERIC_2D_ARRAY<double>* wd);
  146. private:
  147. // Choice between float and 8 bit int implementations.
  148. GENERIC_2D_ARRAY<double> wf_;
  149. GENERIC_2D_ARRAY<int8_t> wi_;
  150. // Transposed copy of wf_, used only for Backward, and set with each Update.
  151. TransposedArray wf_t_;
  152. // Which of wf_ and wi_ are we actually using.
  153. bool int_mode_;
  154. // True if we are running adam in this weight matrix.
  155. bool use_adam_;
  156. // If we are using wi_, then scales_ is a factor to restore the row product
  157. // with a vector to the correct range.
  158. GenericVector<double> scales_;
  159. // Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying
  160. // amount to be added to wf_/wi_.
  161. GENERIC_2D_ARRAY<double> dw_;
  162. GENERIC_2D_ARRAY<double> updates_;
  163. // Iff use_adam_, the sum of squares of dw_. The number of samples is
  164. // given to Update(). Serialized iff use_adam_.
  165. GENERIC_2D_ARRAY<double> dw_sq_sum_;
  166. // The weights matrix reorganized in whatever way suits this instance.
  167. std::vector<int8_t> shaped_w_;
  168. };
  169. } // namespace tesseract.
  170. #endif // TESSERACT_LSTM_WEIGHTMATRIX_H_