discrete_distribution.hpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. /* boost random/discrete_distribution.hpp header file
  2. *
  3. * Copyright Steven Watanabe 2009-2011
  4. * Distributed under the Boost Software License, Version 1.0. (See
  5. * accompanying file LICENSE_1_0.txt or copy at
  6. * http://www.boost.org/LICENSE_1_0.txt)
  7. *
  8. * See http://www.boost.org for most recent version including documentation.
  9. *
  10. * $Id$
  11. */
  12. #ifndef BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED
  13. #define BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED
  14. #include <vector>
  15. #include <limits>
  16. #include <numeric>
  17. #include <utility>
  18. #include <iterator>
  19. #include <boost/assert.hpp>
  20. #include <boost/random/uniform_01.hpp>
  21. #include <boost/random/uniform_int_distribution.hpp>
  22. #include <boost/random/detail/config.hpp>
  23. #include <boost/random/detail/operators.hpp>
  24. #include <boost/random/detail/vector_io.hpp>
  25. #ifndef BOOST_NO_CXX11_HDR_INITIALIZER_LIST
  26. #include <initializer_list>
  27. #endif
  28. #include <boost/random/detail/disable_warnings.hpp>
  29. namespace boost {
  30. namespace random {
  31. namespace detail {
  /**
   * Alias table used by @c discrete_distribution when @c WeightType is
   * an integral type.
   *
   * Weights are never divided.  The total weight of the n bins is
   * represented as <tt>_average * n + _excess</tt>: every bin has
   * weight @c _average and the first @c _excess bins have one unit
   * more (see @c get_weight).
   */
  template<class IntType, class WeightType>
  struct integer_alias_table {
      /** Returns the weight of @c bin: @c _average, plus one if @c bin < @c _excess. */
      WeightType get_weight(IntType bin) const {
          WeightType result = _average;
          if(bin < _excess) ++result;
          return result;
      }

      /**
       * Computes the floor of the mean of the weights in [begin, end)
       * together with the remainder, stores them in @c _average and
       * @c _excess, and resizes the alias table to one entry per weight.
       *
       * @returns the (floored) average weight.
       */
      template<class Iter>
      WeightType init_average(Iter begin, Iter end) {
          WeightType weight_average = 0;
          IntType excess = 0;
          IntType n = 0;
          // weight_average * n + excess == current partial sum
          // This is a bit messy, but it's guaranteed not to overflow
          for(Iter iter = begin; iter != end; ++iter) {
              ++n;
              if(*iter < weight_average) {
                  // The new weight pulls the running average down.
                  WeightType diff = weight_average - *iter;
                  weight_average -= diff / n;
                  if(diff % n > excess) {
                      // The remainder cannot absorb the difference:
                      // borrow one unit from the average.
                      --weight_average;
                      excess += n - diff % n;
                  } else {
                      excess -= diff % n;
                  }
              } else {
                  // The new weight pulls the running average up (or leaves it).
                  WeightType diff = *iter - weight_average;
                  weight_average += diff / n;
                  if(diff % n < n - excess) {
                      excess += diff % n;
                  } else {
                      // The remainder reached n: carry one unit into the average.
                      ++weight_average;
                      excess -= n - diff % n;
                  }
              }
          }
          _alias_table.resize(static_cast<std::size_t>(n));
          _average = weight_average;
          _excess = excess;
          return weight_average;
      }

      /** Resets the table to a single bin of weight one with no excess. */
      void init_empty()
      {
          _alias_table.clear();
          _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
                                                static_cast<IntType>(0)));
          _average = static_cast<WeightType>(1);
          _excess = static_cast<IntType>(0);
      }

      /** Two tables are equal when entries, average and excess all match. */
      bool operator==(const integer_alias_table& other) const
      {
          return _alias_table == other._alias_table &&
              _average == other._average && _excess == other._excess;
      }

      /** Integral weights are used as they are; the average is ignored. */
      static WeightType normalize(WeightType val, WeightType /* average */)
      {
          return val;
      }

      /** Integral weights are reported unchanged, so there is nothing to do. */
      static void normalize(std::vector<WeightType>&) {}

      /** Draws the threshold value from uniform_int_distribution<WeightType>(0, _average). */
      template<class URNG>
      WeightType test(URNG &urng) const
      {
          return uniform_int_distribution<WeightType>(0, _average)(urng);
      }

      /**
       * Decides whether a (bin, threshold) draw is usable.  Bins below
       * @c _excess have weight @c _average + 1 and accept any @c val;
       * the other bins have weight @c _average and only accept
       * @c val < @c _average.
       */
      bool accept(IntType result, WeightType val) const
      {
          return result < _excess || val < _average;
      }

      /**
       * Returns the sum of @c weights, or zero if that sum cannot be
       * represented in @c WeightType.  Callers treat a zero result as
       * "no usable sum".
       *
       * The weights are assumed to be non-negative.
       */
      static WeightType try_get_sum(const std::vector<WeightType>& weights)
      {
          const WeightType limit = (std::numeric_limits<WeightType>::max)();
          WeightType result = static_cast<WeightType>(0);
          for(typename std::vector<WeightType>::const_iterator
                  iter = weights.begin(), end = weights.end();
              iter != end; ++iter)
          {
              // limit - result is the room that is left; a weight larger
              // than that would overflow the running sum.
              if(*iter > limit - result) {
                  return static_cast<WeightType>(0);
              }
              result += *iter;
          }
          return result;
      }

      /** Draws a value from uniform_int_distribution<WeightType>(0, max - 1). */
      template<class URNG>
      static WeightType generate_in_range(URNG &urng, WeightType max)
      {
          return uniform_int_distribution<WeightType>(
              static_cast<WeightType>(0), max-1)(urng);
      }

      typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;

      alias_table_t _alias_table;  // one (threshold, alias) pair per bin
      WeightType _average;         // floor(total weight / number of bins)
      IntType _excess;             // total weight - _average * number of bins
  };
  125. template<class IntType, class WeightType>
  126. struct real_alias_table {
  127. WeightType get_weight(IntType) const
  128. {
  129. return WeightType(1.0);
  130. }
  131. template<class Iter>
  132. WeightType init_average(Iter first, Iter last)
  133. {
  134. std::size_t size = std::distance(first, last);
  135. WeightType weight_sum =
  136. std::accumulate(first, last, static_cast<WeightType>(0));
  137. _alias_table.resize(size);
  138. return weight_sum / size;
  139. }
  140. void init_empty()
  141. {
  142. _alias_table.clear();
  143. _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
  144. static_cast<IntType>(0)));
  145. }
  146. bool operator==(const real_alias_table& other) const
  147. {
  148. return _alias_table == other._alias_table;
  149. }
  150. static WeightType normalize(WeightType val, WeightType average)
  151. {
  152. return val / average;
  153. }
  154. static void normalize(std::vector<WeightType>& weights)
  155. {
  156. WeightType sum =
  157. std::accumulate(weights.begin(), weights.end(),
  158. static_cast<WeightType>(0));
  159. for(typename std::vector<WeightType>::iterator
  160. iter = weights.begin(),
  161. end = weights.end();
  162. iter != end; ++iter)
  163. {
  164. *iter /= sum;
  165. }
  166. }
  167. template<class URNG>
  168. WeightType test(URNG &urng) const
  169. {
  170. return uniform_01<WeightType>()(urng);
  171. }
  172. bool accept(IntType, WeightType) const
  173. {
  174. return true;
  175. }
  176. static WeightType try_get_sum(const std::vector<WeightType>& /* weights */)
  177. {
  178. return static_cast<WeightType>(1);
  179. }
  180. template<class URNG>
  181. static WeightType generate_in_range(URNG &urng, WeightType)
  182. {
  183. return uniform_01<WeightType>()(urng);
  184. }
  185. typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
  186. alias_table_t _alias_table;
  187. };
  188. template<bool IsIntegral>
  189. struct select_alias_table;
  190. template<>
  191. struct select_alias_table<true> {
  192. template<class IntType, class WeightType>
  193. struct apply {
  194. typedef integer_alias_table<IntType, WeightType> type;
  195. };
  196. };
  197. template<>
  198. struct select_alias_table<false> {
  199. template<class IntType, class WeightType>
  200. struct apply {
  201. typedef real_alias_table<IntType, WeightType> type;
  202. };
  203. };
  204. }
  205. /**
  206. * The class @c discrete_distribution models a \random_distribution.
  207. * It produces integers in the range [0, n), where the probability
  208. * of producing each value is specified by the parameters of the
  209. * distribution.
  210. */
  211. template<class IntType = int, class WeightType = double>
  212. class discrete_distribution {
  213. public:
  214. typedef WeightType input_type;
  215. typedef IntType result_type;
  216. class param_type {
  217. public:
  218. typedef discrete_distribution distribution_type;
  219. /**
  220. * Constructs a @c param_type object, representing a distribution
  221. * with \f$p(0) = 1\f$ and \f$p(k|k>0) = 0\f$.
  222. */
  223. param_type() : _probabilities(1, static_cast<WeightType>(1)) {}
  224. /**
  225. * If @c first == @c last, equivalent to the default constructor.
  226. * Otherwise, the values of the range represent weights for the
  227. * possible values of the distribution.
  228. */
  229. template<class Iter>
  230. param_type(Iter first, Iter last) : _probabilities(first, last)
  231. {
  232. normalize();
  233. }
  234. #ifndef BOOST_NO_CXX11_HDR_INITIALIZER_LIST
  235. /**
  236. * If wl.size() == 0, equivalent to the default constructor.
  237. * Otherwise, the values of the @c initializer_list represent
  238. * weights for the possible values of the distribution.
  239. */
  240. param_type(const std::initializer_list<WeightType>& wl)
  241. : _probabilities(wl)
  242. {
  243. normalize();
  244. }
  245. #endif
  246. /**
  247. * If the range is empty, equivalent to the default constructor.
  248. * Otherwise, the elements of the range represent
  249. * weights for the possible values of the distribution.
  250. */
  251. template<class Range>
  252. explicit param_type(const Range& range)
  253. : _probabilities(std::begin(range), std::end(range))
  254. {
  255. normalize();
  256. }
  257. /**
  258. * If nw is zero, equivalent to the default constructor.
  259. * Otherwise, the range of the distribution is [0, nw),
  260. * and the weights are found by calling fw with values
  261. * evenly distributed between \f$\mbox{xmin} + \delta/2\f$ and
  262. * \f$\mbox{xmax} - \delta/2\f$, where
  263. * \f$\delta = (\mbox{xmax} - \mbox{xmin})/\mbox{nw}\f$.
  264. */
  265. template<class Func>
  266. param_type(std::size_t nw, double xmin, double xmax, Func fw)
  267. {
  268. std::size_t n = (nw == 0) ? 1 : nw;
  269. double delta = (xmax - xmin) / n;
  270. BOOST_ASSERT(delta > 0);
  271. for(std::size_t k = 0; k < n; ++k) {
  272. _probabilities.push_back(fw(xmin + k*delta + delta/2));
  273. }
  274. normalize();
  275. }
  276. /**
  277. * Returns a vector containing the probabilities of each possible
  278. * value of the distribution.
  279. */
  280. std::vector<WeightType> probabilities() const
  281. {
  282. return _probabilities;
  283. }
  284. /** Writes the parameters to a @c std::ostream. */
  285. BOOST_RANDOM_DETAIL_OSTREAM_OPERATOR(os, param_type, parm)
  286. {
  287. detail::print_vector(os, parm._probabilities);
  288. return os;
  289. }
  290. /** Reads the parameters from a @c std::istream. */
  291. BOOST_RANDOM_DETAIL_ISTREAM_OPERATOR(is, param_type, parm)
  292. {
  293. std::vector<WeightType> temp;
  294. detail::read_vector(is, temp);
  295. if(is) {
  296. parm._probabilities.swap(temp);
  297. }
  298. return is;
  299. }
  300. /** Returns true if the two sets of parameters are the same. */
  301. BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(param_type, lhs, rhs)
  302. {
  303. return lhs._probabilities == rhs._probabilities;
  304. }
  305. /** Returns true if the two sets of parameters are different. */
  306. BOOST_RANDOM_DETAIL_INEQUALITY_OPERATOR(param_type)
  307. private:
  308. /// @cond show_private
  309. friend class discrete_distribution;
  310. explicit param_type(const discrete_distribution& dist)
  311. : _probabilities(dist.probabilities())
  312. {}
  313. void normalize()
  314. {
  315. impl_type::normalize(_probabilities);
  316. }
  317. std::vector<WeightType> _probabilities;
  318. /// @endcond
  319. };
  320. /**
  321. * Creates a new @c discrete_distribution object that has
  322. * \f$p(0) = 1\f$ and \f$p(i|i>0) = 0\f$.
  323. */
  324. discrete_distribution()
  325. {
  326. _impl.init_empty();
  327. }
  328. /**
  329. * Constructs a discrete_distribution from an iterator range.
  330. * If @c first == @c last, equivalent to the default constructor.
  331. * Otherwise, the values of the range represent weights for the
  332. * possible values of the distribution.
  333. */
  334. template<class Iter>
  335. discrete_distribution(Iter first, Iter last)
  336. {
  337. init(first, last);
  338. }
  339. #ifndef BOOST_NO_CXX11_HDR_INITIALIZER_LIST
  340. /**
  341. * Constructs a @c discrete_distribution from a @c std::initializer_list.
  342. * If the @c initializer_list is empty, equivalent to the default
  343. * constructor. Otherwise, the values of the @c initializer_list
  344. * represent weights for the possible values of the distribution.
  345. * For example, given the distribution
  346. *
  347. * @code
  348. * discrete_distribution<> dist{1, 4, 5};
  349. * @endcode
  350. *
  351. * The probability of a 0 is 1/10, the probability of a 1 is 2/5,
  352. * the probability of a 2 is 1/2, and no other values are possible.
  353. */
  354. discrete_distribution(std::initializer_list<WeightType> wl)
  355. {
  356. init(wl.begin(), wl.end());
  357. }
  358. #endif
  359. /**
  360. * Constructs a discrete_distribution from a Boost.Range range.
  361. * If the range is empty, equivalent to the default constructor.
  362. * Otherwise, the values of the range represent weights for the
  363. * possible values of the distribution.
  364. */
  365. template<class Range>
  366. explicit discrete_distribution(const Range& range)
  367. {
  368. init(std::begin(range), std::end(range));
  369. }
  370. /**
  371. * Constructs a discrete_distribution that approximates a function.
  372. * If nw is zero, equivalent to the default constructor.
  373. * Otherwise, the range of the distribution is [0, nw),
  374. * and the weights are found by calling fw with values
  375. * evenly distributed between \f$\mbox{xmin} + \delta/2\f$ and
  376. * \f$\mbox{xmax} - \delta/2\f$, where
  377. * \f$\delta = (\mbox{xmax} - \mbox{xmin})/\mbox{nw}\f$.
  378. */
  379. template<class Func>
  380. discrete_distribution(std::size_t nw, double xmin, double xmax, Func fw)
  381. {
  382. std::size_t n = (nw == 0) ? 1 : nw;
  383. double delta = (xmax - xmin) / n;
  384. BOOST_ASSERT(delta > 0);
  385. std::vector<WeightType> weights;
  386. for(std::size_t k = 0; k < n; ++k) {
  387. weights.push_back(fw(xmin + k*delta + delta/2));
  388. }
  389. init(weights.begin(), weights.end());
  390. }
  391. /**
  392. * Constructs a discrete_distribution from its parameters.
  393. */
  394. explicit discrete_distribution(const param_type& parm)
  395. {
  396. param(parm);
  397. }
  398. /**
  399. * Returns a value distributed according to the parameters of the
  400. * discrete_distribution.
  401. */
  402. template<class URNG>
  403. IntType operator()(URNG& urng) const
  404. {
  405. BOOST_ASSERT(!_impl._alias_table.empty());
  406. IntType result;
  407. WeightType test;
  408. do {
  409. result = uniform_int_distribution<IntType>((min)(), (max)())(urng);
  410. test = _impl.test(urng);
  411. } while(!_impl.accept(result, test));
  412. if(test < _impl._alias_table[static_cast<std::size_t>(result)].first) {
  413. return result;
  414. } else {
  415. return(_impl._alias_table[static_cast<std::size_t>(result)].second);
  416. }
  417. }
  418. /**
  419. * Returns a value distributed according to the parameters
  420. * specified by param.
  421. */
  422. template<class URNG>
  423. IntType operator()(URNG& urng, const param_type& parm) const
  424. {
  425. if(WeightType limit = impl_type::try_get_sum(parm._probabilities)) {
  426. WeightType val = impl_type::generate_in_range(urng, limit);
  427. WeightType sum = 0;
  428. std::size_t result = 0;
  429. for(typename std::vector<WeightType>::const_iterator
  430. iter = parm._probabilities.begin(),
  431. end = parm._probabilities.end();
  432. iter != end; ++iter, ++result)
  433. {
  434. sum += *iter;
  435. if(sum > val) {
  436. return result;
  437. }
  438. }
  439. // This shouldn't be reachable, but round-off error
  440. // can prevent any match from being found when val is
  441. // very close to 1.
  442. return static_cast<IntType>(parm._probabilities.size() - 1);
  443. } else {
  444. // WeightType is integral and sum(parm._probabilities)
  445. // would overflow. Just use the easy solution.
  446. return discrete_distribution(parm)(urng);
  447. }
  448. }
  449. /** Returns the smallest value that the distribution can produce. */
  450. result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const { return 0; }
  451. /** Returns the largest value that the distribution can produce. */
  452. result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const
  453. { return static_cast<result_type>(_impl._alias_table.size() - 1); }
  454. /**
  455. * Returns a vector containing the probabilities of each
  456. * value of the distribution. For example, given
  457. *
  458. * @code
  459. * discrete_distribution<> dist = { 1, 4, 5 };
  460. * std::vector<double> p = dist.param();
  461. * @endcode
  462. *
  463. * the vector, p will contain {0.1, 0.4, 0.5}.
  464. *
  465. * If @c WeightType is integral, then the weights
  466. * will be returned unchanged.
  467. */
  468. std::vector<WeightType> probabilities() const
  469. {
  470. std::vector<WeightType> result(_impl._alias_table.size(), static_cast<WeightType>(0));
  471. std::size_t i = 0;
  472. for(typename impl_type::alias_table_t::const_iterator
  473. iter = _impl._alias_table.begin(),
  474. end = _impl._alias_table.end();
  475. iter != end; ++iter, ++i)
  476. {
  477. WeightType val = iter->first;
  478. result[i] += val;
  479. result[static_cast<std::size_t>(iter->second)] += _impl.get_weight(i) - val;
  480. }
  481. impl_type::normalize(result);
  482. return(result);
  483. }
  484. /** Returns the parameters of the distribution. */
  485. param_type param() const
  486. {
  487. return param_type(*this);
  488. }
  489. /** Sets the parameters of the distribution. */
  490. void param(const param_type& parm)
  491. {
  492. init(parm._probabilities.begin(), parm._probabilities.end());
  493. }
  494. /**
  495. * Effects: Subsequent uses of the distribution do not depend
  496. * on values produced by any engine prior to invoking reset.
  497. */
  498. void reset() {}
  499. /** Writes a distribution to a @c std::ostream. */
  500. BOOST_RANDOM_DETAIL_OSTREAM_OPERATOR(os, discrete_distribution, dd)
  501. {
  502. os << dd.param();
  503. return os;
  504. }
  505. /** Reads a distribution from a @c std::istream */
  506. BOOST_RANDOM_DETAIL_ISTREAM_OPERATOR(is, discrete_distribution, dd)
  507. {
  508. param_type parm;
  509. if(is >> parm) {
  510. dd.param(parm);
  511. }
  512. return is;
  513. }
  514. /**
  515. * Returns true if the two distributions will return the
  516. * same sequence of values, when passed equal generators.
  517. */
  518. BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(discrete_distribution, lhs, rhs)
  519. {
  520. return lhs._impl == rhs._impl;
  521. }
  522. /**
  523. * Returns true if the two distributions may return different
  524. * sequences of values, when passed equal generators.
  525. */
  526. BOOST_RANDOM_DETAIL_INEQUALITY_OPERATOR(discrete_distribution)
  527. private:
  528. /// @cond show_private
  529. template<class Iter>
  530. void init(Iter first, Iter last, std::input_iterator_tag)
  531. {
  532. std::vector<WeightType> temp(first, last);
  533. init(temp.begin(), temp.end());
  534. }
  535. template<class Iter>
  536. void init(Iter first, Iter last, std::forward_iterator_tag)
  537. {
  538. size_t input_size = std::distance(first, last);
  539. std::vector<std::pair<WeightType, IntType> > below_average;
  540. std::vector<std::pair<WeightType, IntType> > above_average;
  541. below_average.reserve(input_size);
  542. above_average.reserve(input_size);
  543. WeightType weight_average = _impl.init_average(first, last);
  544. WeightType normalized_average = _impl.get_weight(0);
  545. std::size_t i = 0;
  546. for(; first != last; ++first, ++i) {
  547. WeightType val = impl_type::normalize(*first, weight_average);
  548. std::pair<WeightType, IntType> elem(val, static_cast<IntType>(i));
  549. if(val < normalized_average) {
  550. below_average.push_back(elem);
  551. } else {
  552. above_average.push_back(elem);
  553. }
  554. }
  555. typename impl_type::alias_table_t::iterator
  556. b_iter = below_average.begin(),
  557. b_end = below_average.end(),
  558. a_iter = above_average.begin(),
  559. a_end = above_average.end()
  560. ;
  561. while(b_iter != b_end && a_iter != a_end) {
  562. _impl._alias_table[static_cast<std::size_t>(b_iter->second)] =
  563. std::make_pair(b_iter->first, a_iter->second);
  564. a_iter->first -= (_impl.get_weight(b_iter->second) - b_iter->first);
  565. if(a_iter->first < normalized_average) {
  566. *b_iter = *a_iter++;
  567. } else {
  568. ++b_iter;
  569. }
  570. }
  571. for(; b_iter != b_end; ++b_iter) {
  572. _impl._alias_table[static_cast<std::size_t>(b_iter->second)].first =
  573. _impl.get_weight(b_iter->second);
  574. }
  575. for(; a_iter != a_end; ++a_iter) {
  576. _impl._alias_table[static_cast<std::size_t>(a_iter->second)].first =
  577. _impl.get_weight(a_iter->second);
  578. }
  579. }
  580. template<class Iter>
  581. void init(Iter first, Iter last)
  582. {
  583. if(first == last) {
  584. _impl.init_empty();
  585. } else {
  586. typename std::iterator_traits<Iter>::iterator_category category;
  587. init(first, last, category);
  588. }
  589. }
  590. typedef typename detail::select_alias_table<
  591. (::boost::is_integral<WeightType>::value)
  592. >::template apply<IntType, WeightType>::type impl_type;
  593. impl_type _impl;
  594. /// @endcond
  595. };
  596. }
  597. }
  598. #include <boost/random/detail/enable_warnings.hpp>
  599. #endif