//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H #include <__algorithm/upper_bound.h> #include <__config> #include <__random/is_valid.h> #include <__random/uniform_real_distribution.h> #include #include #include #include #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) # pragma GCC system_header #endif _LIBCPP_PUSH_MACROS #include <__undef_macros> _LIBCPP_BEGIN_NAMESPACE_STD template class _LIBCPP_TEMPLATE_VIS discrete_distribution { static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type"); public: // types typedef _IntType result_type; class _LIBCPP_TEMPLATE_VIS param_type { vector __p_; public: typedef discrete_distribution distribution_type; _LIBCPP_INLINE_VISIBILITY param_type() {} template _LIBCPP_INLINE_VISIBILITY param_type(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {__init();} #ifndef _LIBCPP_CXX03_LANG _LIBCPP_INLINE_VISIBILITY param_type(initializer_list __wl) : __p_(__wl.begin(), __wl.end()) {__init();} #endif // _LIBCPP_CXX03_LANG template param_type(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw); vector probabilities() const; friend _LIBCPP_INLINE_VISIBILITY bool operator==(const param_type& __x, const param_type& __y) {return __x.__p_ == __y.__p_;} friend _LIBCPP_INLINE_VISIBILITY bool operator!=(const param_type& __x, const param_type& __y) {return !(__x == __y);} private: void __init(); friend class discrete_distribution; template friend basic_ostream<_CharT, _Traits>& operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x); template friend basic_istream<_CharT, _Traits>& operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x); }; private: param_type __p_; public: // constructor and reset functions _LIBCPP_INLINE_VISIBILITY discrete_distribution() {} template _LIBCPP_INLINE_VISIBILITY discrete_distribution(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {} #ifndef _LIBCPP_CXX03_LANG _LIBCPP_INLINE_VISIBILITY discrete_distribution(initializer_list __wl) : __p_(__wl) {} #endif // _LIBCPP_CXX03_LANG template _LIBCPP_INLINE_VISIBILITY discrete_distribution(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw) : __p_(__nw, __xmin, __xmax, __fw) {} _LIBCPP_INLINE_VISIBILITY explicit discrete_distribution(const param_type& __p) : __p_(__p) {} _LIBCPP_INLINE_VISIBILITY void reset() {} // generating functions template _LIBCPP_INLINE_VISIBILITY result_type operator()(_URNG& __g) {return (*this)(__g, __p_);} template result_type operator()(_URNG& __g, const param_type& __p); // property functions _LIBCPP_INLINE_VISIBILITY vector probabilities() const {return __p_.probabilities();} _LIBCPP_INLINE_VISIBILITY param_type param() const {return __p_;} _LIBCPP_INLINE_VISIBILITY void param(const param_type& __p) {__p_ = __p;} _LIBCPP_INLINE_VISIBILITY result_type min() const {return 0;} _LIBCPP_INLINE_VISIBILITY result_type max() const {return __p_.__p_.size();} friend _LIBCPP_INLINE_VISIBILITY bool operator==(const discrete_distribution& __x, const discrete_distribution& __y) {return __x.__p_ == __y.__p_;} friend _LIBCPP_INLINE_VISIBILITY bool operator!=(const discrete_distribution& __x, const discrete_distribution& __y) {return !(__x == __y);} template friend basic_ostream<_CharT, _Traits>& operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x); template friend basic_istream<_CharT, _Traits>& operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x); }; template template discrete_distribution<_IntType>::param_type::param_type(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw) { if (__nw > 1) { __p_.reserve(__nw - 1); double __d = (__xmax - __xmin) / __nw; double __d2 = __d / 2; for (size_t __k = 0; __k < __nw; ++__k) __p_.push_back(__fw(__xmin + __k * __d + __d2)); __init(); } } template void discrete_distribution<_IntType>::param_type::__init() { if (!__p_.empty()) { if (__p_.size() > 1) { double __s = _VSTD::accumulate(__p_.begin(), __p_.end(), 0.0); for (vector::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i) *__i /= __s; vector __t(__p_.size() - 1); _VSTD::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin()); swap(__p_, __t); } else { __p_.clear(); __p_.shrink_to_fit(); } } } template vector discrete_distribution<_IntType>::param_type::probabilities() const { size_t __n = __p_.size(); vector __p(__n+1); _VSTD::adjacent_difference(__p_.begin(), __p_.end(), __p.begin()); if (__n > 0) __p[__n] = 1 - __p_[__n-1]; else __p[0] = 1; return __p; } template template _IntType discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p) { static_assert(__libcpp_random_is_valid_urng<_URNG>::value, ""); uniform_real_distribution __gen; return static_cast<_IntType>( _VSTD::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) - __p.__p_.begin()); } template _LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>& operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x) { __save_flags<_CharT, _Traits> __lx(__os); typedef basic_ostream<_CharT, _Traits> _OStream; __os.flags(_OStream::dec | _OStream::left | _OStream::fixed | _OStream::scientific); _CharT __sp = __os.widen(' '); __os.fill(__sp); size_t __n = __x.__p_.__p_.size(); __os << __n; for (size_t __i = 0; __i < __n; ++__i) __os << __sp << __x.__p_.__p_[__i]; return __os; } template _LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>& operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x) { __save_flags<_CharT, _Traits> __lx(__is); typedef basic_istream<_CharT, _Traits> _Istream; __is.flags(_Istream::dec | _Istream::skipws); size_t __n; __is >> __n; vector __p(__n); for (size_t __i = 0; __i < __n; ++__i) __is >> __p[__i]; if (!__is.fail()) swap(__x.__p_.__p_, __p); return __is; } _LIBCPP_END_NAMESPACE_STD _LIBCPP_POP_MACROS #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H