1 /*===------------ avx512bf16intrin.h - AVX512_BF16 intrinsics --------------=== 2 * 3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 * See https://llvm.org/LICENSE.txt for license information. 5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 * 7 *===-----------------------------------------------------------------------=== 8 */ 9 #ifndef __IMMINTRIN_H 10 #error "Never use <avx512bf16intrin.h> directly; include <immintrin.h> instead." 11 #endif 12 13 #ifndef __AVX512BF16INTRIN_H 14 #define __AVX512BF16INTRIN_H 15 16 typedef short __m512bh __attribute__((__vector_size__(64), __aligned__(64))); 17 typedef short __m256bh __attribute__((__vector_size__(32), __aligned__(32))); 18 typedef unsigned short __bfloat16; 19 20 #define __DEFAULT_FN_ATTRS512 \ 21 __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"), \ 22 __min_vector_width__(512))) 23 #define __DEFAULT_FN_ATTRS \ 24 __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"))) 25 26 /// Convert One BF16 Data to One Single Float Data. 27 /// 28 /// \headerfile <x86intrin.h> 29 /// 30 /// This intrinsic does not correspond to a specific instruction. 31 /// 32 /// \param __A 33 /// A bfloat data. 34 /// \returns A float data whose sign field and exponent field keep unchanged, 35 /// and fraction field is extended to 23 bits. 36 static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bfloat16 __A) { 37 return __builtin_ia32_cvtsbf162ss_32(__A); 38 } 39 40 /// Convert Two Packed Single Data to One Packed BF16 Data. 41 /// 42 /// \headerfile <x86intrin.h> 43 /// 44 /// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions. 45 /// 46 /// \param __A 47 /// A 512-bit vector of [16 x float]. 48 /// \param __B 49 /// A 512-bit vector of [16 x float]. 50 /// \returns A 512-bit vector of [32 x bfloat] whose lower 256 bits come from 51 /// conversion of __B, and higher 256 bits come from conversion of __A. 52 static __inline__ __m512bh __DEFAULT_FN_ATTRS512 53 _mm512_cvtne2ps_pbh(__m512 __A, __m512 __B) { 54 return (__m512bh)__builtin_ia32_cvtne2ps2bf16_512((__v16sf) __A, 55 (__v16sf) __B); 56 } 57 58 /// Convert Two Packed Single Data to One Packed BF16 Data. 59 /// 60 /// \headerfile <x86intrin.h> 61 /// 62 /// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions. 63 /// 64 /// \param __A 65 /// A 512-bit vector of [16 x float]. 66 /// \param __B 67 /// A 512-bit vector of [16 x float]. 68 /// \param __W 69 /// A 512-bit vector of [32 x bfloat]. 70 /// \param __U 71 /// A 32-bit mask value specifying what is chosen for each element. 72 /// A 1 means conversion of __A or __B. A 0 means element from __W. 73 /// \returns A 512-bit vector of [32 x bfloat] whose lower 256 bits come from 74 /// conversion of __B, and higher 256 bits come from conversion of __A. 75 static __inline__ __m512bh __DEFAULT_FN_ATTRS512 76 _mm512_mask_cvtne2ps_pbh(__m512bh __W, __mmask32 __U, __m512 __A, __m512 __B) { 77 return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U, 78 (__v32hi)_mm512_cvtne2ps_pbh(__A, __B), 79 (__v32hi)__W); 80 } 81 82 /// Convert Two Packed Single Data to One Packed BF16 Data. 83 /// 84 /// \headerfile <x86intrin.h> 85 /// 86 /// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions. 87 /// 88 /// \param __A 89 /// A 512-bit vector of [16 x float]. 90 /// \param __B 91 /// A 512-bit vector of [16 x float]. 92 /// \param __U 93 /// A 32-bit mask value specifying what is chosen for each element. 94 /// A 1 means conversion of __A or __B. A 0 means element is zero. 95 /// \returns A 512-bit vector of [32 x bfloat] whose lower 256 bits come from 96 /// conversion of __B, and higher 256 bits come from conversion of __A. 97 static __inline__ __m512bh __DEFAULT_FN_ATTRS512 98 _mm512_maskz_cvtne2ps_pbh(__mmask32 __U, __m512 __A, __m512 __B) { 99 return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U, 100 (__v32hi)_mm512_cvtne2ps_pbh(__A, __B), 101 (__v32hi)_mm512_setzero_si512()); 102 } 103 104 /// Convert Packed Single Data to Packed BF16 Data. 105 /// 106 /// \headerfile <x86intrin.h> 107 /// 108 /// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions. 109 /// 110 /// \param __A 111 /// A 512-bit vector of [16 x float]. 112 /// \returns A 256-bit vector of [16 x bfloat] come from conversion of __A. 113 static __inline__ __m256bh __DEFAULT_FN_ATTRS512 114 _mm512_cvtneps_pbh(__m512 __A) { 115 return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A, 116 (__v16hi)_mm256_undefined_si256(), 117 (__mmask16)-1); 118 } 119 120 /// Convert Packed Single Data to Packed BF16 Data. 121 /// 122 /// \headerfile <x86intrin.h> 123 /// 124 /// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions. 125 /// 126 /// \param __A 127 /// A 512-bit vector of [16 x float]. 128 /// \param __W 129 /// A 256-bit vector of [16 x bfloat]. 130 /// \param __U 131 /// A 16-bit mask value specifying what is chosen for each element. 132 /// A 1 means conversion of __A. A 0 means element from __W. 133 /// \returns A 256-bit vector of [16 x bfloat] come from conversion of __A. 134 static __inline__ __m256bh __DEFAULT_FN_ATTRS512 135 _mm512_mask_cvtneps_pbh(__m256bh __W, __mmask16 __U, __m512 __A) { 136 return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A, 137 (__v16hi)__W, 138 (__mmask16)__U); 139 } 140 141 /// Convert Packed Single Data to Packed BF16 Data. 142 /// 143 /// \headerfile <x86intrin.h> 144 /// 145 /// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions. 146 /// 147 /// \param __A 148 /// A 512-bit vector of [16 x float]. 149 /// \param __U 150 /// A 16-bit mask value specifying what is chosen for each element. 151 /// A 1 means conversion of __A. A 0 means element is zero. 152 /// \returns A 256-bit vector of [16 x bfloat] come from conversion of __A. 153 static __inline__ __m256bh __DEFAULT_FN_ATTRS512 154 _mm512_maskz_cvtneps_pbh(__mmask16 __U, __m512 __A) { 155 return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A, 156 (__v16hi)_mm256_setzero_si256(), 157 (__mmask16)__U); 158 } 159 160 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision. 161 /// 162 /// \headerfile <x86intrin.h> 163 /// 164 /// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions. 165 /// 166 /// \param __A 167 /// A 512-bit vector of [32 x bfloat]. 168 /// \param __B 169 /// A 512-bit vector of [32 x bfloat]. 170 /// \param __D 171 /// A 512-bit vector of [16 x float]. 172 /// \returns A 512-bit vector of [16 x float] comes from Dot Product of 173 /// __A, __B and __D 174 static __inline__ __m512 __DEFAULT_FN_ATTRS512 175 _mm512_dpbf16_ps(__m512 __D, __m512bh __A, __m512bh __B) { 176 return (__m512)__builtin_ia32_dpbf16ps_512((__v16sf) __D, 177 (__v16si) __A, 178 (__v16si) __B); 179 } 180 181 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision. 182 /// 183 /// \headerfile <x86intrin.h> 184 /// 185 /// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions. 186 /// 187 /// \param __A 188 /// A 512-bit vector of [32 x bfloat]. 189 /// \param __B 190 /// A 512-bit vector of [32 x bfloat]. 191 /// \param __D 192 /// A 512-bit vector of [16 x float]. 193 /// \param __U 194 /// A 16-bit mask value specifying what is chosen for each element. 195 /// A 1 means __A and __B's dot product accumulated with __D. A 0 means __D. 196 /// \returns A 512-bit vector of [16 x float] comes from Dot Product of 197 /// __A, __B and __D 198 static __inline__ __m512 __DEFAULT_FN_ATTRS512 199 _mm512_mask_dpbf16_ps(__m512 __D, __mmask16 __U, __m512bh __A, __m512bh __B) { 200 return (__m512)__builtin_ia32_selectps_512((__mmask16)__U, 201 (__v16sf)_mm512_dpbf16_ps(__D, __A, __B), 202 (__v16sf)__D); 203 } 204 205 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision. 206 /// 207 /// \headerfile <x86intrin.h> 208 /// 209 /// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions. 210 /// 211 /// \param __A 212 /// A 512-bit vector of [32 x bfloat]. 213 /// \param __B 214 /// A 512-bit vector of [32 x bfloat]. 215 /// \param __D 216 /// A 512-bit vector of [16 x float]. 217 /// \param __U 218 /// A 16-bit mask value specifying what is chosen for each element. 219 /// A 1 means __A and __B's dot product accumulated with __D. A 0 means 0. 220 /// \returns A 512-bit vector of [16 x float] comes from Dot Product of 221 /// __A, __B and __D 222 static __inline__ __m512 __DEFAULT_FN_ATTRS512 223 _mm512_maskz_dpbf16_ps(__mmask16 __U, __m512 __D, __m512bh __A, __m512bh __B) { 224 return (__m512)__builtin_ia32_selectps_512((__mmask16)__U, 225 (__v16sf)_mm512_dpbf16_ps(__D, __A, __B), 226 (__v16sf)_mm512_setzero_si512()); 227 } 228 229 /// Convert Packed BF16 Data to Packed float Data. 230 /// 231 /// \headerfile <x86intrin.h> 232 /// 233 /// \param __A 234 /// A 256-bit vector of [16 x bfloat]. 235 /// \returns A 512-bit vector of [16 x float] come from conversion of __A 236 static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) { 237 return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32( 238 (__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16)); 239 } 240 241 /// Convert Packed BF16 Data to Packed float Data using zeroing mask. 242 /// 243 /// \headerfile <x86intrin.h> 244 /// 245 /// \param __U 246 /// A 16-bit mask. Elements are zeroed out when the corresponding mask 247 /// bit is not set. 248 /// \param __A 249 /// A 256-bit vector of [16 x bfloat]. 250 /// \returns A 512-bit vector of [16 x float] come from conversion of __A 251 static __inline__ __m512 __DEFAULT_FN_ATTRS512 252 _mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) { 253 return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32( 254 (__m512i)_mm512_maskz_cvtepi16_epi32((__mmask16)__U, (__m256i)__A), 16)); 255 } 256 257 /// Convert Packed BF16 Data to Packed float Data using merging mask. 258 /// 259 /// \headerfile <x86intrin.h> 260 /// 261 /// \param __S 262 /// A 512-bit vector of [16 x float]. Elements are copied from __S when 263 /// the corresponding mask bit is not set. 264 /// \param __U 265 /// A 16-bit mask. 266 /// \param __A 267 /// A 256-bit vector of [16 x bfloat]. 268 /// \returns A 512-bit vector of [16 x float] come from conversion of __A 269 static __inline__ __m512 __DEFAULT_FN_ATTRS512 270 _mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) { 271 return _mm512_castsi512_ps((__m512i)_mm512_mask_slli_epi32( 272 (__m512i)__S, (__mmask16)__U, 273 (__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16)); 274 } 275 276 #undef __DEFAULT_FN_ATTRS 277 #undef __DEFAULT_FN_ATTRS512 278 279 #endif 280