networkscratch.h 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: networkscratch.h
  3. // Description: Scratch space for Network layers that hides distinction
  4. // between float/int implementations.
  5. // Author: Ray Smith
  6. //
  7. // (C) Copyright 2014, Google Inc.
  8. // Licensed under the Apache License, Version 2.0 (the "License");
  9. // you may not use this file except in compliance with the License.
  10. // You may obtain a copy of the License at
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. // See the License for the specific language governing permissions and
  16. // limitations under the License.
  17. ///////////////////////////////////////////////////////////////////////
  18. #ifndef TESSERACT_LSTM_NETWORKSCRATCH_H_
  19. #define TESSERACT_LSTM_NETWORKSCRATCH_H_
  20. #include "genericvector.h"
  21. #include "matrix.h"
  22. #include "networkio.h"
  23. #include "svutil.h"
  24. namespace tesseract {
  25. // Generic scratch space for network layers. Provides NetworkIO that can store
  26. // a complete set (over time) of intermediates, and GenericVector<double>
  27. // scratch space that auto-frees after use. The aim here is to provide a set
  28. // of temporary buffers to network layers that can be reused between layers
  29. // and don't have to be reallocated on each call.
  30. class NetworkScratch {
  31. public:
  32. NetworkScratch() : int_mode_(false) {}
  33. ~NetworkScratch() = default;
  34. // Sets the network representation. If the representation is integer, then
  35. // default (integer) NetworkIOs are separated from the always-float variety.
  36. // This saves memory by having separate int-specific and float-specific
  37. // stacks. If the network representation is float, then all NetworkIOs go
  38. // to the float stack.
  39. void set_int_mode(bool int_mode) {
  40. int_mode_ = int_mode;
  41. }
  42. // Class that acts like a NetworkIO (by having an implicit cast operator),
  43. // yet actually holds a pointer to NetworkIOs in the source NetworkScratch,
  44. // and knows how to unstack the borrowed pointers on destruction.
  45. class IO {
  46. public:
  47. // The NetworkIO should be sized after construction.
  48. IO(const NetworkIO& src, NetworkScratch* scratch)
  49. : int_mode_(scratch->int_mode_ && src.int_mode()),
  50. scratch_space_(scratch) {
  51. network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
  52. : scratch_space_->float_stack_.Borrow();
  53. }
  54. // Default constructor for arrays. Use one of the Resize functions
  55. // below to initialize and size.
  56. IO() : int_mode_(false), network_io_(nullptr), scratch_space_(nullptr) {}
  57. ~IO() {
  58. if (scratch_space_ == nullptr) {
  59. ASSERT_HOST(network_io_ == nullptr);
  60. } else if (int_mode_) {
  61. scratch_space_->int_stack_.Return(network_io_);
  62. } else {
  63. scratch_space_->float_stack_.Return(network_io_);
  64. }
  65. }
  66. // Resizes the array (and stride), avoiding realloc if possible, to the
  67. // size from various size specs:
  68. // Same time size, given number of features.
  69. void Resize(const NetworkIO& src, int num_features,
  70. NetworkScratch* scratch) {
  71. if (scratch_space_ == nullptr) {
  72. int_mode_ = scratch->int_mode_ && src.int_mode();
  73. scratch_space_ = scratch;
  74. network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
  75. : scratch_space_->float_stack_.Borrow();
  76. }
  77. network_io_->Resize(src, num_features);
  78. }
  79. // Resizes to a specific size as a temp buffer. No batches, no y-dim.
  80. void Resize2d(bool int_mode, int width, int num_features,
  81. NetworkScratch* scratch) {
  82. if (scratch_space_ == nullptr) {
  83. int_mode_ = scratch->int_mode_ && int_mode;
  84. scratch_space_ = scratch;
  85. network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
  86. : scratch_space_->float_stack_.Borrow();
  87. }
  88. network_io_->Resize2d(int_mode, width, num_features);
  89. }
  90. // Resize forcing a float representation with the width of src and the given
  91. // number of features.
  92. void ResizeFloat(const NetworkIO& src, int num_features,
  93. NetworkScratch* scratch) {
  94. if (scratch_space_ == nullptr) {
  95. int_mode_ = false;
  96. scratch_space_ = scratch;
  97. network_io_ = scratch_space_->float_stack_.Borrow();
  98. }
  99. network_io_->ResizeFloat(src, num_features);
  100. }
  101. // Returns a ref to a NetworkIO that enables *this to be treated as if
  102. // it were just a NetworkIO*.
  103. NetworkIO& operator*() {
  104. return *network_io_;
  105. }
  106. NetworkIO* operator->() {
  107. return network_io_;
  108. }
  109. operator NetworkIO*() {
  110. return network_io_;
  111. }
  112. private:
  113. // True if this is from the always-float stack, otherwise the default stack.
  114. bool int_mode_;
  115. // The NetworkIO that we have borrowed from the scratch_space_.
  116. NetworkIO* network_io_;
  117. // The source scratch_space_. Borrowed pointer, used to free the
  118. // NetworkIO. Don't delete!
  119. NetworkScratch* scratch_space_;
  120. }; // class IO.
  121. // Class that acts like a fixed array of float, yet actually uses space
  122. // from a GenericVector<float> in the source NetworkScratch, and knows how
  123. // to unstack the borrowed vector on destruction.
  124. class FloatVec {
  125. public:
  126. // The array will have size elements in it, uninitialized.
  127. FloatVec(int size, NetworkScratch* scratch)
  128. : vec_(nullptr), scratch_space_(scratch) {
  129. Init(size, scratch);
  130. }
  131. // Default constructor is for arrays. Use Init to setup.
  132. FloatVec() : vec_(nullptr), data_(nullptr), scratch_space_(nullptr) {}
  133. ~FloatVec() {
  134. if (scratch_space_ != nullptr) scratch_space_->vec_stack_.Return(vec_);
  135. }
  136. void Init(int size, NetworkScratch* scratch) {
  137. if (scratch_space_ != nullptr && vec_ != nullptr)
  138. scratch_space_->vec_stack_.Return(vec_);
  139. scratch_space_ = scratch;
  140. vec_ = scratch_space_->vec_stack_.Borrow();
  141. vec_->resize_no_init(size);
  142. data_ = &(*vec_)[0];
  143. }
  144. // Use the cast operator instead of operator[] so the FloatVec can be used
  145. // as a double* argument to a function call.
  146. operator double*() const { return data_; }
  147. double* get() { return data_; }
  148. private:
  149. // Vector borrowed from the scratch space. Use Return to free it.
  150. GenericVector<double>* vec_;
  151. // Short-cut pointer to the underlying array.
  152. double* data_;
  153. // The source scratch_space_. Borrowed pointer, used to free the
  154. // vector. Don't delete!
  155. NetworkScratch* scratch_space_;
  156. }; // class FloatVec
  157. // Class that acts like a 2-D array of double, yet actually uses space
  158. // from the source NetworkScratch, and knows how to unstack the borrowed
  159. // array on destruction.
  160. class GradientStore {
  161. public:
  162. // Default constructor is for arrays. Use Init to setup.
  163. GradientStore() : array_(nullptr), scratch_space_(nullptr) {}
  164. ~GradientStore() {
  165. if (scratch_space_ != nullptr) scratch_space_->array_stack_.Return(array_);
  166. }
  167. void Init(int size1, int size2, NetworkScratch* scratch) {
  168. if (scratch_space_ != nullptr && array_ != nullptr)
  169. scratch_space_->array_stack_.Return(array_);
  170. scratch_space_ = scratch;
  171. array_ = scratch_space_->array_stack_.Borrow();
  172. array_->Resize(size1, size2, 0.0);
  173. }
  174. // Accessors to get to the underlying TransposedArray.
  175. TransposedArray* get() const { return array_; }
  176. const TransposedArray& operator*() const { return *array_; }
  177. private:
  178. // Array borrowed from the scratch space. Use Return to free it.
  179. TransposedArray* array_;
  180. // The source scratch_space_. Borrowed pointer, used to free the
  181. // vector. Don't delete!
  182. NetworkScratch* scratch_space_;
  183. }; // class GradientStore
  184. // Class that does the work of holding a stack of objects, a stack pointer
  185. // and a vector of in-use flags, so objects can be returned out of order.
  186. // It is safe to attempt to Borrow/Return in multiple threads.
  187. template<typename T> class Stack {
  188. public:
  189. Stack() : stack_top_(0) {
  190. }
  191. // Lends out the next free item, creating one if none available, sets
  192. // the used flags and increments the stack top.
  193. T* Borrow() {
  194. SVAutoLock lock(&mutex_);
  195. if (stack_top_ == stack_.size()) {
  196. stack_.push_back(new T);
  197. flags_.push_back(false);
  198. }
  199. flags_[stack_top_] = true;
  200. return stack_[stack_top_++];
  201. }
  202. // Takes back the given item, and marks it free. Item does not have to be
  203. // the most recently lent out, but free slots don't get re-used until the
  204. // blocking item is returned. The assumption is that there will only be
  205. // small, temporary variations from true stack use. (Determined by the order
  206. // of destructors within a local scope.)
  207. void Return(T* item) {
  208. SVAutoLock lock(&mutex_);
  209. // Linear search will do.
  210. int index = stack_top_ - 1;
  211. while (index >= 0 && stack_[index] != item) --index;
  212. if (index >= 0) flags_[index] = false;
  213. while (stack_top_ > 0 && !flags_[stack_top_ - 1]) --stack_top_;
  214. }
  215. private:
  216. PointerVector<T> stack_;
  217. GenericVector<bool> flags_;
  218. int stack_top_;
  219. SVMutex mutex_;
  220. }; // class Stack.
  221. private:
  222. // If true, the network weights are int8_t, if false, float.
  223. bool int_mode_;
  224. // Stacks of NetworkIO and GenericVector<float>. Once allocated, they are not
  225. // deleted until the NetworkScratch is deleted.
  226. Stack<NetworkIO> int_stack_;
  227. Stack<NetworkIO> float_stack_;
  228. Stack<GenericVector<double> > vec_stack_;
  229. Stack<TransposedArray> array_stack_;
  230. };
  231. } // namespace tesseract.
  232. #endif // TESSERACT_LSTM_NETWORKSCRATCH_H_