1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 11 12 #include <__algorithm/upper_bound.h> 13 #include <__config> 14 #include <__random/is_valid.h> 15 #include <__random/uniform_real_distribution.h> 16 #include <cstddef> 17 #include <iosfwd> 18 #include <numeric> 19 #include <vector> 20 21 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) 22 # pragma GCC system_header 23 #endif 24 25 _LIBCPP_PUSH_MACROS 26 #include <__undef_macros> 27 28 _LIBCPP_BEGIN_NAMESPACE_STD 29 30 template<class _IntType = int> 31 class _LIBCPP_TEMPLATE_VIS discrete_distribution 32 { 33 static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type"); 34 public: 35 // types 36 typedef _IntType result_type; 37 38 class _LIBCPP_TEMPLATE_VIS param_type 39 { 40 vector<double> __p_; 41 public: 42 typedef discrete_distribution distribution_type; 43 44 _LIBCPP_INLINE_VISIBILITY 45 param_type() {} 46 template<class _InputIterator> 47 _LIBCPP_INLINE_VISIBILITY 48 param_type(_InputIterator __f, _InputIterator __l) 49 : __p_(__f, __l) {__init();} 50 #ifndef _LIBCPP_CXX03_LANG 51 _LIBCPP_INLINE_VISIBILITY 52 param_type(initializer_list<double> __wl) 53 : __p_(__wl.begin(), __wl.end()) {__init();} 54 #endif // _LIBCPP_CXX03_LANG 55 template<class _UnaryOperation> 56 param_type(size_t __nw, double __xmin, double __xmax, 57 _UnaryOperation __fw); 58 59 vector<double> probabilities() const; 60 61 friend _LIBCPP_INLINE_VISIBILITY 62 bool operator==(const param_type& __x, const param_type& __y) 63 {return __x.__p_ == __y.__p_;} 64 friend _LIBCPP_INLINE_VISIBILITY 65 bool operator!=(const param_type& __x, const param_type& __y) 66 {return !(__x == __y);} 67 68 private: 69 void __init(); 70 71 friend class discrete_distribution; 72 73 template <class _CharT, class _Traits, class _IT> 74 friend 75 basic_ostream<_CharT, _Traits>& 76 operator<<(basic_ostream<_CharT, _Traits>& __os, 77 const discrete_distribution<_IT>& __x); 78 79 template <class _CharT, class _Traits, class _IT> 80 friend 81 basic_istream<_CharT, _Traits>& 82 operator>>(basic_istream<_CharT, _Traits>& __is, 83 discrete_distribution<_IT>& __x); 84 }; 85 86 private: 87 param_type __p_; 88 89 public: 90 // constructor and reset functions 91 _LIBCPP_INLINE_VISIBILITY 92 discrete_distribution() {} 93 template<class _InputIterator> 94 _LIBCPP_INLINE_VISIBILITY 95 discrete_distribution(_InputIterator __f, _InputIterator __l) 96 : __p_(__f, __l) {} 97 #ifndef _LIBCPP_CXX03_LANG 98 _LIBCPP_INLINE_VISIBILITY 99 discrete_distribution(initializer_list<double> __wl) 100 : __p_(__wl) {} 101 #endif // _LIBCPP_CXX03_LANG 102 template<class _UnaryOperation> 103 _LIBCPP_INLINE_VISIBILITY 104 discrete_distribution(size_t __nw, double __xmin, double __xmax, 105 _UnaryOperation __fw) 106 : __p_(__nw, __xmin, __xmax, __fw) {} 107 _LIBCPP_INLINE_VISIBILITY 108 explicit discrete_distribution(const param_type& __p) 109 : __p_(__p) {} 110 _LIBCPP_INLINE_VISIBILITY 111 void reset() {} 112 113 // generating functions 114 template<class _URNG> 115 _LIBCPP_INLINE_VISIBILITY 116 result_type operator()(_URNG& __g) 117 {return (*this)(__g, __p_);} 118 template<class _URNG> result_type operator()(_URNG& __g, const param_type& __p); 119 120 // property functions 121 _LIBCPP_INLINE_VISIBILITY 122 vector<double> probabilities() const {return __p_.probabilities();} 123 124 _LIBCPP_INLINE_VISIBILITY 125 param_type param() const {return __p_;} 126 _LIBCPP_INLINE_VISIBILITY 127 void param(const param_type& __p) {__p_ = __p;} 128 129 _LIBCPP_INLINE_VISIBILITY 130 result_type min() const {return 0;} 131 _LIBCPP_INLINE_VISIBILITY 132 result_type max() const {return __p_.__p_.size();} 133 134 friend _LIBCPP_INLINE_VISIBILITY 135 bool operator==(const discrete_distribution& __x, 136 const discrete_distribution& __y) 137 {return __x.__p_ == __y.__p_;} 138 friend _LIBCPP_INLINE_VISIBILITY 139 bool operator!=(const discrete_distribution& __x, 140 const discrete_distribution& __y) 141 {return !(__x == __y);} 142 143 template <class _CharT, class _Traits, class _IT> 144 friend 145 basic_ostream<_CharT, _Traits>& 146 operator<<(basic_ostream<_CharT, _Traits>& __os, 147 const discrete_distribution<_IT>& __x); 148 149 template <class _CharT, class _Traits, class _IT> 150 friend 151 basic_istream<_CharT, _Traits>& 152 operator>>(basic_istream<_CharT, _Traits>& __is, 153 discrete_distribution<_IT>& __x); 154 }; 155 156 template<class _IntType> 157 template<class _UnaryOperation> 158 discrete_distribution<_IntType>::param_type::param_type(size_t __nw, 159 double __xmin, 160 double __xmax, 161 _UnaryOperation __fw) 162 { 163 if (__nw > 1) 164 { 165 __p_.reserve(__nw - 1); 166 double __d = (__xmax - __xmin) / __nw; 167 double __d2 = __d / 2; 168 for (size_t __k = 0; __k < __nw; ++__k) 169 __p_.push_back(__fw(__xmin + __k * __d + __d2)); 170 __init(); 171 } 172 } 173 174 template<class _IntType> 175 void 176 discrete_distribution<_IntType>::param_type::__init() 177 { 178 if (!__p_.empty()) 179 { 180 if (__p_.size() > 1) 181 { 182 double __s = _VSTD::accumulate(__p_.begin(), __p_.end(), 0.0); 183 for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i) 184 *__i /= __s; 185 vector<double> __t(__p_.size() - 1); 186 _VSTD::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin()); 187 swap(__p_, __t); 188 } 189 else 190 { 191 __p_.clear(); 192 __p_.shrink_to_fit(); 193 } 194 } 195 } 196 197 template<class _IntType> 198 vector<double> 199 discrete_distribution<_IntType>::param_type::probabilities() const 200 { 201 size_t __n = __p_.size(); 202 vector<double> __p(__n+1); 203 _VSTD::adjacent_difference(__p_.begin(), __p_.end(), __p.begin()); 204 if (__n > 0) 205 __p[__n] = 1 - __p_[__n-1]; 206 else 207 __p[0] = 1; 208 return __p; 209 } 210 211 template<class _IntType> 212 template<class _URNG> 213 _IntType 214 discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p) 215 { 216 static_assert(__libcpp_random_is_valid_urng<_URNG>::value, ""); 217 uniform_real_distribution<double> __gen; 218 return static_cast<_IntType>( 219 _VSTD::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) - 220 __p.__p_.begin()); 221 } 222 223 template <class _CharT, class _Traits, class _IT> 224 basic_ostream<_CharT, _Traits>& 225 operator<<(basic_ostream<_CharT, _Traits>& __os, 226 const discrete_distribution<_IT>& __x) 227 { 228 __save_flags<_CharT, _Traits> __lx(__os); 229 typedef basic_ostream<_CharT, _Traits> _OStream; 230 __os.flags(_OStream::dec | _OStream::left | _OStream::fixed | 231 _OStream::scientific); 232 _CharT __sp = __os.widen(' '); 233 __os.fill(__sp); 234 size_t __n = __x.__p_.__p_.size(); 235 __os << __n; 236 for (size_t __i = 0; __i < __n; ++__i) 237 __os << __sp << __x.__p_.__p_[__i]; 238 return __os; 239 } 240 241 template <class _CharT, class _Traits, class _IT> 242 basic_istream<_CharT, _Traits>& 243 operator>>(basic_istream<_CharT, _Traits>& __is, 244 discrete_distribution<_IT>& __x) 245 { 246 __save_flags<_CharT, _Traits> __lx(__is); 247 typedef basic_istream<_CharT, _Traits> _Istream; 248 __is.flags(_Istream::dec | _Istream::skipws); 249 size_t __n; 250 __is >> __n; 251 vector<double> __p(__n); 252 for (size_t __i = 0; __i < __n; ++__i) 253 __is >> __p[__i]; 254 if (!__is.fail()) 255 swap(__x.__p_.__p_, __p); 256 return __is; 257 } 258 259 _LIBCPP_END_NAMESPACE_STD 260 261 _LIBCPP_POP_MACROS 262 263 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 264