autodiff_reverse.hpp 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810
  1. // Copyright Maksym Zhelyenzyakov 2025-2026.
  2. // Distributed under the Boost Software License, Version 1.0.
  3. // (See accompanying file LICENSE_1_0.txt or copy at
  4. // https://www.boost.org/LICENSE_1_0.txt)
  5. #ifndef BOOST_MATH_DIFFERENTIATION_AUTODIFF_HPP
  6. #define BOOST_MATH_DIFFERENTIATION_AUTODIFF_HPP
  7. #include <boost/math/constants/constants.hpp>
  8. #if defined(BOOST_MATH_REVERSE_MODE_ET_OFF) && defined(BOOST_MATH_REVERSE_MODE_ET_ON)
  9. #error "Cannot define both BOOST_MATH_REVERSE_MODE_ET_OFF and BOOST_MATH_REVERSE_MODE_ET_ON"
  10. #endif
  11. #if !defined(BOOST_MATH_REVERSE_MODE_ET_OFF) && !defined(BOOST_MATH_REVERSE_MODE_ET_ON)
  12. #define BOOST_MATH_REVERSE_MODE_ET_ON
  13. #endif
  14. #ifdef BOOST_MATH_REVERSE_MODE_ET_ON
  15. #include <boost/math/differentiation/detail/reverse_mode_autodiff_basic_ops_et.hpp>
  16. #include <boost/math/differentiation/detail/reverse_mode_autodiff_stl_et.hpp>
  17. #else
  18. #include <boost/math/differentiation/detail/reverse_mode_autodiff_basic_ops_no_et.hpp>
  19. #include <boost/math/differentiation/detail/reverse_mode_autodiff_stl_no_et.hpp>
  20. #endif
  21. #include <boost/math/differentiation/detail/reverse_mode_autodiff_comparison_operator_overloads.hpp>
  22. #include <boost/math/differentiation/detail/reverse_mode_autodiff_erf_overloads.hpp>
  23. #include <boost/math/differentiation/detail/reverse_mode_autodiff_expression_template_base.hpp>
  24. #include <boost/math/differentiation/detail/reverse_mode_autodiff_memory_management.hpp>
  25. #include <boost/math/special_functions/acosh.hpp>
  26. #include <boost/math/special_functions/asinh.hpp>
  27. #include <boost/math/special_functions/atanh.hpp>
  28. #include <boost/math/special_functions/digamma.hpp>
  29. #include <boost/math/special_functions/erf.hpp>
  30. #include <boost/math/special_functions/lambert_w.hpp>
  31. #include <boost/math/special_functions/polygamma.hpp>
  32. #include <boost/math/special_functions/round.hpp>
  33. #include <boost/math/special_functions/trunc.hpp>
  34. #include <boost/math/tools/config.hpp>
  35. #include <boost/math/tools/promotion.hpp>
  36. #include <cstddef>
  37. #include <iostream>
  38. #include <type_traits>
  39. #include <vector>
  40. #define BOOST_MATH_BUFFER_SIZE 65536
  41. namespace boost {
  42. namespace math {
  43. namespace differentiation {
  44. namespace reverse_mode {
  45. /* forward declarations for utitlity functions */
  46. template<typename RealType, size_t DerivativeOrder, class DerivedExpression>
  47. struct expression;
  48. template<typename RealType, size_t DerivativeOrder>
  49. class rvar;
  50. template<typename RealType,
  51. size_t DerivativeOrder,
  52. typename LHS,
  53. typename RHS,
  54. typename ConcreteBinaryOperation>
  55. struct abstract_binary_expression;
  56. template<typename RealType, size_t DerivativeOrder, typename ARG, typename ConcreteBinaryOperation>
  57. struct abstract_unary_expression;
  58. template<typename RealType, size_t DerivativeOrder>
  59. class gradient_node; // forward declaration for tape
  60. // manages nodes in computational graph
  61. template<typename RealType, size_t DerivativeOrder, size_t buffer_size = BOOST_MATH_BUFFER_SIZE>
  62. class gradient_tape
  63. {
  64. /** @brief tape (graph) management class for autodiff
  65. * holds all the data structures for autodiff */
  66. private:
  67. /* type decays to order - 1 to support higher order derivatives */
  68. using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
  69. /* adjoints are the overall derivative, and derivatives are the "local"
  70. * derivative */
  71. detail::flat_linear_allocator<inner_t, buffer_size> adjoints_;
  72. detail::flat_linear_allocator<inner_t, buffer_size> derivatives_;
  73. detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder>, buffer_size>
  74. gradient_nodes_;
  75. detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder> *, buffer_size>
  76. argument_nodes_;
  77. // compile time check if emplace_back calls on zero
  78. template<size_t n>
  79. gradient_node<RealType, DerivativeOrder> *fill_node_at_compile_time(
  80. std::true_type, gradient_node<RealType, DerivativeOrder> *node_ptr)
  81. {
  82. node_ptr->derivatives_ = derivatives_.template emplace_back_n<n>();
  83. node_ptr->argument_nodes_ = argument_nodes_.template emplace_back_n<n>();
  84. return node_ptr;
  85. }
  86. template<size_t n>
  87. gradient_node<RealType, DerivativeOrder> *fill_node_at_compile_time(
  88. std::false_type, gradient_node<RealType, DerivativeOrder> *node_ptr)
  89. {
  90. node_ptr->derivatives_ = nullptr;
  91. node_ptr->argument_adjoints_ = nullptr;
  92. node_ptr->argument_nodes_ = nullptr;
  93. return node_ptr;
  94. }
  95. public:
  96. /* gradient node stores iterators to its data memebers
  97. * (adjoint/derivative/arguments) so that in case flat linear allocator
  98. * reaches its block boundary and needs more memory for that node, the
  99. * iterator can be invoked to access it */
  100. using adjoint_iterator = typename detail::flat_linear_allocator<inner_t, buffer_size>::iterator;
  101. using derivatives_iterator =
  102. typename detail::flat_linear_allocator<inner_t, buffer_size>::iterator;
  103. using gradient_nodes_iterator =
  104. typename detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder>,
  105. buffer_size>::iterator;
  106. using argument_nodes_iterator =
  107. typename detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder> *,
  108. buffer_size>::iterator;
  109. gradient_tape() { clear(); };
  110. gradient_tape(const gradient_tape &) = delete;
  111. gradient_tape &operator=(const gradient_tape &) = delete;
  112. gradient_tape(gradient_tape &&other) = delete;
  113. gradient_tape operator=(gradient_tape &&other) = delete;
  114. ~gradient_tape() noexcept { clear(); }
  115. void clear()
  116. {
  117. adjoints_.clear();
  118. derivatives_.clear();
  119. gradient_nodes_.clear();
  120. argument_nodes_.clear();
  121. }
  122. // no derivatives or arguments
  123. gradient_node<RealType, DerivativeOrder> *emplace_leaf_node()
  124. {
  125. gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
  126. node->adjoint_ = adjoints_.emplace_back();
  127. node->derivatives_ = derivatives_iterator(); // nullptr;
  128. node->argument_nodes_ = argument_nodes_iterator(); // nullptr;
  129. return node;
  130. };
  131. // single argument, single derivative
  132. gradient_node<RealType, DerivativeOrder> *emplace_active_unary_node()
  133. {
  134. gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
  135. node->n_ = 1;
  136. node->adjoint_ = adjoints_.emplace_back();
  137. node->derivatives_ = derivatives_.emplace_back();
  138. return node;
  139. };
  140. // arbitrary number of arguments/derivatives (compile time)
  141. template<size_t n>
  142. gradient_node<RealType, DerivativeOrder> *emplace_active_multi_node()
  143. {
  144. gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
  145. node->n_ = n;
  146. node->adjoint_ = adjoints_.emplace_back();
  147. // emulate if constexpr
  148. return fill_node_at_compile_time<n>(std::integral_constant<bool, (n > 0)>{}, node);
  149. }
  150. // same as above at runtime
  151. gradient_node<RealType, DerivativeOrder> *emplace_active_multi_node(size_t n)
  152. {
  153. gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
  154. node->n_ = n;
  155. node->adjoint_ = adjoints_.emplace_back();
  156. if (n > 0) {
  157. node->derivatives_ = derivatives_.emplace_back_n(n);
  158. node->argument_nodes_ = argument_nodes_.emplace_back_n(n);
  159. }
  160. return node;
  161. };
  162. /* manual reset button for all adjoints */
  163. void zero_grad()
  164. {
  165. const RealType zero = RealType(0.0);
  166. adjoints_.fill(zero);
  167. }
  168. // return type is an iterator
  169. auto begin() { return gradient_nodes_.begin(); }
  170. auto end() { return gradient_nodes_.end(); }
  171. auto find(gradient_node<RealType, DerivativeOrder> *node)
  172. {
  173. return gradient_nodes_.find(node);
  174. };
  175. void add_checkpoint()
  176. {
  177. gradient_nodes_.add_checkpoint();
  178. adjoints_.add_checkpoint();
  179. derivatives_.add_checkpoint();
  180. argument_nodes_.add_checkpoint();
  181. };
  182. auto last_checkpoint() { return gradient_nodes_.last_checkpoint(); };
  183. auto first_checkpoint() { return gradient_nodes_.last_checkpoint(); };
  184. auto checkpoint_at(size_t index) { return gradient_nodes_.get_checkpoint_at(index); };
  185. void rewind_to_last_checkpoint()
  186. {
  187. gradient_nodes_.rewind_to_last_checkpoint();
  188. adjoints_.rewind_to_last_checkpoint();
  189. derivatives_.rewind_to_last_checkpoint();
  190. argument_nodes_.rewind_to_last_checkpoint();
  191. };
  192. void rewind_to_checkpoint_at(size_t index) // index is "checkpoint" index. so
  193. // order which checkpoint was set
  194. {
  195. gradient_nodes_.rewind_to_checkpoint_at(index);
  196. adjoints_.rewind_to_checkpoint_at(index);
  197. derivatives_.rewind_to_checkpoint_at(index);
  198. argument_nodes_.rewind_to_checkpoint_at(index);
  199. }
  200. // rewind to beginning of computational graph
  201. void rewind()
  202. {
  203. gradient_nodes_.rewind();
  204. adjoints_.rewind();
  205. derivatives_.rewind();
  206. argument_nodes_.rewind();
  207. }
  208. // random acces
  209. gradient_node<RealType, DerivativeOrder> &operator[](size_t i) { return gradient_nodes_[i]; }
  210. const gradient_node<RealType, DerivativeOrder> &operator[](size_t i) const
  211. {
  212. return gradient_nodes_[i];
  213. }
  214. };
  215. // class rvar;
  216. template<typename RealType, size_t DerivativeOrder> // no CRTP, just storage
  217. class gradient_node
  218. {
  219. /*
  220. * @brief manages adjoints, derivatives, and stores points to argument
  221. * adjoints pointers to arguments aren't needed here
  222. * */
  223. public:
  224. using adjoint_iterator = typename gradient_tape<RealType, DerivativeOrder>::adjoint_iterator;
  225. using derivatives_iterator =
  226. typename gradient_tape<RealType, DerivativeOrder>::derivatives_iterator;
  227. using argument_nodes_iterator =
  228. typename gradient_tape<RealType, DerivativeOrder>::argument_nodes_iterator;
  229. private:
  230. size_t n_;
  231. using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
  232. /* these are iterators in case
  233. * flat linear allocator is at capacity, and needs to allocate a new block of
  234. * memory. */
  235. adjoint_iterator adjoint_;
  236. derivatives_iterator derivatives_;
  237. argument_nodes_iterator argument_nodes_;
  238. public:
  239. friend class gradient_tape<RealType, DerivativeOrder>;
  240. friend class rvar<RealType, DerivativeOrder>;
  241. gradient_node() = default;
  242. explicit gradient_node(const size_t n)
  243. : n_(n)
  244. , adjoint_(nullptr)
  245. , derivatives_(nullptr)
  246. {}
  247. explicit gradient_node(const size_t n,
  248. RealType *adjoint,
  249. RealType *derivatives,
  250. rvar<RealType, DerivativeOrder> **arguments)
  251. : n_(n)
  252. , adjoint_(adjoint)
  253. , derivatives_(derivatives)
  254. { static_cast<void>(arguments); }
  255. inner_t get_adjoint_v() const { return *adjoint_; }
  256. inner_t get_derivative_v(size_t arg_id) const { return derivatives_[static_cast<ptrdiff_t>(arg_id)]; };
  257. inner_t get_argument_adjoint_v(size_t arg_id) const
  258. {
  259. return *argument_nodes_[static_cast<ptrdiff_t>(arg_id)]->adjoint_;
  260. }
  261. adjoint_iterator get_adjoint_ptr() { return adjoint_; }
  262. adjoint_iterator get_adjoint_ptr() const { return adjoint_; };
  263. void update_adjoint_v(inner_t value) { *adjoint_ = value; };
  264. void update_derivative_v(size_t arg_id, inner_t value) { derivatives_[static_cast<ptrdiff_t>(arg_id)] = value; };
  265. void update_argument_adj_v(size_t arg_id, inner_t value)
  266. {
  267. argument_nodes_[static_cast<ptrdiff_t>(arg_id)]->update_adjoint_v(value);
  268. };
  269. void update_argument_ptr_at(size_t arg_id, gradient_node<RealType, DerivativeOrder> *node_ptr)
  270. {
  271. argument_nodes_[static_cast<ptrdiff_t>(arg_id)] = node_ptr;
  272. }
  273. void backward()
  274. {
  275. if (!n_) // leaf node
  276. return;
  277. using boost::math::differentiation::reverse_mode::fabs;
  278. using std::fabs;
  279. if (!adjoint_ || fabs(*adjoint_) < 2 * std::numeric_limits<RealType>::epsilon())
  280. return;
  281. if (!argument_nodes_) // no arguments
  282. return;
  283. if (!derivatives_) // no derivatives
  284. return;
  285. for (size_t i = 0; i < n_; ++i) {
  286. auto adjoint = get_adjoint_v();
  287. auto derivative = get_derivative_v(i);
  288. auto argument_adjoint = get_argument_adjoint_v(i);
  289. update_argument_adj_v(i, argument_adjoint + derivative * adjoint);
  290. }
  291. }
  292. };
  293. /****************************************************************************************************************/
  294. template<typename RealType, size_t DerivativeOrder>
  295. inline gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &get_active_tape()
  296. {
  297. static BOOST_MATH_THREAD_LOCAL gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE>
  298. tape;
  299. return tape;
  300. }
  301. template<typename RealType, size_t DerivativeOrder = 1>
  302. class rvar : public expression<RealType, DerivativeOrder, rvar<RealType, DerivativeOrder>>
  303. {
  304. private:
  305. using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
  306. friend class gradient_node<RealType, DerivativeOrder>;
  307. inner_t value_;
  308. gradient_node<RealType, DerivativeOrder> *node_ = nullptr;
  309. template<typename, size_t>
  310. friend class rvar;
  311. /*****************************************************************************************/
  312. /**
  313. * @brief implementation helpers for get_value_at
  314. */
  315. template<size_t target_order, size_t current_order>
  316. struct get_value_at_impl
  317. {
  318. static_assert(target_order <= current_order, "Requested depth exceeds variable order.");
  319. /** @return value_ at rvar_t<T,current_order - 1>
  320. */
  321. static auto &get(rvar<RealType, current_order> &v)
  322. {
  323. return get_value_at_impl<target_order, current_order - 1>::get(v.get_value());
  324. }
  325. /** @return const value_ at rvar_t<T,current_order - 1>
  326. */
  327. static const auto &get(const rvar<RealType, current_order> &v)
  328. {
  329. return get_value_at_impl<target_order, current_order - 1>::get(v.get_value());
  330. }
  331. };
  332. /** @brief base case specialization for target_order == current order
  333. */
  334. template<size_t target_order>
  335. struct get_value_at_impl<target_order, target_order>
  336. {
  337. /** @return value_ at rvar_t<T,target_order>
  338. */
  339. static auto &get(rvar<RealType, target_order> &v) { return v; }
  340. /** @return const value_ at rvar_t<T,target_order>
  341. */
  342. static const auto &get(const rvar<RealType, target_order> &v) { return v; }
  343. };
  344. /*****************************************************************************************/
  345. void make_leaf_node()
  346. {
  347. gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
  348. = get_active_tape<RealType, DerivativeOrder>();
  349. node_ = tape.emplace_leaf_node();
  350. }
  351. void make_unary_node()
  352. {
  353. gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
  354. = get_active_tape<RealType, DerivativeOrder>();
  355. node_ = tape.emplace_active_unary_node();
  356. }
  357. void make_multi_node(size_t n)
  358. {
  359. gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
  360. = get_active_tape<RealType, DerivativeOrder>();
  361. node_ = tape.emplace_active_multi_node(n);
  362. }
  363. template<size_t n>
  364. void make_multi_node()
  365. {
  366. gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
  367. = get_active_tape<RealType, DerivativeOrder>();
  368. node_ = tape.template emplace_active_multi_node<n>();
  369. }
  370. template<typename E>
  371. void make_rvar_from_expr(const expression<RealType, DerivativeOrder, E> &expr)
  372. {
  373. make_multi_node<detail::count_rvars<E, DerivativeOrder>>();
  374. expr.template propagatex<0>(node_, inner_t(static_cast<RealType>(1.0)));
  375. }
  376. RealType get_item_impl(std::true_type) const
  377. {
  378. return value_.get_item_impl(std::integral_constant<bool, (DerivativeOrder - 1 > 1)>{});
  379. }
  380. RealType get_item_impl(std::false_type) const { return value_; }
  381. public:
  382. using value_type = RealType;
  383. static constexpr size_t DerivativeOrder_v = DerivativeOrder;
  384. rvar()
  385. : value_()
  386. {
  387. make_leaf_node();
  388. }
  389. rvar(const RealType value)
  390. : value_(inner_t{static_cast<RealType>(value)})
  391. {
  392. make_leaf_node();
  393. }
  394. rvar &operator=(RealType v)
  395. {
  396. value_ = inner_t(v);
  397. if (node_ == nullptr) {
  398. make_leaf_node();
  399. }
  400. return *this;
  401. }
  402. rvar(const rvar<RealType, DerivativeOrder> &other) = default;
  403. rvar &operator=(const rvar<RealType, DerivativeOrder> &other) = default;
  404. template<size_t arg_index>
  405. void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
  406. {
  407. node->update_derivative_v(arg_index, adj);
  408. node->update_argument_ptr_at(arg_index, node_);
  409. }
  410. template<class E>
  411. rvar(const expression<RealType, DerivativeOrder, E> &expr)
  412. {
  413. value_ = expr.evaluate();
  414. make_rvar_from_expr(expr);
  415. }
  416. template<typename T,
  417. typename = std::enable_if_t<is_floating_point_v<T> && !is_same_v<T, RealType>>>
  418. rvar(T v)
  419. : value_(inner_t{static_cast<RealType>(v)})
  420. {
  421. make_leaf_node();
  422. }
  423. template<class E>
  424. rvar &operator=(const expression<RealType, DerivativeOrder, E> &expr)
  425. {
  426. value_ = expr.evaluate();
  427. make_rvar_from_expr(expr);
  428. return *this;
  429. }
  430. /***************************************************************************************************/
  431. template<class E>
  432. rvar<RealType, DerivativeOrder> &operator+=(const expression<RealType, DerivativeOrder, E> &expr)
  433. {
  434. *this = *this + expr;
  435. return *this;
  436. }
  437. template<class E>
  438. rvar<RealType, DerivativeOrder> &operator*=(const expression<RealType, DerivativeOrder, E> &expr)
  439. {
  440. *this = *this * expr;
  441. return *this;
  442. }
  443. template<class E>
  444. rvar<RealType, DerivativeOrder> &operator-=(const expression<RealType, DerivativeOrder, E> &expr)
  445. {
  446. *this = *this - expr;
  447. return *this;
  448. }
  449. template<class E>
  450. rvar<RealType, DerivativeOrder> &operator/=(const expression<RealType, DerivativeOrder, E> &expr)
  451. {
  452. *this = *this / expr;
  453. return *this;
  454. }
  455. /***************************************************************************************************/
  456. rvar<RealType, DerivativeOrder> &operator+=(const RealType &v)
  457. {
  458. *this = *this + v;
  459. return *this;
  460. }
  461. rvar<RealType, DerivativeOrder> &operator*=(const RealType &v)
  462. {
  463. *this = *this * v;
  464. return *this;
  465. }
  466. rvar<RealType, DerivativeOrder> &operator-=(const RealType &v)
  467. {
  468. *this = *this - v;
  469. return *this;
  470. }
  471. rvar<RealType, DerivativeOrder> &operator/=(const RealType &v)
  472. {
  473. *this = *this / v;
  474. return *this;
  475. }
  476. /***************************************************************************************************/
  477. const inner_t &adjoint() const { return *node_->get_adjoint_ptr(); }
  478. inner_t &adjoint() { return *node_->get_adjoint_ptr(); }
  479. const inner_t &evaluate() const { return value_; };
  480. inner_t &get_value() { return value_; };
  481. explicit operator RealType() const { return item(); }
  482. explicit operator int() const { return static_cast<int>(item()); }
  483. explicit operator long() const { return static_cast<long>(item()); }
  484. explicit operator long long() const { return static_cast<long long>(item()); }
  485. /**
  486. * @brief same as evaluate but returns proper depth for higher order derivatives
  487. * @return value_ at depth N
  488. */
  489. template<size_t N>
  490. auto &get_value_at()
  491. {
  492. static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order.");
  493. return get_value_at_impl<N, DerivativeOrder>::get(*this);
  494. }
  495. /** @brief same as above but const
  496. */
  497. template<size_t N>
  498. const auto &get_value_at() const
  499. {
  500. static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order.");
  501. return get_value_at_impl<N, DerivativeOrder>::get(*this);
  502. }
  503. RealType item() const
  504. {
  505. return get_item_impl(std::integral_constant<bool, (DerivativeOrder > 1)>{});
  506. }
  507. void backward()
  508. {
  509. gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
  510. = get_active_tape<RealType, DerivativeOrder>();
  511. auto it = tape.find(node_);
  512. it->update_adjoint_v(inner_t(static_cast<RealType>(1.0)));
  513. while (it != tape.begin()) {
  514. it->backward();
  515. --it;
  516. }
  517. it->backward();
  518. }
  519. };
  520. template<typename RealType, size_t DerivativeOrder>
  521. std::ostream &operator<<(std::ostream &os, const rvar<RealType, DerivativeOrder> var)
  522. {
  523. os << "rvar<" << DerivativeOrder << ">(" << var.item() << "," << var.adjoint() << ")";
  524. return os;
  525. }
  526. template<typename RealType, size_t DerivativeOrder, typename E>
  527. std::ostream &operator<<(std::ostream &os, const expression<RealType, DerivativeOrder, E> &expr)
  528. {
  529. rvar<RealType, DerivativeOrder> tmp = expr;
  530. os << "rvar<" << DerivativeOrder << ">(" << tmp.item() << "," << tmp.adjoint() << ")";
  531. return os;
  532. }
  533. template<typename RealType, size_t DerivativeOrder>
  534. rvar<RealType, DerivativeOrder> make_rvar(const RealType v)
  535. {
  536. static_assert(DerivativeOrder > 0, "rvar order must be >= 1");
  537. return rvar<RealType, DerivativeOrder>(v);
  538. }
  539. template<typename RealType, size_t DerivativeOrder, typename E>
  540. rvar<RealType, DerivativeOrder> make_rvar(const expression<RealType, DerivativeOrder, E> &expr)
  541. {
  542. static_assert(DerivativeOrder > 0, "rvar order must be >= 1");
  543. return rvar<RealType, DerivativeOrder>(expr);
  544. }
  545. namespace detail {
  546. /** @brief helper overload for grad implementation.
  547. * @return vector<rvar<T,order-1> of gradients of the autodiff graph.
  548. * specialization for autodiffing through autodiff. i.e. being able to
  549. * compute higher order grads
  550. */
  551. template<typename RealType, size_t DerivativeOrder>
  552. struct grad_op_impl
  553. {
  554. std::vector<rvar<RealType, DerivativeOrder - 1>> operator()(
  555. rvar<RealType, DerivativeOrder> &f, std::vector<rvar<RealType, DerivativeOrder> *> &x)
  556. {
  557. auto &tape = get_active_tape<RealType, DerivativeOrder>();
  558. tape.zero_grad();
  559. f.backward();
  560. std::vector<rvar<RealType, DerivativeOrder - 1>> gradient_vector;
  561. gradient_vector.reserve(x.size());
  562. for (auto &xi : x) {
  563. gradient_vector.emplace_back(xi->adjoint());
  564. }
  565. return gradient_vector;
  566. }
  567. };
  568. /** @brief helper overload for grad implementation.
  569. * @return vector<T> of gradients of the autodiff graph.
  570. * base specialization for order 1 autodiff
  571. */
  572. template<typename T>
  573. struct grad_op_impl<T, 1>
  574. {
  575. std::vector<T> operator()(rvar<T, 1> &f, std::vector<rvar<T, 1> *> &x)
  576. {
  577. gradient_tape<T, 1, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, 1>();
  578. tape.zero_grad();
  579. f.backward();
  580. std::vector<T> gradient_vector;
  581. gradient_vector.reserve(x.size());
  582. for (auto &xi : x) {
  583. gradient_vector.push_back(xi->adjoint());
  584. }
  585. return gradient_vector;
  586. }
  587. };
  588. /** @brief helper overload for higher order autodiff
  589. * @return nested vector representing N-d tensor of
  590. * higher order derivatives
  591. */
  592. template<size_t N,
  593. typename RealType,
  594. size_t DerivativeOrder_1,
  595. size_t DerivativeOrder_2,
  596. typename Enable = void>
  597. struct grad_nd_impl
  598. {
  599. auto operator()(rvar<RealType, DerivativeOrder_1> &f,
  600. std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
  601. {
  602. static_assert(N > 1, "N must be greater than 1 for this template");
  603. auto current_grad = grad(f, x); // vector<rvar<T,DerivativeOrder_1-1>> or vector<T>
  604. std::vector<decltype(grad_nd_impl<N - 1, RealType, DerivativeOrder_1 - 1, DerivativeOrder_2>()(
  605. current_grad[0], x))>
  606. result;
  607. result.reserve(current_grad.size());
  608. for (auto &g_i : current_grad) {
  609. result.push_back(
  610. grad_nd_impl<N - 1, RealType, DerivativeOrder_1 - 1, DerivativeOrder_2>()(g_i, x));
  611. }
  612. return result;
  613. }
  614. };
  615. /** @brief spcialization for order = 1,
  616. * @return vector<rvar<T,DerivativeOrder_1-1>> gradients */
  617. template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
  618. struct grad_nd_impl<1, RealType, DerivativeOrder_1, DerivativeOrder_2>
  619. {
  620. auto operator()(rvar<RealType, DerivativeOrder_1> &f,
  621. std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
  622. {
  623. return grad(f, x);
  624. }
  625. };
  626. template<typename ptr>
  627. struct rvar_order;
  628. template<typename RealType, size_t DerivativeOrder>
  629. struct rvar_order<rvar<RealType, DerivativeOrder> *>
  630. {
  631. static constexpr size_t value = DerivativeOrder;
  632. };
  633. } // namespace detail
  634. /**
  635. * @brief grad computes gradient with respect to vector of pointers x
  636. * @param f -> computational graph
  637. * @param x -> variables gradients to record. Note ALL gradients of the graph
  638. * are computed simultaneously, only the ones w.r.t. x are returned
  639. * @return vector<rvar<T,DerivativeOrder_1 - 1> of gradients. in the case of DerivativeOrder_1 = 1
  640. * rvar<T,DerivativeOrder_1-1> decays to T
  641. *
  642. * safe to call recursively with grad(grad(grad...
  643. */
  644. template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
  645. auto grad(rvar<RealType, DerivativeOrder_1> &f, std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
  646. {
  647. static_assert(DerivativeOrder_1 <= DerivativeOrder_2,
  648. "variable differentiating w.r.t. must have order >= function order");
  649. std::vector<rvar<RealType, DerivativeOrder_1> *> xx;
  650. xx.reserve(x.size());
  651. for (auto &xi : x)
  652. xx.push_back(&(xi->template get_value_at<DerivativeOrder_1>()));
  653. return detail::grad_op_impl<RealType, DerivativeOrder_1>{}(f, xx);
  654. }
  655. /** @brief variadic overload of above
  656. */
  657. template<typename RealType, size_t DerivativeOrder_1, typename First, typename... Other>
  658. auto grad(rvar<RealType, DerivativeOrder_1> &f, First first, Other... other)
  659. {
  660. constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
  661. static_assert(DerivativeOrder_1 <= DerivativeOrder_2,
  662. "variable differentiating w.r.t. must have order >= function order");
  663. std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
  664. return grad(f, x_vec);
  665. }
  666. /** @brief computes hessian matrix of computational graph w.r.t.
  667. * vector of variables x.
  668. * @return std::vector<std::vector<rvar<T,DerivativeOrder_1-2>> hessian matrix
  669. * rvar<T,2> decays to T
  670. *
  671. * NOT recursion safe, cannot do hess(hess(
  672. */
  673. template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
  674. auto hess(rvar<RealType, DerivativeOrder_1> &f, std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
  675. {
  676. return detail::grad_nd_impl<2, RealType, DerivativeOrder_1, DerivativeOrder_2>{}(f, x);
  677. }
  678. /** @brief variadic overload of above
  679. */
  680. template<typename RealType, size_t DerivativeOrder_1, typename First, typename... Other>
  681. auto hess(rvar<RealType, DerivativeOrder_1> &f, First first, Other... other)
  682. {
  683. constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
  684. std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
  685. return hess(f, x_vec);
  686. }
  687. /** @brief comput N'th gradient of computational graph w.r.t. x
  688. * @return vector<vector<.... up N nestings representing tensor
  689. * of gradients of order N
  690. *
  691. * NOT recursively safe, cannot do grad_nd(grad_nd(... etc...
  692. */
  693. template<size_t N, typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
  694. auto grad_nd(rvar<RealType, DerivativeOrder_1> &f,
  695. std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
  696. {
  697. static_assert(DerivativeOrder_1 >= N, "Function order must be at least N");
  698. static_assert(DerivativeOrder_2 >= DerivativeOrder_1,
  699. "Variable order must be at least function order");
  700. return detail::grad_nd_impl<N, RealType, DerivativeOrder_1, DerivativeOrder_2>()(f, x);
  701. }
  702. /** @brief variadic overload of above
  703. */
  704. template<size_t N, typename ftype, typename First, typename... Other>
  705. auto grad_nd(ftype &f, First first, Other... other)
  706. {
  707. using RealType = typename ftype::value_type;
  708. constexpr size_t DerivativeOrder_1 = detail::rvar_order<ftype *>::value;
  709. constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
  710. std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
  711. return detail::grad_nd_impl<N, RealType, DerivativeOrder_1, DerivativeOrder_1>{}(f, x_vec);
  712. }
  713. } // namespace reverse_mode
  714. } // namespace differentiation
  715. } // namespace math
  716. } // namespace boost
  717. namespace std {
  718. // copied from forward mode
  719. template<typename RealType, size_t DerivativeOrder>
  720. class numeric_limits<boost::math::differentiation::reverse_mode::rvar<RealType, DerivativeOrder>>
  721. : public numeric_limits<typename boost::math::differentiation::reverse_mode::
  722. rvar<RealType, DerivativeOrder>::value_type>
  723. {};
  724. } // namespace std
  725. #endif