xref: /freebsd/contrib/llvm-project/libcxx/include/__pstl/cpu_algos/transform.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef _LIBCPP___PSTL_CPU_ALGOS_TRANSFORM_H
10 #define _LIBCPP___PSTL_CPU_ALGOS_TRANSFORM_H
11 
12 #include <__algorithm/transform.h>
13 #include <__assert>
14 #include <__config>
15 #include <__iterator/concepts.h>
16 #include <__iterator/iterator_traits.h>
17 #include <__pstl/backend_fwd.h>
18 #include <__pstl/cpu_algos/cpu_traits.h>
19 #include <__type_traits/is_execution_policy.h>
20 #include <__utility/move.h>
21 #include <optional>
22 
23 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
24 #  pragma GCC system_header
25 #endif
26 
27 _LIBCPP_PUSH_MACROS
28 #include <__undef_macros>
29 
30 _LIBCPP_BEGIN_NAMESPACE_STD
31 namespace __pstl {
32 
33 template <class _Iterator1, class _DifferenceType, class _Iterator2, class _Function>
34 _LIBCPP_HIDE_FROM_ABI _Iterator2
__simd_transform(_Iterator1 __first1,_DifferenceType __n,_Iterator2 __first2,_Function __f)35 __simd_transform(_Iterator1 __first1, _DifferenceType __n, _Iterator2 __first2, _Function __f) noexcept {
36   _PSTL_PRAGMA_SIMD
37   for (_DifferenceType __i = 0; __i < __n; ++__i)
38     __f(__first1[__i], __first2[__i]);
39   return __first2 + __n;
40 }
41 
42 template <class _Iterator1, class _DifferenceType, class _Iterator2, class _Iterator3, class _Function>
__simd_transform(_Iterator1 __first1,_DifferenceType __n,_Iterator2 __first2,_Iterator3 __first3,_Function __f)43 _LIBCPP_HIDE_FROM_ABI _Iterator3 __simd_transform(
44     _Iterator1 __first1, _DifferenceType __n, _Iterator2 __first2, _Iterator3 __first3, _Function __f) noexcept {
45   _PSTL_PRAGMA_SIMD
46   for (_DifferenceType __i = 0; __i < __n; ++__i)
47     __f(__first1[__i], __first2[__i], __first3[__i]);
48   return __first3 + __n;
49 }
50 
51 template <class _Backend, class _RawExecutionPolicy>
52 struct __cpu_parallel_transform {
53   template <class _Policy, class _ForwardIterator, class _ForwardOutIterator, class _UnaryOperation>
54   _LIBCPP_HIDE_FROM_ABI optional<_ForwardOutIterator>
operator__cpu_parallel_transform55   operator()(_Policy&& __policy,
56              _ForwardIterator __first,
57              _ForwardIterator __last,
58              _ForwardOutIterator __result,
59              _UnaryOperation __op) const noexcept {
60     if constexpr (__is_parallel_execution_policy_v<_RawExecutionPolicy> &&
61                   __has_random_access_iterator_category_or_concept<_ForwardIterator>::value &&
62                   __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
63       __cpu_traits<_Backend>::__for_each(
64           __first,
65           __last,
66           [&__policy, __op, __first, __result](_ForwardIterator __brick_first, _ForwardIterator __brick_last) {
67             using _TransformUnseq = __pstl::__transform<_Backend, __remove_parallel_policy_t<_RawExecutionPolicy>>;
68             auto __res            = _TransformUnseq()(
69                 std::__remove_parallel_policy(__policy),
70                 __brick_first,
71                 __brick_last,
72                 __result + (__brick_first - __first),
73                 __op);
74             _LIBCPP_ASSERT_INTERNAL(__res, "unseq/seq should never try to allocate!");
75             return *std::move(__res);
76           });
77       return __result + (__last - __first);
78     } else if constexpr (__is_unsequenced_execution_policy_v<_RawExecutionPolicy> &&
79                          __has_random_access_iterator_category_or_concept<_ForwardIterator>::value &&
80                          __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
81       return __pstl::__simd_transform(
82           __first,
83           __last - __first,
84           __result,
85           [&](__iter_reference<_ForwardIterator> __in_value, __iter_reference<_ForwardOutIterator> __out_value) {
86             __out_value = __op(__in_value);
87           });
88     } else {
89       return std::transform(__first, __last, __result, __op);
90     }
91   }
92 };
93 
94 template <class _Backend, class _RawExecutionPolicy>
95 struct __cpu_parallel_transform_binary {
96   template <class _Policy,
97             class _ForwardIterator1,
98             class _ForwardIterator2,
99             class _ForwardOutIterator,
100             class _BinaryOperation>
101   _LIBCPP_HIDE_FROM_ABI optional<_ForwardOutIterator>
operator__cpu_parallel_transform_binary102   operator()(_Policy&& __policy,
103              _ForwardIterator1 __first1,
104              _ForwardIterator1 __last1,
105              _ForwardIterator2 __first2,
106              _ForwardOutIterator __result,
107              _BinaryOperation __op) const noexcept {
108     if constexpr (__is_parallel_execution_policy_v<_RawExecutionPolicy> &&
109                   __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
110                   __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value &&
111                   __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
112       auto __res = __cpu_traits<_Backend>::__for_each(
113           __first1,
114           __last1,
115           [&__policy, __op, __first1, __first2, __result](
116               _ForwardIterator1 __brick_first, _ForwardIterator1 __brick_last) {
117             using _TransformBinaryUnseq =
118                 __pstl::__transform_binary<_Backend, __remove_parallel_policy_t<_RawExecutionPolicy>>;
119             return _TransformBinaryUnseq()(
120                 std::__remove_parallel_policy(__policy),
121                 __brick_first,
122                 __brick_last,
123                 __first2 + (__brick_first - __first1),
124                 __result + (__brick_first - __first1),
125                 __op);
126           });
127       if (!__res)
128         return nullopt;
129       return __result + (__last1 - __first1);
130     } else if constexpr (__is_unsequenced_execution_policy_v<_RawExecutionPolicy> &&
131                          __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
132                          __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value &&
133                          __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
134       return __pstl::__simd_transform(
135           __first1,
136           __last1 - __first1,
137           __first2,
138           __result,
139           [&](__iter_reference<_ForwardIterator1> __in1,
140               __iter_reference<_ForwardIterator2> __in2,
141               __iter_reference<_ForwardOutIterator> __out_value) { __out_value = __op(__in1, __in2); });
142     } else {
143       return std::transform(__first1, __last1, __first2, __result, __op);
144     }
145   }
146 };
147 
148 } // namespace __pstl
149 _LIBCPP_END_NAMESPACE_STD
150 
151 _LIBCPP_POP_MACROS
152 
153 #endif // _LIBCPP___PSTL_CPU_ALGOS_TRANSFORM_H
154