diff --git a/stan/math/fwd/fun.hpp b/stan/math/fwd/fun.hpp index b56232a0121..a100117751c 100644 --- a/stan/math/fwd/fun.hpp +++ b/stan/math/fwd/fun.hpp @@ -80,6 +80,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/fwd/fun/log_add_exp.hpp b/stan/math/fwd/fun/log_add_exp.hpp new file mode 100644 index 00000000000..2a446a301cb --- /dev/null +++ b/stan/math/fwd/fun/log_add_exp.hpp @@ -0,0 +1,162 @@ +#ifndef STAN_MATH_FWD_FUN_LOG_ADD_EXP_HPP +#define STAN_MATH_FWD_FUN_LOG_ADD_EXP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +// Overload for two fvar arguments. NOTE(review): raw exp() here may overflow. +template +inline fvar log_add_exp(const fvar& x1, const fvar& x2) { + auto val = stan::math::log_add_exp(x1.val_, x2.val_); + + auto exp_x1 = stan::math::exp(x1.val_); + auto exp_x2 = stan::math::exp(x2.val_); + auto sum_exp = exp_x1 + exp_x2; + + auto grad1 = exp_x1 / sum_exp; + auto grad2 = exp_x2 / sum_exp; + + return fvar(val, x1.d_ * grad1 + x2.d_ * grad2); +} + +template +inline fvar log_add_exp(const fvar& x1, double x2) { + if (x1.val_ == NEGATIVE_INFTY) { + return fvar(x2, 0.0); // x1 is -inf, so the result is x2 with zero tangent + } + return log_add_exp(x2, x1); +} + +template +inline fvar log_add_exp(double x1, const fvar& x2) { + if (x2.val_ == NEGATIVE_INFTY) { + return fvar(x1, 0.0); // x2 is -inf, so the result is x1 with zero tangent + } + auto val = stan::math::log_add_exp(x1, x2.val_); + auto exp_x2 = stan::math::exp(x2.val_); + auto grad = exp_x2 / (stan::math::exp(x1) + exp_x2); + return fvar(val, x2.d_ * grad); +} + +// Dynamic-size fvar matrix overload. NOTE(review): b's shape is unchecked. +template +inline Eigen::Matrix, -1, -1> log_add_exp( + const Eigen::Matrix, -1, -1>& a, + const Eigen::Matrix, -1, -1>& b) { + using fvar_mat_type = Eigen::Matrix, -1, -1>; + fvar_mat_type result(a.rows(), a.cols()); + + // Reject empty inputs (result has already been allocated at this point) + if (a.size() == 0 || b.size() == 0) { + throw std::invalid_argument("Input containers must not be empty."); + } + + // A NaN entry anywhere makes every element of the result NaN + 
if (a.array().isNaN().any() || b.array().isNaN().any()) { + result.setConstant(fvar(std::numeric_limits::quiet_NaN())); + return result; + } + + // An infinite entry of either sign also makes the whole result NaN + if (a.array().isInf().any() || b.array().isInf().any()) { + result.setConstant(fvar(std::numeric_limits::quiet_NaN())); + return result; + } + + // Apply the scalar log_add_exp overload to each pair of entries + for (int i = 0; i < a.rows(); ++i) { + for (int j = 0; j < a.cols(); ++j) { + result(i, j) = stan::math::log_add_exp(a(i, j), b(i, j)); + } + } + + return result; // Same dimensions as a +} + +// Dynamic-size fvar column vector overload; b's length is unchecked. +template +inline Eigen::Matrix, -1, 1> log_add_exp( + const Eigen::Matrix, -1, 1>& a, + const Eigen::Matrix, -1, 1>& b) { + using fvar_vec_type = Eigen::Matrix, -1, 1>; + fvar_vec_type result(a.rows()); + + // Reject empty inputs (result has already been allocated at this point) + if (a.size() == 0 || b.size() == 0) { + throw std::invalid_argument("Input containers must not be empty."); + } + + // A NaN entry anywhere makes every element of the result NaN + if (a.array().isNaN().any() || b.array().isNaN().any()) { + result.setConstant(fvar(std::numeric_limits::quiet_NaN())); + return result; + } + + // An infinite entry of either sign also makes the whole result NaN + if (a.array().isInf().any() || b.array().isInf().any()) { + result.setConstant(fvar(std::numeric_limits::quiet_NaN())); + return result; + } + + // Apply the scalar log_add_exp overload to each pair of entries + for (int i = 0; i < a.rows(); ++i) { + result(i) = stan::math::log_add_exp(a(i), b(i)); + } + + return result; // Same length as a +} + +// Overload (not a specialization) for dynamic-size nested fvar matrices. +template +inline auto log_add_exp( + const Eigen::Matrix>, -1, -1>& a, + const Eigen::Matrix>, -1, -1>& + b) { + using nested_fvar_mat_type + = Eigen::Matrix>, -1, -1>; + nested_fvar_mat_type result(a.rows(), a.cols()); + + // Reject empty inputs (result has already been allocated at this point) + if (a.size() == 0 || b.size() == 0) { + throw std::invalid_argument("Input containers must not be empty."); + } + + // A NaN entry anywhere makes every element of the result NaN + if (a.array().isNaN().any() || b.array().isNaN().any()) { + result.setConstant(stan::math::fvar>( + 
std::numeric_limits::quiet_NaN())); + return result; + } + + // An infinite entry of either sign also makes the whole result NaN + if (a.array().isInf().any() || b.array().isInf().any()) { + result.setConstant(stan::math::fvar>( + std::numeric_limits::quiet_NaN())); + return result; + } + + // Apply the scalar log_add_exp overload to each pair of nested fvar entries + for (int i = 0; i < a.rows(); ++i) { + for (int j = 0; j < a.cols(); ++j) { + auto inner_a = a(i, j); + auto inner_b = b(i, j); + result(i, j) = stan::math::log_add_exp(inner_a, inner_b); + } + } + + return result; // Same dimensions as a +} + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/fun.hpp b/stan/math/prim/fun.hpp index 21403200b41..1885331725e 100644 --- a/stan/math/prim/fun.hpp +++ b/stan/math/prim/fun.hpp @@ -188,6 +188,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/fun/log_add_exp.hpp b/stan/math/prim/fun/log_add_exp.hpp new file mode 100644 index 00000000000..d41aa4b485c --- /dev/null +++ b/stan/math/prim/fun/log_add_exp.hpp @@ -0,0 +1,159 @@ +#ifndef STAN_MATH_PRIM_FUN_LOG_ADD_EXP_HPP +#define STAN_MATH_PRIM_FUN_LOG_ADD_EXP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Calculates the log of the sum of two exponentials without overflow. + * + * \f$\log (\exp(a) + \exp(b)) = m + \log(\exp(a-m) + \exp(b-m))\f$, + * + * where \f$m = max(a, b)\f$. 
+ * + * @tparam T1 type of the first variable + * @tparam T2 type of the second variable + * @param a the first variable + * @param b the second variable + */ + +template * = nullptr, + require_all_stan_scalar_t* = nullptr> +inline return_type_t log_add_exp(const T2& a, const T1& b) { + if (a == NEGATIVE_INFTY) { + return b; + } + if (b == NEGATIVE_INFTY) { + return a; + } + if (a == INFTY || b == INFTY) { + return INFTY; + } + + const double max_val = std::max(a, b); + return max_val + std::log(std::exp(a - max_val) + std::exp(b - max_val)); +} + +/** + * Calculates the element-wise log sum of exponentials for two containers. + * For vectors a and b, computes log(exp(a[i]) + exp(b[i])) for each element i. + * Sizes must match; a mismatch is rejected instead of being truncated. + * + * @tparam T type of both containers (an Eigen type or a std::vector) + * @param a First input container + * @param b Second input container + * @return Container with element-wise log_add_exp results + * @throw std::invalid_argument if std::vector sizes differ or T is unsupported + */ +template * = nullptr> +inline auto log_add_exp(const T& a, const T& b) { + // Dispatch on the container kind at compile time + if constexpr (stan::is_eigen::value) { + // Validate that a and b have matching dimensions + stan::math::check_matching_dims("log_add_exp", "a", a, "b", b); + + // Result shape: rows are taken from a, columns are taken from b + size_t rows = a.rows(); + size_t cols = b.cols(); + using return_t = return_type_t; + + Eigen::Matrix result(rows, cols); + + // Fill the result one element at a time + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + double a_val = (a.cols() == 1) + ? a(i, 0) + : a(i, j); // Column 0 is read when a has a single column + double b_val = (b.rows() == 1) + ? 
b(0, j) + : b(i, j); // Row 0 is read when b has a single row + + if (a_val == NEGATIVE_INFTY) { + result(i, j) = b_val; + } else if (b_val == NEGATIVE_INFTY) { + result(i, j) = a_val; + } else if (a_val == INFTY || b_val == INFTY) { + result(i, j) = INFTY; + } else { + result(i, j) = log_sum_exp(a_val, b_val); + } + } + } + + return result; + } else if constexpr (std::is_same_v>) { + // std::vector inputs: sizes must be equal, otherwise throw + if (a.size() != b.size()) { + throw std::invalid_argument("Sizes of x and y must match."); + } + + using return_t = return_type_t; + std::vector result(a.size()); + + for (size_t i = 0; i < a.size(); ++i) { + double a_val = a[i]; + double b_val = b[i]; + + if (a_val == NEGATIVE_INFTY) { + result[i] = b_val; + } else if (b_val == NEGATIVE_INFTY) { + result[i] = a_val; + } else if (a_val == INFTY || b_val == INFTY) { + result[i] = INFTY; + } else { + result[i] = log_sum_exp(a_val, b_val); + } + } + + return result; + } else { + throw std::invalid_argument("Unsupported container type."); + } +} + +/** + * Applies log_add_exp to each pair of elements through apply_scalar_binary, + * after checking that a and b have matching dimensions or equal sizes. 
+ * + * @tparam T1 + * @tparam T2 + * @param a + * @param b + * @return auto + */ +template * = nullptr> +inline auto log_add_exp(const T1& a, const T2& b) { + // Both arguments are Eigen types: compare their dimensions + if constexpr (stan::is_eigen::value && stan::is_eigen::value) { + // Validate that a and b have matching dimensions + stan::math::check_matching_dims("log_add_exp", "a", a, "b", b); + } else { + // Otherwise compare element counts and throw on a mismatch + if (a.size() != b.size()) { + throw std::invalid_argument( + "Sizes of x and y must match or be compatible."); + } + } + + // Sizes agree: apply the scalar log_add_exp to each pair of elements + return apply_scalar_binary( + a, b, [](const auto& c, const auto& d) { return log_add_exp(c, d); }); +} + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/rev/fun.hpp b/stan/math/rev/fun.hpp index 929e88aa4c3..edc1d829a16 100644 --- a/stan/math/rev/fun.hpp +++ b/stan/math/rev/fun.hpp @@ -116,6 +116,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/rev/fun/log_add_exp.hpp b/stan/math/rev/fun/log_add_exp.hpp new file mode 100644 index 00000000000..cfe7711b796 --- /dev/null +++ b/stan/math/rev/fun/log_add_exp.hpp @@ -0,0 +1,88 @@ +#ifndef STAN_MATH_REV_FUN_LOG_ADD_EXP_HPP +#define STAN_MATH_REV_FUN_LOG_ADD_EXP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace stan { +namespace math { +namespace internal { + +class log_add_exp_vv_vari : public op_vv_vari { + public: + log_add_exp_vv_vari(vari* avi, vari* bvi) + : op_vv_vari(log_add_exp(avi->val_, bvi->val_), avi, bvi) {} + void chain() { + double exp_a = std::exp(avi_->val_); + double exp_b = std::exp(bvi_->val_); + double sum_exp = exp_a + exp_b; + + avi_->adj_ += adj_ * (exp_a / sum_exp); + bvi_->adj_ += adj_ * (exp_b / sum_exp); + } +}; + +class log_add_exp_vd_vari : public op_vd_vari { + public: + log_add_exp_vd_vari(vari* avi, double b) + : 
op_vd_vari(log_add_exp(avi->val_, b), avi, b) {} + void chain() { + if (val_ == NEGATIVE_INFTY) { + avi_->adj_ += adj_; + } else { + double exp_a = std::exp(avi_->val_); + avi_->adj_ += adj_ * (exp_a / (exp_a + std::exp(bd_))); + } + } +}; + +} // namespace internal + +/** + * Log sum of exponentials of two vars. NOTE(review): chain() may overflow. + */ +inline var log_add_exp(const var& a, const var& b) { + return var(new internal::log_add_exp_vv_vari(a.vi_, b.vi_)); +} + +/** + * Returns the log sum of exponentials of a var and a double. + */ +inline var log_add_exp(const var& a, double b) { + return var(new internal::log_add_exp_vd_vari(a.vi_, b)); +} + +/** + * Returns the log sum of exponentials of a double and a var. + */ +inline var log_add_exp(double a, const var& b) { + return var(new internal::log_add_exp_vd_vari(b.vi_, a)); +} + +/** + * Applies the scalar log_add_exp to each pair of elements of x and y. + * + * @tparam T A type inheriting from EigenBase with var scalar type + * @param x First input + * @param y Second input + */ +template * = nullptr> +inline T log_add_exp(const T& x, const T& y) { + return apply_scalar_binary( + x, y, [](const auto& a, const auto& b) { return log_add_exp(a, b); }); +} + +} // namespace math +} // namespace stan + +#endif diff --git a/test/unit/math/mix/fun/log_add_exp_test.cpp b/test/unit/math/mix/fun/log_add_exp_test.cpp new file mode 100644 index 00000000000..cd58b217b83 --- /dev/null +++ b/test/unit/math/mix/fun/log_add_exp_test.cpp @@ -0,0 +1,86 @@ +#include +#include +#include + +TEST(MathMixMatFun, logAddExp) { + auto f = [](const auto& x, const auto& y) { + return stan::math::log_add_exp(x, y); + }; + // Finite values in two length-2 vectors + Eigen::VectorXd x1(2); + x1 << 2.0, 1.0; + Eigen::VectorXd y1(2); + y1 << 3.0, 2.0; + stan::test::expect_ad(f, x1, y1); + + // Negative infinity in either scalar argument + + stan::test::expect_ad(f, stan::math::NEGATIVE_INFTY, 1.0); + stan::test::expect_ad(f, 1.0, stan::math::NEGATIVE_INFTY); + + // Positive infinity in both scalar arguments + stan::test::expect_ad(f, 
stan::math::INFTY, stan::math::INFTY); +} + +TEST(MathMixMatFun, log_add_exp_elementwise_values) { + auto f = [](const auto& x, const auto& y) { + return stan::math::log_add_exp(x, y); + }; + + Eigen::VectorXd x1(2); + x1 << 2.0, 1.0; + Eigen::VectorXd y1(2); + y1 << 3.0, 2.0; + stan::test::expect_ad(f, x1, y1); + + Eigen::VectorXd x2(2); + x2 << 0.5, -1.0; + Eigen::VectorXd y2(2); + y2 << 1.0, 2.0; + stan::test::expect_ad(f, x2, y2); + + // Positive infinity in one argument of each element pair + Eigen::VectorXd x3(2); + x3 << std::numeric_limits::infinity(), 1.0; + Eigen::VectorXd y3(2); + y3 << 2.0, std::numeric_limits::infinity(); + + Eigen::VectorXd result = f(x3, y3); + EXPECT_TRUE(std::isinf(result[0])); // x3[0] is infinite + EXPECT_TRUE(std::isinf(result[1])); +} + +TEST(MathMixMatFun, log_add_exp_mismatched_sizes) { + auto f = [](const auto& x, const auto& y) { + return stan::math::log_add_exp(x, y); + }; + + std::vector x{1.0, 2.0}; + std::vector y{1.0, 2.0, 3.0}; + + stan::test::expect_ad(f, x, y); + stan::test::expect_ad(f, y, x); +} + +TEST(MathMixMatFun, log_add_exp_container_tests) { + auto f = [](const auto& x, const auto& y) { + return stan::math::log_add_exp(x, y); + }; + + // 1x2 Eigen::MatrixXd inputs with matching dimensions + Eigen::MatrixXd x_row(1, 2); + x_row << 2.0, 1.0; + Eigen::MatrixXd y_row(1, 2); + y_row << 3.0, 2.0; + + stan::test::expect_ad(f, x_row, y_row); + + // A 2x1 and a 1x3 matrix do not match and must throw + Eigen::MatrixXd x_mismatch(2, 1); + x_mismatch << 0.5, -1.0; + Eigen::MatrixXd y_mismatch(1, 3); + y_mismatch << 1.0, 2.0, 3.0; + + EXPECT_THROW(stan::math::log_add_exp(x_mismatch, y_mismatch), + std::invalid_argument); +}