xref: /freebsd/contrib/arm-optimized-routines/math/aarch64/sve/sv_math.h (revision f3087bef11543b42e0d69b708f367097a4118d24)
1 /*
2  * Wrapper functions for SVE ACLE.
3  *
4  * Copyright (c) 2019-2024, Arm Limited.
5  * SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
6  */
7 
8 #ifndef SV_MATH_H
9 #define SV_MATH_H
10 
/* Enable SVE in this translation unit.  With clang the target attribute is
   'pushed', so any file that includes sv_math.h must pop it again by ending
   the source file with CLOSE_SVE_ATTR; with GCC the pragma applies to the
   rest of the translation unit and CLOSE_SVE_ATTR expands to nothing.
   sv_math.h must be included first so that every function receives the
   target attribute.  */
15 #ifdef __clang__
16 # pragma clang attribute push(__attribute__((target("sve"))),                \
17 			       apply_to = any(function))
18 # define CLOSE_SVE_ATTR _Pragma("clang attribute pop")
19 #else
20 # pragma GCC target("+sve")
21 # define CLOSE_SVE_ATTR
22 #endif
23 
24 #include <arm_sve.h>
25 #include <stdbool.h>
26 
27 #include "math_config.h"
28 
/* Vector-ABI symbol names for the SVE routines.  Each macro pastes together
   _ZGV, 's' (SVE), 'M' (masked), 'x' (scalable vector length), one code per
   argument, an underscore and the scalar routine NAME, following the AArch64
   Vector Function ABI mangling scheme:
     v      - vector argument;
     l4/l8  - linear argument with a 4- or 8-byte step (presumably per-lane
	      float/double output pointers, e.g. for sincos-style routines --
	      confirm against callers).
   The F (single-precision) variants append 'f' to NAME; the D variants use
   NAME unchanged.  */
#define SV_NAME_F1(name) _ZGVsMxv_##name##f
#define SV_NAME_D1(name) _ZGVsMxv_##name
#define SV_NAME_F2(name) _ZGVsMxvv_##name##f
#define SV_NAME_D2(name) _ZGVsMxvv_##name
#define SV_NAME_F1_L1(name) _ZGVsMxvl4_##name##f
#define SV_NAME_D1_L1(name) _ZGVsMxvl8_##name
#define SV_NAME_F1_L2(name) _ZGVsMxvl4l4_##name##f
36 
37 /* Double precision.  */
38 static inline svint64_t
sv_s64(int64_t x)39 sv_s64 (int64_t x)
40 {
41   return svdup_s64 (x);
42 }
43 
44 static inline svuint64_t
sv_u64(uint64_t x)45 sv_u64 (uint64_t x)
46 {
47   return svdup_u64 (x);
48 }
49 
50 static inline svfloat64_t
sv_f64(double x)51 sv_f64 (double x)
52 {
53   return svdup_f64 (x);
54 }
55 
56 static inline svfloat64_t
sv_call_f64(double (* f)(double),svfloat64_t x,svfloat64_t y,svbool_t cmp)57 sv_call_f64 (double (*f) (double), svfloat64_t x, svfloat64_t y, svbool_t cmp)
58 {
59   svbool_t p = svpfirst (cmp, svpfalse ());
60   while (svptest_any (cmp, p))
61     {
62       double elem = svclastb (p, 0, x);
63       elem = (*f) (elem);
64       svfloat64_t y2 = sv_f64 (elem);
65       y = svsel (p, y2, y);
66       p = svpnext_b64 (cmp, p);
67     }
68   return y;
69 }
70 
71 static inline svfloat64_t
sv_call2_f64(double (* f)(double,double),svfloat64_t x1,svfloat64_t x2,svfloat64_t y,svbool_t cmp)72 sv_call2_f64 (double (*f) (double, double), svfloat64_t x1, svfloat64_t x2,
73 	      svfloat64_t y, svbool_t cmp)
74 {
75   svbool_t p = svpfirst (cmp, svpfalse ());
76   while (svptest_any (cmp, p))
77     {
78       double elem1 = svclastb (p, 0, x1);
79       double elem2 = svclastb (p, 0, x2);
80       double ret = (*f) (elem1, elem2);
81       svfloat64_t y2 = sv_f64 (ret);
82       y = svsel (p, y2, y);
83       p = svpnext_b64 (cmp, p);
84     }
85   return y;
86 }
87 
88 static inline svuint64_t
sv_mod_n_u64_x(svbool_t pg,svuint64_t x,uint64_t y)89 sv_mod_n_u64_x (svbool_t pg, svuint64_t x, uint64_t y)
90 {
91   svuint64_t q = svdiv_x (pg, x, y);
92   return svmls_x (pg, x, q, y);
93 }
94 
95 /* Single precision.  */
96 static inline svint32_t
sv_s32(int32_t x)97 sv_s32 (int32_t x)
98 {
99   return svdup_s32 (x);
100 }
101 
102 static inline svuint32_t
sv_u32(uint32_t x)103 sv_u32 (uint32_t x)
104 {
105   return svdup_u32 (x);
106 }
107 
108 static inline svfloat32_t
sv_f32(float x)109 sv_f32 (float x)
110 {
111   return svdup_f32 (x);
112 }
113 
114 static inline svfloat32_t
sv_call_f32(float (* f)(float),svfloat32_t x,svfloat32_t y,svbool_t cmp)115 sv_call_f32 (float (*f) (float), svfloat32_t x, svfloat32_t y, svbool_t cmp)
116 {
117   svbool_t p = svpfirst (cmp, svpfalse ());
118   while (svptest_any (cmp, p))
119     {
120       float elem = svclastb (p, 0, x);
121       elem = (*f) (elem);
122       svfloat32_t y2 = sv_f32 (elem);
123       y = svsel (p, y2, y);
124       p = svpnext_b32 (cmp, p);
125     }
126   return y;
127 }
128 
129 static inline svfloat32_t
sv_call2_f32(float (* f)(float,float),svfloat32_t x1,svfloat32_t x2,svfloat32_t y,svbool_t cmp)130 sv_call2_f32 (float (*f) (float, float), svfloat32_t x1, svfloat32_t x2,
131 	      svfloat32_t y, svbool_t cmp)
132 {
133   svbool_t p = svpfirst (cmp, svpfalse ());
134   while (svptest_any (cmp, p))
135     {
136       float elem1 = svclastb (p, 0, x1);
137       float elem2 = svclastb (p, 0, x2);
138       float ret = (*f) (elem1, elem2);
139       svfloat32_t y2 = sv_f32 (ret);
140       y = svsel (p, y2, y);
141       p = svpnext_b32 (cmp, p);
142     }
143   return y;
144 }
145 #endif
146