14824e7fdSDimitry Andric //===----------------------------------------------------------------------===// 24824e7fdSDimitry Andric // 34824e7fdSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44824e7fdSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 54824e7fdSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64824e7fdSDimitry Andric // 74824e7fdSDimitry Andric //===----------------------------------------------------------------------===// 84824e7fdSDimitry Andric 94824e7fdSDimitry Andric #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 104824e7fdSDimitry Andric #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 114824e7fdSDimitry Andric 124824e7fdSDimitry Andric #include <__algorithm/upper_bound.h> 134824e7fdSDimitry Andric #include <__config> 1481ad6265SDimitry Andric #include <__random/is_valid.h> 154824e7fdSDimitry Andric #include <__random/uniform_real_distribution.h> 164824e7fdSDimitry Andric #include <cstddef> 174824e7fdSDimitry Andric #include <iosfwd> 184824e7fdSDimitry Andric #include <numeric> 194824e7fdSDimitry Andric #include <vector> 204824e7fdSDimitry Andric 214824e7fdSDimitry Andric #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) 224824e7fdSDimitry Andric # pragma GCC system_header 234824e7fdSDimitry Andric #endif 244824e7fdSDimitry Andric 254824e7fdSDimitry Andric _LIBCPP_PUSH_MACROS 264824e7fdSDimitry Andric #include <__undef_macros> 274824e7fdSDimitry Andric 284824e7fdSDimitry Andric _LIBCPP_BEGIN_NAMESPACE_STD 294824e7fdSDimitry Andric 304824e7fdSDimitry Andric template <class _IntType = int> 31*cb14a3feSDimitry Andric class _LIBCPP_TEMPLATE_VIS discrete_distribution { 32fcaf7f86SDimitry Andric static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type"); 33*cb14a3feSDimitry Andric 344824e7fdSDimitry Andric public: 354824e7fdSDimitry Andric // types 364824e7fdSDimitry Andric typedef _IntType result_type; 374824e7fdSDimitry Andric 38*cb14a3feSDimitry Andric class _LIBCPP_TEMPLATE_VIS param_type { 394824e7fdSDimitry Andric vector<double> __p_; 40*cb14a3feSDimitry Andric 414824e7fdSDimitry Andric public: 424824e7fdSDimitry Andric typedef discrete_distribution distribution_type; 434824e7fdSDimitry Andric 44*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI param_type() {} 454824e7fdSDimitry Andric template <class _InputIterator> 46*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI param_type(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) { 47*cb14a3feSDimitry Andric __init(); 48*cb14a3feSDimitry Andric } 494824e7fdSDimitry Andric #ifndef _LIBCPP_CXX03_LANG 50*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI param_type(initializer_list<double> __wl) : __p_(__wl.begin(), __wl.end()) { __init(); } 514824e7fdSDimitry Andric #endif // _LIBCPP_CXX03_LANG 524824e7fdSDimitry Andric template <class _UnaryOperation> 53*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI param_type(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw); 544824e7fdSDimitry Andric 5506c3fb27SDimitry Andric _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const; 564824e7fdSDimitry Andric 57*cb14a3feSDimitry Andric friend _LIBCPP_HIDE_FROM_ABI bool operator==(const param_type& __x, const param_type& __y) { 58*cb14a3feSDimitry Andric return __x.__p_ == __y.__p_; 59*cb14a3feSDimitry Andric } 60*cb14a3feSDimitry Andric friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const param_type& __x, const param_type& __y) { return !(__x == __y); } 614824e7fdSDimitry Andric 624824e7fdSDimitry Andric private: 6306c3fb27SDimitry Andric _LIBCPP_HIDE_FROM_ABI void __init(); 644824e7fdSDimitry Andric 654824e7fdSDimitry Andric friend class discrete_distribution; 664824e7fdSDimitry Andric 674824e7fdSDimitry Andric template <class _CharT, class _Traits, class _IT> 68*cb14a3feSDimitry Andric friend basic_ostream<_CharT, _Traits>& 69*cb14a3feSDimitry Andric operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x); 704824e7fdSDimitry Andric 714824e7fdSDimitry Andric template <class _CharT, class _Traits, class _IT> 72*cb14a3feSDimitry Andric friend basic_istream<_CharT, _Traits>& 73*cb14a3feSDimitry Andric operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x); 744824e7fdSDimitry Andric }; 754824e7fdSDimitry Andric 764824e7fdSDimitry Andric private: 774824e7fdSDimitry Andric param_type __p_; 784824e7fdSDimitry Andric 794824e7fdSDimitry Andric public: 804824e7fdSDimitry Andric // constructor and reset functions 81*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI discrete_distribution() {} 824824e7fdSDimitry Andric template <class _InputIterator> 83*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI discrete_distribution(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {} 844824e7fdSDimitry Andric #ifndef _LIBCPP_CXX03_LANG 85*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI discrete_distribution(initializer_list<double> __wl) : __p_(__wl) {} 864824e7fdSDimitry Andric #endif // _LIBCPP_CXX03_LANG 874824e7fdSDimitry Andric template <class _UnaryOperation> 88*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI discrete_distribution(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw) 894824e7fdSDimitry Andric : __p_(__nw, __xmin, __xmax, __fw) {} 90*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI explicit discrete_distribution(const param_type& __p) : __p_(__p) {} 91*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI void reset() {} 924824e7fdSDimitry Andric 934824e7fdSDimitry Andric // generating functions 944824e7fdSDimitry Andric template <class _URNG> 95*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g) { 96*cb14a3feSDimitry Andric return (*this)(__g, __p_); 97*cb14a3feSDimitry Andric } 9806c3fb27SDimitry Andric template <class _URNG> 9906c3fb27SDimitry Andric _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g, const param_type& __p); 1004824e7fdSDimitry Andric 1014824e7fdSDimitry Andric // property functions 102*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const { return __p_.probabilities(); } 1034824e7fdSDimitry Andric 104*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI param_type param() const { return __p_; } 105*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI void param(const param_type& __p) { __p_ = __p; } 1064824e7fdSDimitry Andric 107*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI result_type min() const { return 0; } 108*cb14a3feSDimitry Andric _LIBCPP_HIDE_FROM_ABI result_type max() const { return __p_.__p_.size(); } 1094824e7fdSDimitry Andric 110*cb14a3feSDimitry Andric friend _LIBCPP_HIDE_FROM_ABI bool operator==(const discrete_distribution& __x, const discrete_distribution& __y) { 111*cb14a3feSDimitry Andric return __x.__p_ == __y.__p_; 112*cb14a3feSDimitry Andric } 113*cb14a3feSDimitry Andric friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const discrete_distribution& __x, const discrete_distribution& __y) { 114*cb14a3feSDimitry Andric return !(__x == __y); 115*cb14a3feSDimitry Andric } 1164824e7fdSDimitry Andric 1174824e7fdSDimitry Andric template <class _CharT, class _Traits, class _IT> 118*cb14a3feSDimitry Andric friend basic_ostream<_CharT, _Traits>& 119*cb14a3feSDimitry Andric operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x); 1204824e7fdSDimitry Andric 1214824e7fdSDimitry Andric template <class _CharT, class _Traits, class _IT> 122*cb14a3feSDimitry Andric friend basic_istream<_CharT, _Traits>& 123*cb14a3feSDimitry Andric operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x); 1244824e7fdSDimitry Andric }; 1254824e7fdSDimitry Andric 1264824e7fdSDimitry Andric template <class _IntType> 1274824e7fdSDimitry Andric template <class _UnaryOperation> 128*cb14a3feSDimitry Andric discrete_distribution<_IntType>::param_type::param_type( 129*cb14a3feSDimitry Andric size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw) { 130*cb14a3feSDimitry Andric if (__nw > 1) { 1314824e7fdSDimitry Andric __p_.reserve(__nw - 1); 1324824e7fdSDimitry Andric double __d = (__xmax - __xmin) / __nw; 1334824e7fdSDimitry Andric double __d2 = __d / 2; 1344824e7fdSDimitry Andric for (size_t __k = 0; __k < __nw; ++__k) 1354824e7fdSDimitry Andric __p_.push_back(__fw(__xmin + __k * __d + __d2)); 1364824e7fdSDimitry Andric __init(); 1374824e7fdSDimitry Andric } 1384824e7fdSDimitry Andric } 1394824e7fdSDimitry Andric 1404824e7fdSDimitry Andric template <class _IntType> 141*cb14a3feSDimitry Andric void discrete_distribution<_IntType>::param_type::__init() { 142*cb14a3feSDimitry Andric if (!__p_.empty()) { 143*cb14a3feSDimitry Andric if (__p_.size() > 1) { 1445f757f3fSDimitry Andric double __s = std::accumulate(__p_.begin(), __p_.end(), 0.0); 1454824e7fdSDimitry Andric for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i) 1464824e7fdSDimitry Andric *__i /= __s; 1474824e7fdSDimitry Andric vector<double> __t(__p_.size() - 1); 1485f757f3fSDimitry Andric std::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin()); 1494824e7fdSDimitry Andric swap(__p_, __t); 150*cb14a3feSDimitry Andric } else { 1514824e7fdSDimitry Andric __p_.clear(); 1524824e7fdSDimitry Andric __p_.shrink_to_fit(); 1534824e7fdSDimitry Andric } 1544824e7fdSDimitry Andric } 1554824e7fdSDimitry Andric } 1564824e7fdSDimitry Andric 1574824e7fdSDimitry Andric template <class _IntType> 158*cb14a3feSDimitry Andric vector<double> discrete_distribution<_IntType>::param_type::probabilities() const { 1594824e7fdSDimitry Andric size_t __n = __p_.size(); 1604824e7fdSDimitry Andric vector<double> __p(__n + 1); 1615f757f3fSDimitry Andric std::adjacent_difference(__p_.begin(), __p_.end(), __p.begin()); 1624824e7fdSDimitry Andric if (__n > 0) 1634824e7fdSDimitry Andric __p[__n] = 1 - __p_[__n - 1]; 1644824e7fdSDimitry Andric else 1654824e7fdSDimitry Andric __p[0] = 1; 1664824e7fdSDimitry Andric return __p; 1674824e7fdSDimitry Andric } 1684824e7fdSDimitry Andric 1694824e7fdSDimitry Andric template <class _IntType> 1704824e7fdSDimitry Andric template <class _URNG> 171*cb14a3feSDimitry Andric _IntType discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p) { 17281ad6265SDimitry Andric static_assert(__libcpp_random_is_valid_urng<_URNG>::value, ""); 1734824e7fdSDimitry Andric uniform_real_distribution<double> __gen; 174*cb14a3feSDimitry Andric return static_cast<_IntType>(std::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) - __p.__p_.begin()); 1754824e7fdSDimitry Andric } 1764824e7fdSDimitry Andric 1774824e7fdSDimitry Andric template <class _CharT, class _Traits, class _IT> 178bdd1243dSDimitry Andric _LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>& 179*cb14a3feSDimitry Andric operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x) { 1804824e7fdSDimitry Andric __save_flags<_CharT, _Traits> __lx(__os); 1814824e7fdSDimitry Andric typedef basic_ostream<_CharT, _Traits> _OStream; 182*cb14a3feSDimitry Andric __os.flags(_OStream::dec | _OStream::left | _OStream::fixed | _OStream::scientific); 1834824e7fdSDimitry Andric _CharT __sp = __os.widen(' '); 1844824e7fdSDimitry Andric __os.fill(__sp); 1854824e7fdSDimitry Andric size_t __n = __x.__p_.__p_.size(); 1864824e7fdSDimitry Andric __os << __n; 1874824e7fdSDimitry Andric for (size_t __i = 0; __i < __n; ++__i) 1884824e7fdSDimitry Andric __os << __sp << __x.__p_.__p_[__i]; 1894824e7fdSDimitry Andric return __os; 1904824e7fdSDimitry Andric } 1914824e7fdSDimitry Andric 1924824e7fdSDimitry Andric template <class _CharT, class _Traits, class _IT> 193bdd1243dSDimitry Andric _LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>& 194*cb14a3feSDimitry Andric operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x) { 1954824e7fdSDimitry Andric __save_flags<_CharT, _Traits> __lx(__is); 1964824e7fdSDimitry Andric typedef basic_istream<_CharT, _Traits> _Istream; 1974824e7fdSDimitry Andric __is.flags(_Istream::dec | _Istream::skipws); 1984824e7fdSDimitry Andric size_t __n; 1994824e7fdSDimitry Andric __is >> __n; 2004824e7fdSDimitry Andric vector<double> __p(__n); 2014824e7fdSDimitry Andric for (size_t __i = 0; __i < __n; ++__i) 2024824e7fdSDimitry Andric __is >> __p[__i]; 2034824e7fdSDimitry Andric if (!__is.fail()) 2044824e7fdSDimitry Andric swap(__x.__p_.__p_, __p); 2054824e7fdSDimitry Andric return __is; 2064824e7fdSDimitry Andric } 2074824e7fdSDimitry Andric 2084824e7fdSDimitry Andric _LIBCPP_END_NAMESPACE_STD 2094824e7fdSDimitry Andric 2104824e7fdSDimitry Andric _LIBCPP_POP_MACROS 2114824e7fdSDimitry Andric 2124824e7fdSDimitry Andric #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 213