diff --git a/stan/math/prim/fun/hypergeometric_2F1.hpp b/stan/math/prim/fun/hypergeometric_2F1.hpp index ae327e033f1..730be40fc45 100644 --- a/stan/math/prim/fun/hypergeometric_2F1.hpp +++ b/stan/math/prim/fun/hypergeometric_2F1.hpp @@ -16,7 +16,7 @@ #include #include #include -#include +#include #include namespace stan { @@ -187,7 +187,7 @@ inline return_type_t hypergeometric_2F1(const Ta1& a1, a_args << a1, a2; b_args << b; - return hypergeometric_pFq(a_args, b_args, z); + return internal::hypergeometric_pFq_helper(a_args, b_args, z); } catch (const std::exception& e) { // Apply Euler's hypergeometric transformation if function // will not converge with current arguments @@ -200,7 +200,8 @@ inline return_type_t hypergeometric_2F1(const Ta1& a1, a_args << a1_t, a2_t; b_args << b_t; - return hypergeometric_pFq(a_args, b_args, z_t) / pow(1 - z, a2); + return internal::hypergeometric_pFq_helper(a_args, b_args, z_t) + / pow(1 - z, a2); } } } // namespace math diff --git a/stan/math/prim/fun/hypergeometric_pFq.hpp b/stan/math/prim/fun/hypergeometric_pFq.hpp index c2b30610314..65b387b3f3c 100644 --- a/stan/math/prim/fun/hypergeometric_pFq.hpp +++ b/stan/math/prim/fun/hypergeometric_pFq.hpp @@ -4,12 +4,14 @@ #include #include #include +#include #include -#include +#include +#include +#include namespace stan { namespace math { - /** * Returns the generalized hypergeometric function applied to the * input arguments: @@ -29,6 +31,13 @@ return_type_t hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) { plain_type_t a_ref = a; plain_type_t b_ref = b; + + if (a_ref.size() == 1 && b_ref.size() == 0) { + return hypergeometric_1F0(a_ref[0], z); + } else if (a_ref.size() == 2 && b_ref.size() == 1) { + return hypergeometric_2F1(a_ref[0], a_ref[1], b_ref[0], z); + } + check_finite("hypergeometric_pFq", "a", a_ref); check_finite("hypergeometric_pFq", "b", b_ref); check_finite("hypergeometric_pFq", "z", z); @@ -50,9 +59,7 @@ return_type_t hypergeometric_pFq(const Ta& a, const Tb& b, throw std::domain_error(msg.str()); } - return boost::math::hypergeometric_pFq( - std::vector(a_ref.data(), a_ref.data() + a_ref.size()), - std::vector(b_ref.data(), b_ref.data() + b_ref.size()), z); + return internal::hypergeometric_pFq_helper(a_ref, b_ref, z); } } // namespace math } // namespace stan diff --git a/stan/math/prim/fun/hypergeometric_pFq_helper.hpp b/stan/math/prim/fun/hypergeometric_pFq_helper.hpp new file mode 100644 index 00000000000..8867f316c3c --- /dev/null +++ b/stan/math/prim/fun/hypergeometric_pFq_helper.hpp @@ -0,0 +1,32 @@ +#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_PFQ_HELPER_HPP +#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_PFQ_HELPER_HPP + +#include +#include +#include + +namespace stan { +namespace math { +namespace internal { +/** + * Implementation for calculating the generalized hypergeometric function + * \f$_pF_q(a_1,...,a_p;b_1,...,b_q;z)\f$. + * + * This is declared separatel to avoid circular dependencies between the + * various hypergeometric functions. + * + * @param[in] a Vector of 'a' arguments to function + * @param[in] b Vector of 'b' arguments to function + * @param[in] z Scalar z argument + * @return Generalized hypergeometric function + */ +template * = nullptr, + require_arithmetic_t* = nullptr> +inline double hypergeometric_pFq_helper(const Ta& a, const Tb& b, const Tz& z) { + return boost::math::hypergeometric_pFq(to_array_1d(a), to_array_1d(b), z); +} +} // namespace internal +} // namespace math +} // namespace stan +#endif