
///////////////////////////////////////////////////////////////////////
// File: plumbing.h
// Description: Base class for networks that organize other networks,
//              e.g. series or parallel.
// Author: Ray Smith
// Created: Mon May 12 08:11:36 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_PLUMBING_H_
#define TESSERACT_LSTM_PLUMBING_H_

#include "genericvector.h"
#include "matrix.h"
#include "network.h"

namespace tesseract {

// Holds a collection of other networks and forwards calls to each of them.
class Plumbing : public Network {
 public:
  // ni_ and no_ will be set by AddToStack.
  explicit Plumbing(const STRING& name);
  ~Plumbing() override = default;
  // Returns the required shape input to the network.
  StaticShape InputShape() const override { return stack_[0]->InputShape(); }
  STRING spec() const override {
    return "Sub-classes of Plumbing must implement spec()!";
  }
  // Returns true if the given type is derived from Plumbing, and thus contains
  // multiple sub-networks that can have their own learning rate.
  bool IsPlumbingType() const override { return true; }
  // Suspends/Enables training by setting the training_ flag. Serialize and
  // DeSerialize only operate on the run-time data if state is false.
  void SetEnableTraining(TrainingState state) override;
  // Sets flags that control the action of the network. See NetworkFlags enum
  // for bit values.
  void SetNetworkFlags(uint32_t flags) override;
  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  // Returns the number of weights initialized.
  int InitWeights(float range, TRand* randomizer) override;
  // Recursively searches the network for softmaxes with old_no outputs,
  // and remaps their outputs according to code_map. See network.h for details.
  int RemapOutputs(int old_no, const std::vector<int>& code_map) override;
  // Converts a float network to an int network.
  void ConvertToInt() override;
  // Provides a pointer to a TRand for any networks that care to use it.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  void SetRandomizer(TRand* randomizer) override;
  // Adds the given network to the stack.
  virtual void AddToStack(Network* network);
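  //
  // Illustrative sketch (assumptions: Series is a concrete sub-class; the
  // layer constructor arguments are elided). Sub-classes build their
  // pipelines by adding networks in order; the added networks are owned by
  // this Plumbing via stack_:
  //   Series* series = new Series("MySeries");
  //   series->AddToStack(/* Network* layer1 */);
  //   series->AddToStack(/* Network* layer2 */);
  //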
  // Sets needs_to_backprop_ to needs_backprop and returns true if
  // needs_backprop is true or this network has any weights, so the next layer
  // forward can be told to produce backprop for this layer if needed.
  bool SetupNeedsBackprop(bool needs_backprop) override;
  // Returns an integer reduction factor that the network applies to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will return
  // the minimum scale factor of the paths through the GlobalMinimax.
  int XScaleFactor() const override;
  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  void CacheXScaleFactor(int factor) override;
  // Provides debug output on the weights.
  void DebugWeights() override;
  // Returns the current stack.
  const PointerVector<Network>& stack() const {
    return stack_;
  }
  // Returns a set of strings representing the layer-ids of all layers below.
  void EnumerateLayers(const STRING* prefix,
                       GenericVector<STRING>* layers) const;
  // Returns a pointer to the network layer corresponding to the given id.
  Network* GetLayer(const char* id) const;
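  // Note (an assumption based on EnumerateLayers above): layer ids are the
  // colon-separated stack indices that EnumerateLayers generates, e.g. ":0"
  // for the first network in the stack, or ":1:0" for the first network of a
  // nested Plumbing at stack index 1.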
  // Returns the learning rate for a specific layer of the stack.
  float LayerLearningRate(const char* id) const {
    const float* lr_ptr = LayerLearningRatePtr(id);
    ASSERT_HOST(lr_ptr != nullptr);
    return *lr_ptr;
  }
  // Scales the learning rate for a specific layer of the stack.
  void ScaleLayerLearningRate(const char* id, double factor) {
    float* lr_ptr = LayerLearningRatePtr(id);
    ASSERT_HOST(lr_ptr != nullptr);
    *lr_ptr *= factor;
  }
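  // Usage sketch (assumes the network was built with NF_LAYER_SPECIFIC_LR in
  // network_flags_ so that learning_rates_ is populated; `net` is a
  // hypothetical Network*):
  //   if (net->IsPlumbingType()) {
  //     Plumbing* plumbing = static_cast<Plumbing*>(net);
  //     plumbing->ScaleLayerLearningRate(":0", 0.5);  // Halve layer 0's rate.
  //   }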
  // Returns a pointer to the learning rate for the given layer id.
  float* LayerLearningRatePtr(const char* id) const;
  // Writes to the given file. Returns false in case of error.
  bool Serialize(TFile* fp) const override;
  // Reads from the given file. Returns false in case of error.
  bool DeSerialize(TFile* fp) override;
  // Updates the weights using the given learning rate, momentum and adam_beta.
  // num_samples is used in the adam computation iff use_adam_ is true.
  void Update(float learning_rate, float momentum, float adam_beta,
              int num_samples) override;
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  void CountAlternators(const Network& other, double* same,
                        double* changed) const override;

 protected:
  // The networks.
  PointerVector<Network> stack_;
  // Layer-specific learning rate iff network_flags_ & NF_LAYER_SPECIFIC_LR.
  // One element for each element of stack_.
  GenericVector<float> learning_rates_;
};

}  // namespace tesseract.

#endif  // TESSERACT_LSTM_PLUMBING_H_