1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "clang/Support/RISCVVIntrinsicUtils.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/ADT/SmallSet.h" 12 #include "llvm/ADT/StringExtras.h" 13 #include "llvm/ADT/StringMap.h" 14 #include "llvm/ADT/StringSet.h" 15 #include "llvm/ADT/Twine.h" 16 #include "llvm/Support/ErrorHandling.h" 17 #include "llvm/Support/raw_ostream.h" 18 #include <numeric> 19 #include <optional> 20 21 using namespace llvm; 22 23 namespace clang { 24 namespace RISCV { 25 26 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor( 27 BaseTypeModifier::Vector, VectorTypeModifier::MaskVector); 28 const PrototypeDescriptor PrototypeDescriptor::VL = 29 PrototypeDescriptor(BaseTypeModifier::SizeT); 30 const PrototypeDescriptor PrototypeDescriptor::Vector = 31 PrototypeDescriptor(BaseTypeModifier::Vector); 32 33 //===----------------------------------------------------------------------===// 34 // Type implementation 35 //===----------------------------------------------------------------------===// 36 37 LMULType::LMULType(int NewLog2LMUL) { 38 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 39 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); 40 Log2LMUL = NewLog2LMUL; 41 } 42 43 std::string LMULType::str() const { 44 if (Log2LMUL < 0) 45 return "mf" + utostr(1ULL << (-Log2LMUL)); 46 return "m" + utostr(1ULL << Log2LMUL); 47 } 48 49 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { 50 int Log2ScaleResult = 0; 51 switch (ElementBitwidth) { 52 default: 53 break; 54 case 8: 55 Log2ScaleResult = Log2LMUL + 3; 56 break; 57 case 16: 58 Log2ScaleResult = Log2LMUL + 2; 59 break; 60 case 32: 61 Log2ScaleResult = Log2LMUL + 1; 62 break; 63 case 64: 64 Log2ScaleResult = Log2LMUL; 65 break; 66 } 67 // Illegal vscale result would be less than 1 68 if (Log2ScaleResult < 0) 69 return std::nullopt; 70 return 1 << Log2ScaleResult; 71 } 72 73 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } 74 75 RVVType::RVVType(BasicType BT, int Log2LMUL, 76 const PrototypeDescriptor &prototype) 77 : BT(BT), LMUL(LMULType(Log2LMUL)) { 78 applyBasicType(); 79 applyModifier(prototype); 80 Valid = verifyType(); 81 if (Valid) { 82 initBuiltinStr(); 83 initTypeStr(); 84 if (isVector()) { 85 initClangBuiltinStr(); 86 } 87 } 88 } 89 90 // clang-format off 91 // boolean type are encoded the ratio of n (SEW/LMUL) 92 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 93 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t 94 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 95 96 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 97 // -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- 98 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 99 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 100 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 101 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 102 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 103 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 104 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 105 // clang-format on 106 107 bool RVVType::verifyType() const { 108 if (ScalarType == Invalid) 109 return false; 110 if (isScalar()) 111 return true; 112 if (!Scale) 113 return false; 114 if (isFloat() && ElementBitwidth == 8) 115 return false; 116 if (IsTuple && (NF == 1 || NF > 8)) 117 return false; 118 if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8) 119 return false; 120 unsigned V = *Scale; 121 switch (ElementBitwidth) { 122 case 1: 123 case 8: 124 // Check Scale is 1,2,4,8,16,32,64 125 return (V <= 64 && isPowerOf2_32(V)); 126 case 16: 127 // Check Scale is 1,2,4,8,16,32 128 return (V <= 32 && isPowerOf2_32(V)); 129 case 32: 130 // Check Scale is 1,2,4,8,16 131 return (V <= 16 && isPowerOf2_32(V)); 132 case 64: 133 // Check Scale is 1,2,4,8 134 return (V <= 8 && isPowerOf2_32(V)); 135 } 136 return false; 137 } 138 139 void RVVType::initBuiltinStr() { 140 assert(isValid() && "RVVType is invalid"); 141 switch (ScalarType) { 142 case ScalarTypeKind::Void: 143 BuiltinStr = "v"; 144 return; 145 case ScalarTypeKind::Size_t: 146 BuiltinStr = "z"; 147 if (IsImmediate) 148 BuiltinStr = "I" + BuiltinStr; 149 if (IsPointer) 150 BuiltinStr += "*"; 151 return; 152 case ScalarTypeKind::Ptrdiff_t: 153 BuiltinStr = "Y"; 154 return; 155 case ScalarTypeKind::UnsignedLong: 156 BuiltinStr = "ULi"; 157 return; 158 case ScalarTypeKind::SignedLong: 159 BuiltinStr = "Li"; 160 return; 161 case ScalarTypeKind::Boolean: 162 assert(ElementBitwidth == 1); 163 BuiltinStr += "b"; 164 break; 165 case ScalarTypeKind::SignedInteger: 166 case ScalarTypeKind::UnsignedInteger: 167 switch (ElementBitwidth) { 168 case 8: 169 BuiltinStr += "c"; 170 break; 171 case 16: 172 BuiltinStr += "s"; 173 break; 174 case 32: 175 BuiltinStr += "i"; 176 break; 177 case 64: 178 BuiltinStr += "Wi"; 179 break; 180 default: 181 llvm_unreachable("Unhandled ElementBitwidth!"); 182 } 183 if (isSignedInteger()) 184 BuiltinStr = "S" + BuiltinStr; 185 else 186 BuiltinStr = "U" + BuiltinStr; 187 break; 188 case ScalarTypeKind::Float: 189 switch (ElementBitwidth) { 190 case 16: 191 BuiltinStr += "x"; 192 break; 193 case 32: 194 BuiltinStr += "f"; 195 break; 196 case 64: 197 BuiltinStr += "d"; 198 break; 199 default: 200 llvm_unreachable("Unhandled ElementBitwidth!"); 201 } 202 break; 203 default: 204 llvm_unreachable("ScalarType is invalid!"); 205 } 206 if (IsImmediate) 207 BuiltinStr = "I" + BuiltinStr; 208 if (isScalar()) { 209 if (IsConstant) 210 BuiltinStr += "C"; 211 if (IsPointer) 212 BuiltinStr += "*"; 213 return; 214 } 215 BuiltinStr = "q" + utostr(*Scale) + BuiltinStr; 216 // Pointer to vector types. Defined for segment load intrinsics. 217 // segment load intrinsics have pointer type arguments to store the loaded 218 // vector values. 219 if (IsPointer) 220 BuiltinStr += "*"; 221 222 if (IsTuple) 223 BuiltinStr = "T" + utostr(NF) + BuiltinStr; 224 } 225 226 void RVVType::initClangBuiltinStr() { 227 assert(isValid() && "RVVType is invalid"); 228 assert(isVector() && "Handle Vector type only"); 229 230 ClangBuiltinStr = "__rvv_"; 231 switch (ScalarType) { 232 case ScalarTypeKind::Boolean: 233 ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t"; 234 return; 235 case ScalarTypeKind::Float: 236 ClangBuiltinStr += "float"; 237 break; 238 case ScalarTypeKind::SignedInteger: 239 ClangBuiltinStr += "int"; 240 break; 241 case ScalarTypeKind::UnsignedInteger: 242 ClangBuiltinStr += "uint"; 243 break; 244 default: 245 llvm_unreachable("ScalarTypeKind is invalid"); 246 } 247 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + 248 (IsTuple ? "x" + utostr(NF) : "") + "_t"; 249 } 250 251 void RVVType::initTypeStr() { 252 assert(isValid() && "RVVType is invalid"); 253 254 if (IsConstant) 255 Str += "const "; 256 257 auto getTypeString = [&](StringRef TypeStr) { 258 if (isScalar()) 259 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); 260 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + 261 (IsTuple ? "x" + utostr(NF) : "") + "_t") 262 .str(); 263 }; 264 265 switch (ScalarType) { 266 case ScalarTypeKind::Void: 267 Str = "void"; 268 return; 269 case ScalarTypeKind::Size_t: 270 Str = "size_t"; 271 if (IsPointer) 272 Str += " *"; 273 return; 274 case ScalarTypeKind::Ptrdiff_t: 275 Str = "ptrdiff_t"; 276 return; 277 case ScalarTypeKind::UnsignedLong: 278 Str = "unsigned long"; 279 return; 280 case ScalarTypeKind::SignedLong: 281 Str = "long"; 282 return; 283 case ScalarTypeKind::Boolean: 284 if (isScalar()) 285 Str += "bool"; 286 else 287 // Vector bool is special case, the formulate is 288 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 289 Str += "vbool" + utostr(64 / *Scale) + "_t"; 290 break; 291 case ScalarTypeKind::Float: 292 if (isScalar()) { 293 if (ElementBitwidth == 64) 294 Str += "double"; 295 else if (ElementBitwidth == 32) 296 Str += "float"; 297 else if (ElementBitwidth == 16) 298 Str += "_Float16"; 299 else 300 llvm_unreachable("Unhandled floating type."); 301 } else 302 Str += getTypeString("float"); 303 break; 304 case ScalarTypeKind::SignedInteger: 305 Str += getTypeString("int"); 306 break; 307 case ScalarTypeKind::UnsignedInteger: 308 Str += getTypeString("uint"); 309 break; 310 default: 311 llvm_unreachable("ScalarType is invalid!"); 312 } 313 if (IsPointer) 314 Str += " *"; 315 } 316 317 void RVVType::initShortStr() { 318 switch (ScalarType) { 319 case ScalarTypeKind::Boolean: 320 assert(isVector()); 321 ShortStr = "b" + utostr(64 / *Scale); 322 return; 323 case ScalarTypeKind::Float: 324 ShortStr = "f" + utostr(ElementBitwidth); 325 break; 326 case ScalarTypeKind::SignedInteger: 327 ShortStr = "i" + utostr(ElementBitwidth); 328 break; 329 case ScalarTypeKind::UnsignedInteger: 330 ShortStr = "u" + utostr(ElementBitwidth); 331 break; 332 default: 333 llvm_unreachable("Unhandled case!"); 334 } 335 if (isVector()) 336 ShortStr += LMUL.str(); 337 if (isTuple()) 338 ShortStr += "x" + utostr(NF); 339 } 340 341 static VectorTypeModifier getTupleVTM(unsigned NF) { 342 assert(2 <= NF && NF <= 8 && "2 <= NF <= 8"); 343 return static_cast<VectorTypeModifier>( 344 static_cast<uint8_t>(VectorTypeModifier::Tuple2) + (NF - 2)); 345 } 346 347 void RVVType::applyBasicType() { 348 switch (BT) { 349 case BasicType::Int8: 350 ElementBitwidth = 8; 351 ScalarType = ScalarTypeKind::SignedInteger; 352 break; 353 case BasicType::Int16: 354 ElementBitwidth = 16; 355 ScalarType = ScalarTypeKind::SignedInteger; 356 break; 357 case BasicType::Int32: 358 ElementBitwidth = 32; 359 ScalarType = ScalarTypeKind::SignedInteger; 360 break; 361 case BasicType::Int64: 362 ElementBitwidth = 64; 363 ScalarType = ScalarTypeKind::SignedInteger; 364 break; 365 case BasicType::Float16: 366 ElementBitwidth = 16; 367 ScalarType = ScalarTypeKind::Float; 368 break; 369 case BasicType::Float32: 370 ElementBitwidth = 32; 371 ScalarType = ScalarTypeKind::Float; 372 break; 373 case BasicType::Float64: 374 ElementBitwidth = 64; 375 ScalarType = ScalarTypeKind::Float; 376 break; 377 default: 378 llvm_unreachable("Unhandled type code!"); 379 } 380 assert(ElementBitwidth != 0 && "Bad element bitwidth!"); 381 } 382 383 std::optional<PrototypeDescriptor> 384 PrototypeDescriptor::parsePrototypeDescriptor( 385 llvm::StringRef PrototypeDescriptorStr) { 386 PrototypeDescriptor PD; 387 BaseTypeModifier PT = BaseTypeModifier::Invalid; 388 VectorTypeModifier VTM = VectorTypeModifier::NoModifier; 389 390 if (PrototypeDescriptorStr.empty()) 391 return PD; 392 393 // Handle base type modifier 394 auto PType = PrototypeDescriptorStr.back(); 395 switch (PType) { 396 case 'e': 397 PT = BaseTypeModifier::Scalar; 398 break; 399 case 'v': 400 PT = BaseTypeModifier::Vector; 401 break; 402 case 'w': 403 PT = BaseTypeModifier::Vector; 404 VTM = VectorTypeModifier::Widening2XVector; 405 break; 406 case 'q': 407 PT = BaseTypeModifier::Vector; 408 VTM = VectorTypeModifier::Widening4XVector; 409 break; 410 case 'o': 411 PT = BaseTypeModifier::Vector; 412 VTM = VectorTypeModifier::Widening8XVector; 413 break; 414 case 'm': 415 PT = BaseTypeModifier::Vector; 416 VTM = VectorTypeModifier::MaskVector; 417 break; 418 case '0': 419 PT = BaseTypeModifier::Void; 420 break; 421 case 'z': 422 PT = BaseTypeModifier::SizeT; 423 break; 424 case 't': 425 PT = BaseTypeModifier::Ptrdiff; 426 break; 427 case 'u': 428 PT = BaseTypeModifier::UnsignedLong; 429 break; 430 case 'l': 431 PT = BaseTypeModifier::SignedLong; 432 break; 433 default: 434 llvm_unreachable("Illegal primitive type transformers!"); 435 } 436 PD.PT = static_cast<uint8_t>(PT); 437 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back(); 438 439 // Compute the vector type transformers, it can only appear one time. 440 if (PrototypeDescriptorStr.startswith("(")) { 441 assert(VTM == VectorTypeModifier::NoModifier && 442 "VectorTypeModifier should only have one modifier"); 443 size_t Idx = PrototypeDescriptorStr.find(')'); 444 assert(Idx != StringRef::npos); 445 StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx); 446 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1); 447 assert(!PrototypeDescriptorStr.contains('(') && 448 "Only allow one vector type modifier"); 449 450 auto ComplexTT = ComplexType.split(":"); 451 if (ComplexTT.first == "Log2EEW") { 452 uint32_t Log2EEW; 453 if (ComplexTT.second.getAsInteger(10, Log2EEW)) { 454 llvm_unreachable("Invalid Log2EEW value!"); 455 return std::nullopt; 456 } 457 switch (Log2EEW) { 458 case 3: 459 VTM = VectorTypeModifier::Log2EEW3; 460 break; 461 case 4: 462 VTM = VectorTypeModifier::Log2EEW4; 463 break; 464 case 5: 465 VTM = VectorTypeModifier::Log2EEW5; 466 break; 467 case 6: 468 VTM = VectorTypeModifier::Log2EEW6; 469 break; 470 default: 471 llvm_unreachable("Invalid Log2EEW value, should be [3-6]"); 472 return std::nullopt; 473 } 474 } else if (ComplexTT.first == "FixedSEW") { 475 uint32_t NewSEW; 476 if (ComplexTT.second.getAsInteger(10, NewSEW)) { 477 llvm_unreachable("Invalid FixedSEW value!"); 478 return std::nullopt; 479 } 480 switch (NewSEW) { 481 case 8: 482 VTM = VectorTypeModifier::FixedSEW8; 483 break; 484 case 16: 485 VTM = VectorTypeModifier::FixedSEW16; 486 break; 487 case 32: 488 VTM = VectorTypeModifier::FixedSEW32; 489 break; 490 case 64: 491 VTM = VectorTypeModifier::FixedSEW64; 492 break; 493 default: 494 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64"); 495 return std::nullopt; 496 } 497 } else if (ComplexTT.first == "LFixedLog2LMUL") { 498 int32_t Log2LMUL; 499 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { 500 llvm_unreachable("Invalid LFixedLog2LMUL value!"); 501 return std::nullopt; 502 } 503 switch (Log2LMUL) { 504 case -3: 505 VTM = VectorTypeModifier::LFixedLog2LMULN3; 506 break; 507 case -2: 508 VTM = VectorTypeModifier::LFixedLog2LMULN2; 509 break; 510 case -1: 511 VTM = VectorTypeModifier::LFixedLog2LMULN1; 512 break; 513 case 0: 514 VTM = VectorTypeModifier::LFixedLog2LMUL0; 515 break; 516 case 1: 517 VTM = VectorTypeModifier::LFixedLog2LMUL1; 518 break; 519 case 2: 520 VTM = VectorTypeModifier::LFixedLog2LMUL2; 521 break; 522 case 3: 523 VTM = VectorTypeModifier::LFixedLog2LMUL3; 524 break; 525 default: 526 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); 527 return std::nullopt; 528 } 529 } else if (ComplexTT.first == "SFixedLog2LMUL") { 530 int32_t Log2LMUL; 531 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { 532 llvm_unreachable("Invalid SFixedLog2LMUL value!"); 533 return std::nullopt; 534 } 535 switch (Log2LMUL) { 536 case -3: 537 VTM = VectorTypeModifier::SFixedLog2LMULN3; 538 break; 539 case -2: 540 VTM = VectorTypeModifier::SFixedLog2LMULN2; 541 break; 542 case -1: 543 VTM = VectorTypeModifier::SFixedLog2LMULN1; 544 break; 545 case 0: 546 VTM = VectorTypeModifier::SFixedLog2LMUL0; 547 break; 548 case 1: 549 VTM = VectorTypeModifier::SFixedLog2LMUL1; 550 break; 551 case 2: 552 VTM = VectorTypeModifier::SFixedLog2LMUL2; 553 break; 554 case 3: 555 VTM = VectorTypeModifier::SFixedLog2LMUL3; 556 break; 557 default: 558 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); 559 return std::nullopt; 560 } 561 562 } else if (ComplexTT.first == "Tuple") { 563 unsigned NF = 0; 564 if (ComplexTT.second.getAsInteger(10, NF)) { 565 llvm_unreachable("Invalid NF value!"); 566 return std::nullopt; 567 } 568 VTM = getTupleVTM(NF); 569 } else { 570 llvm_unreachable("Illegal complex type transformers!"); 571 } 572 } 573 PD.VTM = static_cast<uint8_t>(VTM); 574 575 // Compute the remain type transformers 576 TypeModifier TM = TypeModifier::NoModifier; 577 for (char I : PrototypeDescriptorStr) { 578 switch (I) { 579 case 'P': 580 if ((TM & TypeModifier::Const) == TypeModifier::Const) 581 llvm_unreachable("'P' transformer cannot be used after 'C'"); 582 if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer) 583 llvm_unreachable("'P' transformer cannot be used twice"); 584 TM |= TypeModifier::Pointer; 585 break; 586 case 'C': 587 TM |= TypeModifier::Const; 588 break; 589 case 'K': 590 TM |= TypeModifier::Immediate; 591 break; 592 case 'U': 593 TM |= TypeModifier::UnsignedInteger; 594 break; 595 case 'I': 596 TM |= TypeModifier::SignedInteger; 597 break; 598 case 'F': 599 TM |= TypeModifier::Float; 600 break; 601 case 'S': 602 TM |= TypeModifier::LMUL1; 603 break; 604 default: 605 llvm_unreachable("Illegal non-primitive type transformer!"); 606 } 607 } 608 PD.TM = static_cast<uint8_t>(TM); 609 610 return PD; 611 } 612 613 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) { 614 // Handle primitive type transformer 615 switch (static_cast<BaseTypeModifier>(Transformer.PT)) { 616 case BaseTypeModifier::Scalar: 617 Scale = 0; 618 break; 619 case BaseTypeModifier::Vector: 620 Scale = LMUL.getScale(ElementBitwidth); 621 break; 622 case BaseTypeModifier::Void: 623 ScalarType = ScalarTypeKind::Void; 624 break; 625 case BaseTypeModifier::SizeT: 626 ScalarType = ScalarTypeKind::Size_t; 627 break; 628 case BaseTypeModifier::Ptrdiff: 629 ScalarType = ScalarTypeKind::Ptrdiff_t; 630 break; 631 case BaseTypeModifier::UnsignedLong: 632 ScalarType = ScalarTypeKind::UnsignedLong; 633 break; 634 case BaseTypeModifier::SignedLong: 635 ScalarType = ScalarTypeKind::SignedLong; 636 break; 637 case BaseTypeModifier::Invalid: 638 ScalarType = ScalarTypeKind::Invalid; 639 return; 640 } 641 642 switch (static_cast<VectorTypeModifier>(Transformer.VTM)) { 643 case VectorTypeModifier::Widening2XVector: 644 ElementBitwidth *= 2; 645 LMUL.MulLog2LMUL(1); 646 Scale = LMUL.getScale(ElementBitwidth); 647 break; 648 case VectorTypeModifier::Widening4XVector: 649 ElementBitwidth *= 4; 650 LMUL.MulLog2LMUL(2); 651 Scale = LMUL.getScale(ElementBitwidth); 652 break; 653 case VectorTypeModifier::Widening8XVector: 654 ElementBitwidth *= 8; 655 LMUL.MulLog2LMUL(3); 656 Scale = LMUL.getScale(ElementBitwidth); 657 break; 658 case VectorTypeModifier::MaskVector: 659 ScalarType = ScalarTypeKind::Boolean; 660 Scale = LMUL.getScale(ElementBitwidth); 661 ElementBitwidth = 1; 662 break; 663 case VectorTypeModifier::Log2EEW3: 664 applyLog2EEW(3); 665 break; 666 case VectorTypeModifier::Log2EEW4: 667 applyLog2EEW(4); 668 break; 669 case VectorTypeModifier::Log2EEW5: 670 applyLog2EEW(5); 671 break; 672 case VectorTypeModifier::Log2EEW6: 673 applyLog2EEW(6); 674 break; 675 case VectorTypeModifier::FixedSEW8: 676 applyFixedSEW(8); 677 break; 678 case VectorTypeModifier::FixedSEW16: 679 applyFixedSEW(16); 680 break; 681 case VectorTypeModifier::FixedSEW32: 682 applyFixedSEW(32); 683 break; 684 case VectorTypeModifier::FixedSEW64: 685 applyFixedSEW(64); 686 break; 687 case VectorTypeModifier::LFixedLog2LMULN3: 688 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan); 689 break; 690 case VectorTypeModifier::LFixedLog2LMULN2: 691 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan); 692 break; 693 case VectorTypeModifier::LFixedLog2LMULN1: 694 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan); 695 break; 696 case VectorTypeModifier::LFixedLog2LMUL0: 697 applyFixedLog2LMUL(0, FixedLMULType::LargerThan); 698 break; 699 case VectorTypeModifier::LFixedLog2LMUL1: 700 applyFixedLog2LMUL(1, FixedLMULType::LargerThan); 701 break; 702 case VectorTypeModifier::LFixedLog2LMUL2: 703 applyFixedLog2LMUL(2, FixedLMULType::LargerThan); 704 break; 705 case VectorTypeModifier::LFixedLog2LMUL3: 706 applyFixedLog2LMUL(3, FixedLMULType::LargerThan); 707 break; 708 case VectorTypeModifier::SFixedLog2LMULN3: 709 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan); 710 break; 711 case VectorTypeModifier::SFixedLog2LMULN2: 712 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan); 713 break; 714 case VectorTypeModifier::SFixedLog2LMULN1: 715 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan); 716 break; 717 case VectorTypeModifier::SFixedLog2LMUL0: 718 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan); 719 break; 720 case VectorTypeModifier::SFixedLog2LMUL1: 721 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan); 722 break; 723 case VectorTypeModifier::SFixedLog2LMUL2: 724 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan); 725 break; 726 case VectorTypeModifier::SFixedLog2LMUL3: 727 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan); 728 break; 729 case VectorTypeModifier::Tuple2: 730 case VectorTypeModifier::Tuple3: 731 case VectorTypeModifier::Tuple4: 732 case VectorTypeModifier::Tuple5: 733 case VectorTypeModifier::Tuple6: 734 case VectorTypeModifier::Tuple7: 735 case VectorTypeModifier::Tuple8: { 736 IsTuple = true; 737 NF = 2 + static_cast<uint8_t>(Transformer.VTM) - 738 static_cast<uint8_t>(VectorTypeModifier::Tuple2); 739 break; 740 } 741 case VectorTypeModifier::NoModifier: 742 break; 743 } 744 745 for (unsigned TypeModifierMaskShift = 0; 746 TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset); 747 ++TypeModifierMaskShift) { 748 unsigned TypeModifierMask = 1 << TypeModifierMaskShift; 749 if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) != 750 TypeModifierMask) 751 continue; 752 switch (static_cast<TypeModifier>(TypeModifierMask)) { 753 case TypeModifier::Pointer: 754 IsPointer = true; 755 break; 756 case TypeModifier::Const: 757 IsConstant = true; 758 break; 759 case TypeModifier::Immediate: 760 IsImmediate = true; 761 IsConstant = true; 762 break; 763 case TypeModifier::UnsignedInteger: 764 ScalarType = ScalarTypeKind::UnsignedInteger; 765 break; 766 case TypeModifier::SignedInteger: 767 ScalarType = ScalarTypeKind::SignedInteger; 768 break; 769 case TypeModifier::Float: 770 ScalarType = ScalarTypeKind::Float; 771 break; 772 case TypeModifier::LMUL1: 773 LMUL = LMULType(0); 774 // Update ElementBitwidth need to update Scale too. 775 Scale = LMUL.getScale(ElementBitwidth); 776 break; 777 default: 778 llvm_unreachable("Unknown type modifier mask!"); 779 } 780 } 781 } 782 783 void RVVType::applyLog2EEW(unsigned Log2EEW) { 784 // update new elmul = (eew/sew) * lmul 785 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); 786 // update new eew 787 ElementBitwidth = 1 << Log2EEW; 788 ScalarType = ScalarTypeKind::SignedInteger; 789 Scale = LMUL.getScale(ElementBitwidth); 790 } 791 792 void RVVType::applyFixedSEW(unsigned NewSEW) { 793 // Set invalid type if src and dst SEW are same. 794 if (ElementBitwidth == NewSEW) { 795 ScalarType = ScalarTypeKind::Invalid; 796 return; 797 } 798 // Update new SEW 799 ElementBitwidth = NewSEW; 800 Scale = LMUL.getScale(ElementBitwidth); 801 } 802 803 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) { 804 switch (Type) { 805 case FixedLMULType::LargerThan: 806 if (Log2LMUL < LMUL.Log2LMUL) { 807 ScalarType = ScalarTypeKind::Invalid; 808 return; 809 } 810 break; 811 case FixedLMULType::SmallerThan: 812 if (Log2LMUL > LMUL.Log2LMUL) { 813 ScalarType = ScalarTypeKind::Invalid; 814 return; 815 } 816 break; 817 } 818 819 // Update new LMUL 820 LMUL = LMULType(Log2LMUL); 821 Scale = LMUL.getScale(ElementBitwidth); 822 } 823 824 std::optional<RVVTypes> 825 RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, 826 ArrayRef<PrototypeDescriptor> Prototype) { 827 RVVTypes Types; 828 for (const PrototypeDescriptor &Proto : Prototype) { 829 auto T = computeType(BT, Log2LMUL, Proto); 830 if (!T) 831 return std::nullopt; 832 // Record legal type index 833 Types.push_back(*T); 834 } 835 return Types; 836 } 837 838 // Compute the hash value of RVVType, used for cache the result of computeType. 839 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL, 840 PrototypeDescriptor Proto) { 841 // Layout of hash value: 842 // 0 8 16 24 32 40 843 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM | 844 assert(Log2LMUL >= -3 && Log2LMUL <= 3); 845 return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 | 846 ((uint64_t)(Proto.PT & 0xff) << 16) | 847 ((uint64_t)(Proto.TM & 0xff) << 24) | 848 ((uint64_t)(Proto.VTM & 0xff) << 32); 849 } 850 851 std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL, 852 PrototypeDescriptor Proto) { 853 uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto); 854 // Search first 855 auto It = LegalTypes.find(Idx); 856 if (It != LegalTypes.end()) 857 return &(It->second); 858 859 if (IllegalTypes.count(Idx)) 860 return std::nullopt; 861 862 // Compute type and record the result. 863 RVVType T(BT, Log2LMUL, Proto); 864 if (T.isValid()) { 865 // Record legal type index and value. 866 std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool> 867 InsertResult = LegalTypes.insert({Idx, T}); 868 return &(InsertResult.first->second); 869 } 870 // Record illegal type index. 871 IllegalTypes.insert(Idx); 872 return std::nullopt; 873 } 874 875 //===----------------------------------------------------------------------===// 876 // RVVIntrinsic implementation 877 //===----------------------------------------------------------------------===// 878 RVVIntrinsic::RVVIntrinsic( 879 StringRef NewName, StringRef Suffix, StringRef NewOverloadedName, 880 StringRef OverloadedSuffix, StringRef IRName, bool IsMasked, 881 bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, 882 bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen, 883 const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes, 884 const std::vector<StringRef> &RequiredFeatures, unsigned NF, 885 Policy NewPolicyAttrs, bool HasFRMRoundModeOp) 886 : IRName(IRName), IsMasked(IsMasked), 887 HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme), 888 SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias), 889 ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) { 890 891 // Init BuiltinName, Name and OverloadedName 892 BuiltinName = NewName.str(); 893 Name = BuiltinName; 894 if (NewOverloadedName.empty()) 895 OverloadedName = NewName.split("_").first.str(); 896 else 897 OverloadedName = NewOverloadedName.str(); 898 if (!Suffix.empty()) 899 Name += "_" + Suffix.str(); 900 if (!OverloadedSuffix.empty()) 901 OverloadedName += "_" + OverloadedSuffix.str(); 902 903 updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName, 904 PolicyAttrs, HasFRMRoundModeOp); 905 906 // Init OutputType and InputTypes 907 OutputType = OutInTypes[0]; 908 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); 909 910 // IntrinsicTypes is unmasked TA version index. Need to update it 911 // if there is merge operand (It is always in first operand). 912 IntrinsicTypes = NewIntrinsicTypes; 913 if ((IsMasked && hasMaskedOffOperand()) || 914 (!IsMasked && hasPassthruOperand())) { 915 for (auto &I : IntrinsicTypes) { 916 if (I >= 0) 917 I += NF; 918 } 919 } 920 } 921 922 std::string RVVIntrinsic::getBuiltinTypeStr() const { 923 std::string S; 924 S += OutputType->getBuiltinStr(); 925 for (const auto &T : InputTypes) { 926 S += T->getBuiltinStr(); 927 } 928 return S; 929 } 930 931 std::string RVVIntrinsic::getSuffixStr( 932 RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL, 933 llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) { 934 SmallVector<std::string> SuffixStrs; 935 for (auto PD : PrototypeDescriptors) { 936 auto T = TypeCache.computeType(Type, Log2LMUL, PD); 937 SuffixStrs.push_back((*T)->getShortStr()); 938 } 939 return join(SuffixStrs, "_"); 940 } 941 942 llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes( 943 llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked, 944 bool HasMaskedOffOperand, bool HasVL, unsigned NF, 945 PolicyScheme DefaultScheme, Policy PolicyAttrs, bool IsTuple) { 946 SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(), 947 Prototype.end()); 948 bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand; 949 if (IsMasked) { 950 // If HasMaskedOffOperand, insert result type as first input operand if 951 // need. 952 if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) { 953 if (NF == 1) { 954 NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]); 955 } else if (NF > 1) { 956 if (IsTuple) { 957 PrototypeDescriptor BasePtrOperand = Prototype[1]; 958 PrototypeDescriptor MaskoffType = PrototypeDescriptor( 959 static_cast<uint8_t>(BaseTypeModifier::Vector), 960 static_cast<uint8_t>(getTupleVTM(NF)), 961 BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer)); 962 NewPrototype.insert(NewPrototype.begin() + 1, MaskoffType); 963 } else { 964 // Convert 965 // (void, op0 address, op1 address, ...) 966 // to 967 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 968 PrototypeDescriptor MaskoffType = NewPrototype[1]; 969 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer); 970 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType); 971 } 972 } 973 } 974 if (HasMaskedOffOperand && NF > 1) { 975 // Convert 976 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 977 // to 978 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1, 979 // ...) 980 if (IsTuple) 981 NewPrototype.insert(NewPrototype.begin() + 1, 982 PrototypeDescriptor::Mask); 983 else 984 NewPrototype.insert(NewPrototype.begin() + NF + 1, 985 PrototypeDescriptor::Mask); 986 } else { 987 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand. 988 NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask); 989 } 990 } else { 991 if (NF == 1) { 992 if (PolicyAttrs.isTUPolicy() && HasPassthruOp) 993 NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]); 994 } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) { 995 if (IsTuple) { 996 PrototypeDescriptor BasePtrOperand = Prototype[0]; 997 PrototypeDescriptor MaskoffType = PrototypeDescriptor( 998 static_cast<uint8_t>(BaseTypeModifier::Vector), 999 static_cast<uint8_t>(getTupleVTM(NF)), 1000 BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer)); 1001 NewPrototype.insert(NewPrototype.begin(), MaskoffType); 1002 } else { 1003 // NF > 1 cases for segment load operations. 1004 // Convert 1005 // (void, op0 address, op1 address, ...) 1006 // to 1007 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...) 1008 PrototypeDescriptor MaskoffType = Prototype[1]; 1009 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer); 1010 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType); 1011 } 1012 } 1013 } 1014 1015 // If HasVL, append PrototypeDescriptor:VL to last operand 1016 if (HasVL) 1017 NewPrototype.push_back(PrototypeDescriptor::VL); 1018 1019 return NewPrototype; 1020 } 1021 1022 llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() { 1023 return {Policy(Policy::PolicyType::Undisturbed)}; // TU 1024 } 1025 1026 llvm::SmallVector<Policy> 1027 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy, 1028 bool HasMaskPolicy) { 1029 if (HasTailPolicy && HasMaskPolicy) 1030 return {Policy(Policy::PolicyType::Undisturbed, 1031 Policy::PolicyType::Agnostic), // TUM 1032 Policy(Policy::PolicyType::Undisturbed, 1033 Policy::PolicyType::Undisturbed), // TUMU 1034 Policy(Policy::PolicyType::Agnostic, 1035 Policy::PolicyType::Undisturbed)}; // MU 1036 if (HasTailPolicy && !HasMaskPolicy) 1037 return {Policy(Policy::PolicyType::Undisturbed, 1038 Policy::PolicyType::Agnostic)}; // TU 1039 if (!HasTailPolicy && HasMaskPolicy) 1040 return {Policy(Policy::PolicyType::Agnostic, 1041 Policy::PolicyType::Undisturbed)}; // MU 1042 llvm_unreachable("An RVV instruction should not be without both tail policy " 1043 "and mask policy"); 1044 } 1045 1046 void RVVIntrinsic::updateNamesAndPolicy( 1047 bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName, 1048 std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) { 1049 1050 auto appendPolicySuffix = [&](const std::string &suffix) { 1051 Name += suffix; 1052 BuiltinName += suffix; 1053 OverloadedName += suffix; 1054 }; 1055 1056 // This follows the naming guideline under riscv-c-api-doc to add the 1057 // `__riscv_` suffix for all RVV intrinsics. 1058 Name = "__riscv_" + Name; 1059 OverloadedName = "__riscv_" + OverloadedName; 1060 1061 if (HasFRMRoundModeOp) { 1062 Name += "_rm"; 1063 BuiltinName += "_rm"; 1064 } 1065 1066 if (IsMasked) { 1067 if (PolicyAttrs.isTUMUPolicy()) 1068 appendPolicySuffix("_tumu"); 1069 else if (PolicyAttrs.isTUMAPolicy()) 1070 appendPolicySuffix("_tum"); 1071 else if (PolicyAttrs.isTAMUPolicy()) 1072 appendPolicySuffix("_mu"); 1073 else if (PolicyAttrs.isTAMAPolicy()) { 1074 Name += "_m"; 1075 BuiltinName += "_m"; 1076 } else 1077 llvm_unreachable("Unhandled policy condition"); 1078 } else { 1079 if (PolicyAttrs.isTUPolicy()) 1080 appendPolicySuffix("_tu"); 1081 else if (PolicyAttrs.isTAPolicy()) // no suffix needed 1082 return; 1083 else 1084 llvm_unreachable("Unhandled policy condition"); 1085 } 1086 } 1087 1088 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) { 1089 SmallVector<PrototypeDescriptor> PrototypeDescriptors; 1090 const StringRef Primaries("evwqom0ztul"); 1091 while (!Prototypes.empty()) { 1092 size_t Idx = 0; 1093 // Skip over complex prototype because it could contain primitive type 1094 // character. 1095 if (Prototypes[0] == '(') 1096 Idx = Prototypes.find_first_of(')'); 1097 Idx = Prototypes.find_first_of(Primaries, Idx); 1098 assert(Idx != StringRef::npos); 1099 auto PD = PrototypeDescriptor::parsePrototypeDescriptor( 1100 Prototypes.slice(0, Idx + 1)); 1101 if (!PD) 1102 llvm_unreachable("Error during parsing prototype."); 1103 PrototypeDescriptors.push_back(*PD); 1104 Prototypes = Prototypes.drop_front(Idx + 1); 1105 } 1106 return PrototypeDescriptors; 1107 } 1108 1109 raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) { 1110 OS << "{"; 1111 OS << "\"" << Record.Name << "\","; 1112 if (Record.OverloadedName == nullptr || 1113 StringRef(Record.OverloadedName).empty()) 1114 OS << "nullptr,"; 1115 else 1116 OS << "\"" << Record.OverloadedName << "\","; 1117 OS << Record.PrototypeIndex << ","; 1118 OS << Record.SuffixIndex << ","; 1119 OS << Record.OverloadedSuffixIndex << ","; 1120 OS << (int)Record.PrototypeLength << ","; 1121 OS << (int)Record.SuffixLength << ","; 1122 OS << (int)Record.OverloadedSuffixSize << ","; 1123 OS << (int)Record.RequiredExtensions << ","; 1124 OS << (int)Record.TypeRangeMask << ","; 1125 OS << (int)Record.Log2LMULMask << ","; 1126 OS << (int)Record.NF << ","; 1127 OS << (int)Record.HasMasked << ","; 1128 OS << (int)Record.HasVL << ","; 1129 OS << (int)Record.HasMaskedOffOperand << ","; 1130 OS << (int)Record.HasTailPolicy << ","; 1131 OS << (int)Record.HasMaskPolicy << ","; 1132 OS << (int)Record.HasFRMRoundModeOp << ","; 1133 OS << (int)Record.IsTuple << ","; 1134 OS << (int)Record.UnMaskedPolicyScheme << ","; 1135 OS << (int)Record.MaskedPolicyScheme << ","; 1136 OS << "},\n"; 1137 return OS; 1138 } 1139 1140 } // end namespace RISCV 1141 } // end namespace clang 1142