1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 11 12 #include <__algorithm/upper_bound.h> 13 #include <__config> 14 #include <__random/uniform_real_distribution.h> 15 #include <cstddef> 16 #include <iosfwd> 17 #include <numeric> 18 #include <vector> 19 20 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) 21 #pragma GCC system_header 22 #endif 23 24 _LIBCPP_PUSH_MACROS 25 #include <__undef_macros> 26 27 _LIBCPP_BEGIN_NAMESPACE_STD 28 29 template<class _IntType = int> 30 class _LIBCPP_TEMPLATE_VIS discrete_distribution 31 { 32 public: 33 // types 34 typedef _IntType result_type; 35 36 class _LIBCPP_TEMPLATE_VIS param_type 37 { 38 vector<double> __p_; 39 public: 40 typedef discrete_distribution distribution_type; 41 42 _LIBCPP_INLINE_VISIBILITY 43 param_type() {} 44 template<class _InputIterator> 45 _LIBCPP_INLINE_VISIBILITY 46 param_type(_InputIterator __f, _InputIterator __l) 47 : __p_(__f, __l) {__init();} 48 #ifndef _LIBCPP_CXX03_LANG 49 _LIBCPP_INLINE_VISIBILITY 50 param_type(initializer_list<double> __wl) 51 : __p_(__wl.begin(), __wl.end()) {__init();} 52 #endif // _LIBCPP_CXX03_LANG 53 template<class _UnaryOperation> 54 param_type(size_t __nw, double __xmin, double __xmax, 55 _UnaryOperation __fw); 56 57 vector<double> probabilities() const; 58 59 friend _LIBCPP_INLINE_VISIBILITY 60 bool operator==(const param_type& __x, const param_type& __y) 61 {return __x.__p_ == __y.__p_;} 62 friend _LIBCPP_INLINE_VISIBILITY 63 bool operator!=(const param_type& __x, const param_type& __y) 64 {return !(__x == __y);} 65 66 private: 67 void __init(); 68 69 friend class discrete_distribution; 70 71 template <class _CharT, class _Traits, class _IT> 72 friend 73 basic_ostream<_CharT, _Traits>& 74 operator<<(basic_ostream<_CharT, _Traits>& __os, 75 const discrete_distribution<_IT>& __x); 76 77 template <class _CharT, class _Traits, class _IT> 78 friend 79 basic_istream<_CharT, _Traits>& 80 operator>>(basic_istream<_CharT, _Traits>& __is, 81 discrete_distribution<_IT>& __x); 82 }; 83 84 private: 85 param_type __p_; 86 87 public: 88 // constructor and reset functions 89 _LIBCPP_INLINE_VISIBILITY 90 discrete_distribution() {} 91 template<class _InputIterator> 92 _LIBCPP_INLINE_VISIBILITY 93 discrete_distribution(_InputIterator __f, _InputIterator __l) 94 : __p_(__f, __l) {} 95 #ifndef _LIBCPP_CXX03_LANG 96 _LIBCPP_INLINE_VISIBILITY 97 discrete_distribution(initializer_list<double> __wl) 98 : __p_(__wl) {} 99 #endif // _LIBCPP_CXX03_LANG 100 template<class _UnaryOperation> 101 _LIBCPP_INLINE_VISIBILITY 102 discrete_distribution(size_t __nw, double __xmin, double __xmax, 103 _UnaryOperation __fw) 104 : __p_(__nw, __xmin, __xmax, __fw) {} 105 _LIBCPP_INLINE_VISIBILITY 106 explicit discrete_distribution(const param_type& __p) 107 : __p_(__p) {} 108 _LIBCPP_INLINE_VISIBILITY 109 void reset() {} 110 111 // generating functions 112 template<class _URNG> 113 _LIBCPP_INLINE_VISIBILITY 114 result_type operator()(_URNG& __g) 115 {return (*this)(__g, __p_);} 116 template<class _URNG> result_type operator()(_URNG& __g, const param_type& __p); 117 118 // property functions 119 _LIBCPP_INLINE_VISIBILITY 120 vector<double> probabilities() const {return __p_.probabilities();} 121 122 _LIBCPP_INLINE_VISIBILITY 123 param_type param() const {return __p_;} 124 _LIBCPP_INLINE_VISIBILITY 125 void param(const param_type& __p) {__p_ = __p;} 126 127 _LIBCPP_INLINE_VISIBILITY 128 result_type min() const {return 0;} 129 _LIBCPP_INLINE_VISIBILITY 130 result_type max() const {return __p_.__p_.size();} 131 132 friend _LIBCPP_INLINE_VISIBILITY 133 bool operator==(const discrete_distribution& __x, 134 const discrete_distribution& __y) 135 {return __x.__p_ == __y.__p_;} 136 friend _LIBCPP_INLINE_VISIBILITY 137 bool operator!=(const discrete_distribution& __x, 138 const discrete_distribution& __y) 139 {return !(__x == __y);} 140 141 template <class _CharT, class _Traits, class _IT> 142 friend 143 basic_ostream<_CharT, _Traits>& 144 operator<<(basic_ostream<_CharT, _Traits>& __os, 145 const discrete_distribution<_IT>& __x); 146 147 template <class _CharT, class _Traits, class _IT> 148 friend 149 basic_istream<_CharT, _Traits>& 150 operator>>(basic_istream<_CharT, _Traits>& __is, 151 discrete_distribution<_IT>& __x); 152 }; 153 154 template<class _IntType> 155 template<class _UnaryOperation> 156 discrete_distribution<_IntType>::param_type::param_type(size_t __nw, 157 double __xmin, 158 double __xmax, 159 _UnaryOperation __fw) 160 { 161 if (__nw > 1) 162 { 163 __p_.reserve(__nw - 1); 164 double __d = (__xmax - __xmin) / __nw; 165 double __d2 = __d / 2; 166 for (size_t __k = 0; __k < __nw; ++__k) 167 __p_.push_back(__fw(__xmin + __k * __d + __d2)); 168 __init(); 169 } 170 } 171 172 template<class _IntType> 173 void 174 discrete_distribution<_IntType>::param_type::__init() 175 { 176 if (!__p_.empty()) 177 { 178 if (__p_.size() > 1) 179 { 180 double __s = _VSTD::accumulate(__p_.begin(), __p_.end(), 0.0); 181 for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i) 182 *__i /= __s; 183 vector<double> __t(__p_.size() - 1); 184 _VSTD::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin()); 185 swap(__p_, __t); 186 } 187 else 188 { 189 __p_.clear(); 190 __p_.shrink_to_fit(); 191 } 192 } 193 } 194 195 template<class _IntType> 196 vector<double> 197 discrete_distribution<_IntType>::param_type::probabilities() const 198 { 199 size_t __n = __p_.size(); 200 vector<double> __p(__n+1); 201 _VSTD::adjacent_difference(__p_.begin(), __p_.end(), __p.begin()); 202 if (__n > 0) 203 __p[__n] = 1 - __p_[__n-1]; 204 else 205 __p[0] = 1; 206 return __p; 207 } 208 209 template<class _IntType> 210 template<class _URNG> 211 _IntType 212 discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p) 213 { 214 uniform_real_distribution<double> __gen; 215 return static_cast<_IntType>( 216 _VSTD::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) - 217 __p.__p_.begin()); 218 } 219 220 template <class _CharT, class _Traits, class _IT> 221 basic_ostream<_CharT, _Traits>& 222 operator<<(basic_ostream<_CharT, _Traits>& __os, 223 const discrete_distribution<_IT>& __x) 224 { 225 __save_flags<_CharT, _Traits> __lx(__os); 226 typedef basic_ostream<_CharT, _Traits> _OStream; 227 __os.flags(_OStream::dec | _OStream::left | _OStream::fixed | 228 _OStream::scientific); 229 _CharT __sp = __os.widen(' '); 230 __os.fill(__sp); 231 size_t __n = __x.__p_.__p_.size(); 232 __os << __n; 233 for (size_t __i = 0; __i < __n; ++__i) 234 __os << __sp << __x.__p_.__p_[__i]; 235 return __os; 236 } 237 238 template <class _CharT, class _Traits, class _IT> 239 basic_istream<_CharT, _Traits>& 240 operator>>(basic_istream<_CharT, _Traits>& __is, 241 discrete_distribution<_IT>& __x) 242 { 243 __save_flags<_CharT, _Traits> __lx(__is); 244 typedef basic_istream<_CharT, _Traits> _Istream; 245 __is.flags(_Istream::dec | _Istream::skipws); 246 size_t __n; 247 __is >> __n; 248 vector<double> __p(__n); 249 for (size_t __i = 0; __i < __n; ++__i) 250 __is >> __p[__i]; 251 if (!__is.fail()) 252 swap(__x.__p_.__p_, __p); 253 return __is; 254 } 255 256 _LIBCPP_END_NAMESPACE_STD 257 258 _LIBCPP_POP_MACROS 259 260 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H 261