chatterjee_correlation.hpp 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. // (C) Copyright Matt Borland 2022.
  2. // Use, modification and distribution are subject to the
  3. // Boost Software License, Version 1.0. (See accompanying file
  4. // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  5. #ifndef BOOST_MATH_STATISTICS_CHATTERJEE_CORRELATION_HPP
  6. #define BOOST_MATH_STATISTICS_CHATTERJEE_CORRELATION_HPP
  7. #include <cstdint>
  8. #include <cmath>
  9. #include <algorithm>
  10. #include <iterator>
  11. #include <vector>
  12. #include <limits>
  13. #include <utility>
  14. #include <type_traits>
  15. #include <boost/math/tools/assert.hpp>
  16. #include <boost/math/tools/config.hpp>
  17. #include <boost/math/statistics/detail/rank.hpp>
  18. #ifdef BOOST_MATH_EXEC_COMPATIBLE
  19. #include <execution>
  20. #include <future>
  21. #include <thread>
  22. #endif
  23. namespace boost { namespace math { namespace statistics {
  24. namespace detail {
  25. template <typename BDIter>
  26. std::size_t chatterjee_transform(BDIter begin, BDIter end)
  27. {
  28. std::size_t sum = 0;
  29. while(++begin != end)
  30. {
  31. if(*begin > *std::prev(begin))
  32. {
  33. sum += *begin - *std::prev(begin);
  34. }
  35. else
  36. {
  37. sum += *std::prev(begin) - *begin;
  38. }
  39. }
  40. return sum;
  41. }
  42. template <typename ReturnType, typename ForwardIterator>
  43. ReturnType chatterjee_correlation_seq_impl(ForwardIterator u_begin, ForwardIterator u_end, ForwardIterator v_begin, ForwardIterator v_end)
  44. {
  45. using std::abs;
  46. BOOST_MATH_ASSERT_MSG(std::is_sorted(u_begin, u_end), "The x values must be sorted in order to use this functionality");
  47. const std::vector<std::size_t> rank_vector = rank(v_begin, v_end);
  48. std::size_t sum = chatterjee_transform(rank_vector.begin(), rank_vector.end());
  49. ReturnType result = static_cast<ReturnType>(1) - (static_cast<ReturnType>(3 * sum) / static_cast<ReturnType>(rank_vector.size() * rank_vector.size() - 1));
  50. // If the result is 1 then Y is constant and all the elements must be ties
  51. if (abs(result - static_cast<ReturnType>(1)) < std::numeric_limits<ReturnType>::epsilon())
  52. {
  53. return std::numeric_limits<ReturnType>::quiet_NaN();
  54. }
  55. return result;
  56. }
  57. } // Namespace detail
  58. template <typename Container, typename Real = typename Container::value_type,
  59. typename ReturnType = typename std::conditional<std::is_integral<Real>::value, double, Real>::type>
  60. inline ReturnType chatterjee_correlation(const Container& u, const Container& v)
  61. {
  62. return detail::chatterjee_correlation_seq_impl<ReturnType>(std::begin(u), std::end(u), std::begin(v), std::end(v));
  63. }
  64. }}} // Namespace boost::math::statistics
  65. #ifdef BOOST_MATH_EXEC_COMPATIBLE
  66. namespace boost::math::statistics {
  67. namespace detail {
  68. template <typename ReturnType, typename ExecutionPolicy, typename ForwardIterator>
  69. ReturnType chatterjee_correlation_par_impl(ExecutionPolicy&& exec, ForwardIterator u_begin, ForwardIterator u_end,
  70. ForwardIterator v_begin, ForwardIterator v_end)
  71. {
  72. using std::abs;
  73. BOOST_MATH_ASSERT_MSG(std::is_sorted(std::forward<ExecutionPolicy>(exec), u_begin, u_end), "The x values must be sorted in order to use this functionality");
  74. auto rank_vector = rank(std::forward<ExecutionPolicy>(exec), v_begin, v_end);
  75. const auto num_threads = std::thread::hardware_concurrency() == 0 ? 2u : std::thread::hardware_concurrency();
  76. std::vector<std::future<std::size_t>> future_manager {};
  77. const auto elements_per_thread = std::ceil(static_cast<double>(rank_vector.size()) / num_threads);
  78. auto it = rank_vector.begin();
  79. auto end = rank_vector.end();
  80. for(std::size_t i {}; i < num_threads - 1; ++i)
  81. {
  82. future_manager.emplace_back(std::async(std::launch::async | std::launch::deferred, [it, elements_per_thread]() -> std::size_t
  83. {
  84. return chatterjee_transform(it, std::next(it, elements_per_thread));
  85. }));
  86. it = std::next(it, elements_per_thread - 1);
  87. }
  88. future_manager.emplace_back(std::async(std::launch::async | std::launch::deferred, [it, end]() -> std::size_t
  89. {
  90. return chatterjee_transform(it, end);
  91. }));
  92. std::size_t sum {};
  93. for(std::size_t i {}; i < future_manager.size(); ++i)
  94. {
  95. sum += future_manager[i].get();
  96. }
  97. ReturnType result = static_cast<ReturnType>(1) - (static_cast<ReturnType>(3 * sum) / static_cast<ReturnType>(rank_vector.size() * rank_vector.size() - 1));
  98. // If the result is 1 then Y is constant and all the elements must be ties
  99. if (abs(result - static_cast<ReturnType>(1)) < std::numeric_limits<ReturnType>::epsilon())
  100. {
  101. return std::numeric_limits<ReturnType>::quiet_NaN();
  102. }
  103. return result;
  104. }
  105. } // Namespace detail
  106. template <typename ExecutionPolicy, typename Container, typename Real = typename Container::value_type,
  107. typename ReturnType = std::conditional_t<std::is_integral_v<Real>, double, Real>>
  108. inline ReturnType chatterjee_correlation(ExecutionPolicy&& exec, const Container& u, const Container& v)
  109. {
  110. if constexpr (std::is_same_v<std::remove_reference_t<decltype(exec)>, decltype(std::execution::seq)>)
  111. {
  112. return detail::chatterjee_correlation_seq_impl<ReturnType>(std::cbegin(u), std::cend(u),
  113. std::cbegin(v), std::cend(v));
  114. }
  115. else
  116. {
  117. return detail::chatterjee_correlation_par_impl<ReturnType>(std::forward<ExecutionPolicy>(exec),
  118. std::cbegin(u), std::cend(u),
  119. std::cbegin(v), std::cend(v));
  120. }
  121. }
  122. } // Namespace boost::math::statistics
  123. #endif
  124. #endif // BOOST_MATH_STATISTICS_CHATTERJEE_CORRELATION_HPP