1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef _LIBCPP___PSTL_CPU_ALGOS_TRANSFORM_H
10 #define _LIBCPP___PSTL_CPU_ALGOS_TRANSFORM_H
11
12 #include <__algorithm/transform.h>
13 #include <__assert>
14 #include <__config>
15 #include <__iterator/concepts.h>
16 #include <__iterator/iterator_traits.h>
17 #include <__pstl/backend_fwd.h>
18 #include <__pstl/cpu_algos/cpu_traits.h>
19 #include <__type_traits/is_execution_policy.h>
20 #include <__utility/move.h>
21 #include <optional>
22
23 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
24 # pragma GCC system_header
25 #endif
26
27 _LIBCPP_PUSH_MACROS
28 #include <__undef_macros>
29
30 _LIBCPP_BEGIN_NAMESPACE_STD
31 namespace __pstl {
32
// Invokes `__f(__first1[__i], __first2[__i])` for every index __i in [0, __n) and returns
// `__first2 + __n`. Both iterators are accessed by subscript, so they must support `[]`; in this
// file the function is only reached with random-access iterators. The callers in this file pass a
// functor that assigns into its second argument, i.e. the second range is the output.
//
// _PSTL_PRAGMA_SIMD is a macro defined elsewhere that presumably requests vectorization of the
// following loop (TODO confirm its expansion); keep the loop in this simple counted form so the
// pragma can still apply to it.
// The function is noexcept, so an exception escaping __f terminates the program.
template <class _Iterator1, class _DifferenceType, class _Iterator2, class _Function>
_LIBCPP_HIDE_FROM_ABI _Iterator2
__simd_transform(_Iterator1 __first1, _DifferenceType __n, _Iterator2 __first2, _Function __f) noexcept {
  _PSTL_PRAGMA_SIMD
  for (_DifferenceType __i = 0; __i < __n; ++__i)
    __f(__first1[__i], __first2[__i]);
  return __first2 + __n;
}
41
// Three-range variant: invokes `__f(__first1[__i], __first2[__i], __first3[__i])` for every index
// __i in [0, __n) and returns `__first3 + __n`. All three iterators are accessed by subscript; in
// this file the function is only reached with random-access iterators. The caller in this file
// passes a functor that assigns into its third argument, i.e. the third range is the output.
//
// As in the two-range overload, the loop directly follows _PSTL_PRAGMA_SIMD (a macro defined
// elsewhere, presumably a vectorization hint -- TODO confirm), so its counted form is kept as is.
// noexcept: an exception escaping __f terminates the program.
template <class _Iterator1, class _DifferenceType, class _Iterator2, class _Iterator3, class _Function>
_LIBCPP_HIDE_FROM_ABI _Iterator3 __simd_transform(
    _Iterator1 __first1, _DifferenceType __n, _Iterator2 __first2, _Iterator3 __first3, _Function __f) noexcept {
  _PSTL_PRAGMA_SIMD
  for (_DifferenceType __i = 0; __i < __n; ++__i)
    __f(__first1[__i], __first2[__i], __first3[__i]);
  return __first3 + __n;
}
50
51 template <class _Backend, class _RawExecutionPolicy>
52 struct __cpu_parallel_transform {
53 template <class _Policy, class _ForwardIterator, class _ForwardOutIterator, class _UnaryOperation>
54 _LIBCPP_HIDE_FROM_ABI optional<_ForwardOutIterator>
operator__cpu_parallel_transform55 operator()(_Policy&& __policy,
56 _ForwardIterator __first,
57 _ForwardIterator __last,
58 _ForwardOutIterator __result,
59 _UnaryOperation __op) const noexcept {
60 if constexpr (__is_parallel_execution_policy_v<_RawExecutionPolicy> &&
61 __has_random_access_iterator_category_or_concept<_ForwardIterator>::value &&
62 __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
63 __cpu_traits<_Backend>::__for_each(
64 __first,
65 __last,
66 [&__policy, __op, __first, __result](_ForwardIterator __brick_first, _ForwardIterator __brick_last) {
67 using _TransformUnseq = __pstl::__transform<_Backend, __remove_parallel_policy_t<_RawExecutionPolicy>>;
68 auto __res = _TransformUnseq()(
69 std::__remove_parallel_policy(__policy),
70 __brick_first,
71 __brick_last,
72 __result + (__brick_first - __first),
73 __op);
74 _LIBCPP_ASSERT_INTERNAL(__res, "unseq/seq should never try to allocate!");
75 return *std::move(__res);
76 });
77 return __result + (__last - __first);
78 } else if constexpr (__is_unsequenced_execution_policy_v<_RawExecutionPolicy> &&
79 __has_random_access_iterator_category_or_concept<_ForwardIterator>::value &&
80 __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
81 return __pstl::__simd_transform(
82 __first,
83 __last - __first,
84 __result,
85 [&](__iter_reference<_ForwardIterator> __in_value, __iter_reference<_ForwardOutIterator> __out_value) {
86 __out_value = __op(__in_value);
87 });
88 } else {
89 return std::transform(__first, __last, __result, __op);
90 }
91 }
92 };
93
94 template <class _Backend, class _RawExecutionPolicy>
95 struct __cpu_parallel_transform_binary {
96 template <class _Policy,
97 class _ForwardIterator1,
98 class _ForwardIterator2,
99 class _ForwardOutIterator,
100 class _BinaryOperation>
101 _LIBCPP_HIDE_FROM_ABI optional<_ForwardOutIterator>
operator__cpu_parallel_transform_binary102 operator()(_Policy&& __policy,
103 _ForwardIterator1 __first1,
104 _ForwardIterator1 __last1,
105 _ForwardIterator2 __first2,
106 _ForwardOutIterator __result,
107 _BinaryOperation __op) const noexcept {
108 if constexpr (__is_parallel_execution_policy_v<_RawExecutionPolicy> &&
109 __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
110 __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value &&
111 __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
112 auto __res = __cpu_traits<_Backend>::__for_each(
113 __first1,
114 __last1,
115 [&__policy, __op, __first1, __first2, __result](
116 _ForwardIterator1 __brick_first, _ForwardIterator1 __brick_last) {
117 using _TransformBinaryUnseq =
118 __pstl::__transform_binary<_Backend, __remove_parallel_policy_t<_RawExecutionPolicy>>;
119 return _TransformBinaryUnseq()(
120 std::__remove_parallel_policy(__policy),
121 __brick_first,
122 __brick_last,
123 __first2 + (__brick_first - __first1),
124 __result + (__brick_first - __first1),
125 __op);
126 });
127 if (!__res)
128 return nullopt;
129 return __result + (__last1 - __first1);
130 } else if constexpr (__is_unsequenced_execution_policy_v<_RawExecutionPolicy> &&
131 __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
132 __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value &&
133 __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
134 return __pstl::__simd_transform(
135 __first1,
136 __last1 - __first1,
137 __first2,
138 __result,
139 [&](__iter_reference<_ForwardIterator1> __in1,
140 __iter_reference<_ForwardIterator2> __in2,
141 __iter_reference<_ForwardOutIterator> __out_value) { __out_value = __op(__in1, __in2); });
142 } else {
143 return std::transform(__first1, __last1, __first2, __result, __op);
144 }
145 }
146 };
147
148 } // namespace __pstl
149 _LIBCPP_END_NAMESPACE_STD
150
151 _LIBCPP_POP_MACROS
152
153 #endif // _LIBCPP___PSTL_CPU_ALGOS_TRANSFORM_H
154