intsimdmatrix.h 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: intsimdmatrix.h
  3. // Description: Base class for 8-bit int SIMD matrix multipliers.
  4. // Author: Ray Smith
  5. // Created: Tue Aug 15 07:37:20 PST 2017
  6. //
  7. // (C) Copyright 2017, Google Inc.
  8. // Licensed under the Apache License, Version 2.0 (the "License");
  9. // you may not use this file except in compliance with the License.
  10. // You may obtain a copy of the License at
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. // See the License for the specific language governing permissions and
  16. // limitations under the License.
  17. ///////////////////////////////////////////////////////////////////////
  18. #ifndef TESSERACT_ARCH_INTSIMDMATRIX_H_
  19. #define TESSERACT_ARCH_INTSIMDMATRIX_H_
  20. #include <cstdint>
  21. #include <vector>
  // Forward declarations of project containers defined elsewhere; this header
  // only names them in declarations, so their full definitions are not needed.
  template <class T>
  class GENERIC_2D_ARRAY;
  template <typename T>
  class GenericVector;

  namespace tesseract {

  // Describes a SIMD function to multiply a matrix by a vector, with sources
  // of 8-bit signed integer, and result in a double, after appropriate scaling.
  // Assumes a specific method of multiplication that can be applied to any size
  // and number of SIMD registers as follows:
  // int32_t results are computed with num_outputs_per_register_ in each of
  // max_output_registers_ result registers, repeatedly until it would make too
  // many results, then the number of registers is halved, and so on down to a
  // single result register. The last calculation only outputs the required
  // number of results instead of writing beyond the bounds. E.g. a matrix with
  // 75 outputs, num_outputs_per_register_ = 4 and max_output_registers_ = 8:
  //   Step 1: 8x4=32 results are computed,
  //   Step 2: 8x4=32 again, total 64,
  //   Step 3: 2x4=8 (since 8x4 is too many, and so is 4x4), total 72,
  //   Step 4: 1x3, total 75.
  // Each step above is computed using a PartialFunc, which runs over the input
  // vector once. The input is read one registerful of num_inputs_per_register_
  // at a time (presumably 4x num_outputs_per_register_ since they are int8_t)
  // so the inputs MUST BE PADDED to a multiple of num_inputs_per_register_.
  // Since it is slow (on Intel at least) to horizontally add in a register,
  // provision is made to process num_inputs_per_group_ inputs at a time, with
  // the group being replicated num_input_groups_ times and multiplied by a
  // num_inputs_per_group_ by num_input_groups_ rectangle of the weights matrix.
  // This is most convenient if num_inputs_per_group_ is 4, and the product
  // sign-extends and sums 8x8=16 bit results to 32 bits, adding 4 adjacent
  // results in the process, but it doesn't have to be implemented that way.
  // The weights are re-ordered by Init() to be used sequentially by the above
  // algorithm, followed by the biases, so they can be added at the end.
  // There are no subclasses and no virtual methods: each hardware variant is a
  // separate static const instance of this struct (intSimdMatrixAVX2,
  // intSimdMatrixSSE) whose fields, including matrixDotVectorFunction, carry
  // everything that variant needs. The static MatrixDotVector() is the plain
  // C++ implementation and does not use any instance fields.
  struct IntSimdMatrix {
    // Computes a reshaped copy of the weight matrix w into shaped_w, in the
    // order in which this variant's matrixDotVectorFunction consumes it.
    void Init(const GENERIC_2D_ARRAY<int8_t>& w,
              std::vector<int8_t>& shaped_w) const;

    // Rounds the size up to a multiple of the input register size (in int8_t).
    // Requires num_inputs_per_register_ > 0.
    int RoundInputs(int size) const {
      return Roundup(size, num_inputs_per_register_);
    }

    // Rounds the size up to a multiple of the output register size (in
    // int32_t). Requires num_outputs_per_register_ > 0.
    int RoundOutputs(int size) const {
      return Roundup(size, num_outputs_per_register_);
    }

    // Computes matrix.vector v = Wu.
    // u is of size W.dim2() - 1 and the output v is of size W.dim1().
    // u is imagined to have an extra element at the end with value 1, to
    // implement the bias, but it doesn't actually have it.
    // Computes the base C++ implementation.
    static void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
                                const GenericVector<double>& scales,
                                const int8_t* u, double* v);

    // Rounds input up to the nearest multiple of factor (ceiling), e.g.
    // Roundup(75, 4) == 76 and Roundup(76, 4) == 76. factor must be positive.
    // Uses quotient + remainder instead of (input + factor - 1) / factor so no
    // intermediate value overflows when the result itself fits in an int.
    // Integer division truncates towards zero, so for negative inputs the bare
    // quotient is already the ceiling; only a positive remainder needs a bump.
    static int Roundup(int input, int factor) {
      return (input / factor + (input % factor > 0 ? 1 : 0)) * factor;
    }

    // Signature of an optimized implementation of matrix.vector v = Wu.
    // u is of size W.dim2() - 1 and the output v is of size W.dim1().
    // u is imagined to have an extra element at the end with value 1, to
    // implement the bias, but it doesn't actually have it.
    // NOTE(review): the arguments are presumably (dim1, dim2, weights shaped by
    // Init(), scales, u, v) — confirm against the SIMD implementations.
    // NOTE: The size of the input vector (u) must be padded using
    // RoundInputs above.
    // The input will be over-read to the extent of the padding. There are no
    // alignment requirements.
    using MatrixDotVectorFunction = void (*)(int, int, const int8_t*,
                                             const double*, const int8_t*,
                                             double*);
    // This variant's optimized implementation.
    MatrixDotVectorFunction matrixDotVectorFunction;

    // Number of 32 bit outputs held in each register.
    int num_outputs_per_register_;
    // Maximum number of registers that we will use to hold outputs.
    int max_output_registers_;
    // Number of 8 bit inputs in the inputs register.
    int num_inputs_per_register_;
    // Number of inputs in each weight group.
    int num_inputs_per_group_;
    // The number of groups of inputs to be broadcast (num_input_groups_ in the
    // description above) is not stored; it is derived as
    // num_inputs_per_register_ / num_inputs_per_group_.

    // Presumably the variant selected for the running CPU.
    // NOTE(review): confirm where it is assigned and whether it can be null.
    static const IntSimdMatrix* intSimdMatrix;
    // Static descriptors for the individual SIMD instruction sets.
    static const IntSimdMatrix intSimdMatrixAVX2;
    static const IntSimdMatrix intSimdMatrixSSE;
  };

  }  // namespace tesseract
  110. #endif // TESSERACT_ARCH_INTSIMDMATRIX_H_