1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 11 12 #include <__algorithm/upper_bound.h> 13 #include <__config> 14 #include <__random/is_valid.h> 15 #include <__random/uniform_real_distribution.h> 16 #include <cstddef> 17 #include <iosfwd> 18 #include <numeric> 19 #include <vector> 20 21 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) 22 # pragma GCC system_header 23 #endif 24 25 _LIBCPP_PUSH_MACROS 26 #include <__undef_macros> 27 28 _LIBCPP_BEGIN_NAMESPACE_STD 29 30 template <class _IntType = int> 31 class _LIBCPP_TEMPLATE_VIS discrete_distribution { 32 static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type"); 33 34 public: 35 // types 36 typedef _IntType result_type; 37 38 class _LIBCPP_TEMPLATE_VIS param_type { 39 vector<double> __p_; 40 41 public: 42 typedef discrete_distribution distribution_type; 43 44 _LIBCPP_HIDE_FROM_ABI param_type() {} 45 template <class _InputIterator> 46 _LIBCPP_HIDE_FROM_ABI param_type(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) { 47 __init(); 48 } 49 #ifndef _LIBCPP_CXX03_LANG 50 _LIBCPP_HIDE_FROM_ABI param_type(initializer_list<double> __wl) : __p_(__wl.begin(), __wl.end()) { __init(); } 51 #endif // _LIBCPP_CXX03_LANG 52 template <class _UnaryOperation> 53 _LIBCPP_HIDE_FROM_ABI param_type(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw); 54 55 _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const; 56 57 friend _LIBCPP_HIDE_FROM_ABI bool operator==(const param_type& __x, const param_type& __y) { 58 return __x.__p_ == __y.__p_; 59 } 60 friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const param_type& __x, const param_type& __y) { return !(__x == __y); } 61 62 private: 63 _LIBCPP_HIDE_FROM_ABI void __init(); 64 65 friend class discrete_distribution; 66 67 template <class _CharT, class _Traits, class _IT> 68 friend basic_ostream<_CharT, _Traits>& 69 operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x); 70 71 template <class _CharT, class _Traits, class _IT> 72 friend basic_istream<_CharT, _Traits>& 73 operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x); 74 }; 75 76 private: 77 param_type __p_; 78 79 public: 80 // constructor and reset functions 81 _LIBCPP_HIDE_FROM_ABI discrete_distribution() {} 82 template <class _InputIterator> 83 _LIBCPP_HIDE_FROM_ABI discrete_distribution(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {} 84 #ifndef _LIBCPP_CXX03_LANG 85 _LIBCPP_HIDE_FROM_ABI discrete_distribution(initializer_list<double> __wl) : __p_(__wl) {} 86 #endif // _LIBCPP_CXX03_LANG 87 template <class _UnaryOperation> 88 _LIBCPP_HIDE_FROM_ABI discrete_distribution(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw) 89 : __p_(__nw, __xmin, __xmax, __fw) {} 90 _LIBCPP_HIDE_FROM_ABI explicit discrete_distribution(const param_type& __p) : __p_(__p) {} 91 _LIBCPP_HIDE_FROM_ABI void reset() {} 92 93 // generating functions 94 template <class _URNG> 95 _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g) { 96 return (*this)(__g, __p_); 97 } 98 template <class _URNG> 99 _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g, const param_type& __p); 100 101 // property functions 102 _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const { return __p_.probabilities(); } 103 104 _LIBCPP_HIDE_FROM_ABI param_type param() const { return __p_; } 105 _LIBCPP_HIDE_FROM_ABI void param(const param_type& __p) { __p_ = __p; } 106 107 _LIBCPP_HIDE_FROM_ABI result_type min() const { return 0; } 108 _LIBCPP_HIDE_FROM_ABI result_type max() const { return __p_.__p_.size(); } 109 110 friend _LIBCPP_HIDE_FROM_ABI bool operator==(const discrete_distribution& __x, const discrete_distribution& __y) { 111 return __x.__p_ == __y.__p_; 112 } 113 friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const discrete_distribution& __x, const discrete_distribution& __y) { 114 return !(__x == __y); 115 } 116 117 template <class _CharT, class _Traits, class _IT> 118 friend basic_ostream<_CharT, _Traits>& 119 operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x); 120 121 template <class _CharT, class _Traits, class _IT> 122 friend basic_istream<_CharT, _Traits>& 123 operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x); 124 }; 125 126 template <class _IntType> 127 template <class _UnaryOperation> 128 discrete_distribution<_IntType>::param_type::param_type( 129 size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw) { 130 if (__nw > 1) { 131 __p_.reserve(__nw - 1); 132 double __d = (__xmax - __xmin) / __nw; 133 double __d2 = __d / 2; 134 for (size_t __k = 0; __k < __nw; ++__k) 135 __p_.push_back(__fw(__xmin + __k * __d + __d2)); 136 __init(); 137 } 138 } 139 140 template <class _IntType> 141 void discrete_distribution<_IntType>::param_type::__init() { 142 if (!__p_.empty()) { 143 if (__p_.size() > 1) { 144 double __s = std::accumulate(__p_.begin(), __p_.end(), 0.0); 145 for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i) 146 *__i /= __s; 147 vector<double> __t(__p_.size() - 1); 148 std::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin()); 149 swap(__p_, __t); 150 } else { 151 __p_.clear(); 152 __p_.shrink_to_fit(); 153 } 154 } 155 } 156 157 template <class _IntType> 158 vector<double> discrete_distribution<_IntType>::param_type::probabilities() const { 159 size_t __n = __p_.size(); 160 vector<double> __p(__n + 1); 161 std::adjacent_difference(__p_.begin(), __p_.end(), __p.begin()); 162 if (__n > 0) 163 __p[__n] = 1 - __p_[__n - 1]; 164 else 165 __p[0] = 1; 166 return __p; 167 } 168 169 template <class _IntType> 170 template <class _URNG> 171 _IntType discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p) { 172 static_assert(__libcpp_random_is_valid_urng<_URNG>::value, ""); 173 uniform_real_distribution<double> __gen; 174 return static_cast<_IntType>(std::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) - __p.__p_.begin()); 175 } 176 177 template <class _CharT, class _Traits, class _IT> 178 _LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>& 179 operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x) { 180 __save_flags<_CharT, _Traits> __lx(__os); 181 typedef basic_ostream<_CharT, _Traits> _OStream; 182 __os.flags(_OStream::dec | _OStream::left | _OStream::fixed | _OStream::scientific); 183 _CharT __sp = __os.widen(' '); 184 __os.fill(__sp); 185 size_t __n = __x.__p_.__p_.size(); 186 __os << __n; 187 for (size_t __i = 0; __i < __n; ++__i) 188 __os << __sp << __x.__p_.__p_[__i]; 189 return __os; 190 } 191 192 template <class _CharT, class _Traits, class _IT> 193 _LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>& 194 operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x) { 195 __save_flags<_CharT, _Traits> __lx(__is); 196 typedef basic_istream<_CharT, _Traits> _Istream; 197 __is.flags(_Istream::dec | _Istream::skipws); 198 size_t __n; 199 __is >> __n; 200 vector<double> __p(__n); 201 for (size_t __i = 0; __i < __n; ++__i) 202 __is >> __p[__i]; 203 if (!__is.fail()) 204 swap(__x.__p_.__p_, __p); 205 return __is; 206 } 207 208 _LIBCPP_END_NAMESPACE_STD 209 210 _LIBCPP_POP_MACROS 211 212 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 213