static_shape.h 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. ///////////////////////////////////////////////////////////////////////
  2. // File: static_shape.h
  3. // Description: Defines the size of the 4-d tensor input/output from a network.
  4. // Author: Ray Smith
  5. // Created: Fri Oct 14 09:07:31 PST 2016
  6. //
  7. // (C) Copyright 2016, Google Inc.
  8. // Licensed under the Apache License, Version 2.0 (the "License");
  9. // you may not use this file except in compliance with the License.
  10. // You may obtain a copy of the License at
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. // See the License for the specific language governing permissions and
  16. // limitations under the License.
  17. ///////////////////////////////////////////////////////////////////////
  18. #ifndef TESSERACT_LSTM_STATIC_SHAPE_H_
  19. #define TESSERACT_LSTM_STATIC_SHAPE_H_
  20. #include "serialis.h" // for TFile
  21. #include "tprintf.h" // for tprintf
  22. namespace tesseract {
  // Selects the loss function used while training and/or the decoding method
  // used at runtime.
  // The numeric values are spelled out because StaticShape::Serialize stores
  // this enum as an int32_t; they must keep matching previously written data.
  enum LossType {
    LT_NONE = 0,      // Undefined.
    LT_CTC = 1,       // Softmax with standard CTC for training/decoding.
    LT_SOFTMAX = 2,   // Outputs sum to 1 in fixed positions.
    LT_LOGISTIC = 3,  // Logistic outputs with independent values.
  };
  31. // Simple class to hold the tensor shape that is known at network build time
  32. // and the LossType of the loss function.
  33. class StaticShape {
  34. public:
  35. StaticShape()
  36. : batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {}
  37. int batch() const { return batch_; }
  38. void set_batch(int value) { batch_ = value; }
  39. int height() const { return height_; }
  40. void set_height(int value) { height_ = value; }
  41. int width() const { return width_; }
  42. void set_width(int value) { width_ = value; }
  43. int depth() const { return depth_; }
  44. void set_depth(int value) { depth_ = value; }
  45. LossType loss_type() const { return loss_type_; }
  46. void set_loss_type(LossType value) { loss_type_ = value; }
  47. void SetShape(int batch, int height, int width, int depth) {
  48. batch_ = batch;
  49. height_ = height;
  50. width_ = width;
  51. depth_ = depth;
  52. }
  53. void Print() const {
  54. tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_,
  55. height_, width_, depth_, loss_type_);
  56. }
  57. bool DeSerialize(TFile *fp) {
  58. int32_t tmp = LT_NONE;
  59. bool result =
  60. fp->DeSerialize(&batch_) &&
  61. fp->DeSerialize(&height_) &&
  62. fp->DeSerialize(&width_) &&
  63. fp->DeSerialize(&depth_) &&
  64. fp->DeSerialize(&tmp);
  65. loss_type_ = static_cast<LossType>(tmp);
  66. return result;
  67. }
  68. bool Serialize(TFile *fp) const {
  69. int32_t tmp = loss_type_;
  70. return
  71. fp->Serialize(&batch_) &&
  72. fp->Serialize(&height_) &&
  73. fp->Serialize(&width_) &&
  74. fp->Serialize(&depth_) &&
  75. fp->Serialize(&tmp);
  76. }
  77. private:
  78. // Size of the 4-D tensor input/output to a network. A value of zero is
  79. // allowed for all except depth_ and means to be determined at runtime, and
  80. // regarded as variable.
  81. // Number of elements in a batch, or number of frames in a video stream.
  82. int32_t batch_;
  83. // Height of the image.
  84. int32_t height_;
  85. // Width of the image.
  86. int32_t width_;
  87. // Depth of the image. (Number of "nodes").
  88. int32_t depth_;
  89. // How to train/interpret the output.
  90. LossType loss_type_;
  91. };
  92. } // namespace tesseract
  93. #endif // TESSERACT_LSTM_STATIC_SHAPE_H_