stridemap.h 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: stridemap.h
  3. // Description: Indexing into a 4-d tensor held in a 2-d Array.
  4. // Author: Ray Smith
  5. //
  6. // (C) Copyright 2016, Google Inc.
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. ///////////////////////////////////////////////////////////////////////
  17. #ifndef TESSERACT_LSTM_STRIDEMAP_H_
  18. #define TESSERACT_LSTM_STRIDEMAP_H_
  19. #include <cstring>
  20. #include <vector>
  21. namespace tesseract {
  22. // Enum describing the dimensions of the 'Tensor' in a NetworkIO.
  23. // A NetworkIO is analogous to a TF Tensor, except that the number of dimensions
  24. // is fixed (4), and they always have the same meaning. The underlying
  25. // representation is a 2-D array, for which the product batch*height*width
  26. // is always dim1 and depth is always dim2. FlexDimensions is used only for
  27. // batch, height, width with the StrideMap, and therefore represents the runtime
  28. // shape. The build-time shape is defined by StaticShape.
  29. enum FlexDimensions {
  30. FD_BATCH, // Index of multiple images.
  31. FD_HEIGHT, // y-coordinate in image.
  32. FD_WIDTH, // x-coordinate in image.
  33. FD_DIMSIZE, // Number of flexible non-depth dimensions.
  34. };
  35. // Encapsulation of information relating to the mapping from [batch][y][x] to
  36. // the first index into the 2-d array underlying a NetworkIO.
  37. class StrideMap {
  38. public:
  39. // Class holding the non-depth indices.
  40. class Index {
  41. public:
  42. explicit Index(const StrideMap& stride_map) : stride_map_(&stride_map) {
  43. InitToFirst();
  44. }
  45. Index(const StrideMap& stride_map, int batch, int y, int x)
  46. : stride_map_(&stride_map) {
  47. indices_[FD_BATCH] = batch;
  48. indices_[FD_HEIGHT] = y;
  49. indices_[FD_WIDTH] = x;
  50. SetTFromIndices();
  51. }
  52. // Accesses the index to the underlying array.
  53. int t() const { return t_; }
  54. int index(FlexDimensions dimension) const { return indices_[dimension]; }
  55. // Initializes the indices to the first valid location.
  56. void InitToFirst() {
  57. memset(indices_, 0, sizeof(indices_));
  58. t_ = 0;
  59. }
  60. // Initializes the indices to the last valid location.
  61. void InitToLast() { InitToLastOfBatch(MaxIndexOfDim(FD_BATCH)); }
  62. // Returns true if *this is a valid index.
  63. bool IsValid() const;
  64. // Returns true if the index of the given dimension is the last.
  65. bool IsLast(FlexDimensions dimension) const;
  66. // Given that the dimensions up to and including dim-1 are valid, returns
  67. // the maximum index for dimension dim.
  68. int MaxIndexOfDim(FlexDimensions dim) const;
  69. // Adds the given offset to the given dimension. Returns true if the result
  70. // makes a valid index.
  71. bool AddOffset(int offset, FlexDimensions dimension);
  72. // Increments the index in some encapsulated way that guarantees to remain
  73. // valid until it returns false, meaning that the iteration is complete.
  74. bool Increment();
  75. // Decrements the index in some encapsulated way that guarantees to remain
  76. // valid until it returns false, meaning that the iteration (that started
  77. // with InitToLast()) is complete.
  78. bool Decrement();
  79. private:
  80. // Initializes the indices to the last valid location in the given batch
  81. // index.
  82. void InitToLastOfBatch(int batch);
  83. // Computes and sets t_ from the current indices_.
  84. void SetTFromIndices();
  85. // Map into which *this is an index.
  86. const StrideMap* stride_map_;
  87. // Index to the first dimension of the underlying array.
  88. int t_;
  89. // Indices into the individual dimensions.
  90. int indices_[FD_DIMSIZE];
  91. };
  92. StrideMap() {
  93. memset(shape_, 0, sizeof(shape_));
  94. memset(t_increments_, 0, sizeof(t_increments_));
  95. }
  96. // Default copy constructor and operator= are OK to use here!
  97. // Sets up the stride for the given array of height, width pairs.
  98. void SetStride(const std::vector<std::pair<int, int>>& h_w_pairs);
  99. // Scales width and height dimensions by the given factors.
  100. void ScaleXY(int x_factor, int y_factor);
  101. // Reduces width to 1, across the batch, whatever the input size.
  102. void ReduceWidthTo1();
  103. // Transposes the width and height dimensions.
  104. void TransposeXY();
  105. // Returns the size of the given dimension.
  106. int Size(FlexDimensions dimension) const { return shape_[dimension]; }
  107. // Returns the total width required.
  108. int Width() const { return t_increments_[FD_BATCH] * shape_[FD_BATCH]; }
  109. private:
  110. // Computes t_increments_ from shape_.
  111. void ComputeTIncrements();
  112. // The size of each non-depth dimension.
  113. int shape_[FD_DIMSIZE];
  114. // Precomputed 't' increments for each dimension. This is the value of
  115. // the given dimension in the packed 3-d array that the shape_ represents.
  116. int t_increments_[FD_DIMSIZE];
  117. // Vector of size shape_[FD_BATCH] holds the height of each image in a batch.
  118. std::vector<int> heights_;
  119. // Vector of size shape_[FD_BATCH] holds the width of each image in a batch.
  120. std::vector<int> widths_;
  121. };
  122. } // namespace tesseract
  123. #endif // TESSERACT_LSTM_STRIDEMAP_H_