/*
 * Wrapper functions for SVE ACLE.
 *
 * Copyright (c) 2019-2023, Arm Limited.
 * SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
 */

#ifndef SV_MATH_H
#define SV_MATH_H

#ifndef WANT_VMATH
/* Enable the build of vector math code. */
#define WANT_VMATH 1
#endif
#if WANT_VMATH

#if WANT_SVE_MATH
#define SV_SUPPORTED 1

#include <arm_sve.h>
#include <stdbool.h>

#include "math_config.h"

typedef float f32_t;
typedef uint32_t u32_t;
typedef int32_t s32_t;
typedef double f64_t;
typedef uint64_t u64_t;
typedef int64_t s64_t;

typedef svfloat64_t sv_f64_t;
typedef svuint64_t sv_u64_t;
typedef svint64_t sv_s64_t;

typedef svfloat32_t sv_f32_t;
typedef svuint32_t sv_u32_t;
typedef svint32_t sv_s32_t;

/* Double precision. */

/* Duplicate a scalar into every lane. */
static inline sv_s64_t
sv_s64 (s64_t x)
{
  return svdup_n_s64 (x);
}

static inline sv_u64_t
sv_u64 (u64_t x)
{
  return svdup_n_u64 (x);
}

static inline sv_f64_t
sv_f64 (f64_t x)
{
  return svdup_n_f64 (x);
}

/* res = z + x * y. */
static inline sv_f64_t
sv_fma_f64_x (svbool_t pg, sv_f64_t x, sv_f64_t y, sv_f64_t z)
{
  return svmla_f64_x (pg, z, x, y);
}

/* res = z + x * y with x scalar. */
static inline sv_f64_t
sv_fma_n_f64_x (svbool_t pg, f64_t x, sv_f64_t y, sv_f64_t z)
{
  return svmla_n_f64_x (pg, z, y, x);
}

/* Bitwise reinterpret casts. */
static inline sv_s64_t
sv_as_s64_u64 (sv_u64_t x)
{
  return svreinterpret_s64_u64 (x);
}

static inline sv_u64_t
sv_as_u64_f64 (sv_f64_t x)
{
  return svreinterpret_u64_f64 (x);
}

static inline sv_f64_t
sv_as_f64_u64 (sv_u64_t x)
{
  return svreinterpret_f64_u64 (x);
}

/* Convert signed 64-bit integers to double. */
static inline sv_f64_t
sv_to_f64_s64_x (svbool_t pg, sv_s64_t s)
{
  return svcvt_f64_x (pg, s);
}

/* Apply the scalar callback f to every lane of x selected by cmp and merge
   the results into y; lanes not selected by cmp keep their value from y.
   This is the slow path for special-case inputs. */
static inline sv_f64_t
sv_call_f64 (f64_t (*f) (f64_t), sv_f64_t x, sv_f64_t y, svbool_t cmp)
{
  /* p selects the first active lane of cmp; svpnext advances it one active
     lane per iteration. */
  svbool_t p = svpfirst (cmp, svpfalse ());
  while (svptest_any (cmp, p))
    {
      /* Extract the active lane, apply f, and insert the result back. */
      f64_t elem = svclastb_n_f64 (p, 0, x);
      elem = (*f) (elem);
      sv_f64_t y2 = svdup_n_f64 (elem);
      y = svsel_f64 (p, y2, y);
      p = svpnext_b64 (cmp, p);
    }
  return y;
}

/* Two-argument variant of sv_call_f64. */
static inline sv_f64_t
sv_call2_f64 (f64_t (*f) (f64_t, f64_t), sv_f64_t x1, sv_f64_t x2, sv_f64_t y,
	      svbool_t cmp)
{
  svbool_t p = svpfirst (cmp, svpfalse ());
  while (svptest_any (cmp, p))
    {
      f64_t elem1 = svclastb_n_f64 (p, 0, x1);
      f64_t elem2 = svclastb_n_f64 (p, 0, x2);
      f64_t ret = (*f) (elem1, elem2);
      sv_f64_t y2 = svdup_n_f64 (ret);
      y = svsel_f64 (p, y2, y);
      p = svpnext_b64 (cmp, p);
    }
  return y;
}

/* Gather tab[idx] per lane from an array of uint64_t into svuint64_t. */
static inline sv_u64_t
sv_lookup_u64_x (svbool_t pg, const u64_t *tab, sv_u64_t idx)
{
  return svld1_gather_u64index_u64 (pg, tab, idx);
}

/* Gather tab[idx] per lane from an array of double into svfloat64_t. */
static inline sv_f64_t
sv_lookup_f64_x (svbool_t pg, const f64_t *tab, sv_u64_t idx)
{
  return svld1_gather_u64index_f64 (pg, tab, idx);
}

/* x % y per lane, computed as x - (x / y) * y. */
static inline sv_u64_t
sv_mod_n_u64_x (svbool_t pg, sv_u64_t x, u64_t y)
{
  sv_u64_t q = svdiv_n_u64_x (pg, x, y);
  return svmls_n_u64_x (pg, x, q, y);
}
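/* Usage sketch (illustrative, not part of this header): routines built on
   these wrappers typically compute a full-vector result first, then repair
   only the lanes that hit a special case via sv_call_f64, so the scalar
   fallback runs once per flagged lane rather than per vector. The names
   top_bits, special_bound and scalar_fallback are hypothetical stand-ins
   for a routine's own threshold test and scalar implementation:

     svbool_t special = svcmpge_n_u64 (pg, top_bits, special_bound);
     if (svptest_any (pg, special))
       y = sv_call_f64 (scalar_fallback, x, y, special);

   Lanes not selected by special keep their vector result in y. */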
/* Single precision. */

/* Duplicate a scalar into every lane. */
static inline sv_s32_t
sv_s32 (s32_t x)
{
  return svdup_n_s32 (x);
}

static inline sv_u32_t
sv_u32 (u32_t x)
{
  return svdup_n_u32 (x);
}

static inline sv_f32_t
sv_f32 (f32_t x)
{
  return svdup_n_f32 (x);
}

/* res = z + x * y. */
static inline sv_f32_t
sv_fma_f32_x (svbool_t pg, sv_f32_t x, sv_f32_t y, sv_f32_t z)
{
  return svmla_f32_x (pg, z, x, y);
}

/* res = z + x * y with x scalar. */
static inline sv_f32_t
sv_fma_n_f32_x (svbool_t pg, f32_t x, sv_f32_t y, sv_f32_t z)
{
  return svmla_n_f32_x (pg, z, y, x);
}

/* Bitwise reinterpret casts. */
static inline sv_u32_t
sv_as_u32_f32 (sv_f32_t x)
{
  return svreinterpret_u32_f32 (x);
}

static inline sv_f32_t
sv_as_f32_u32 (sv_u32_t x)
{
  return svreinterpret_f32_u32 (x);
}

static inline sv_s32_t
sv_as_s32_u32 (sv_u32_t x)
{
  return svreinterpret_s32_u32 (x);
}

/* Conversions between signed 32-bit integers and float. */
static inline sv_f32_t
sv_to_f32_s32_x (svbool_t pg, sv_s32_t s)
{
  return svcvt_f32_x (pg, s);
}

static inline sv_s32_t
sv_to_s32_f32_x (svbool_t pg, sv_f32_t x)
{
  return svcvt_s32_f32_x (pg, x);
}

/* Single-precision variant of sv_call_f64: apply the scalar callback f to
   every lane of x selected by cmp and merge the results into y. */
static inline sv_f32_t
sv_call_f32 (f32_t (*f) (f32_t), sv_f32_t x, sv_f32_t y, svbool_t cmp)
{
  svbool_t p = svpfirst (cmp, svpfalse ());
  while (svptest_any (cmp, p))
    {
      f32_t elem = svclastb_n_f32 (p, 0, x);
      elem = (*f) (elem);
      sv_f32_t y2 = svdup_n_f32 (elem);
      y = svsel_f32 (p, y2, y);
      p = svpnext_b32 (cmp, p);
    }
  return y;
}

/* Two-argument variant of sv_call_f32. */
static inline sv_f32_t
sv_call2_f32 (f32_t (*f) (f32_t, f32_t), sv_f32_t x1, sv_f32_t x2, sv_f32_t y,
	      svbool_t cmp)
{
  svbool_t p = svpfirst (cmp, svpfalse ());
  while (svptest_any (cmp, p))
    {
      f32_t elem1 = svclastb_n_f32 (p, 0, x1);
      f32_t elem2 = svclastb_n_f32 (p, 0, x2);
      f32_t ret = (*f) (elem1, elem2);
      sv_f32_t y2 = svdup_n_f32 (ret);
      y = svsel_f32 (p, y2, y);
      p = svpnext_b32 (cmp, p);
    }
  return y;
}

#endif
#endif
#endif