input.h 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: input.h
  3. // Description: Input layer class for neural network implementations.
  4. // Author: Ray Smith
  5. //
  6. // (C) Copyright 2014, Google Inc.
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. ///////////////////////////////////////////////////////////////////////
  17. #ifndef TESSERACT_LSTM_INPUT_H_
  18. #define TESSERACT_LSTM_INPUT_H_
  19. #include "network.h"
  20. class ScrollView;
  21. namespace tesseract {
  22. class Input : public Network {
  23. public:
  24. Input(const STRING& name, int ni, int no);
  25. Input(const STRING& name, const StaticShape& shape);
  26. ~Input() override = default;
  27. STRING spec() const override {
  28. STRING spec;
  29. spec.add_str_int("", shape_.batch());
  30. spec.add_str_int(",", shape_.height());
  31. spec.add_str_int(",", shape_.width());
  32. spec.add_str_int(",", shape_.depth());
  33. return spec;
  34. }
  35. // Returns the required shape input to the network.
  36. StaticShape InputShape() const override { return shape_; }
  37. // Returns the shape output from the network given an input shape (which may
  38. // be partially unknown ie zero).
  39. StaticShape OutputShape(const StaticShape& input_shape) const override {
  40. return shape_;
  41. }
  42. // Writes to the given file. Returns false in case of error.
  43. // Should be overridden by subclasses, but called by their Serialize.
  44. bool Serialize(TFile* fp) const override;
  45. // Reads from the given file. Returns false in case of error.
  46. bool DeSerialize(TFile* fp) override;
  47. // Returns an integer reduction factor that the network applies to the
  48. // time sequence. Assumes that any 2-d is already eliminated. Used for
  49. // scaling bounding boxes of truth data.
  50. // WARNING: if GlobalMinimax is used to vary the scale, this will return
  51. // the last used scale factor. Call it before any forward, and it will return
  52. // the minimum scale factor of the paths through the GlobalMinimax.
  53. int XScaleFactor() const override;
  54. // Provides the (minimum) x scale factor to the network (of interest only to
  55. // input units) so they can determine how to scale bounding boxes.
  56. void CacheXScaleFactor(int factor) override;
  57. // Runs forward propagation of activations on the input line.
  58. // See Network for a detailed discussion of the arguments.
  59. void Forward(bool debug, const NetworkIO& input,
  60. const TransposedArray* input_transpose,
  61. NetworkScratch* scratch, NetworkIO* output) override;
  62. // Runs backward propagation of errors on the deltas line.
  63. // See Network for a detailed discussion of the arguments.
  64. bool Backward(bool debug, const NetworkIO& fwd_deltas,
  65. NetworkScratch* scratch,
  66. NetworkIO* back_deltas) override;
  67. // Creates and returns a Pix of appropriate size for the network from the
  68. // image_data. If non-null, *image_scale returns the image scale factor used.
  69. // Returns nullptr on error.
  70. /* static */
  71. static Pix* PrepareLSTMInputs(const ImageData& image_data,
  72. const Network* network, int min_width,
  73. TRand* randomizer, float* image_scale);
  74. // Converts the given pix to a NetworkIO of height and depth appropriate to
  75. // the given StaticShape:
  76. // If depth == 3, convert to 24 bit color, otherwise normalized grey.
  77. // Scale to target height, if the shape's height is > 1, or its depth if the
  78. // height == 1. If height == 0 then no scaling.
  79. // NOTE: It isn't safe for multiple threads to call this on the same pix.
  80. static void PreparePixInput(const StaticShape& shape, const Pix* pix,
  81. TRand* randomizer, NetworkIO* input);
  82. private:
  83. void DebugWeights() override {
  84. tprintf("Must override Network::DebugWeights for type %d\n", type_);
  85. }
  86. // Input shape determines how images are dealt with.
  87. StaticShape shape_;
  88. // Cached total network x scale factor for scaling bounding boxes.
  89. int cached_x_scale_;
  90. };
  91. } // namespace tesseract.
  92. #endif // TESSERACT_LSTM_INPUT_H_