functions.h 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: functions.h
  3. // Description: Collection of function-objects used by the network layers.
  4. // Author: Ray Smith
  5. //
  6. // (C) Copyright 2014, Google Inc.
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. ///////////////////////////////////////////////////////////////////////
  17. #ifndef TESSERACT_LSTM_FUNCTIONS_H_
  18. #define TESSERACT_LSTM_FUNCTIONS_H_
  19. #include "helpers.h"
  20. // Setting this to 1 or more causes massive dumps of debug data: weights,
  21. // updates, internal calculations etc, and reduces the number of test iterations
  22. // to a small number, so outputs can be diffed.
  23. #define DEBUG_DETAIL 0
  24. #if DEBUG_DETAIL > 0
  25. #undef _OPENMP // Disable open mp to get the outputs in sync.
  26. #endif
  27. namespace tesseract {
  28. // Size of static tables.
  29. constexpr int kTableSize = 4096;
  30. // Scale factor for float arg to int index.
  31. constexpr double kScaleFactor = 256.0;
  32. #if __cplusplus < 201402 || defined(__clang__) // C++11
  33. extern double TanhTable[];
  34. extern double LogisticTable[];
  35. #else // C++14 or newer
  36. typedef double (*LUT_FUNCTION)(int i);
  37. constexpr double LUTFuncTanh(int i) {
  38. return std::tanh(i / kScaleFactor);
  39. }
  40. constexpr double LUTFuncLog(int i) {
  41. return 1 / (1 + std::exp(-i / kScaleFactor));
  42. }
  43. template<int n, LUT_FUNCTION f>
  44. struct LUTTempl {
  45. constexpr LUTTempl() : table_() {
  46. for (auto i = 0; i < n; ++i) {
  47. table_[i] = f(i);
  48. }
  49. }
  50. const double& operator[](size_t i) const {
  51. return table_[i];
  52. }
  53. double table_[n];
  54. };
  55. extern const LUTTempl<kTableSize, LUTFuncTanh> TanhTable;
  56. extern const LUTTempl<kTableSize, LUTFuncLog> LogisticTable;
  57. #endif
  58. // Non-linearity (sigmoid) functions with cache tables and clipping.
  59. inline double Tanh(double x) {
  60. if (x < 0.0) return -Tanh(-x);
  61. x *= kScaleFactor;
  62. int index = static_cast<int>(x);
  63. if (index >= (kTableSize - 1)) return 1.0;
  64. double tanh_i0 = TanhTable[index];
  65. double tanh_i1 = TanhTable[index + 1];
  66. // Linear interpolation.
  67. return tanh_i0 + (tanh_i1 - tanh_i0) * (x - index);
  68. }
  69. inline double Logistic(double x) {
  70. if (x < 0.0) return 1.0 - Logistic(-x);
  71. x *= kScaleFactor;
  72. int index = static_cast<int>(x);
  73. if (index >= (kTableSize - 1)) return 1.0;
  74. double l0 = LogisticTable[index];
  75. double l1 = LogisticTable[index + 1];
  76. // Linear interpolation.
  77. return l0 + (l1 - l0) * (x - index);
  78. }
  79. // Non-linearity (sigmoid) functions and their derivatives.
  80. struct FFunc {
  81. inline double operator()(double x) const { return Logistic(x); }
  82. };
  83. struct FPrime {
  84. inline double operator()(double y) const { return y * (1.0 - y); }
  85. };
  86. struct ClipFFunc {
  87. inline double operator()(double x) const {
  88. if (x <= 0.0) return 0.0;
  89. if (x >= 1.0) return 1.0;
  90. return x;
  91. }
  92. };
  93. struct ClipFPrime {
  94. inline double operator()(double y) const {
  95. return 0.0 < y && y < 1.0 ? 1.0 : 0.0;
  96. }
  97. };
  98. struct Relu {
  99. inline double operator()(double x) const {
  100. if (x <= 0.0) return 0.0;
  101. return x;
  102. }
  103. };
  104. struct ReluPrime {
  105. inline double operator()(double y) const { return 0.0 < y ? 1.0 : 0.0; }
  106. };
  107. struct GFunc {
  108. inline double operator()(double x) const { return Tanh(x); }
  109. };
  110. struct GPrime {
  111. inline double operator()(double y) const { return 1.0 - y * y; }
  112. };
  113. struct ClipGFunc {
  114. inline double operator()(double x) const {
  115. if (x <= -1.0) return -1.0;
  116. if (x >= 1.0) return 1.0;
  117. return x;
  118. }
  119. };
  120. struct ClipGPrime {
  121. inline double operator()(double y) const {
  122. return -1.0 < y && y < 1.0 ? 1.0 : 0.0;
  123. }
  124. };
  125. struct HFunc {
  126. inline double operator()(double x) const { return Tanh(x); }
  127. };
  128. struct HPrime {
  129. inline double operator()(double y) const {
  130. double u = Tanh(y);
  131. return 1.0 - u * u;
  132. }
  133. };
  134. struct UnityFunc {
  135. inline double operator()(double /*x*/) const { return 1.0; }
  136. };
  137. struct IdentityFunc {
  138. inline double operator()(double x) const { return x; }
  139. };
  140. // Applies Func in-place to inout, of size n.
  141. template <class Func>
  142. inline void FuncInplace(int n, double* inout) {
  143. Func f;
  144. for (int i = 0; i < n; ++i) {
  145. inout[i] = f(inout[i]);
  146. }
  147. }
  148. // Applies Func to u and multiplies the result by v component-wise,
  149. // putting the product in out, all of size n.
  150. template <class Func>
  151. inline void FuncMultiply(const double* u, const double* v, int n, double* out) {
  152. Func f;
  153. for (int i = 0; i < n; ++i) {
  154. out[i] = f(u[i]) * v[i];
  155. }
  156. }
  157. // Applies the Softmax function in-place to inout, of size n.
  158. template <typename T>
  159. inline void SoftmaxInPlace(int n, T* inout) {
  160. if (n <= 0) return;
  161. // A limit on the negative range input to exp to guarantee non-zero output.
  162. const T kMaxSoftmaxActivation = 86.0f;
  163. T max_output = inout[0];
  164. for (int i = 1; i < n; i++) {
  165. T output = inout[i];
  166. if (output > max_output) max_output = output;
  167. }
  168. T prob_total = 0.0;
  169. for (int i = 0; i < n; i++) {
  170. T prob = inout[i] - max_output;
  171. prob = exp(ClipToRange(prob, -kMaxSoftmaxActivation, static_cast<T>(0)));
  172. prob_total += prob;
  173. inout[i] = prob;
  174. }
  175. if (prob_total > 0.0) {
  176. for (int i = 0; i < n; i++) inout[i] /= prob_total;
  177. }
  178. }
  179. // Copies n values of the given src vector to dest.
  180. inline void CopyVector(int n, const double* src, double* dest) {
  181. memcpy(dest, src, n * sizeof(dest[0]));
  182. }
  183. // Adds n values of the given src vector to dest.
  184. inline void AccumulateVector(int n, const double* src, double* dest) {
  185. for (int i = 0; i < n; ++i) dest[i] += src[i];
  186. }
  187. // Multiplies n values of inout in-place element-wise by the given src vector.
  188. inline void MultiplyVectorsInPlace(int n, const double* src, double* inout) {
  189. for (int i = 0; i < n; ++i) inout[i] *= src[i];
  190. }
  191. // Multiplies n values of u by v, element-wise, accumulating to out.
  192. inline void MultiplyAccumulate(int n, const double* u, const double* v,
  193. double* out) {
  194. for (int i = 0; i < n; i++) {
  195. out[i] += u[i] * v[i];
  196. }
  197. }
  198. // Sums the given 5 n-vectors putting the result into sum.
  199. inline void SumVectors(int n, const double* v1, const double* v2,
  200. const double* v3, const double* v4, const double* v5,
  201. double* sum) {
  202. for (int i = 0; i < n; ++i) {
  203. sum[i] = v1[i] + v2[i] + v3[i] + v4[i] + v5[i];
  204. }
  205. }
  206. // Sets the given n-vector vec to 0.
  207. template <typename T>
  208. inline void ZeroVector(int n, T* vec) {
  209. memset(vec, 0, n * sizeof(*vec));
  210. }
  211. // Clips the given vector vec, of size n to [lower, upper].
  212. template <typename T>
  213. inline void ClipVector(int n, T lower, T upper, T* vec) {
  214. for (int i = 0; i < n; ++i) vec[i] = ClipToRange(vec[i], lower, upper);
  215. }
  216. // Converts the given n-vector to a binary encoding of the maximum value,
  217. // encoded as vector of nf binary values.
  218. inline void CodeInBinary(int n, int nf, double* vec) {
  219. if (nf <= 0 || n < nf) return;
  220. int index = 0;
  221. double best_score = vec[0];
  222. for (int i = 1; i < n; ++i) {
  223. if (vec[i] > best_score) {
  224. best_score = vec[i];
  225. index = i;
  226. }
  227. }
  228. int mask = 1;
  229. for (int i = 0; i < nf; ++i, mask *= 2) {
  230. vec[i] = (index & mask) ? 1.0 : 0.0;
  231. }
  232. }
  233. } // namespace tesseract.
  234. #endif // TESSERACT_LSTM_FUNCTIONS_H_