linearize.hpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. // Copyright 2015-2018 Hans Dembinski
  2. //
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // (See accompanying file LICENSE_1_0.txt
  5. // or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. #ifndef BOOST_HISTOGRAM_DETAIL_LINEARIZE_HPP
  7. #define BOOST_HISTOGRAM_DETAIL_LINEARIZE_HPP
  8. #include <algorithm>
  9. #include <boost/assert.hpp>
  10. #include <boost/histogram/axis/traits.hpp>
  11. #include <boost/histogram/axis/variant.hpp>
  12. #include <boost/histogram/detail/axes.hpp>
  13. #include <boost/histogram/detail/meta.hpp>
  14. #include <boost/histogram/fwd.hpp>
  15. #include <boost/histogram/unsafe_access.hpp>
  16. #include <boost/mp11/algorithm.hpp>
  17. #include <boost/mp11/function.hpp>
  18. #include <boost/mp11/integral.hpp>
  19. #include <boost/mp11/list.hpp>
  20. #include <boost/mp11/tuple.hpp>
  21. #include <boost/throw_exception.hpp>
  22. #include <stdexcept>
  23. #include <tuple>
  24. #include <type_traits>
  25. namespace boost {
  26. namespace histogram {
  27. namespace detail {
  28. template <class T>
  29. struct is_accumulator_set : std::false_type {};
  30. template <class T>
  31. using has_underflow =
  32. decltype(axis::traits::static_options<T>::test(axis::option::underflow));
  33. template <class T>
  34. struct is_growing
  35. : decltype(axis::traits::static_options<T>::test(axis::option::growth)) {};
  36. template <class... Ts>
  37. struct is_growing<std::tuple<Ts...>> : mp11::mp_or<is_growing<Ts>...> {};
  38. template <class... Ts>
  39. struct is_growing<axis::variant<Ts...>> : mp11::mp_or<is_growing<Ts>...> {};
  40. template <class T>
  41. using has_growing_axis =
  42. mp11::mp_if<is_vector_like<T>, is_growing<mp11::mp_first<T>>, is_growing<T>>;
  43. /// Index with an invalid state
  44. struct optional_index {
  45. std::size_t idx = 0;
  46. std::size_t stride = 1;
  47. operator bool() const { return stride > 0; }
  48. std::size_t operator*() const { return idx; }
  49. };
  50. inline void linearize(optional_index& out, const axis::index_type extent,
  51. const axis::index_type j) noexcept {
  52. // j is internal index shifted by +1 if axis has underflow bin
  53. out.idx += j * out.stride;
  54. // set stride to 0, if j is invalid
  55. out.stride *= (0 <= j && j < extent) * extent;
  56. }
  57. // for non-growing axis
  58. template <class Axis, class Value>
  59. void linearize_value(optional_index& o, const Axis& a, const Value& v) {
  60. using B = decltype(axis::traits::static_options<Axis>::test(axis::option::underflow));
  61. const auto j = axis::traits::index(a, v) + B::value;
  62. linearize(o, axis::traits::extent(a), j);
  63. }
  64. // for variant that does not contain any growing axis
  65. template <class... Ts, class Value>
  66. void linearize_value(optional_index& o, const axis::variant<Ts...>& a, const Value& v) {
  67. axis::visit([&o, &v](const auto& a) { linearize_value(o, a, v); }, a);
  68. }
  69. // for growing axis
  70. template <class Axis, class Value>
  71. void linearize_value(optional_index& o, axis::index_type& s, Axis& a, const Value& v) {
  72. axis::index_type j;
  73. std::tie(j, s) = axis::traits::update(a, v);
  74. j += has_underflow<Axis>::value;
  75. linearize(o, axis::traits::extent(a), j);
  76. }
  77. // for variant which contains at least one growing axis
  78. template <class... Ts, class Value>
  79. void linearize_value(optional_index& o, axis::index_type& s, axis::variant<Ts...>& a,
  80. const Value& v) {
  81. axis::visit([&o, &s, &v](auto&& a) { linearize_value(o, s, a, v); }, a);
  82. }
  83. template <class A>
  84. void linearize_index(optional_index& out, const A& axis, const axis::index_type j) {
  85. // A may be axis or variant, cannot use static option detection here
  86. const auto opt = axis::traits::options(axis);
  87. const auto shift = opt & axis::option::underflow ? 1 : 0;
  88. const auto n = axis.size() + (opt & axis::option::overflow ? 1 : 0);
  89. linearize(out, n + shift, j + shift);
  90. }
/// Rebuild `storage` after one or more axes grew during a fill.
///
/// `shifts` holds one shift per axis (the second value produced by
/// axis::traits::update) and is read through a pointer, one entry per axis. If
/// every shift is zero, nothing happens. Otherwise:
/// - a new storage is created via make_default/reset with bincount(axes) cells;
/// - every old cell is copied to its position in the new layout;
/// - `storage` is replaced by the new storage.
/// New cells that no old cell maps to keep whatever make_default/reset produced
/// (presumably a default value — confirm).
///
/// NOTE(review): the shift sign convention is inferred from the code below. The
/// magnitude is the number of bins added to the axis. A positive shift moves the
/// old inner bins up by that amount; a negative shift leaves them in place.
/// Confirm against axis::traits::update.
/// NOTE(review): std::abs(int) is declared in <cstdlib>, which this header does
/// not include directly.
template <class S, class A, class T>
void maybe_replace_storage(S& storage, const A& axes, const T& shifts) {
  // fast path: skip the rebuild unless at least one axis reported a shift
  bool update_needed = false;
  auto sit = shifts;
  for_each_axis(axes, [&](const auto&) { update_needed |= (*sit++ != 0); });
  if (!update_needed) return;
  // Per-axis bookkeeping:
  //   idx        - current multi-index component in the OLD layout
  //   old_extent - axis extent before growth
  //   new_stride - axis stride in the NEW layout
  struct item {
    axis::index_type idx, old_extent;
    std::size_t new_stride;
  } data[buffer_size<A>::value];
  sit = shifts;
  auto dit = data;
  std::size_t s = 1;
  for_each_axis(axes, [&](const auto& a) {
    const auto n = axis::traits::extent(a);
    // old extent = current extent minus the number of bins added by growth
    *dit++ = {0, n - std::abs(*sit++), s};
    s *= n;
  });
  auto new_storage = make_default(storage);
  new_storage.reset(detail::bincount(axes));
  const auto dlast = data + get_size(axes) - 1;
  // Walk the old storage in linear order while tracking the matching multi-index
  // in `data`, and compute each cell's destination in the new storage.
  for (const auto& x : storage) {
    auto ns = new_storage.begin();
    sit = shifts;
    dit = data;
    for_each_axis(axes, [&](const auto& a) {
      // NOTE(review): decltype(a) is a const reference type here, so this relies
      // on static_options handling cv/ref-qualified types. For axis::variant
      // elements, static options may not reflect the active axis (linearize_index
      // queries options at runtime for that reason). Confirm.
      using opt = axis::traits::static_options<decltype(a)>;
      if (opt::test(axis::option::underflow)) {
        if (dit->idx == 0) {
          // axis has underflow and we are in the underflow bin:
          // keep storage pointer unchanged
          ++dit;
          ++sit;
          return;
        }
      }
      if (opt::test(axis::option::overflow)) {
        if (dit->idx == dit->old_extent - 1) {
          // axis has overflow and we are in the overflow bin:
          // move storage pointer to corresponding overflow bin position
          ns += (axis::traits::extent(a) - 1) * dit->new_stride;
          ++dit;
          ++sit;
          return;
        }
      }
      // we are in a normal bin:
      // move storage pointer to index position, apply positive shifts
      ns += (dit->idx + std::max(*sit, 0)) * dit->new_stride;
      ++dit;
      ++sit;
    });
    // assign old value to new location
    *ns = x;
    // Advance the multi-dimensional index like an odometer: the first axis varies
    // fastest, and a component that reaches its old extent wraps to 0 and carries
    // into the next axis. The last component is never wrapped; the range-for
    // loop ends when the old storage is exhausted.
    dit = data;
    ++dit->idx;
    while (dit != dlast && dit->idx == dit->old_extent) {
      dit->idx = 0;
      ++(++dit)->idx;
    }
  }
  storage = std::move(new_storage);
}
// Special case: if histogram::operator()(tuple(1, 2)) is called on a 1d
// histogram whose axis accepts a 2d tuple, the call should not fail.
// - Solution: for 1d histograms, forward tuples of size > 1 directly to the
//   axis.
// - This has the nice side effect of making histogram::operator()(1, 2) work
//   as well.
// - The call signature of the axis cannot be detected at compile time in all
//   configurations (axis::variant provides a generic call interface and hides
//   the concrete one), so we throw at runtime if an incompatible argument is
//   passed (e.g. a 3d tuple).
  164. // histogram has only non-growing axes
  165. template <unsigned I, unsigned N, class S, class T, class U>
  166. optional_index args_to_index(std::false_type, S&, const T& axes, const U& args) {
  167. optional_index idx;
  168. const auto rank = get_size(axes);
  169. if (rank == 1 && N > 1)
  170. linearize_value(idx, axis_get<0>(axes), tuple_slice<I, N>(args));
  171. else {
  172. if (rank != N)
  173. BOOST_THROW_EXCEPTION(
  174. std::invalid_argument("number of arguments != histogram rank"));
  175. constexpr unsigned M = buffer_size<remove_cvref_t<decltype(axes)>>::value;
  176. mp11::mp_for_each<mp11::mp_iota_c<(N < M ? N : M)>>([&](auto J) {
  177. linearize_value(idx, axis_get<J>(axes), std::get<(J + I)>(args));
  178. });
  179. }
  180. return idx;
  181. }
  182. // histogram has growing axes
  183. template <unsigned I, unsigned N, class S, class T, class U>
  184. optional_index args_to_index(std::true_type, S& storage, T& axes, const U& args) {
  185. optional_index idx;
  186. axis::index_type shifts[buffer_size<T>::value];
  187. const auto rank = get_size(axes);
  188. if (rank == 1 && N > 1)
  189. linearize_value(idx, shifts[0], axis_get<0>(axes), tuple_slice<I, N>(args));
  190. else {
  191. if (rank != N)
  192. BOOST_THROW_EXCEPTION(
  193. std::invalid_argument("number of arguments != histogram rank"));
  194. constexpr unsigned M = buffer_size<remove_cvref_t<decltype(axes)>>::value;
  195. mp11::mp_for_each<mp11::mp_iota_c<(N < M ? N : M)>>([&](auto J) {
  196. linearize_value(idx, shifts[J], axis_get<J>(axes), std::get<(J + I)>(args));
  197. });
  198. }
  199. maybe_replace_storage(storage, axes, shifts);
  200. return idx;
  201. }
  202. template <typename U>
  203. constexpr auto weight_sample_indices() {
  204. if (is_weight<U>::value) return std::make_pair(0, -1);
  205. if (is_sample<U>::value) return std::make_pair(-1, 0);
  206. return std::make_pair(-1, -1);
  207. }
/// Locate the weight and sample markers among two or more fill arguments.
///
/// Returns a (weight position, sample position) pair; -1 means "not present".
/// Only positions 0, 1, n-1 and n are inspected, where n is the index of the
/// last argument, so markers are recognized only at the front or the back.
/// The checks below run in order and the first match wins.
template <typename U0, typename U1, typename... Us>
constexpr auto weight_sample_indices() {
  using L = mp11::mp_list<U0, U1, Us...>;
  // n is the index of the last argument
  const int n = sizeof...(Us) + 1;
  // weight first: a sample may follow directly or be the last argument
  if (is_weight<mp11::mp_at_c<L, 0>>::value) {
    if (is_sample<mp11::mp_at_c<L, 1>>::value) return std::make_pair(0, 1);
    if (is_sample<mp11::mp_at_c<L, n>>::value) return std::make_pair(0, n);
    return std::make_pair(0, -1);
  }
  // sample first: a weight may follow directly or be the last argument
  if (is_sample<mp11::mp_at_c<L, 0>>::value) {
    if (is_weight<mp11::mp_at_c<L, 1>>::value) return std::make_pair(1, 0);
    if (is_weight<mp11::mp_at_c<L, n>>::value) return std::make_pair(n, 0);
    return std::make_pair(-1, 0);
  }
  // weight last: a sample may directly precede it
  if (is_weight<mp11::mp_at_c<L, n>>::value) {
    // 0, n already covered
    if (is_sample<mp11::mp_at_c<L, (n - 1)>>::value) return std::make_pair(n, n - 1);
    return std::make_pair(n, -1);
  }
  // sample last: a weight may directly precede it
  if (is_sample<mp11::mp_at_c<L, n>>::value) {
    // n, 0 already covered
    if (is_weight<mp11::mp_at_c<L, (n - 1)>>::value) return std::make_pair(n - 1, n);
    return std::make_pair(-1, n);
  }
  return std::make_pair(-1, -1);
}
  234. template <class T, class U>
  235. void fill_storage(mp11::mp_int<-1>, mp11::mp_int<-1>, T&& t, U&&) {
  236. static_if<is_incrementable<remove_cvref_t<T>>>(
  237. [](auto&& t) { ++t; }, [](auto&& t) { t(); }, std::forward<T>(t));
  238. }
  239. template <class IW, class T, class U>
  240. void fill_storage(IW, mp11::mp_int<-1>, T&& t, U&& args) {
  241. static_if<is_incrementable<remove_cvref_t<T>>>(
  242. [](auto&& t, const auto& w) { t += w; },
  243. [](auto&& t, const auto& w) {
  244. #ifdef BOOST_HISTOGRAM_WITH_ACCUMULATORS_SUPPORT
  245. static_if<is_accumulator_set<remove_cvref_t<T>>>(
  246. [w](auto&& t) { t(::boost::accumulators::weight = w); },
  247. [w](auto&& t) { t(w); }, t);
  248. #else
  249. t(w);
  250. #endif
  251. },
  252. std::forward<T>(t), std::get<IW::value>(args).value);
  253. }
  254. template <class IS, class T, class U>
  255. void fill_storage(mp11::mp_int<-1>, IS, T&& t, U&& args) {
  256. mp11::tuple_apply([&t](auto&&... args) { t(args...); },
  257. std::get<IS::value>(args).value);
  258. }
  259. template <class IW, class IS, class T, class U>
  260. void fill_storage(IW, IS, T&& t, U&& args) {
  261. mp11::tuple_apply(
  262. [&](auto&&... args2) { t(std::get<IW::value>(args).value, args2...); },
  263. std::get<IS::value>(args).value);
  264. }
  265. template <class S, class A, class... Us>
  266. auto fill(S& storage, A& axes, const std::tuple<Us...>& args) {
  267. constexpr auto iws = weight_sample_indices<Us...>();
  268. constexpr unsigned n = sizeof...(Us) - (iws.first > -1) - (iws.second > -1);
  269. constexpr unsigned i = (iws.first == 0 || iws.second == 0)
  270. ? (iws.first == 1 || iws.second == 1 ? 2 : 1)
  271. : 0;
  272. optional_index idx = args_to_index<i, n>(has_growing_axis<A>(), storage, axes, args);
  273. if (idx) {
  274. fill_storage(mp11::mp_int<iws.first>(), mp11::mp_int<iws.second>(), storage[*idx],
  275. args);
  276. return storage.begin() + *idx;
  277. }
  278. return storage.end();
  279. }
  280. template <typename A, typename... Us>
  281. optional_index at(const A& axes, const std::tuple<Us...>& args) {
  282. if (get_size(axes) != sizeof...(Us))
  283. BOOST_THROW_EXCEPTION(std::invalid_argument("number of arguments != histogram rank"));
  284. optional_index idx;
  285. mp11::mp_for_each<mp11::mp_iota_c<sizeof...(Us)>>([&](auto I) {
  286. // axes_get works with static and dynamic axes
  287. linearize_index(idx, axis_get<I>(axes),
  288. static_cast<axis::index_type>(std::get<I>(args)));
  289. });
  290. return idx;
  291. }
  292. template <typename A, typename U>
  293. optional_index at(const A& axes, const U& args) {
  294. if (get_size(axes) != get_size(args))
  295. BOOST_THROW_EXCEPTION(std::invalid_argument("number of arguments != histogram rank"));
  296. optional_index idx;
  297. using std::begin;
  298. auto it = begin(args);
  299. for_each_axis(axes, [&](const auto& a) {
  300. linearize_index(idx, a, static_cast<axis::index_type>(*it++));
  301. });
  302. return idx;
  303. }
  304. } // namespace detail
  305. } // namespace histogram
  306. } // namespace boost
  307. #endif