1 /*===------------ avx512bf16intrin.h - AVX512_BF16 intrinsics --------------=== 2 * 3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 * See https://llvm.org/LICENSE.txt for license information. 5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 * 7 *===-----------------------------------------------------------------------=== 8 */ 9 #ifndef __IMMINTRIN_H 10 #error "Never use <avx512bf16intrin.h> directly; include <immintrin.h> instead." 11 #endif 12 13 #ifdef __SSE2__ 14 15 #ifndef __AVX512BF16INTRIN_H 16 #define __AVX512BF16INTRIN_H 17 18 typedef __bf16 __v32bf __attribute__((__vector_size__(64), __aligned__(64))); 19 typedef __bf16 __m512bh __attribute__((__vector_size__(64), __aligned__(64))); 20 typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead"))); 21 22 #define __DEFAULT_FN_ATTRS512 \ 23 __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"), \ 24 __min_vector_width__(512))) 25 #define __DEFAULT_FN_ATTRS \ 26 __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"))) 27 28 /// Convert One BF16 Data to One Single Float Data. 29 /// 30 /// \headerfile <x86intrin.h> 31 /// 32 /// This intrinsic does not correspond to a specific instruction. 33 /// 34 /// \param __A 35 /// A bfloat data. 36 /// \returns A float data whose sign field and exponent field keep unchanged, 37 /// and fraction field is extended to 23 bits. 38 static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bf16 __A) { 39 return __builtin_ia32_cvtsbf162ss_32(__A); 40 } 41 42 /// Convert Two Packed Single Data to One Packed BF16 Data. 43 /// 44 /// \headerfile <x86intrin.h> 45 /// 46 /// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions. 47 /// 48 /// \param __A 49 /// A 512-bit vector of [16 x float]. 50 /// \param __B 51 /// A 512-bit vector of [16 x float]. 52 /// \returns A 512-bit vector of [32 x bfloat] whose lower 256 bits come from 53 /// conversion of __B, and higher 256 bits come from conversion of __A. 54 static __inline__ __m512bh __DEFAULT_FN_ATTRS512 55 _mm512_cvtne2ps_pbh(__m512 __A, __m512 __B) { 56 return (__m512bh)__builtin_ia32_cvtne2ps2bf16_512((__v16sf) __A, 57 (__v16sf) __B); 58 } 59 60 /// Convert Two Packed Single Data to One Packed BF16 Data. 61 /// 62 /// \headerfile <x86intrin.h> 63 /// 64 /// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions. 65 /// 66 /// \param __A 67 /// A 512-bit vector of [16 x float]. 68 /// \param __B 69 /// A 512-bit vector of [16 x float]. 70 /// \param __W 71 /// A 512-bit vector of [32 x bfloat]. 72 /// \param __U 73 /// A 32-bit mask value specifying what is chosen for each element. 74 /// A 1 means conversion of __A or __B. A 0 means element from __W. 75 /// \returns A 512-bit vector of [32 x bfloat] whose lower 256 bits come from 76 /// conversion of __B, and higher 256 bits come from conversion of __A. 77 static __inline__ __m512bh __DEFAULT_FN_ATTRS512 78 _mm512_mask_cvtne2ps_pbh(__m512bh __W, __mmask32 __U, __m512 __A, __m512 __B) { 79 return (__m512bh)__builtin_ia32_selectpbf_512((__mmask32)__U, 80 (__v32bf)_mm512_cvtne2ps_pbh(__A, __B), 81 (__v32bf)__W); 82 } 83 84 /// Convert Two Packed Single Data to One Packed BF16 Data. 85 /// 86 /// \headerfile <x86intrin.h> 87 /// 88 /// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions. 89 /// 90 /// \param __A 91 /// A 512-bit vector of [16 x float]. 92 /// \param __B 93 /// A 512-bit vector of [16 x float]. 94 /// \param __U 95 /// A 32-bit mask value specifying what is chosen for each element. 96 /// A 1 means conversion of __A or __B. A 0 means element is zero. 97 /// \returns A 512-bit vector of [32 x bfloat] whose lower 256 bits come from 98 /// conversion of __B, and higher 256 bits come from conversion of __A. 99 static __inline__ __m512bh __DEFAULT_FN_ATTRS512 100 _mm512_maskz_cvtne2ps_pbh(__mmask32 __U, __m512 __A, __m512 __B) { 101 return (__m512bh)__builtin_ia32_selectpbf_512((__mmask32)__U, 102 (__v32bf)_mm512_cvtne2ps_pbh(__A, __B), 103 (__v32bf)_mm512_setzero_si512()); 104 } 105 106 /// Convert Packed Single Data to Packed BF16 Data. 107 /// 108 /// \headerfile <x86intrin.h> 109 /// 110 /// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions. 111 /// 112 /// \param __A 113 /// A 512-bit vector of [16 x float]. 114 /// \returns A 256-bit vector of [16 x bfloat] come from conversion of __A. 115 static __inline__ __m256bh __DEFAULT_FN_ATTRS512 116 _mm512_cvtneps_pbh(__m512 __A) { 117 return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A, 118 (__v16bf)_mm256_undefined_si256(), 119 (__mmask16)-1); 120 } 121 122 /// Convert Packed Single Data to Packed BF16 Data. 123 /// 124 /// \headerfile <x86intrin.h> 125 /// 126 /// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions. 127 /// 128 /// \param __A 129 /// A 512-bit vector of [16 x float]. 130 /// \param __W 131 /// A 256-bit vector of [16 x bfloat]. 132 /// \param __U 133 /// A 16-bit mask value specifying what is chosen for each element. 134 /// A 1 means conversion of __A. A 0 means element from __W. 135 /// \returns A 256-bit vector of [16 x bfloat] come from conversion of __A. 136 static __inline__ __m256bh __DEFAULT_FN_ATTRS512 137 _mm512_mask_cvtneps_pbh(__m256bh __W, __mmask16 __U, __m512 __A) { 138 return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A, 139 (__v16bf)__W, 140 (__mmask16)__U); 141 } 142 143 /// Convert Packed Single Data to Packed BF16 Data. 144 /// 145 /// \headerfile <x86intrin.h> 146 /// 147 /// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions. 148 /// 149 /// \param __A 150 /// A 512-bit vector of [16 x float]. 151 /// \param __U 152 /// A 16-bit mask value specifying what is chosen for each element. 153 /// A 1 means conversion of __A. A 0 means element is zero. 154 /// \returns A 256-bit vector of [16 x bfloat] come from conversion of __A. 155 static __inline__ __m256bh __DEFAULT_FN_ATTRS512 156 _mm512_maskz_cvtneps_pbh(__mmask16 __U, __m512 __A) { 157 return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A, 158 (__v16bf)_mm256_setzero_si256(), 159 (__mmask16)__U); 160 } 161 162 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision. 163 /// 164 /// \headerfile <x86intrin.h> 165 /// 166 /// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions. 167 /// 168 /// \param __A 169 /// A 512-bit vector of [32 x bfloat]. 170 /// \param __B 171 /// A 512-bit vector of [32 x bfloat]. 172 /// \param __D 173 /// A 512-bit vector of [16 x float]. 174 /// \returns A 512-bit vector of [16 x float] comes from Dot Product of 175 /// __A, __B and __D 176 static __inline__ __m512 __DEFAULT_FN_ATTRS512 177 _mm512_dpbf16_ps(__m512 __D, __m512bh __A, __m512bh __B) { 178 return (__m512)__builtin_ia32_dpbf16ps_512((__v16sf) __D, 179 (__v32bf) __A, 180 (__v32bf) __B); 181 } 182 183 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision. 184 /// 185 /// \headerfile <x86intrin.h> 186 /// 187 /// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions. 188 /// 189 /// \param __A 190 /// A 512-bit vector of [32 x bfloat]. 191 /// \param __B 192 /// A 512-bit vector of [32 x bfloat]. 193 /// \param __D 194 /// A 512-bit vector of [16 x float]. 195 /// \param __U 196 /// A 16-bit mask value specifying what is chosen for each element. 197 /// A 1 means __A and __B's dot product accumulated with __D. A 0 means __D. 198 /// \returns A 512-bit vector of [16 x float] comes from Dot Product of 199 /// __A, __B and __D 200 static __inline__ __m512 __DEFAULT_FN_ATTRS512 201 _mm512_mask_dpbf16_ps(__m512 __D, __mmask16 __U, __m512bh __A, __m512bh __B) { 202 return (__m512)__builtin_ia32_selectps_512((__mmask16)__U, 203 (__v16sf)_mm512_dpbf16_ps(__D, __A, __B), 204 (__v16sf)__D); 205 } 206 207 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision. 208 /// 209 /// \headerfile <x86intrin.h> 210 /// 211 /// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions. 212 /// 213 /// \param __A 214 /// A 512-bit vector of [32 x bfloat]. 215 /// \param __B 216 /// A 512-bit vector of [32 x bfloat]. 217 /// \param __D 218 /// A 512-bit vector of [16 x float]. 219 /// \param __U 220 /// A 16-bit mask value specifying what is chosen for each element. 221 /// A 1 means __A and __B's dot product accumulated with __D. A 0 means 0. 222 /// \returns A 512-bit vector of [16 x float] comes from Dot Product of 223 /// __A, __B and __D 224 static __inline__ __m512 __DEFAULT_FN_ATTRS512 225 _mm512_maskz_dpbf16_ps(__mmask16 __U, __m512 __D, __m512bh __A, __m512bh __B) { 226 return (__m512)__builtin_ia32_selectps_512((__mmask16)__U, 227 (__v16sf)_mm512_dpbf16_ps(__D, __A, __B), 228 (__v16sf)_mm512_setzero_si512()); 229 } 230 231 /// Convert Packed BF16 Data to Packed float Data. 232 /// 233 /// \headerfile <x86intrin.h> 234 /// 235 /// \param __A 236 /// A 256-bit vector of [16 x bfloat]. 237 /// \returns A 512-bit vector of [16 x float] come from conversion of __A 238 static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) { 239 return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32( 240 (__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16)); 241 } 242 243 /// Convert Packed BF16 Data to Packed float Data using zeroing mask. 244 /// 245 /// \headerfile <x86intrin.h> 246 /// 247 /// \param __U 248 /// A 16-bit mask. Elements are zeroed out when the corresponding mask 249 /// bit is not set. 250 /// \param __A 251 /// A 256-bit vector of [16 x bfloat]. 252 /// \returns A 512-bit vector of [16 x float] come from conversion of __A 253 static __inline__ __m512 __DEFAULT_FN_ATTRS512 254 _mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) { 255 return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32( 256 (__m512i)_mm512_maskz_cvtepi16_epi32((__mmask16)__U, (__m256i)__A), 16)); 257 } 258 259 /// Convert Packed BF16 Data to Packed float Data using merging mask. 260 /// 261 /// \headerfile <x86intrin.h> 262 /// 263 /// \param __S 264 /// A 512-bit vector of [16 x float]. Elements are copied from __S when 265 /// the corresponding mask bit is not set. 266 /// \param __U 267 /// A 16-bit mask. 268 /// \param __A 269 /// A 256-bit vector of [16 x bfloat]. 270 /// \returns A 512-bit vector of [16 x float] come from conversion of __A 271 static __inline__ __m512 __DEFAULT_FN_ATTRS512 272 _mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) { 273 return _mm512_castsi512_ps((__m512i)_mm512_mask_slli_epi32( 274 (__m512i)__S, (__mmask16)__U, 275 (__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16)); 276 } 277 278 #undef __DEFAULT_FN_ATTRS 279 #undef __DEFAULT_FN_ATTRS512 280 281 #endif 282 #endif 283