xref: /freebsd/contrib/arm-optimized-routines/math/aarch64/advsimd/v_math.h (revision 6c05f3a74f30934ee60919cc97e16ec69b542b06)
1 /*
2  * Vector math abstractions.
3  *
4  * Copyright (c) 2019-2024, Arm Limited.
5  * SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
6  */
7 
8 #ifndef _V_MATH_H
9 #define _V_MATH_H
10 
11 #if !__aarch64__
12 # error "Cannot build without AArch64"
13 #endif
14 
/* Function attribute requesting the AArch64 vector procedure call standard
   (aarch64_vector_pcs) for the routine it is attached to.  */
#define VPCS_ATTR __attribute__ ((aarch64_vector_pcs))

/* Exported symbol names for the Advanced SIMD routines.  They follow the
   AArch64 Vector Function ABI mangling _ZGV<isa><mask><vlen><params>_<name>:
     n      Advanced SIMD
     N      unmasked variant
     4 / 2  number of lanes (4 x float or 2 x double)
     v      one vector parameter per 'v'
     l4/l8  a linear parameter with step 4 / 8 (NOTE(review): per the VFABI
	    this is presumably a pointer advancing by one float / double per
	    lane - confirm against the routines that use V_NAME_*_L1)
   The single-precision (F) forms also append the C 'f' suffix, so e.g.
   V_NAME_F1 (exp) yields _ZGVnN4v_expf and V_NAME_D1 (exp) yields
   _ZGVnN2v_exp.  */
#define V_NAME_F1(fun) _ZGVnN4v_##fun##f
#define V_NAME_D1(fun) _ZGVnN2v_##fun
#define V_NAME_F2(fun) _ZGVnN4vv_##fun##f
#define V_NAME_D2(fun) _ZGVnN2vv_##fun
#define V_NAME_F1_L1(fun) _ZGVnN4vl4_##fun##f
#define V_NAME_D1_L1(fun) _ZGVnN2vl8_##fun
23 
/* Optional 2-lane single-precision entry points.

   When USE_GLIBC_ABI is nonzero, HALF_WIDTH_ALIAS_F1 (fun) and
   HALF_WIDTH_ALIAS_F2 (fun) each define a function with external linkage
   (no 'static'), _ZGVnN2v_<fun>f or _ZGVnN2vv_<fun>f, that takes and returns
   float32x2_t.  The wrapper duplicates each 64-bit argument into both halves
   of a float32x4_t with vcombine_f32.  It then calls the matching 4-lane
   routine (_ZGVnN4v_<fun>f / _ZGVnN4vv_<fun>f) and returns the low two lanes
   of the result.  The 4-lane routine must already be declared at the point
   where the macro is expanded.

   When USE_GLIBC_ABI is zero or undefined, both macros expand to nothing.  */
#if USE_GLIBC_ABI

