xref: /freebsd/contrib/llvm-project/libcxx/include/__random/discrete_distribution.h (revision f126890ac5386406dadf7c4cfa9566cbb56537c5)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
11 
12 #include <__algorithm/upper_bound.h>
13 #include <__config>
14 #include <__random/is_valid.h>
15 #include <__random/uniform_real_distribution.h>
16 #include <cstddef>
17 #include <iosfwd>
18 #include <numeric>
19 #include <vector>
20 
21 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
22 #  pragma GCC system_header
23 #endif
24 
25 _LIBCPP_PUSH_MACROS
26 #include <__undef_macros>
27 
28 _LIBCPP_BEGIN_NAMESPACE_STD
29 
30 template<class _IntType = int>
31 class _LIBCPP_TEMPLATE_VIS discrete_distribution
32 {
33     static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type");
34 public:
35     // types
36     typedef _IntType result_type;
37 
38     class _LIBCPP_TEMPLATE_VIS param_type
39     {
40         vector<double> __p_;
41     public:
42         typedef discrete_distribution distribution_type;
43 
44         _LIBCPP_INLINE_VISIBILITY
45         param_type() {}
46         template<class _InputIterator>
47             _LIBCPP_INLINE_VISIBILITY
48             param_type(_InputIterator __f, _InputIterator __l)
49             : __p_(__f, __l) {__init();}
50 #ifndef _LIBCPP_CXX03_LANG
51         _LIBCPP_INLINE_VISIBILITY
52         param_type(initializer_list<double> __wl)
53             : __p_(__wl.begin(), __wl.end()) {__init();}
54 #endif // _LIBCPP_CXX03_LANG
55         template<class _UnaryOperation>
56         _LIBCPP_HIDE_FROM_ABI param_type(size_t __nw, double __xmin, double __xmax,
57                        _UnaryOperation __fw);
58 
59         _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const;
60 
61         friend _LIBCPP_INLINE_VISIBILITY
62             bool operator==(const param_type& __x, const param_type& __y)
63             {return __x.__p_ == __y.__p_;}
64         friend _LIBCPP_INLINE_VISIBILITY
65             bool operator!=(const param_type& __x, const param_type& __y)
66             {return !(__x == __y);}
67 
68     private:
69         _LIBCPP_HIDE_FROM_ABI void __init();
70 
71         friend class discrete_distribution;
72 
73         template <class _CharT, class _Traits, class _IT>
74         friend
75         basic_ostream<_CharT, _Traits>&
76         operator<<(basic_ostream<_CharT, _Traits>& __os,
77                    const discrete_distribution<_IT>& __x);
78 
79         template <class _CharT, class _Traits, class _IT>
80         friend
81         basic_istream<_CharT, _Traits>&
82         operator>>(basic_istream<_CharT, _Traits>& __is,
83                    discrete_distribution<_IT>& __x);
84     };
85 
86 private:
87     param_type __p_;
88 
89 public:
90     // constructor and reset functions
91     _LIBCPP_INLINE_VISIBILITY
92     discrete_distribution() {}
93     template<class _InputIterator>
94         _LIBCPP_INLINE_VISIBILITY
95         discrete_distribution(_InputIterator __f, _InputIterator __l)
96             : __p_(__f, __l) {}
97 #ifndef _LIBCPP_CXX03_LANG
98     _LIBCPP_INLINE_VISIBILITY
99     discrete_distribution(initializer_list<double> __wl)
100         : __p_(__wl) {}
101 #endif // _LIBCPP_CXX03_LANG
102     template<class _UnaryOperation>
103         _LIBCPP_INLINE_VISIBILITY
104         discrete_distribution(size_t __nw, double __xmin, double __xmax,
105                               _UnaryOperation __fw)
106         : __p_(__nw, __xmin, __xmax, __fw) {}
107     _LIBCPP_INLINE_VISIBILITY
108     explicit discrete_distribution(const param_type& __p)
109         : __p_(__p) {}
110     _LIBCPP_INLINE_VISIBILITY
111     void reset() {}
112 
113     // generating functions
114     template<class _URNG>
115         _LIBCPP_INLINE_VISIBILITY
116         result_type operator()(_URNG& __g)
117         {return (*this)(__g, __p_);}
118     template<class _URNG>
119     _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g, const param_type& __p);
120 
121     // property functions
122     _LIBCPP_INLINE_VISIBILITY
123     vector<double> probabilities() const {return __p_.probabilities();}
124 
125     _LIBCPP_INLINE_VISIBILITY
126     param_type param() const {return __p_;}
127     _LIBCPP_INLINE_VISIBILITY
128     void param(const param_type& __p) {__p_ = __p;}
129 
130     _LIBCPP_INLINE_VISIBILITY
131     result_type min() const {return 0;}
132     _LIBCPP_INLINE_VISIBILITY
133     result_type max() const {return __p_.__p_.size();}
134 
135     friend _LIBCPP_INLINE_VISIBILITY
136         bool operator==(const discrete_distribution& __x,
137                         const discrete_distribution& __y)
138         {return __x.__p_ == __y.__p_;}
139     friend _LIBCPP_INLINE_VISIBILITY
140         bool operator!=(const discrete_distribution& __x,
141                         const discrete_distribution& __y)
142         {return !(__x == __y);}
143 
144     template <class _CharT, class _Traits, class _IT>
145     friend
146     basic_ostream<_CharT, _Traits>&
147     operator<<(basic_ostream<_CharT, _Traits>& __os,
148                const discrete_distribution<_IT>& __x);
149 
150     template <class _CharT, class _Traits, class _IT>
151     friend
152     basic_istream<_CharT, _Traits>&
153     operator>>(basic_istream<_CharT, _Traits>& __is,
154                discrete_distribution<_IT>& __x);
155 };
156 
157 template<class _IntType>
158 template<class _UnaryOperation>
159 discrete_distribution<_IntType>::param_type::param_type(size_t __nw,
160                                                         double __xmin,
161                                                         double __xmax,
162                                                         _UnaryOperation __fw)
163 {
164     if (__nw > 1)
165     {
166         __p_.reserve(__nw - 1);
167         double __d = (__xmax - __xmin) / __nw;
168         double __d2 = __d / 2;
169         for (size_t __k = 0; __k < __nw; ++__k)
170             __p_.push_back(__fw(__xmin + __k * __d + __d2));
171         __init();
172     }
173 }
174 
175 template<class _IntType>
176 void
177 discrete_distribution<_IntType>::param_type::__init()
178 {
179     if (!__p_.empty())
180     {
181         if (__p_.size() > 1)
182         {
183             double __s = _VSTD::accumulate(__p_.begin(), __p_.end(), 0.0);
184             for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i)
185                 *__i /= __s;
186             vector<double> __t(__p_.size() - 1);
187             _VSTD::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin());
188             swap(__p_, __t);
189         }
190         else
191         {
192             __p_.clear();
193             __p_.shrink_to_fit();
194         }
195     }
196 }
197 
198 template<class _IntType>
199 vector<double>
200 discrete_distribution<_IntType>::param_type::probabilities() const
201 {
202     size_t __n = __p_.size();
203     vector<double> __p(__n+1);
204     _VSTD::adjacent_difference(__p_.begin(), __p_.end(), __p.begin());
205     if (__n > 0)
206         __p[__n] = 1 - __p_[__n-1];
207     else
208         __p[0] = 1;
209     return __p;
210 }
211 
212 template<class _IntType>
213 template<class _URNG>
214 _IntType
215 discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p)
216 {
217     static_assert(__libcpp_random_is_valid_urng<_URNG>::value, "");
218     uniform_real_distribution<double> __gen;
219     return static_cast<_IntType>(
220            _VSTD::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) -
221                                                               __p.__p_.begin());
222 }
223 
224 template <class _CharT, class _Traits, class _IT>
225 _LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>&
226 operator<<(basic_ostream<_CharT, _Traits>& __os,
227            const discrete_distribution<_IT>& __x)
228 {
229     __save_flags<_CharT, _Traits> __lx(__os);
230     typedef basic_ostream<_CharT, _Traits> _OStream;
231     __os.flags(_OStream::dec | _OStream::left | _OStream::fixed |
232                _OStream::scientific);
233     _CharT __sp = __os.widen(' ');
234     __os.fill(__sp);
235     size_t __n = __x.__p_.__p_.size();
236     __os << __n;
237     for (size_t __i = 0; __i < __n; ++__i)
238         __os << __sp << __x.__p_.__p_[__i];
239     return __os;
240 }
241 
242 template <class _CharT, class _Traits, class _IT>
243 _LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>&
244 operator>>(basic_istream<_CharT, _Traits>& __is,
245            discrete_distribution<_IT>& __x)
246 {
247     __save_flags<_CharT, _Traits> __lx(__is);
248     typedef basic_istream<_CharT, _Traits> _Istream;
249     __is.flags(_Istream::dec | _Istream::skipws);
250     size_t __n;
251     __is >> __n;
252     vector<double> __p(__n);
253     for (size_t __i = 0; __i < __n; ++__i)
254         __is >> __p[__i];
255     if (!__is.fail())
256         swap(__x.__p_.__p_, __p);
257     return __is;
258 }
259 
260 _LIBCPP_END_NAMESPACE_STD
261 
262 _LIBCPP_POP_MACROS
263 
264 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
265