| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810 |
- // Copyright Maksym Zhelyenzyakov 2025-2026.
- // Distributed under the Boost Software License, Version 1.0.
- // (See accompanying file LICENSE_1_0.txt or copy at
- // https://www.boost.org/LICENSE_1_0.txt)
- #ifndef BOOST_MATH_DIFFERENTIATION_AUTODIFF_HPP
- #define BOOST_MATH_DIFFERENTIATION_AUTODIFF_HPP
- #include <boost/math/constants/constants.hpp>
- #if defined(BOOST_MATH_REVERSE_MODE_ET_OFF) && defined(BOOST_MATH_REVERSE_MODE_ET_ON)
- #error "Cannot define both BOOST_MATH_REVERSE_MODE_ET_OFF and BOOST_MATH_REVERSE_MODE_ET_ON"
- #endif
- #if !defined(BOOST_MATH_REVERSE_MODE_ET_OFF) && !defined(BOOST_MATH_REVERSE_MODE_ET_ON)
- #define BOOST_MATH_REVERSE_MODE_ET_ON
- #endif
- #ifdef BOOST_MATH_REVERSE_MODE_ET_ON
- #include <boost/math/differentiation/detail/reverse_mode_autodiff_basic_ops_et.hpp>
- #include <boost/math/differentiation/detail/reverse_mode_autodiff_stl_et.hpp>
- #else
- #include <boost/math/differentiation/detail/reverse_mode_autodiff_basic_ops_no_et.hpp>
- #include <boost/math/differentiation/detail/reverse_mode_autodiff_stl_no_et.hpp>
- #endif
- #include <boost/math/differentiation/detail/reverse_mode_autodiff_comparison_operator_overloads.hpp>
- #include <boost/math/differentiation/detail/reverse_mode_autodiff_erf_overloads.hpp>
- #include <boost/math/differentiation/detail/reverse_mode_autodiff_expression_template_base.hpp>
- #include <boost/math/differentiation/detail/reverse_mode_autodiff_memory_management.hpp>
- #include <boost/math/special_functions/acosh.hpp>
- #include <boost/math/special_functions/asinh.hpp>
- #include <boost/math/special_functions/atanh.hpp>
- #include <boost/math/special_functions/digamma.hpp>
- #include <boost/math/special_functions/erf.hpp>
- #include <boost/math/special_functions/lambert_w.hpp>
- #include <boost/math/special_functions/polygamma.hpp>
- #include <boost/math/special_functions/round.hpp>
- #include <boost/math/special_functions/trunc.hpp>
- #include <boost/math/tools/config.hpp>
- #include <boost/math/tools/promotion.hpp>
- #include <cstddef>
- #include <iostream>
- #include <type_traits>
- #include <vector>
- #define BOOST_MATH_BUFFER_SIZE 65536
- namespace boost {
- namespace math {
- namespace differentiation {
- namespace reverse_mode {
- /* forward declarations for utitlity functions */
- template<typename RealType, size_t DerivativeOrder, class DerivedExpression>
- struct expression;
- template<typename RealType, size_t DerivativeOrder>
- class rvar;
- template<typename RealType,
- size_t DerivativeOrder,
- typename LHS,
- typename RHS,
- typename ConcreteBinaryOperation>
- struct abstract_binary_expression;
- template<typename RealType, size_t DerivativeOrder, typename ARG, typename ConcreteBinaryOperation>
- struct abstract_unary_expression;
- template<typename RealType, size_t DerivativeOrder>
- class gradient_node; // forward declaration for tape
- // manages nodes in computational graph
- template<typename RealType, size_t DerivativeOrder, size_t buffer_size = BOOST_MATH_BUFFER_SIZE>
- class gradient_tape
- {
- /** @brief tape (graph) management class for autodiff
- * holds all the data structures for autodiff */
- private:
- /* type decays to order - 1 to support higher order derivatives */
- using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
- /* adjoints are the overall derivative, and derivatives are the "local"
- * derivative */
- detail::flat_linear_allocator<inner_t, buffer_size> adjoints_;
- detail::flat_linear_allocator<inner_t, buffer_size> derivatives_;
- detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder>, buffer_size>
- gradient_nodes_;
- detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder> *, buffer_size>
- argument_nodes_;
- // compile time check if emplace_back calls on zero
- template<size_t n>
- gradient_node<RealType, DerivativeOrder> *fill_node_at_compile_time(
- std::true_type, gradient_node<RealType, DerivativeOrder> *node_ptr)
- {
- node_ptr->derivatives_ = derivatives_.template emplace_back_n<n>();
- node_ptr->argument_nodes_ = argument_nodes_.template emplace_back_n<n>();
- return node_ptr;
- }
- template<size_t n>
- gradient_node<RealType, DerivativeOrder> *fill_node_at_compile_time(
- std::false_type, gradient_node<RealType, DerivativeOrder> *node_ptr)
- {
- node_ptr->derivatives_ = nullptr;
- node_ptr->argument_adjoints_ = nullptr;
- node_ptr->argument_nodes_ = nullptr;
- return node_ptr;
- }
- public:
- /* gradient node stores iterators to its data memebers
- * (adjoint/derivative/arguments) so that in case flat linear allocator
- * reaches its block boundary and needs more memory for that node, the
- * iterator can be invoked to access it */
- using adjoint_iterator = typename detail::flat_linear_allocator<inner_t, buffer_size>::iterator;
- using derivatives_iterator =
- typename detail::flat_linear_allocator<inner_t, buffer_size>::iterator;
- using gradient_nodes_iterator =
- typename detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder>,
- buffer_size>::iterator;
- using argument_nodes_iterator =
- typename detail::flat_linear_allocator<gradient_node<RealType, DerivativeOrder> *,
- buffer_size>::iterator;
- gradient_tape() { clear(); };
- gradient_tape(const gradient_tape &) = delete;
- gradient_tape &operator=(const gradient_tape &) = delete;
- gradient_tape(gradient_tape &&other) = delete;
- gradient_tape operator=(gradient_tape &&other) = delete;
- ~gradient_tape() noexcept { clear(); }
- void clear()
- {
- adjoints_.clear();
- derivatives_.clear();
- gradient_nodes_.clear();
- argument_nodes_.clear();
- }
- // no derivatives or arguments
- gradient_node<RealType, DerivativeOrder> *emplace_leaf_node()
- {
- gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
- node->adjoint_ = adjoints_.emplace_back();
- node->derivatives_ = derivatives_iterator(); // nullptr;
- node->argument_nodes_ = argument_nodes_iterator(); // nullptr;
- return node;
- };
- // single argument, single derivative
- gradient_node<RealType, DerivativeOrder> *emplace_active_unary_node()
- {
- gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
- node->n_ = 1;
- node->adjoint_ = adjoints_.emplace_back();
- node->derivatives_ = derivatives_.emplace_back();
- return node;
- };
- // arbitrary number of arguments/derivatives (compile time)
- template<size_t n>
- gradient_node<RealType, DerivativeOrder> *emplace_active_multi_node()
- {
- gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
- node->n_ = n;
- node->adjoint_ = adjoints_.emplace_back();
- // emulate if constexpr
- return fill_node_at_compile_time<n>(std::integral_constant<bool, (n > 0)>{}, node);
- }
- // same as above at runtime
- gradient_node<RealType, DerivativeOrder> *emplace_active_multi_node(size_t n)
- {
- gradient_node<RealType, DerivativeOrder> *node = &*gradient_nodes_.emplace_back();
- node->n_ = n;
- node->adjoint_ = adjoints_.emplace_back();
- if (n > 0) {
- node->derivatives_ = derivatives_.emplace_back_n(n);
- node->argument_nodes_ = argument_nodes_.emplace_back_n(n);
- }
- return node;
- };
- /* manual reset button for all adjoints */
- void zero_grad()
- {
- const RealType zero = RealType(0.0);
- adjoints_.fill(zero);
- }
- // return type is an iterator
- auto begin() { return gradient_nodes_.begin(); }
- auto end() { return gradient_nodes_.end(); }
- auto find(gradient_node<RealType, DerivativeOrder> *node)
- {
- return gradient_nodes_.find(node);
- };
- void add_checkpoint()
- {
- gradient_nodes_.add_checkpoint();
- adjoints_.add_checkpoint();
- derivatives_.add_checkpoint();
- argument_nodes_.add_checkpoint();
- };
- auto last_checkpoint() { return gradient_nodes_.last_checkpoint(); };
- auto first_checkpoint() { return gradient_nodes_.last_checkpoint(); };
- auto checkpoint_at(size_t index) { return gradient_nodes_.get_checkpoint_at(index); };
- void rewind_to_last_checkpoint()
- {
- gradient_nodes_.rewind_to_last_checkpoint();
- adjoints_.rewind_to_last_checkpoint();
- derivatives_.rewind_to_last_checkpoint();
- argument_nodes_.rewind_to_last_checkpoint();
- };
- void rewind_to_checkpoint_at(size_t index) // index is "checkpoint" index. so
- // order which checkpoint was set
- {
- gradient_nodes_.rewind_to_checkpoint_at(index);
- adjoints_.rewind_to_checkpoint_at(index);
- derivatives_.rewind_to_checkpoint_at(index);
- argument_nodes_.rewind_to_checkpoint_at(index);
- }
- // rewind to beginning of computational graph
- void rewind()
- {
- gradient_nodes_.rewind();
- adjoints_.rewind();
- derivatives_.rewind();
- argument_nodes_.rewind();
- }
- // random acces
- gradient_node<RealType, DerivativeOrder> &operator[](size_t i) { return gradient_nodes_[i]; }
- const gradient_node<RealType, DerivativeOrder> &operator[](size_t i) const
- {
- return gradient_nodes_[i];
- }
- };
- // class rvar;
- template<typename RealType, size_t DerivativeOrder> // no CRTP, just storage
- class gradient_node
- {
- /*
- * @brief manages adjoints, derivatives, and stores points to argument
- * adjoints pointers to arguments aren't needed here
- * */
- public:
- using adjoint_iterator = typename gradient_tape<RealType, DerivativeOrder>::adjoint_iterator;
- using derivatives_iterator =
- typename gradient_tape<RealType, DerivativeOrder>::derivatives_iterator;
- using argument_nodes_iterator =
- typename gradient_tape<RealType, DerivativeOrder>::argument_nodes_iterator;
- private:
- size_t n_;
- using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
- /* these are iterators in case
- * flat linear allocator is at capacity, and needs to allocate a new block of
- * memory. */
- adjoint_iterator adjoint_;
- derivatives_iterator derivatives_;
- argument_nodes_iterator argument_nodes_;
- public:
- friend class gradient_tape<RealType, DerivativeOrder>;
- friend class rvar<RealType, DerivativeOrder>;
- gradient_node() = default;
- explicit gradient_node(const size_t n)
- : n_(n)
- , adjoint_(nullptr)
- , derivatives_(nullptr)
- {}
- explicit gradient_node(const size_t n,
- RealType *adjoint,
- RealType *derivatives,
- rvar<RealType, DerivativeOrder> **arguments)
- : n_(n)
- , adjoint_(adjoint)
- , derivatives_(derivatives)
- { static_cast<void>(arguments); }
- inner_t get_adjoint_v() const { return *adjoint_; }
- inner_t get_derivative_v(size_t arg_id) const { return derivatives_[static_cast<ptrdiff_t>(arg_id)]; };
- inner_t get_argument_adjoint_v(size_t arg_id) const
- {
- return *argument_nodes_[static_cast<ptrdiff_t>(arg_id)]->adjoint_;
- }
- adjoint_iterator get_adjoint_ptr() { return adjoint_; }
- adjoint_iterator get_adjoint_ptr() const { return adjoint_; };
- void update_adjoint_v(inner_t value) { *adjoint_ = value; };
- void update_derivative_v(size_t arg_id, inner_t value) { derivatives_[static_cast<ptrdiff_t>(arg_id)] = value; };
- void update_argument_adj_v(size_t arg_id, inner_t value)
- {
- argument_nodes_[static_cast<ptrdiff_t>(arg_id)]->update_adjoint_v(value);
- };
- void update_argument_ptr_at(size_t arg_id, gradient_node<RealType, DerivativeOrder> *node_ptr)
- {
- argument_nodes_[static_cast<ptrdiff_t>(arg_id)] = node_ptr;
- }
- void backward()
- {
- if (!n_) // leaf node
- return;
- using boost::math::differentiation::reverse_mode::fabs;
- using std::fabs;
- if (!adjoint_ || fabs(*adjoint_) < 2 * std::numeric_limits<RealType>::epsilon())
- return;
- if (!argument_nodes_) // no arguments
- return;
- if (!derivatives_) // no derivatives
- return;
- for (size_t i = 0; i < n_; ++i) {
- auto adjoint = get_adjoint_v();
- auto derivative = get_derivative_v(i);
- auto argument_adjoint = get_argument_adjoint_v(i);
- update_argument_adj_v(i, argument_adjoint + derivative * adjoint);
- }
- }
- };
- /****************************************************************************************************************/
- template<typename RealType, size_t DerivativeOrder>
- inline gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &get_active_tape()
- {
- static BOOST_MATH_THREAD_LOCAL gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE>
- tape;
- return tape;
- }
- template<typename RealType, size_t DerivativeOrder = 1>
- class rvar : public expression<RealType, DerivativeOrder, rvar<RealType, DerivativeOrder>>
- {
- private:
- using inner_t = rvar_t<RealType, DerivativeOrder - 1>;
- friend class gradient_node<RealType, DerivativeOrder>;
- inner_t value_;
- gradient_node<RealType, DerivativeOrder> *node_ = nullptr;
- template<typename, size_t>
- friend class rvar;
- /*****************************************************************************************/
- /**
- * @brief implementation helpers for get_value_at
- */
- template<size_t target_order, size_t current_order>
- struct get_value_at_impl
- {
- static_assert(target_order <= current_order, "Requested depth exceeds variable order.");
- /** @return value_ at rvar_t<T,current_order - 1>
- */
- static auto &get(rvar<RealType, current_order> &v)
- {
- return get_value_at_impl<target_order, current_order - 1>::get(v.get_value());
- }
- /** @return const value_ at rvar_t<T,current_order - 1>
- */
- static const auto &get(const rvar<RealType, current_order> &v)
- {
- return get_value_at_impl<target_order, current_order - 1>::get(v.get_value());
- }
- };
- /** @brief base case specialization for target_order == current order
- */
- template<size_t target_order>
- struct get_value_at_impl<target_order, target_order>
- {
- /** @return value_ at rvar_t<T,target_order>
- */
- static auto &get(rvar<RealType, target_order> &v) { return v; }
- /** @return const value_ at rvar_t<T,target_order>
- */
- static const auto &get(const rvar<RealType, target_order> &v) { return v; }
- };
- /*****************************************************************************************/
- void make_leaf_node()
- {
- gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
- = get_active_tape<RealType, DerivativeOrder>();
- node_ = tape.emplace_leaf_node();
- }
- void make_unary_node()
- {
- gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
- = get_active_tape<RealType, DerivativeOrder>();
- node_ = tape.emplace_active_unary_node();
- }
- void make_multi_node(size_t n)
- {
- gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
- = get_active_tape<RealType, DerivativeOrder>();
- node_ = tape.emplace_active_multi_node(n);
- }
- template<size_t n>
- void make_multi_node()
- {
- gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
- = get_active_tape<RealType, DerivativeOrder>();
- node_ = tape.template emplace_active_multi_node<n>();
- }
- template<typename E>
- void make_rvar_from_expr(const expression<RealType, DerivativeOrder, E> &expr)
- {
- make_multi_node<detail::count_rvars<E, DerivativeOrder>>();
- expr.template propagatex<0>(node_, inner_t(static_cast<RealType>(1.0)));
- }
- RealType get_item_impl(std::true_type) const
- {
- return value_.get_item_impl(std::integral_constant<bool, (DerivativeOrder - 1 > 1)>{});
- }
- RealType get_item_impl(std::false_type) const { return value_; }
- public:
- using value_type = RealType;
- static constexpr size_t DerivativeOrder_v = DerivativeOrder;
- rvar()
- : value_()
- {
- make_leaf_node();
- }
- rvar(const RealType value)
- : value_(inner_t{static_cast<RealType>(value)})
- {
- make_leaf_node();
- }
- rvar &operator=(RealType v)
- {
- value_ = inner_t(v);
- if (node_ == nullptr) {
- make_leaf_node();
- }
- return *this;
- }
- rvar(const rvar<RealType, DerivativeOrder> &other) = default;
- rvar &operator=(const rvar<RealType, DerivativeOrder> &other) = default;
- template<size_t arg_index>
- void propagatex(gradient_node<RealType, DerivativeOrder> *node, inner_t adj) const
- {
- node->update_derivative_v(arg_index, adj);
- node->update_argument_ptr_at(arg_index, node_);
- }
- template<class E>
- rvar(const expression<RealType, DerivativeOrder, E> &expr)
- {
- value_ = expr.evaluate();
- make_rvar_from_expr(expr);
- }
- template<typename T,
- typename = std::enable_if_t<is_floating_point_v<T> && !is_same_v<T, RealType>>>
- rvar(T v)
- : value_(inner_t{static_cast<RealType>(v)})
- {
- make_leaf_node();
- }
- template<class E>
- rvar &operator=(const expression<RealType, DerivativeOrder, E> &expr)
- {
- value_ = expr.evaluate();
- make_rvar_from_expr(expr);
- return *this;
- }
- /***************************************************************************************************/
- template<class E>
- rvar<RealType, DerivativeOrder> &operator+=(const expression<RealType, DerivativeOrder, E> &expr)
- {
- *this = *this + expr;
- return *this;
- }
- template<class E>
- rvar<RealType, DerivativeOrder> &operator*=(const expression<RealType, DerivativeOrder, E> &expr)
- {
- *this = *this * expr;
- return *this;
- }
- template<class E>
- rvar<RealType, DerivativeOrder> &operator-=(const expression<RealType, DerivativeOrder, E> &expr)
- {
- *this = *this - expr;
- return *this;
- }
- template<class E>
- rvar<RealType, DerivativeOrder> &operator/=(const expression<RealType, DerivativeOrder, E> &expr)
- {
- *this = *this / expr;
- return *this;
- }
- /***************************************************************************************************/
- rvar<RealType, DerivativeOrder> &operator+=(const RealType &v)
- {
- *this = *this + v;
- return *this;
- }
- rvar<RealType, DerivativeOrder> &operator*=(const RealType &v)
- {
- *this = *this * v;
- return *this;
- }
- rvar<RealType, DerivativeOrder> &operator-=(const RealType &v)
- {
- *this = *this - v;
- return *this;
- }
- rvar<RealType, DerivativeOrder> &operator/=(const RealType &v)
- {
- *this = *this / v;
- return *this;
- }
- /***************************************************************************************************/
- const inner_t &adjoint() const { return *node_->get_adjoint_ptr(); }
- inner_t &adjoint() { return *node_->get_adjoint_ptr(); }
- const inner_t &evaluate() const { return value_; };
- inner_t &get_value() { return value_; };
- explicit operator RealType() const { return item(); }
- explicit operator int() const { return static_cast<int>(item()); }
- explicit operator long() const { return static_cast<long>(item()); }
- explicit operator long long() const { return static_cast<long long>(item()); }
- /**
- * @brief same as evaluate but returns proper depth for higher order derivatives
- * @return value_ at depth N
- */
- template<size_t N>
- auto &get_value_at()
- {
- static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order.");
- return get_value_at_impl<N, DerivativeOrder>::get(*this);
- }
- /** @brief same as above but const
- */
- template<size_t N>
- const auto &get_value_at() const
- {
- static_assert(N <= DerivativeOrder, "Requested depth exceeds variable order.");
- return get_value_at_impl<N, DerivativeOrder>::get(*this);
- }
- RealType item() const
- {
- return get_item_impl(std::integral_constant<bool, (DerivativeOrder > 1)>{});
- }
- void backward()
- {
- gradient_tape<RealType, DerivativeOrder, BOOST_MATH_BUFFER_SIZE> &tape
- = get_active_tape<RealType, DerivativeOrder>();
- auto it = tape.find(node_);
- it->update_adjoint_v(inner_t(static_cast<RealType>(1.0)));
- while (it != tape.begin()) {
- it->backward();
- --it;
- }
- it->backward();
- }
- };
- template<typename RealType, size_t DerivativeOrder>
- std::ostream &operator<<(std::ostream &os, const rvar<RealType, DerivativeOrder> var)
- {
- os << "rvar<" << DerivativeOrder << ">(" << var.item() << "," << var.adjoint() << ")";
- return os;
- }
- template<typename RealType, size_t DerivativeOrder, typename E>
- std::ostream &operator<<(std::ostream &os, const expression<RealType, DerivativeOrder, E> &expr)
- {
- rvar<RealType, DerivativeOrder> tmp = expr;
- os << "rvar<" << DerivativeOrder << ">(" << tmp.item() << "," << tmp.adjoint() << ")";
- return os;
- }
- template<typename RealType, size_t DerivativeOrder>
- rvar<RealType, DerivativeOrder> make_rvar(const RealType v)
- {
- static_assert(DerivativeOrder > 0, "rvar order must be >= 1");
- return rvar<RealType, DerivativeOrder>(v);
- }
- template<typename RealType, size_t DerivativeOrder, typename E>
- rvar<RealType, DerivativeOrder> make_rvar(const expression<RealType, DerivativeOrder, E> &expr)
- {
- static_assert(DerivativeOrder > 0, "rvar order must be >= 1");
- return rvar<RealType, DerivativeOrder>(expr);
- }
- namespace detail {
- /** @brief helper overload for grad implementation.
- * @return vector<rvar<T,order-1> of gradients of the autodiff graph.
- * specialization for autodiffing through autodiff. i.e. being able to
- * compute higher order grads
- */
- template<typename RealType, size_t DerivativeOrder>
- struct grad_op_impl
- {
- std::vector<rvar<RealType, DerivativeOrder - 1>> operator()(
- rvar<RealType, DerivativeOrder> &f, std::vector<rvar<RealType, DerivativeOrder> *> &x)
- {
- auto &tape = get_active_tape<RealType, DerivativeOrder>();
- tape.zero_grad();
- f.backward();
- std::vector<rvar<RealType, DerivativeOrder - 1>> gradient_vector;
- gradient_vector.reserve(x.size());
- for (auto &xi : x) {
- gradient_vector.emplace_back(xi->adjoint());
- }
- return gradient_vector;
- }
- };
- /** @brief helper overload for grad implementation.
- * @return vector<T> of gradients of the autodiff graph.
- * base specialization for order 1 autodiff
- */
- template<typename T>
- struct grad_op_impl<T, 1>
- {
- std::vector<T> operator()(rvar<T, 1> &f, std::vector<rvar<T, 1> *> &x)
- {
- gradient_tape<T, 1, BOOST_MATH_BUFFER_SIZE> &tape = get_active_tape<T, 1>();
- tape.zero_grad();
- f.backward();
- std::vector<T> gradient_vector;
- gradient_vector.reserve(x.size());
- for (auto &xi : x) {
- gradient_vector.push_back(xi->adjoint());
- }
- return gradient_vector;
- }
- };
- /** @brief helper overload for higher order autodiff
- * @return nested vector representing N-d tensor of
- * higher order derivatives
- */
- template<size_t N,
- typename RealType,
- size_t DerivativeOrder_1,
- size_t DerivativeOrder_2,
- typename Enable = void>
- struct grad_nd_impl
- {
- auto operator()(rvar<RealType, DerivativeOrder_1> &f,
- std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
- {
- static_assert(N > 1, "N must be greater than 1 for this template");
- auto current_grad = grad(f, x); // vector<rvar<T,DerivativeOrder_1-1>> or vector<T>
- std::vector<decltype(grad_nd_impl<N - 1, RealType, DerivativeOrder_1 - 1, DerivativeOrder_2>()(
- current_grad[0], x))>
- result;
- result.reserve(current_grad.size());
- for (auto &g_i : current_grad) {
- result.push_back(
- grad_nd_impl<N - 1, RealType, DerivativeOrder_1 - 1, DerivativeOrder_2>()(g_i, x));
- }
- return result;
- }
- };
- /** @brief spcialization for order = 1,
- * @return vector<rvar<T,DerivativeOrder_1-1>> gradients */
- template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
- struct grad_nd_impl<1, RealType, DerivativeOrder_1, DerivativeOrder_2>
- {
- auto operator()(rvar<RealType, DerivativeOrder_1> &f,
- std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
- {
- return grad(f, x);
- }
- };
- template<typename ptr>
- struct rvar_order;
- template<typename RealType, size_t DerivativeOrder>
- struct rvar_order<rvar<RealType, DerivativeOrder> *>
- {
- static constexpr size_t value = DerivativeOrder;
- };
- } // namespace detail
- /**
- * @brief grad computes gradient with respect to vector of pointers x
- * @param f -> computational graph
- * @param x -> variables gradients to record. Note ALL gradients of the graph
- * are computed simultaneously, only the ones w.r.t. x are returned
- * @return vector<rvar<T,DerivativeOrder_1 - 1> of gradients. in the case of DerivativeOrder_1 = 1
- * rvar<T,DerivativeOrder_1-1> decays to T
- *
- * safe to call recursively with grad(grad(grad...
- */
- template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
- auto grad(rvar<RealType, DerivativeOrder_1> &f, std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
- {
- static_assert(DerivativeOrder_1 <= DerivativeOrder_2,
- "variable differentiating w.r.t. must have order >= function order");
- std::vector<rvar<RealType, DerivativeOrder_1> *> xx;
- xx.reserve(x.size());
- for (auto &xi : x)
- xx.push_back(&(xi->template get_value_at<DerivativeOrder_1>()));
- return detail::grad_op_impl<RealType, DerivativeOrder_1>{}(f, xx);
- }
- /** @brief variadic overload of above
- */
- template<typename RealType, size_t DerivativeOrder_1, typename First, typename... Other>
- auto grad(rvar<RealType, DerivativeOrder_1> &f, First first, Other... other)
- {
- constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
- static_assert(DerivativeOrder_1 <= DerivativeOrder_2,
- "variable differentiating w.r.t. must have order >= function order");
- std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
- return grad(f, x_vec);
- }
- /** @brief computes hessian matrix of computational graph w.r.t.
- * vector of variables x.
- * @return std::vector<std::vector<rvar<T,DerivativeOrder_1-2>> hessian matrix
- * rvar<T,2> decays to T
- *
- * NOT recursion safe, cannot do hess(hess(
- */
- template<typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
- auto hess(rvar<RealType, DerivativeOrder_1> &f, std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
- {
- return detail::grad_nd_impl<2, RealType, DerivativeOrder_1, DerivativeOrder_2>{}(f, x);
- }
- /** @brief variadic overload of above
- */
- template<typename RealType, size_t DerivativeOrder_1, typename First, typename... Other>
- auto hess(rvar<RealType, DerivativeOrder_1> &f, First first, Other... other)
- {
- constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
- std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
- return hess(f, x_vec);
- }
- /** @brief comput N'th gradient of computational graph w.r.t. x
- * @return vector<vector<.... up N nestings representing tensor
- * of gradients of order N
- *
- * NOT recursively safe, cannot do grad_nd(grad_nd(... etc...
- */
- template<size_t N, typename RealType, size_t DerivativeOrder_1, size_t DerivativeOrder_2>
- auto grad_nd(rvar<RealType, DerivativeOrder_1> &f,
- std::vector<rvar<RealType, DerivativeOrder_2> *> &x)
- {
- static_assert(DerivativeOrder_1 >= N, "Function order must be at least N");
- static_assert(DerivativeOrder_2 >= DerivativeOrder_1,
- "Variable order must be at least function order");
- return detail::grad_nd_impl<N, RealType, DerivativeOrder_1, DerivativeOrder_2>()(f, x);
- }
- /** @brief variadic overload of above
- */
- template<size_t N, typename ftype, typename First, typename... Other>
- auto grad_nd(ftype &f, First first, Other... other)
- {
- using RealType = typename ftype::value_type;
- constexpr size_t DerivativeOrder_1 = detail::rvar_order<ftype *>::value;
- constexpr size_t DerivativeOrder_2 = detail::rvar_order<First>::value;
- std::vector<rvar<RealType, DerivativeOrder_2> *> x_vec = {first, other...};
- return detail::grad_nd_impl<N, RealType, DerivativeOrder_1, DerivativeOrder_1>{}(f, x_vec);
- }
- } // namespace reverse_mode
- } // namespace differentiation
- } // namespace math
- } // namespace boost
- namespace std {
- // copied from forward mode
- template<typename RealType, size_t DerivativeOrder>
- class numeric_limits<boost::math::differentiation::reverse_mode::rvar<RealType, DerivativeOrder>>
- : public numeric_limits<typename boost::math::differentiation::reverse_mode::
- rvar<RealType, DerivativeOrder>::value_type>
- {};
- } // namespace std
- #endif
|