/* Unary form: float32x2_t _ZGVnN2v_<fun>f (float32x2_t x).  */
# define HALF_WIDTH_ALIAS_F1(fun)                                             \
    float32x2_t VPCS_ATTR _ZGVnN2v_##fun##f (float32x2_t x)                   \
    {                                                                         \
      return vget_low_f32 (_ZGVnN4v_##fun##f (vcombine_f32 (x, x)));          \
    }

/* Binary form: float32x2_t _ZGVnN2vv_<fun>f (float32x2_t x, float32x2_t y).  */
# define HALF_WIDTH_ALIAS_F2(fun)                                             \
    float32x2_t VPCS_ATTR _ZGVnN2vv_##fun##f (float32x2_t x, float32x2_t y)   \
    {                                                                         \
      return vget_low_f32 (                                                   \
	  _ZGVnN4vv_##fun##f (vcombine_f32 (x, x), vcombine_f32 (y, y)));     \
    }

#else
/* No half-width aliases outside the glibc ABI build.  */
# define HALF_WIDTH_ALIAS_F1(fun)
# define HALF_WIDTH_ALIAS_F2(fun)
#endif
43 
44 #include <stdint.h>
45 #include "math_config.h"
46 #include <arm_neon.h>
47 
/* Brace-initializer shorthands that repeat the scalar X once per lane of a
   2-, 4- or 8-element vector constant, e.g. (float32x4_t) V4 (1.0f) or a
   struct member initialized as .c = V2 (0x1p-3).  */
#define V2(X) { X, X }
#define V4(X) { X, X, X, X }
#define V8(X) { X, X, X, X, X, X, X, X }
61 
/* Return 1 if any of the four 16-bit lanes of X has a set bit, else 0.
   The 64-bit register is tested as a single integer, so no mask-shape
   assumption is needed.  */
static inline int
v_any_u16h (uint16x4_t x)
{
  uint64_t bits = vget_lane_u64 (vreinterpret_u64_u16 (x), 0);
  return bits != 0;
}
67 
/* Number of 32-bit lanes in one 128-bit Advanced SIMD register.  */
static inline int
v_lanes32 (void)
{
  const int vector_bits = 128;
  const int lane_bits = 32;
  return vector_bits / lane_bits;
}
73 
/* Broadcast constructors: each returns a 4-lane vector with every lane
   equal to X.  */
static inline float32x4_t
v_f32 (float x)
{
  return vdupq_n_f32 (x);
}
static inline uint32x4_t
v_u32 (uint32_t x)
{
  return vdupq_n_u32 (x);
}
static inline int32x4_t
v_s32 (int32_t x)
{
  return vdupq_n_s32 (x);
}
88 }
89 
/* Return 1 if any lane of the v_cond result X is nonzero, else 0.
   X is expected to be a comparison mask, with each 32-bit lane either 0 or
   all-ones.  Under that assumption the wrapping sum of the two 64-bit halves
   computed below is zero only when all four lanes are zero.  For arbitrary
   lane values the sum could wrap to zero, so do not pass non-mask data.  */
static inline int
v_any_u32 (uint32x4_t x)
{
  uint64x2_t halves = vreinterpretq_u64_u32 (x);
  uint64_t sum = vpaddd_u64 (halves);
  return sum != 0;
}
/* Return 1 if either 32-bit lane of the 64-bit vector X has a set bit,
   else 0.  Both lanes are tested together as one 64-bit integer.  */
static inline int
v_any_u32h (uint32x2_t x)
{
  uint64_t packed = vget_lane_u64 (vreinterpret_u64_u32 (x), 0);
  return packed != 0;
}
/* Scalar gathers: return { TAB[IDX[0]], TAB[IDX[1]], TAB[IDX[2]], TAB[IDX[3]] }.
   No bounds checking is done; every lane of IDX must index into TAB.  */
static inline float32x4_t
v_lookup_f32 (const float *tab, uint32x4_t idx)
{
  uint32_t i0 = vgetq_lane_u32 (idx, 0);
  uint32_t i1 = vgetq_lane_u32 (idx, 1);
  uint32_t i2 = vgetq_lane_u32 (idx, 2);
  uint32_t i3 = vgetq_lane_u32 (idx, 3);
  float32x4_t out = { tab[i0], tab[i1], tab[i2], tab[i3] };
  return out;
}
static inline uint32x4_t
v_lookup_u32 (const uint32_t *tab, uint32x4_t idx)
{
  uint32_t i0 = vgetq_lane_u32 (idx, 0);
  uint32_t i1 = vgetq_lane_u32 (idx, 1);
  uint32_t i2 = vgetq_lane_u32 (idx, 2);
  uint32_t i3 = vgetq_lane_u32 (idx, 3);
  uint32x4_t out = { tab[i0], tab[i1], tab[i2], tab[i3] };
  return out;
}
/* Per-lane scalar fallbacks.  Start from Y, then for each lane i whose
   predicate P[i] is nonzero overwrite lane i with F applied to the inputs'
   lane i.  F is not called for lanes whose predicate is zero.  */
static inline float32x4_t
v_call_f32 (float (*f) (float), float32x4_t x, float32x4_t y, uint32x4_t p)
{
  float32x4_t r = y;
  if (p[0])
    r[0] = f (x[0]);
  if (p[1])
    r[1] = f (x[1]);
  if (p[2])
    r[2] = f (x[2]);
  if (p[3])
    r[3] = f (x[3]);
  return r;
}
static inline float32x4_t
v_call2_f32 (float (*f) (float, float), float32x4_t x1, float32x4_t x2,
	     float32x4_t y, uint32x4_t p)
{
  float32x4_t r = y;
  if (p[0])
    r[0] = f (x1[0], x2[0]);
  if (p[1])
    r[1] = f (x1[1], x2[1]);
  if (p[2])
    r[2] = f (x1[2], x2[2]);
  if (p[3])
    r[3] = f (x1[3], x2[3]);
  return r;
}
/* Return X with the bits selected by MASK cleared (X & ~MASK, lane-wise).
   With a comparison mask (lanes 0 or all-ones), lanes where MASK is
   all-ones become +0.0f and the remaining lanes pass through unchanged.  */
static inline float32x4_t
v_zerofy_f32 (float32x4_t x, uint32x4_t mask)
{
  uint32x4_t bits = vreinterpretq_u32_f32 (x);
  uint32x4_t kept = vbicq_u32 (bits, mask);
  return vreinterpretq_f32_u32 (kept);
}
132 
/* Number of 64-bit lanes in one 128-bit Advanced SIMD register.  */
static inline int
v_lanes64 (void)
{
  const int vector_bits = 128;
  const int lane_bits = 64;
  return vector_bits / lane_bits;
}
/* Broadcast constructors: each returns a 2-lane vector with both lanes
   equal to X.  */
static inline float64x2_t
v_f64 (double x)
{
  return vdupq_n_f64 (x);
}
static inline uint64x2_t
v_u64 (uint64_t x)
{
  return vdupq_n_u64 (x);
}
static inline int64x2_t
v_s64 (int64_t x)
{
  return vdupq_n_s64 (x);
}
153 
/* Return 1 if either lane of the v_cond result X is nonzero, else 0.
   X is expected to be a comparison mask, with each 64-bit lane either 0 or
   all-ones (not the 32-bit -1u).  Under that assumption the wrapping sum of
   the two lanes is zero only when both lanes are zero; arbitrary values such
   as { 1, ~0 } would sum to zero and be missed.  */
static inline int
v_any_u64 (uint64x2_t x)
{
  uint64_t sum = vpaddd_u64 (x);
  return sum != 0;
}
/* Scalar gathers: return { TAB[IDX[0]], TAB[IDX[1]] }.  No bounds checking
   is done; both lanes of IDX must index into TAB.  */
static inline float64x2_t
v_lookup_f64 (const double *tab, uint64x2_t idx)
{
  uint64_t lo = vgetq_lane_u64 (idx, 0);
  uint64_t hi = vgetq_lane_u64 (idx, 1);
  float64x2_t out = { tab[lo], tab[hi] };
  return out;
}
static inline uint64x2_t
v_lookup_u64 (const uint64_t *tab, uint64x2_t idx)
{
  uint64_t lo = vgetq_lane_u64 (idx, 0);
  uint64_t hi = vgetq_lane_u64 (idx, 1);
  uint64x2_t out = { tab[lo], tab[hi] };
  return out;
}
/* Scalar fallback for double-precision routines.  Returns Y with lane i
   replaced by F (X[i]) for each lane whose predicate P[i] is nonzero.  Lanes
   with a zero predicate keep their value from Y, and F is not called for
   them.  likely () comes from math_config.h (presumably a branch-probability
   hint).

   The high-lane predicate and argument are copied into scalars before the
   first call to F.  NOTE(review): presumably this lets them stay in
   callee-saved scalar registers across the call instead of spilling whole
   vectors (AAPCS64 preserves only the low 64 bits of v8-v15).

   The predicate lane is held as uint64_t, its own element type, so the
   nonzero test needs no integer-to-floating-point conversion.  */
static inline float64x2_t
v_call_f64 (double (*f) (double), float64x2_t x, float64x2_t y, uint64x2_t p)
{
  uint64_t p1 = p[1];
  double x1 = x[1];
  if (likely (p[0]))
    y[0] = f (x[0]);
  if (likely (p1))
    y[1] = f (x1);
  return y;
}
182 
/* Two-argument scalar fallback for double-precision routines.  Returns Y
   with lane i replaced by F (X1[i], X2[i]) for each lane whose predicate
   P[i] is nonzero.  Lanes with a zero predicate keep their value from Y, and
   F is not called for them.  likely () comes from math_config.h (presumably
   a branch-probability hint).

   The high-lane predicate and both high-lane arguments are copied into
   scalars before the first call to F.  NOTE(review): presumably this lets
   them stay in callee-saved scalar registers across the call instead of
   spilling whole vectors (AAPCS64 preserves only the low 64 bits of
   v8-v15).

   The predicate lane is held as uint64_t, its own element type, so the
   nonzero test needs no integer-to-floating-point conversion.  */
static inline float64x2_t
v_call2_f64 (double (*f) (double, double), float64x2_t x1, float64x2_t x2,
	     float64x2_t y, uint64x2_t p)
{
  uint64_t p1 = p[1];
  double x1h = x1[1];
  double x2h = x2[1];
  if (likely (p[0]))
    y[0] = f (x1[0], x2[0]);
  if (likely (p1))
    y[1] = f (x1h, x2h);
  return y;
}
/* Return X with the bits selected by MASK cleared (X & ~MASK, lane-wise).
   With a comparison mask (lanes 0 or all-ones), lanes where MASK is
   all-ones become +0.0 and the other lane passes through unchanged.  */
static inline float64x2_t
v_zerofy_f64 (float64x2_t x, uint64x2_t mask)
{
  uint64x2_t bits = vreinterpretq_u64_f64 (x);
  uint64x2_t kept = vbicq_u64 (bits, mask);
  return vreinterpretq_f64_u64 (kept);
}
201 
202 #endif
203