reversed.h 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: reversed.h
  3. // Description: Runs a single network on x-reversed, y-reversed or x/y-transposed input, reordering the output likewise.
  4. // Author: Ray Smith
  5. // Created: Thu May 02 08:38:06 PST 2013
  6. //
  7. // (C) Copyright 2013, Google Inc.
  8. // Licensed under the Apache License, Version 2.0 (the "License");
  9. // you may not use this file except in compliance with the License.
  10. // You may obtain a copy of the License at
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. // See the License for the specific language governing permissions and
  16. // limitations under the License.
  17. ///////////////////////////////////////////////////////////////////////
  18. #ifndef TESSERACT_LSTM_REVERSED_H_
  19. #define TESSERACT_LSTM_REVERSED_H_
  20. #include "matrix.h"
  21. #include "plumbing.h"
  22. namespace tesseract {
  23. // C++ Implementation of the Reversed class from lstm.py.
  24. class Reversed : public Plumbing {
  25. public:
  26. explicit Reversed(const STRING& name, NetworkType type);
  27. ~Reversed() override = default;
  28. // Returns the shape output from the network given an input shape (which may
  29. // be partially unknown ie zero).
  30. StaticShape OutputShape(const StaticShape& input_shape) const override;
  31. STRING spec() const override {
  32. STRING spec(type_ == NT_XREVERSED ? "Rx"
  33. : (type_ == NT_YREVERSED ? "Ry" : "Txy"));
  34. // For most simple cases, we will output Rx<net> or Ry<net> where <net> is
  35. // the network in stack_[0], but in the special case that <net> is an
  36. // LSTM, we will just output the LSTM's spec modified to take the reversal
  37. // into account. This is because when the user specified Lfy64, we actually
  38. // generated TxyLfx64, and if the user specified Lrx64 we actually
  39. // generated RxLfx64, and we want to display what the user asked for.
  40. STRING net_spec = stack_[0]->spec();
  41. if (net_spec[0] == 'L') {
  42. // Setup a from and to character according to the type of the reversal
  43. // such that the LSTM spec gets modified to the spec that the user
  44. // asked for
  45. char from = 'f';
  46. char to = 'r';
  47. if (type_ == NT_XYTRANSPOSE) {
  48. from = 'x';
  49. to = 'y';
  50. }
  51. // Change the from char to the to char.
  52. for (int i = 0; i < net_spec.length(); ++i) {
  53. if (net_spec[i] == from) net_spec[i] = to;
  54. }
  55. return net_spec;
  56. }
  57. spec += net_spec;
  58. return spec;
  59. }
  60. // Takes ownership of the given network to make it the reversed one.
  61. void SetNetwork(Network* network);
  62. // Runs forward propagation of activations on the input line.
  63. // See Network for a detailed discussion of the arguments.
  64. void Forward(bool debug, const NetworkIO& input,
  65. const TransposedArray* input_transpose,
  66. NetworkScratch* scratch, NetworkIO* output) override;
  67. // Runs backward propagation of errors on the deltas line.
  68. // See Network for a detailed discussion of the arguments.
  69. bool Backward(bool debug, const NetworkIO& fwd_deltas,
  70. NetworkScratch* scratch,
  71. NetworkIO* back_deltas) override;
  72. private:
  73. // Copies src to *dest with the reversal according to type_.
  74. void ReverseData(const NetworkIO& src, NetworkIO* dest) const;
  75. };
  76. } // namespace tesseract.
  77. #endif // TESSERACT_LSTM_REVERSED_H_