// weighted_mean.hpp (3.5 KB)
  1. // Copyright 2018 Hans Dembinski
  2. //
  3. // Distributed under the Boost Software License, version 1.0.
  4. // (See accompanying file LICENSE_1_0.txt
  5. // or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. #ifndef BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_MEAN_HPP
  7. #define BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_MEAN_HPP
  8. #include <boost/histogram/fwd.hpp>
  9. #include <type_traits>
  10. namespace boost {
  11. namespace histogram {
  12. namespace accumulators {
  /**
    Accumulates the mean and variance of a weighted sample.

    Samples are added one at a time with West's incremental algorithm, which
    improves the numerical stability of the mean and variance computation.
    Two accumulators can be merged with operator+=; the merge includes the
    between-group term, so the result matches filling a single accumulator
    with all samples (up to floating-point rounding).

    @tparam RealType arithmetic type used for all internal sums.
  */
  template <typename RealType>
  class weighted_mean {
    // operator+= and operator== accept weighted_mean<T> with T != RealType and
    // read its private sums, so every instantiation must be a friend.
    template <typename U>
    friend class weighted_mean;

  public:
    /// Creates an empty accumulator (all sums are RealType()).
    weighted_mean() = default;

    /**
      Restores an accumulator from previously computed statistics.

      @param wsum     sum of weights.
      @param wsum2    sum of squared weights.
      @param mean     weighted mean.
      @param variance variance as returned by variance().
    */
    weighted_mean(const RealType& wsum, const RealType& wsum2, const RealType& mean,
                  const RealType& variance)
        : sum_of_weights_(wsum)
        , sum_of_weights_squared_(wsum2)
        , weighted_mean_(mean)
        // Inverse of variance(); computed from the arguments so the result does
        // not depend on the declaration order of the data members.
        , sum_of_weighted_deltas_squared_(variance * (wsum - wsum2 / wsum)) {}

    /// Adds sample x with unit weight.
    void operator()(const RealType& x) { operator()(1, x); }

    /// Adds sample x with weight w (West's incremental update).
    void operator()(const RealType& w, const RealType& x) {
      sum_of_weights_ += w;
      sum_of_weights_squared_ += w * w;
      const auto delta = x - weighted_mean_;
      weighted_mean_ += w * delta / sum_of_weights_;
      // Uses the deviation from the old mean (delta) times the deviation from
      // the updated mean.
      sum_of_weighted_deltas_squared_ += w * delta * (x - weighted_mean_);
    }

    /**
      Merges another accumulator into this one.

      Besides adding both sums of squared deviations, each group's shift from
      its own mean to the combined mean contributes n_i * (mean - mu_i)^2.
      Merging an empty accumulator is a no-op, which also avoids a 0/0 division
      when both sides are empty.

      @param rhs accumulator to merge; its sums are converted to RealType.
      @return reference to this accumulator.
    */
    template <typename T>
    weighted_mean& operator+=(const weighted_mean<T>& rhs) {
      if (rhs.sum_of_weights_ == 0) return *this;

      // Snapshot everything needed before mutating (also keeps a += a well-defined).
      const RealType n1 = sum_of_weights_;
      const RealType mu1 = weighted_mean_;
      const RealType n2 = static_cast<RealType>(rhs.sum_of_weights_);
      const RealType mu2 = static_cast<RealType>(rhs.weighted_mean_);
      const RealType rhs_w2 = static_cast<RealType>(rhs.sum_of_weights_squared_);
      const RealType rhs_ss = static_cast<RealType>(rhs.sum_of_weighted_deltas_squared_);

      sum_of_weights_ += n2;
      sum_of_weights_squared_ += rhs_w2;
      weighted_mean_ = (n1 * mu1 + n2 * mu2) / sum_of_weights_;

      const RealType shift1 = weighted_mean_ - mu1;
      const RealType shift2 = weighted_mean_ - mu2;
      sum_of_weighted_deltas_squared_ +=
          rhs_ss + n1 * shift1 * shift1 + n2 * shift2 * shift2;
      return *this;
    }

    /**
      Rescales the accumulated statistics as if every sample had been
      multiplied by s: the mean scales by s, the squared deviations by s*s.
      The weight sums are left unchanged.
    */
    weighted_mean& operator*=(const RealType& s) {
      weighted_mean_ *= s;
      sum_of_weighted_deltas_squared_ *= s * s;
      return *this;
    }

    /// Returns true if all four internal sums compare equal.
    template <typename T>
    bool operator==(const weighted_mean<T>& rhs) const noexcept {
      return sum_of_weights_ == rhs.sum_of_weights_ &&
             sum_of_weights_squared_ == rhs.sum_of_weights_squared_ &&
             weighted_mean_ == rhs.weighted_mean_ &&
             sum_of_weighted_deltas_squared_ == rhs.sum_of_weighted_deltas_squared_;
    }

    /// Negation of operator==.
    template <typename T>
    bool operator!=(const T& rhs) const noexcept {
      return !operator==(rhs);
    }

    /// Returns the sum of all weights added so far.
    const RealType& sum_of_weights() const noexcept { return sum_of_weights_; }

    /// Returns the current weighted mean.
    const RealType& value() const noexcept { return weighted_mean_; }

    /**
      Returns the variance estimate
      sum_of_weighted_deltas_squared / (sum_w - sum_w2 / sum_w).

      No guard is applied: the denominator is zero for an empty accumulator or
      a single sample, so the result is then not a finite number for
      floating-point RealType.
    */
    RealType variance() const {
      return sum_of_weighted_deltas_squared_ /
             (sum_of_weights_ - sum_of_weights_squared_ / sum_of_weights_);
    }

    /// Serialization hook; declared here, defined elsewhere.
    template <class Archive>
    void serialize(Archive&, unsigned /* version */);

  private:
    RealType sum_of_weights_ = RealType();
    RealType sum_of_weights_squared_ = RealType();
    RealType weighted_mean_ = RealType();
    RealType sum_of_weighted_deltas_squared_ = RealType();
  };
  76. } // namespace accumulators
  77. } // namespace histogram
  78. } // namespace boost
  79. #ifndef BOOST_HISTOGRAM_DOXYGEN_INVOKED
  80. namespace std {
  81. template <class T, class U>
  82. /// Specialization for boost::histogram::accumulators::weighted_mean.
  83. struct common_type<boost::histogram::accumulators::weighted_mean<T>,
  84. boost::histogram::accumulators::weighted_mean<U>> {
  85. using type = boost::histogram::accumulators::weighted_mean<common_type_t<T, U>>;
  86. };
  87. } // namespace std
  88. #endif
  89. #endif