xref: /freebsd/contrib/llvm-project/libcxx/include/__random/discrete_distribution.h (revision 924226fba12cc9a228c73b956e1b7fa24c60b055)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
11 
12 #include <__algorithm/upper_bound.h>
13 #include <__config>
14 #include <__random/uniform_real_distribution.h>
15 #include <cstddef>
16 #include <iosfwd>
17 #include <numeric>
18 #include <vector>
19 
20 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
21 #pragma GCC system_header
22 #endif
23 
24 _LIBCPP_PUSH_MACROS
25 #include <__undef_macros>
26 
27 _LIBCPP_BEGIN_NAMESPACE_STD
28 
// discrete_distribution: produces integers in [0, n) where n is the number of
// weights supplied at construction; index i is drawn with probability
// w_i / sum(w).
//
// The weights are not kept as given: param_type::__init() normalizes them and
// stores a cumulative table that operator() searches with upper_bound.
template<class _IntType = int>
class _LIBCPP_TEMPLATE_VIS discrete_distribution
{
public:
    // types
    typedef _IntType result_type;

    // Parameter set of the distribution.
    //
    // Representation of __p_ once __init() has run: the running sums of the
    // normalized weights for every outcome except the last, i.e. n - 1
    // non-decreasing values. The last outcome's bound is implicitly 1 and is
    // not stored. An empty __p_ means a single outcome (index 0) with
    // probability 1.
    class _LIBCPP_TEMPLATE_VIS param_type
    {
        vector<double> __p_;
    public:
        typedef discrete_distribution distribution_type;

        // Default: single outcome 0 with probability 1 (empty table).
        _LIBCPP_INLINE_VISIBILITY
        param_type() {}
        // Weights copied from [__f, __l), then converted to the cumulative
        // table.
        template<class _InputIterator>
            _LIBCPP_INLINE_VISIBILITY
            param_type(_InputIterator __f, _InputIterator __l)
            : __p_(__f, __l) {__init();}
#ifndef _LIBCPP_CXX03_LANG
        // Weights copied from an initializer list, then converted.
        _LIBCPP_INLINE_VISIBILITY
        param_type(initializer_list<double> __wl)
            : __p_(__wl.begin(), __wl.end()) {__init();}
#endif // _LIBCPP_CXX03_LANG
        // Weights obtained by evaluating __fw at the midpoints of __nw
        // equal-width sub-intervals of [__xmin, __xmax] (when __nw > 1;
        // otherwise the single-outcome default is used). Defined below.
        template<class _UnaryOperation>
            param_type(size_t __nw, double __xmin, double __xmax,
                       _UnaryOperation __fw);

        // Per-outcome probabilities rebuilt from the cumulative table.
        vector<double> probabilities() const;

        // Equality compares the stored cumulative tables element-wise.
        friend _LIBCPP_INLINE_VISIBILITY
            bool operator==(const param_type& __x, const param_type& __y)
            {return __x.__p_ == __y.__p_;}
        friend _LIBCPP_INLINE_VISIBILITY
            bool operator!=(const param_type& __x, const param_type& __y)
            {return !(__x == __y);}

    private:
        // Turns the raw weights held in __p_ into the cumulative table.
        void __init();

        // The distribution reads __p_ directly in operator() and max().
        friend class discrete_distribution;

        // The stream operators serialize / restore __p_ verbatim.
        template <class _CharT, class _Traits, class _IT>
        friend
        basic_ostream<_CharT, _Traits>&
        operator<<(basic_ostream<_CharT, _Traits>& __os,
                   const discrete_distribution<_IT>& __x);

        template <class _CharT, class _Traits, class _IT>
        friend
        basic_istream<_CharT, _Traits>&
        operator>>(basic_istream<_CharT, _Traits>& __is,
                   discrete_distribution<_IT>& __x);
    };

private:
    param_type __p_;

public:
    // constructor and reset functions
    // Each constructor forwards its arguments to the matching param_type
    // constructor.
    _LIBCPP_INLINE_VISIBILITY
    discrete_distribution() {}
    template<class _InputIterator>
        _LIBCPP_INLINE_VISIBILITY
        discrete_distribution(_InputIterator __f, _InputIterator __l)
            : __p_(__f, __l) {}
#ifndef _LIBCPP_CXX03_LANG
    _LIBCPP_INLINE_VISIBILITY
    discrete_distribution(initializer_list<double> __wl)
        : __p_(__wl) {}
#endif // _LIBCPP_CXX03_LANG
    template<class _UnaryOperation>
        _LIBCPP_INLINE_VISIBILITY
        discrete_distribution(size_t __nw, double __xmin, double __xmax,
                              _UnaryOperation __fw)
        : __p_(__nw, __xmin, __xmax, __fw) {}
    _LIBCPP_INLINE_VISIBILITY
    explicit discrete_distribution(const param_type& __p)
        : __p_(__p) {}
    // No-op: operator() builds a fresh uniform generator on every call, so
    // there is no cached state to discard.
    _LIBCPP_INLINE_VISIBILITY
    void reset() {}

    // generating functions
    // Draw using this distribution's own parameters.
    template<class _URNG>
        _LIBCPP_INLINE_VISIBILITY
        result_type operator()(_URNG& __g)
        {return (*this)(__g, __p_);}
    // Draw using the supplied parameters; defined below.
    template<class _URNG> result_type operator()(_URNG& __g, const param_type& __p);

