hypergeometric_pdf.hpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. // Copyright 2008 Gautam Sewani
  2. // Copyright 2008 John Maddock
  3. //
  4. // Use, modification and distribution are subject to the
  5. // Boost Software License, Version 1.0.
  6. // (See accompanying file LICENSE_1_0.txt
  7. // or copy at http://www.boost.org/LICENSE_1_0.txt)
  8. #ifndef BOOST_MATH_DISTRIBUTIONS_DETAIL_HG_PDF_HPP
  9. #define BOOST_MATH_DISTRIBUTIONS_DETAIL_HG_PDF_HPP
  10. #include <boost/math/constants/constants.hpp>
  11. #include <boost/math/special_functions/lanczos.hpp>
  12. #include <boost/math/special_functions/gamma.hpp>
  13. #include <boost/math/special_functions/pow.hpp>
  14. #include <boost/math/special_functions/prime.hpp>
  15. #include <boost/math/policies/error_handling.hpp>
  16. #include <algorithm>
  17. #include <cstdint>
  18. #ifdef BOOST_MATH_INSTRUMENT
  19. #include <typeinfo>
  20. #endif
  21. namespace boost{ namespace math{ namespace detail{
  //
  // Moves the element at *first towards the end of [first, last) until
  // the ordering predicate is satisfied again: while f(*first, *next) is
  // false for the element and its immediate successor, the two are swapped
  // and both pointers advance by one.  Every other element keeps its
  // relative order.
  //
  // first, last : the range to repair, only *first may be out of place.
  // f           : binary predicate, f(a, b) is true when a may stay in
  //               front of b.
  //
  // An empty range is a no-op.
  //
  template <class T, class Func>
  void bubble_down_one(T* first, T* last, Func f)
  {
     using std::swap;
     if(first == last)
     {
        // Nothing to move; without this check first + 1 would already be
        // beyond last and the loop below could never terminate on it.
        return;
     }
     for(T* next = first + 1; (next != last) && !f(*first, *next); ++first, ++next)
     {
        swap(*first, *next);
     }
  }
  //
  // Comparator over indexes rather than values: given a pointer to an
  // array of keys, index i orders before index j when key[i] > key[j],
  // so sorting a list of indexes with it puts the largest key first.
  // The key array is not owned and is read at the time of each call.
  //
  template <class T>
  struct sort_functor
  {
  private:
     const T* m_keys;

  public:
     sort_functor(const T* exponents)
        : m_keys(exponents)
     {
     }

     bool operator()(std::size_t lhs, std::size_t rhs)
     {
        const T& left_key = m_keys[lhs];
        const T& right_key = m_keys[rhs];
        return left_key > right_key;
     }
  };
  46. template <class T, class Lanczos, class Policy>
  47. T hypergeometric_pdf_lanczos_imp(T /*dummy*/, std::uint64_t x, std::uint64_t r, std::uint64_t n, std::uint64_t N, const Lanczos&, const Policy&)
  48. {
  49. BOOST_MATH_STD_USING
  50. BOOST_MATH_INSTRUMENT_FPU
  51. BOOST_MATH_INSTRUMENT_VARIABLE(x);
  52. BOOST_MATH_INSTRUMENT_VARIABLE(r);
  53. BOOST_MATH_INSTRUMENT_VARIABLE(n);
  54. BOOST_MATH_INSTRUMENT_VARIABLE(N);
  55. BOOST_MATH_INSTRUMENT_VARIABLE(typeid(Lanczos).name());
  56. T bases[9] = {
  57. T(n) + static_cast<T>(Lanczos::g()) + 0.5f,
  58. T(r) + static_cast<T>(Lanczos::g()) + 0.5f,
  59. T(N - n) + static_cast<T>(Lanczos::g()) + 0.5f,
  60. T(N - r) + static_cast<T>(Lanczos::g()) + 0.5f,
  61. 1 / (T(N) + static_cast<T>(Lanczos::g()) + 0.5f),
  62. 1 / (T(x) + static_cast<T>(Lanczos::g()) + 0.5f),
  63. 1 / (T(n - x) + static_cast<T>(Lanczos::g()) + 0.5f),
  64. 1 / (T(r - x) + static_cast<T>(Lanczos::g()) + 0.5f),
  65. 1 / (T(N - n - r + x) + static_cast<T>(Lanczos::g()) + 0.5f)
  66. };
  67. T exponents[9] = {
  68. n + T(0.5f),
  69. r + T(0.5f),
  70. N - n + T(0.5f),
  71. N - r + T(0.5f),
  72. N + T(0.5f),
  73. x + T(0.5f),
  74. n - x + T(0.5f),
  75. r - x + T(0.5f),
  76. N - n - r + x + T(0.5f)
  77. };
  78. int base_e_factors[9] = {
  79. -1, -1, -1, -1, 1, 1, 1, 1, 1
  80. };
  81. int sorted_indexes[9] = {
  82. 0, 1, 2, 3, 4, 5, 6, 7, 8
  83. };
  84. #ifdef BOOST_MATH_INSTRUMENT
  85. BOOST_MATH_INSTRUMENT_FPU
  86. for(unsigned i = 0; i < 9; ++i)
  87. {
  88. BOOST_MATH_INSTRUMENT_VARIABLE(i);
  89. BOOST_MATH_INSTRUMENT_VARIABLE(bases[i]);
  90. BOOST_MATH_INSTRUMENT_VARIABLE(exponents[i]);
  91. BOOST_MATH_INSTRUMENT_VARIABLE(base_e_factors[i]);
  92. BOOST_MATH_INSTRUMENT_VARIABLE(sorted_indexes[i]);
  93. }
  94. #endif
  95. std::sort(sorted_indexes, sorted_indexes + 9, sort_functor<T>(exponents));
  96. #ifdef BOOST_MATH_INSTRUMENT
  97. BOOST_MATH_INSTRUMENT_FPU
  98. for(unsigned i = 0; i < 9; ++i)
  99. {
  100. BOOST_MATH_INSTRUMENT_VARIABLE(i);
  101. BOOST_MATH_INSTRUMENT_VARIABLE(bases[i]);
  102. BOOST_MATH_INSTRUMENT_VARIABLE(exponents[i]);
  103. BOOST_MATH_INSTRUMENT_VARIABLE(base_e_factors[i]);
  104. BOOST_MATH_INSTRUMENT_VARIABLE(sorted_indexes[i]);
  105. }
  106. #endif
  107. do{
  108. exponents[sorted_indexes[0]] -= exponents[sorted_indexes[1]];
  109. bases[sorted_indexes[1]] *= bases[sorted_indexes[0]];
  110. if((bases[sorted_indexes[1]] < tools::min_value<T>()) && (exponents[sorted_indexes[1]] != 0))
  111. {
  112. return 0;
  113. }
  114. base_e_factors[sorted_indexes[1]] += base_e_factors[sorted_indexes[0]];
  115. bubble_down_one(sorted_indexes, sorted_indexes + 9, sort_functor<T>(exponents));
  116. #ifdef BOOST_MATH_INSTRUMENT
  117. for(unsigned i = 0; i < 9; ++i)
  118. {
  119. BOOST_MATH_INSTRUMENT_VARIABLE(i);
  120. BOOST_MATH_INSTRUMENT_VARIABLE(bases[i]);
  121. BOOST_MATH_INSTRUMENT_VARIABLE(exponents[i]);
  122. BOOST_MATH_INSTRUMENT_VARIABLE(base_e_factors[i]);
  123. BOOST_MATH_INSTRUMENT_VARIABLE(sorted_indexes[i]);
  124. }
  125. #endif
  126. }while(exponents[sorted_indexes[1]] > 1);
  127. //
  128. // Combine equal powers:
  129. //
  130. std::size_t j = 8;
  131. while(exponents[sorted_indexes[j]] == 0) --j;
  132. while(j)
  133. {
  134. while(j && (exponents[sorted_indexes[j-1]] == exponents[sorted_indexes[j]]))
  135. {
  136. bases[sorted_indexes[j-1]] *= bases[sorted_indexes[j]];
  137. exponents[sorted_indexes[j]] = 0;
  138. base_e_factors[sorted_indexes[j-1]] += base_e_factors[sorted_indexes[j]];
  139. bubble_down_one(sorted_indexes + j, sorted_indexes + 9, sort_functor<T>(exponents));
  140. --j;
  141. }
  142. --j;
  143. #ifdef BOOST_MATH_INSTRUMENT
  144. BOOST_MATH_INSTRUMENT_VARIABLE(j);
  145. for(unsigned i = 0; i < 9; ++i)
  146. {
  147. BOOST_MATH_INSTRUMENT_VARIABLE(i);
  148. BOOST_MATH_INSTRUMENT_VARIABLE(bases[i]);
  149. BOOST_MATH_INSTRUMENT_VARIABLE(exponents[i]);
  150. BOOST_MATH_INSTRUMENT_VARIABLE(base_e_factors[i]);
  151. BOOST_MATH_INSTRUMENT_VARIABLE(sorted_indexes[i]);
  152. }
  153. #endif
  154. }
  155. #ifdef BOOST_MATH_INSTRUMENT
  156. BOOST_MATH_INSTRUMENT_FPU
  157. for(unsigned i = 0; i < 9; ++i)
  158. {
  159. BOOST_MATH_INSTRUMENT_VARIABLE(i);
  160. BOOST_MATH_INSTRUMENT_VARIABLE(bases[i]);
  161. BOOST_MATH_INSTRUMENT_VARIABLE(exponents[i]);
  162. BOOST_MATH_INSTRUMENT_VARIABLE(base_e_factors[i]);
  163. BOOST_MATH_INSTRUMENT_VARIABLE(sorted_indexes[i]);
  164. }
  165. #endif
  166. T result;
  167. BOOST_MATH_INSTRUMENT_VARIABLE(bases[sorted_indexes[0]] * exp(static_cast<T>(base_e_factors[sorted_indexes[0]])));
  168. BOOST_MATH_INSTRUMENT_VARIABLE(exponents[sorted_indexes[0]]);
  169. {
  170. BOOST_FPU_EXCEPTION_GUARD
  171. result = pow(bases[sorted_indexes[0]] * exp(static_cast<T>(base_e_factors[sorted_indexes[0]])), exponents[sorted_indexes[0]]);
  172. }
  173. BOOST_MATH_INSTRUMENT_VARIABLE(result);
  174. for(std::size_t i = 1; (i < 9) && (exponents[sorted_indexes[i]] > 0); ++i)
  175. {
  176. BOOST_FPU_EXCEPTION_GUARD
  177. if(result < tools::min_value<T>())
  178. return 0; // short circuit further evaluation
  179. if(exponents[sorted_indexes[i]] == 1)
  180. result *= bases[sorted_indexes[i]] * exp(static_cast<T>(base_e_factors[sorted_indexes[i]]));
  181. else if(exponents[sorted_indexes[i]] == 0.5f)
  182. result *= sqrt(bases[sorted_indexes[i]] * exp(static_cast<T>(base_e_factors[sorted_indexes[i]])));
  183. else
  184. result *= pow(bases[sorted_indexes[i]] * exp(static_cast<T>(base_e_factors[sorted_indexes[i]])), exponents[sorted_indexes[i]]);
  185. BOOST_MATH_INSTRUMENT_VARIABLE(result);
  186. }
  187. result *= Lanczos::lanczos_sum_expG_scaled(static_cast<T>(n + 1))
  188. * Lanczos::lanczos_sum_expG_scaled(static_cast<T>(r + 1))
  189. * Lanczos::lanczos_sum_expG_scaled(static_cast<T>(N - n + 1))
  190. * Lanczos::lanczos_sum_expG_scaled(static_cast<T>(N - r + 1))
  191. /
  192. ( Lanczos::lanczos_sum_expG_scaled(static_cast<T>(N + 1))
  193. * Lanczos::lanczos_sum_expG_scaled(static_cast<T>(x + 1))
  194. * Lanczos::lanczos_sum_expG_scaled(static_cast<T>(n - x + 1))
  195. * Lanczos::lanczos_sum_expG_scaled(static_cast<T>(r - x + 1))
  196. * Lanczos::lanczos_sum_expG_scaled(static_cast<T>(N - n - r + x + 1)));
  197. BOOST_MATH_INSTRUMENT_VARIABLE(result);
  198. return result;
  199. }
  200. template <class T, class Policy>
  201. T hypergeometric_pdf_lanczos_imp(T /*dummy*/, std::uint64_t x, std::uint64_t r, std::uint64_t n, std::uint64_t N, const boost::math::lanczos::undefined_lanczos&, const Policy& pol)
  202. {
  203. BOOST_MATH_STD_USING
  204. return exp(
  205. boost::math::lgamma(T(n + 1), pol)
  206. + boost::math::lgamma(T(r + 1), pol)
  207. + boost::math::lgamma(T(N - n + 1), pol)
  208. + boost::math::lgamma(T(N - r + 1), pol)
  209. - boost::math::lgamma(T(N + 1), pol)
  210. - boost::math::lgamma(T(x + 1), pol)
  211. - boost::math::lgamma(T(n - x + 1), pol)
  212. - boost::math::lgamma(T(r - x + 1), pol)
  213. - boost::math::lgamma(T(N - n - r + x + 1), pol));
  214. }
  215. template <class T>
  216. inline T integer_power(const T& x, int ex)
  217. {
  218. if(ex < 0)
  219. return 1 / integer_power(x, -ex);
  220. switch(ex)
  221. {
  222. case 0:
  223. return 1;
  224. case 1:
  225. return x;
  226. case 2:
  227. return x * x;
  228. case 3:
  229. return x * x * x;
  230. case 4:
  231. return boost::math::pow<4>(x);
  232. case 5:
  233. return boost::math::pow<5>(x);
  234. case 6:
  235. return boost::math::pow<6>(x);
  236. case 7:
  237. return boost::math::pow<7>(x);
  238. case 8:
  239. return boost::math::pow<8>(x);
  240. }
  241. BOOST_MATH_STD_USING
  242. #ifdef __SUNPRO_CC
  243. return pow(x, T(ex));
  244. #else
  245. return static_cast<T>(pow(x, ex));
  246. #endif
  247. }
  //
  // One node of a singly linked list of partial products.  Kept as a
  // plain aggregate so that it can be brace-initialised as
  // { value, pointer-to-next }; a null next marks the end of the list.
  //
  template <class T>
  struct hypergeometric_pdf_prime_loop_result_entry
  {
     // The partial product held by this node.
     T value;

     // The following node, or null for the last one.
     const hypergeometric_pdf_prime_loop_result_entry* next;
  };
  //
  // State shared by every level of the prime factorisation loop: the four
  // distribution arguments, which never change, plus the position reached
  // in the table of primes.  Kept as a plain aggregate, initialised as
  // { x, r, n, N, prime_index, current_prime }.
  //
  // NOTE(review): the MSVC warnings disabled around the definition are
  // presumably those provoked by the const data members - confirm.
  //
  #ifdef _MSC_VER
  #  pragma warning(push)
  #  pragma warning(disable:4510 4512 4610)
  #endif

  struct hypergeometric_pdf_prime_loop_data
  {
     // Fixed arguments of the PDF being evaluated:
     const std::uint64_t x;
     const std::uint64_t r;
     const std::uint64_t n;
     const std::uint64_t N;

     // Index into the prime table of the prime being processed ...
     std::size_t prime_index;
     // ... and the value of that prime.
     std::uint64_t current_prime;
  };

  #ifdef _MSC_VER
  #  pragma warning(pop)
  #endif
  270. template <class T>
  271. T hypergeometric_pdf_prime_loop_imp(hypergeometric_pdf_prime_loop_data& data, hypergeometric_pdf_prime_loop_result_entry<T>& result)
  272. {
  273. while(data.current_prime <= data.N)
  274. {
  275. std::uint64_t base = data.current_prime;
  276. std::uint64_t prime_powers = 0;
  277. while(base <= data.N)
  278. {
  279. prime_powers += data.n / base;
  280. prime_powers += data.r / base;
  281. prime_powers += (data.N - data.n) / base;
  282. prime_powers += (data.N - data.r) / base;
  283. prime_powers -= data.N / base;
  284. prime_powers -= data.x / base;
  285. prime_powers -= (data.n - data.x) / base;
  286. prime_powers -= (data.r - data.x) / base;
  287. prime_powers -= (data.N - data.n - data.r + data.x) / base;
  288. base *= data.current_prime;
  289. }
  290. if(prime_powers)
  291. {
  292. T p = integer_power<T>(static_cast<T>(data.current_prime), static_cast<int>(prime_powers));
  293. if((p > 1) && (tools::max_value<T>() / p < result.value))
  294. {
  295. //
  296. // The next calculation would overflow, use recursion
  297. // to sidestep the issue:
  298. //
  299. hypergeometric_pdf_prime_loop_result_entry<T> t = { p, &result };
  300. data.current_prime = prime(static_cast<unsigned>(++data.prime_index));
  301. return hypergeometric_pdf_prime_loop_imp<T>(data, t);
  302. }
  303. if((p < 1) && (tools::min_value<T>() / p > result.value))
  304. {
  305. //
  306. // The next calculation would underflow, use recursion
  307. // to sidestep the issue:
  308. //
  309. hypergeometric_pdf_prime_loop_result_entry<T> t = { p, &result };
  310. data.current_prime = prime(static_cast<unsigned>(++data.prime_index));
  311. return hypergeometric_pdf_prime_loop_imp<T>(data, t);
  312. }
  313. result.value *= p;
  314. }
  315. data.current_prime = prime(static_cast<unsigned>(++data.prime_index));
  316. }
  317. //
  318. // When we get to here we have run out of prime factors,
  319. // the overall result is the product of all the partial
  320. // results we have accumulated on the stack so far, these
  321. // are in a linked list starting with "data.head" and ending
  322. // with "result".
  323. //
  324. // All that remains is to multiply them together, taking
  325. // care not to overflow or underflow.
  326. //
  327. // Enumerate partial results >= 1 in variable i
  328. // and partial results < 1 in variable j:
  329. //
  330. hypergeometric_pdf_prime_loop_result_entry<T> const *i, *j;
  331. i = &result;
  332. while(i && i->value < 1)
  333. i = i->next;
  334. j = &result;
  335. while(j && j->value >= 1)
  336. j = j->next;
  337. T prod = 1;
  338. while(i || j)
  339. {
  340. while(i && ((prod <= 1) || (j == 0)))
  341. {
  342. prod *= i->value;
  343. i = i->next;
  344. while(i && i->value < 1)
  345. i = i->next;
  346. }
  347. while(j && ((prod >= 1) || (i == 0)))
  348. {
  349. prod *= j->value;
  350. j = j->next;
  351. while(j && j->value >= 1)
  352. j = j->next;
  353. }
  354. }
  355. return prod;
  356. }
  357. template <class T, class Policy>
  358. inline T hypergeometric_pdf_prime_imp(std::uint64_t x, std::uint64_t r, std::uint64_t n, std::uint64_t N, const Policy&)
  359. {
  360. hypergeometric_pdf_prime_loop_result_entry<T> result = { 1, 0 };
  361. hypergeometric_pdf_prime_loop_data data = { x, r, n, N, 0, prime(0) };
  362. return hypergeometric_pdf_prime_loop_imp<T>(data, result);
  363. }
  364. template <class T, class Policy>
  365. T hypergeometric_pdf_factorial_imp(std::uint64_t x, std::uint64_t r, std::uint64_t n, std::uint64_t N, const Policy&)
  366. {
  367. BOOST_MATH_STD_USING
  368. BOOST_MATH_ASSERT(N <= boost::math::max_factorial<T>::value);
  369. T result = boost::math::unchecked_factorial<T>(static_cast<unsigned>(n));
  370. T num[3] = {
  371. boost::math::unchecked_factorial<T>(static_cast<unsigned>(r)),
  372. boost::math::unchecked_factorial<T>(static_cast<unsigned>(N - n)),
  373. boost::math::unchecked_factorial<T>(static_cast<unsigned>(N - r))
  374. };
  375. T denom[5] = {
  376. boost::math::unchecked_factorial<T>(static_cast<unsigned>(N)),
  377. boost::math::unchecked_factorial<T>(static_cast<unsigned>(x)),
  378. boost::math::unchecked_factorial<T>(static_cast<unsigned>(n - x)),
  379. boost::math::unchecked_factorial<T>(static_cast<unsigned>(r - x)),
  380. boost::math::unchecked_factorial<T>(static_cast<unsigned>(N - n - r + x))
  381. };
  382. std::size_t i = 0;
  383. std::size_t j = 0;
  384. while((i < 3) || (j < 5))
  385. {
  386. while((j < 5) && ((result >= 1) || (i >= 3)))
  387. {
  388. result /= denom[j];
  389. ++j;
  390. }
  391. while((i < 3) && ((result <= 1) || (j >= 5)))
  392. {
  393. result *= num[i];
  394. ++i;
  395. }
  396. }
  397. return result;
  398. }
  399. template <class T, class Policy>
  400. inline typename tools::promote_args<T>::type
  401. hypergeometric_pdf(std::uint64_t x, std::uint64_t r, std::uint64_t n, std::uint64_t N, const Policy&)
  402. {
  403. BOOST_FPU_EXCEPTION_GUARD
  404. typedef typename tools::promote_args<T>::type result_type;
  405. typedef typename policies::evaluation<result_type, Policy>::type value_type;
  406. typedef typename lanczos::lanczos<value_type, Policy>::type evaluation_type;
  407. typedef typename policies::normalise<
  408. Policy,
  409. policies::promote_float<false>,
  410. policies::promote_double<false>,
  411. policies::discrete_quantile<>,
  412. policies::assert_undefined<> >::type forwarding_policy;
  413. value_type result;
  414. if(N <= boost::math::max_factorial<value_type>::value)
  415. {
  416. //
  417. // If N is small enough then we can evaluate the PDF via the factorials
  418. // directly: table lookup of the factorials gives the best performance
  419. // of the methods available:
  420. //
  421. result = detail::hypergeometric_pdf_factorial_imp<value_type>(x, r, n, N, forwarding_policy());
  422. }
  423. else if(N <= boost::math::prime(boost::math::max_prime - 1))
  424. {
  425. //
  426. // If N is no larger than the largest prime number in our lookup table
  427. // (104729) then we can use prime factorisation to evaluate the PDF,
  428. // this is slow but accurate:
  429. //
  430. result = detail::hypergeometric_pdf_prime_imp<value_type>(x, r, n, N, forwarding_policy());
  431. }
  432. else
  433. {
  434. //
  435. // Catch all case - use the lanczos approximation - where available -
  436. // to evaluate the ratio of factorials. This is reasonably fast
  437. // (almost as quick as using logarithmic evaluation in terms of lgamma)
  438. // but only a few digits better in accuracy than using lgamma:
  439. //
  440. result = detail::hypergeometric_pdf_lanczos_imp(value_type(), x, r, n, N, evaluation_type(), forwarding_policy());
  441. }
  442. if(result > 1)
  443. {
  444. result = 1;
  445. }
  446. if(result < 0)
  447. {
  448. result = 0;
  449. }
  450. return policies::checked_narrowing_cast<result_type, forwarding_policy>(result, "boost::math::hypergeometric_pdf<%1%>(%1%,%1%,%1%,%1%)");
  451. }
  452. }}} // namespaces
  453. #endif