/*
 * Wrapper functions for SVE ACLE.
 *
 * Copyright (c) 2019-2024, Arm Limited.
 * SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
 */

#ifndef SV_MATH_H
#define SV_MATH_H

/* Enable SVE in this translation unit. Note that because the attribute is
   'pushed' in clang, any file including sv_math.h has to pop it back off
   again by ending the source file with CLOSE_SVE_ATTR (see the sketch
   below). It is important that sv_math.h is included first so that all
   functions get the target attribute. */
#ifdef __clang__
# pragma clang attribute push(__attribute__((target("sve"))), \
			      apply_to = any(function))
# define CLOSE_SVE_ATTR _Pragma("clang attribute pop")
#else
# pragma GCC target("+sve")
# define CLOSE_SVE_ATTR
#endif
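
/* A consumer of this header therefore follows the pattern sketched below.
   This is illustrative only: `foo` is a placeholder routine name, not part
   of this library.

     #include "sv_math.h"

     svfloat64_t SV_NAME_D1 (foo) (svfloat64_t x, svbool_t pg)
     {
       ...
     }

     CLOSE_SVE_ATTR
*/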

#include <arm_sve.h>
#include <stdbool.h>

#include "math_config.h"

#define SV_NAME_F1(fun) _ZGVsMxv_##fun##f
#define SV_NAME_D1(fun) _ZGVsMxv_##fun
#define SV_NAME_F2(fun) _ZGVsMxvv_##fun##f
#define SV_NAME_D2(fun) _ZGVsMxvv_##fun
#define SV_NAME_F1_L1(fun) _ZGVsMxvl4_##fun##f
#define SV_NAME_D1_L1(fun) _ZGVsMxvl8_##fun
#define SV_NAME_F1_L2(fun) _ZGVsMxvl4l4_##fun##f
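
/* These macros build names following the AArch64 vector function ABI:
   '_ZGV', then 's' (SVE), 'Mx' (masked, scalable vector length), then one
   token per parameter ('v' for a vector argument, 'lN' for a linear
   pointer argument advancing N bytes per lane), then the scalar function's
   name. For example, SV_NAME_F1 (sin) expands to _ZGVsMxv_sinf, and
   SV_NAME_D2 (pow) to _ZGVsMxvv_pow.  */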

/* Double precision. */
static inline svint64_t
sv_s64 (int64_t x)
{
  return svdup_s64 (x);
}

static inline svuint64_t
sv_u64 (uint64_t x)
{
  return svdup_u64 (x);
}

static inline svfloat64_t
sv_f64 (double x)
{
  return svdup_f64 (x);
}

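/* Scalar fallback for lanes that need a special-case path: for each lane
   of x selected by cmp, call the scalar routine f on that element and
   merge the result into y; lanes where cmp is false keep their value from
   y. The loop visits active lanes one at a time using svpfirst/svpnext_b64
   and extracts each element with svclastb. An illustrative use, where
   `special` is a hypothetical predicate marking slow-path lanes:

     if (unlikely (svptest_any (pg, special)))
       return sv_call_f64 (exp, x, y, special);
*/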
static inline svfloat64_t
sv_call_f64 (double (*f) (double), svfloat64_t x, svfloat64_t y, svbool_t cmp)
{
  svbool_t p = svpfirst (cmp, svpfalse ());
  while (svptest_any (cmp, p))
    {
      double elem = svclastb (p, 0, x);
      elem = (*f) (elem);
      svfloat64_t y2 = sv_f64 (elem);
      y = svsel (p, y2, y);
      p = svpnext_b64 (cmp, p);
    }
  return y;
}

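/* Two-argument variant of sv_call_f64: for each lane selected by cmp,
   calls f (x1[i], x2[i]) and merges the result into y.  */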
static inline svfloat64_t
sv_call2_f64 (double (*f) (double, double), svfloat64_t x1, svfloat64_t x2,
	      svfloat64_t y, svbool_t cmp)
{
  svbool_t p = svpfirst (cmp, svpfalse ());
  while (svptest_any (cmp, p))
    {
      double elem1 = svclastb (p, 0, x1);
      double elem2 = svclastb (p, 0, x2);
      double ret = (*f) (elem1, elem2);
      svfloat64_t y2 = sv_f64 (ret);
      y = svsel (p, y2, y);
      p = svpnext_b64 (cmp, p);
    }
  return y;
}

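/* Elementwise x mod y for a scalar y: q = x / y with truncating unsigned
   division, then a multiply-subtract gives x - q * y. For example,
   sv_mod_n_u64_x (pg, sv_u64 (10), 3) is 1 in every active lane.  */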
static inline svuint64_t
sv_mod_n_u64_x (svbool_t pg, svuint64_t x, uint64_t y)
{
  svuint64_t q = svdiv_x (pg, x, y);
  return svmls_x (pg, x, q, y);
}

/* Single precision. */
static inline svint32_t
sv_s32 (int32_t x)
{
  return svdup_s32 (x);
}

static inline svuint32_t
sv_u32 (uint32_t x)
{
  return svdup_u32 (x);
}

static inline svfloat32_t
sv_f32 (float x)
{
  return svdup_f32 (x);
}

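/* Single-precision counterpart of sv_call_f64, stepping through the
   active lanes of cmp with the 32-bit predicate iterator.  */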
static inline svfloat32_t
sv_call_f32 (float (*f) (float), svfloat32_t x, svfloat32_t y, svbool_t cmp)
{
  svbool_t p = svpfirst (cmp, svpfalse ());
  while (svptest_any (cmp, p))
    {
      float elem = svclastb (p, 0, x);
      elem = (*f) (elem);
      svfloat32_t y2 = sv_f32 (elem);
      y = svsel (p, y2, y);
      p = svpnext_b32 (cmp, p);
    }
  return y;
}

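/* Two-argument single-precision variant; see sv_call2_f64 above.  */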
static inline svfloat32_t
sv_call2_f32 (float (*f) (float, float), svfloat32_t x1, svfloat32_t x2,
	      svfloat32_t y, svbool_t cmp)
{
  svbool_t p = svpfirst (cmp, svpfalse ());
  while (svptest_any (cmp, p))
    {
      float elem1 = svclastb (p, 0, x1);
      float elem2 = svclastb (p, 0, x2);
      float ret = (*f) (elem1, elem2);
      svfloat32_t y2 = sv_f32 (ret);
      y = svsel (p, y2, y);
      p = svpnext_b32 (cmp, p);
    }
  return y;
}
#endif