    // property functions
    _LIBCPP_INLINE_VISIBILITY
    vector<double> probabilities() const {return __p_.probabilities();}

    _LIBCPP_INLINE_VISIBILITY
    param_type param() const {return __p_;}
    _LIBCPP_INLINE_VISIBILITY
    void param(const param_type& __p) {__p_ = __p;}

    // Range of values operator() can return: upper_bound over the cumulative
    // table yields an index in [0, table size].
    _LIBCPP_INLINE_VISIBILITY
    result_type min() const {return 0;}
    _LIBCPP_INLINE_VISIBILITY
    result_type max() const {return __p_.__p_.size();}

    friend _LIBCPP_INLINE_VISIBILITY
        bool operator==(const discrete_distribution& __x,
                        const discrete_distribution& __y)
        {return __x.__p_ == __y.__p_;}
    friend _LIBCPP_INLINE_VISIBILITY
        bool operator!=(const discrete_distribution& __x,
                        const discrete_distribution& __y)
        {return !(__x == __y);}

    template <class _CharT, class _Traits, class _IT>
    friend
    basic_ostream<_CharT, _Traits>&
    operator<<(basic_ostream<_CharT, _Traits>& __os,
               const discrete_distribution<_IT>& __x);

    template <class _CharT, class _Traits, class _IT>
    friend
    basic_istream<_CharT, _Traits>&
    operator>>(basic_istream<_CharT, _Traits>& __is,
               discrete_distribution<_IT>& __x);
};
153 
154 template<class _IntType>
155 template<class _UnaryOperation>
156 discrete_distribution<_IntType>::param_type::param_type(size_t __nw,
157                                                         double __xmin,
158                                                         double __xmax,
159                                                         _UnaryOperation __fw)
160 {
161     if (__nw > 1)
162     {
163         __p_.reserve(__nw - 1);
164         double __d = (__xmax - __xmin) / __nw;
165         double __d2 = __d / 2;
166         for (size_t __k = 0; __k < __nw; ++__k)
167             __p_.push_back(__fw(__xmin + __k * __d + __d2));
168         __init();
169     }
170 }
171 
172 template<class _IntType>
173 void
174 discrete_distribution<_IntType>::param_type::__init()
175 {
176     if (!__p_.empty())
177     {
178         if (__p_.size() > 1)
179         {
180             double __s = _VSTD::accumulate(__p_.begin(), __p_.end(), 0.0);
181             for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i)
182                 *__i /= __s;
183             vector<double> __t(__p_.size() - 1);
184             _VSTD::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin());
185             swap(__p_, __t);
186         }
187         else
188         {
189             __p_.clear();
190             __p_.shrink_to_fit();
191         }
192     }
193 }
194 
195 template<class _IntType>
196 vector<double>
197 discrete_distribution<_IntType>::param_type::probabilities() const
198 {
199     size_t __n = __p_.size();
200     vector<double> __p(__n+1);
201     _VSTD::adjacent_difference(__p_.begin(), __p_.end(), __p.begin());
202     if (__n > 0)
203         __p[__n] = 1 - __p_[__n-1];
204     else
205         __p[0] = 1;
206     return __p;
207 }
208 
209 template<class _IntType>
210 template<class _URNG>
211 _IntType
212 discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p)
213 {
214     uniform_real_distribution<double> __gen;
215     return static_cast<_IntType>(
216            _VSTD::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) -
217                                                               __p.__p_.begin());
218 }
219 
220 template <class _CharT, class _Traits, class _IT>
221 basic_ostream<_CharT, _Traits>&
222 operator<<(basic_ostream<_CharT, _Traits>& __os,
223            const discrete_distribution<_IT>& __x)
224 {
225     __save_flags<_CharT, _Traits> __lx(__os);
226     typedef basic_ostream<_CharT, _Traits> _OStream;
227     __os.flags(_OStream::dec | _OStream::left | _OStream::fixed |
228                _OStream::scientific);
229     _CharT __sp = __os.widen(' ');
230     __os.fill(__sp);
231     size_t __n = __x.__p_.__p_.size();
232     __os << __n;
233     for (size_t __i = 0; __i < __n; ++__i)
234         __os << __sp << __x.__p_.__p_[__i];
235     return __os;
236 }
237 
238 template <class _CharT, class _Traits, class _IT>
239 basic_istream<_CharT, _Traits>&
240 operator>>(basic_istream<_CharT, _Traits>& __is,
241            discrete_distribution<_IT>& __x)
242 {
243     __save_flags<_CharT, _Traits> __lx(__is);
244     typedef basic_istream<_CharT, _Traits> _Istream;
245     __is.flags(_Istream::dec | _Istream::skipws);
246     size_t __n;
247     __is >> __n;
248     vector<double> __p(__n);
249     for (size_t __i = 0; __i < __n; ++__i)
250         __is >> __p[__i];
251     if (!__is.fail())
252         swap(__x.__p_.__p_, __p);
253     return __is;
254 }
255 
256 _LIBCPP_END_NAMESPACE_STD
257 
258 _LIBCPP_POP_MACROS
259 
260 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
261