1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 11 12 #include <__algorithm/upper_bound.h> 13 #include <__config> 14 #include <__random/is_valid.h> 15 #include <__random/uniform_real_distribution.h> 16 #include <cstddef> 17 #include <iosfwd> 18 #include <numeric> 19 #include <vector> 20 21 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) 22 # pragma GCC system_header 23 #endif 24 25 _LIBCPP_PUSH_MACROS 26 #include <__undef_macros> 27 28 _LIBCPP_BEGIN_NAMESPACE_STD 29 30 template<class _IntType = int> 31 class _LIBCPP_TEMPLATE_VIS discrete_distribution 32 { 33 static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type"); 34 public: 35 // types 36 typedef _IntType result_type; 37 38 class _LIBCPP_TEMPLATE_VIS param_type 39 { 40 vector<double> __p_; 41 public: 42 typedef discrete_distribution distribution_type; 43 44 _LIBCPP_INLINE_VISIBILITY 45 param_type() {} 46 template<class _InputIterator> 47 _LIBCPP_INLINE_VISIBILITY 48 param_type(_InputIterator __f, _InputIterator __l) 49 : __p_(__f, __l) {__init();} 50 #ifndef _LIBCPP_CXX03_LANG 51 _LIBCPP_INLINE_VISIBILITY 52 param_type(initializer_list<double> __wl) 53 : __p_(__wl.begin(), __wl.end()) {__init();} 54 #endif // _LIBCPP_CXX03_LANG 55 template<class _UnaryOperation> 56 _LIBCPP_HIDE_FROM_ABI param_type(size_t __nw, double __xmin, double __xmax, 57 _UnaryOperation __fw); 58 59 _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const; 60 61 friend _LIBCPP_INLINE_VISIBILITY 62 bool operator==(const param_type& __x, const param_type& __y) 63 {return __x.__p_ == __y.__p_;} 64 friend _LIBCPP_INLINE_VISIBILITY 65 bool operator!=(const param_type& __x, const param_type& __y) 66 {return !(__x == __y);} 67 68 private: 69 _LIBCPP_HIDE_FROM_ABI void __init(); 70 71 friend class discrete_distribution; 72 73 template <class _CharT, class _Traits, class _IT> 74 friend 75 basic_ostream<_CharT, _Traits>& 76 operator<<(basic_ostream<_CharT, _Traits>& __os, 77 const discrete_distribution<_IT>& __x); 78 79 template <class _CharT, class _Traits, class _IT> 80 friend 81 basic_istream<_CharT, _Traits>& 82 operator>>(basic_istream<_CharT, _Traits>& __is, 83 discrete_distribution<_IT>& __x); 84 }; 85 86 private: 87 param_type __p_; 88 89 public: 90 // constructor and reset functions 91 _LIBCPP_INLINE_VISIBILITY 92 discrete_distribution() {} 93 template<class _InputIterator> 94 _LIBCPP_INLINE_VISIBILITY 95 discrete_distribution(_InputIterator __f, _InputIterator __l) 96 : __p_(__f, __l) {} 97 #ifndef _LIBCPP_CXX03_LANG 98 _LIBCPP_INLINE_VISIBILITY 99 discrete_distribution(initializer_list<double> __wl) 100 : __p_(__wl) {} 101 #endif // _LIBCPP_CXX03_LANG 102 template<class _UnaryOperation> 103 _LIBCPP_INLINE_VISIBILITY 104 discrete_distribution(size_t __nw, double __xmin, double __xmax, 105 _UnaryOperation __fw) 106 : __p_(__nw, __xmin, __xmax, __fw) {} 107 _LIBCPP_INLINE_VISIBILITY 108 explicit discrete_distribution(const param_type& __p) 109 : __p_(__p) {} 110 _LIBCPP_INLINE_VISIBILITY 111 void reset() {} 112 113 // generating functions 114 template<class _URNG> 115 _LIBCPP_INLINE_VISIBILITY 116 result_type operator()(_URNG& __g) 117 {return (*this)(__g, __p_);} 118 template<class _URNG> 119 _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g, const param_type& __p); 120 121 // property functions 122 _LIBCPP_INLINE_VISIBILITY 123 vector<double> probabilities() const {return __p_.probabilities();} 124 125 _LIBCPP_INLINE_VISIBILITY 126 param_type param() const {return __p_;} 127 _LIBCPP_INLINE_VISIBILITY 128 void param(const param_type& __p) {__p_ = __p;} 129 130 _LIBCPP_INLINE_VISIBILITY 131 result_type min() const {return 0;} 132 _LIBCPP_INLINE_VISIBILITY 133 result_type max() const {return __p_.__p_.size();} 134 135 friend _LIBCPP_INLINE_VISIBILITY 136 bool operator==(const discrete_distribution& __x, 137 const discrete_distribution& __y) 138 {return __x.__p_ == __y.__p_;} 139 friend _LIBCPP_INLINE_VISIBILITY 140 bool operator!=(const discrete_distribution& __x, 141 const discrete_distribution& __y) 142 {return !(__x == __y);} 143 144 template <class _CharT, class _Traits, class _IT> 145 friend 146 basic_ostream<_CharT, _Traits>& 147 operator<<(basic_ostream<_CharT, _Traits>& __os, 148 const discrete_distribution<_IT>& __x); 149 150 template <class _CharT, class _Traits, class _IT> 151 friend 152 basic_istream<_CharT, _Traits>& 153 operator>>(basic_istream<_CharT, _Traits>& __is, 154 discrete_distribution<_IT>& __x); 155 }; 156 157 template<class _IntType> 158 template<class _UnaryOperation> 159 discrete_distribution<_IntType>::param_type::param_type(size_t __nw, 160 double __xmin, 161 double __xmax, 162 _UnaryOperation __fw) 163 { 164 if (__nw > 1) 165 { 166 __p_.reserve(__nw - 1); 167 double __d = (__xmax - __xmin) / __nw; 168 double __d2 = __d / 2; 169 for (size_t __k = 0; __k < __nw; ++__k) 170 __p_.push_back(__fw(__xmin + __k * __d + __d2)); 171 __init(); 172 } 173 } 174 175 template<class _IntType> 176 void 177 discrete_distribution<_IntType>::param_type::__init() 178 { 179 if (!__p_.empty()) 180 { 181 if (__p_.size() > 1) 182 { 183 double __s = _VSTD::accumulate(__p_.begin(), __p_.end(), 0.0); 184 for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i) 185 *__i /= __s; 186 vector<double> __t(__p_.size() - 1); 187 _VSTD::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin()); 188 swap(__p_, __t); 189 } 190 else 191 { 192 __p_.clear(); 193 __p_.shrink_to_fit(); 194 } 195 } 196 } 197 198 template<class _IntType> 199 vector<double> 200 discrete_distribution<_IntType>::param_type::probabilities() const 201 { 202 size_t __n = __p_.size(); 203 vector<double> __p(__n+1); 204 _VSTD::adjacent_difference(__p_.begin(), __p_.end(), __p.begin()); 205 if (__n > 0) 206 __p[__n] = 1 - __p_[__n-1]; 207 else 208 __p[0] = 1; 209 return __p; 210 } 211 212 template<class _IntType> 213 template<class _URNG> 214 _IntType 215 discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p) 216 { 217 static_assert(__libcpp_random_is_valid_urng<_URNG>::value, ""); 218 uniform_real_distribution<double> __gen; 219 return static_cast<_IntType>( 220 _VSTD::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) - 221 __p.__p_.begin()); 222 } 223 224 template <class _CharT, class _Traits, class _IT> 225 _LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>& 226 operator<<(basic_ostream<_CharT, _Traits>& __os, 227 const discrete_distribution<_IT>& __x) 228 { 229 __save_flags<_CharT, _Traits> __lx(__os); 230 typedef basic_ostream<_CharT, _Traits> _OStream; 231 __os.flags(_OStream::dec | _OStream::left | _OStream::fixed | 232 _OStream::scientific); 233 _CharT __sp = __os.widen(' '); 234 __os.fill(__sp); 235 size_t __n = __x.__p_.__p_.size(); 236 __os << __n; 237 for (size_t __i = 0; __i < __n; ++__i) 238 __os << __sp << __x.__p_.__p_[__i]; 239 return __os; 240 } 241 242 template <class _CharT, class _Traits, class _IT> 243 _LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>& 244 operator>>(basic_istream<_CharT, _Traits>& __is, 245 discrete_distribution<_IT>& __x) 246 { 247 __save_flags<_CharT, _Traits> __lx(__is); 248 typedef basic_istream<_CharT, _Traits> _Istream; 249 __is.flags(_Istream::dec | _Istream::skipws); 250 size_t __n; 251 __is >> __n; 252 vector<double> __p(__n); 253 for (size_t __i = 0; __i < __n; ++__i) 254 __is >> __p[__i]; 255 if (!__is.fail()) 256 swap(__x.__p_.__p_, __p); 257 return __is; 258 } 259 260 _LIBCPP_END_NAMESPACE_STD 261 262 _LIBCPP_POP_MACROS 263 264 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 265