1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "clang/Support/RISCVVIntrinsicUtils.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/ADT/SmallSet.h" 12 #include "llvm/ADT/StringExtras.h" 13 #include "llvm/ADT/StringMap.h" 14 #include "llvm/ADT/StringSet.h" 15 #include "llvm/ADT/Twine.h" 16 #include "llvm/Support/ErrorHandling.h" 17 #include "llvm/Support/raw_ostream.h" 18 #include <numeric> 19 #include <optional> 20 21 using namespace llvm; 22 23 namespace clang { 24 namespace RISCV { 25 26 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor( 27 BaseTypeModifier::Vector, VectorTypeModifier::MaskVector); 28 const PrototypeDescriptor PrototypeDescriptor::VL = 29 PrototypeDescriptor(BaseTypeModifier::SizeT); 30 const PrototypeDescriptor PrototypeDescriptor::Vector = 31 PrototypeDescriptor(BaseTypeModifier::Vector); 32 33 //===----------------------------------------------------------------------===// 34 // Type implementation 35 //===----------------------------------------------------------------------===// 36 37 LMULType::LMULType(int NewLog2LMUL) { 38 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 39 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); 40 Log2LMUL = NewLog2LMUL; 41 } 42 43 std::string LMULType::str() const { 44 if (Log2LMUL < 0) 45 return "mf" + utostr(1ULL << (-Log2LMUL)); 46 return "m" + utostr(1ULL << Log2LMUL); 47 } 48 49 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { 50 int Log2ScaleResult = 0; 51 switch (ElementBitwidth) { 52 default: 53 break; 54 case 8: 55 Log2ScaleResult = Log2LMUL + 3; 56 break; 57 case 16: 58 Log2ScaleResult = Log2LMUL + 2; 59 break; 60 case 32: 61 Log2ScaleResult = Log2LMUL + 1; 62 break; 63 case 64: 64 Log2ScaleResult = Log2LMUL; 65 break; 66 } 67 // Illegal vscale result would be less than 1 68 if (Log2ScaleResult < 0) 69 return std::nullopt; 70 return 1 << Log2ScaleResult; 71 } 72 73 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } 74 75 RVVType::RVVType(BasicType BT, int Log2LMUL, 76 const PrototypeDescriptor &prototype) 77 : BT(BT), LMUL(LMULType(Log2LMUL)) { 78 applyBasicType(); 79 applyModifier(prototype); 80 Valid = verifyType(); 81 if (Valid) { 82 initBuiltinStr(); 83 initTypeStr(); 84 if (isVector()) { 85 initClangBuiltinStr(); 86 } 87 } 88 } 89 90 // clang-format off 91 // boolean type are encoded the ratio of n (SEW/LMUL) 92 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 93 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t 94 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 95 96 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 97 // -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- 98 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 99 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 100 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 101 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 102 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 103 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 104 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 105 // clang-format on 106 107 bool RVVType::verifyType() const { 108 if (ScalarType == Invalid) 109 return false; 110 if (isScalar()) 111 return true; 112 if (!Scale) 113 return false; 114 if (isFloat() && ElementBitwidth == 8) 115 return false; 116 unsigned V = *Scale; 117 switch (ElementBitwidth) { 118 case 1: 119 case 8: 120 // Check Scale is 1,2,4,8,16,32,64 121 return (V <= 64 && isPowerOf2_32(V)); 122 case 16: 123 // Check Scale is 1,2,4,8,16,32 124 return (V <= 32 && isPowerOf2_32(V)); 125 case 32: 126 // Check Scale is 1,2,4,8,16 127 return (V <= 16 && isPowerOf2_32(V)); 128 case 64: 129 // Check Scale is 1,2,4,8 130 return (V <= 8 && isPowerOf2_32(V)); 131 } 132 return false; 133 } 134 135 void RVVType::initBuiltinStr() { 136 assert(isValid() && "RVVType is invalid"); 137 switch (ScalarType) { 138 case ScalarTypeKind::Void: 139 BuiltinStr = "v"; 140 return; 141 case ScalarTypeKind::Size_t: 142 BuiltinStr = "z"; 143 if (IsImmediate) 144 BuiltinStr = "I" + BuiltinStr; 145 if (IsPointer) 146 BuiltinStr += "*"; 147 return; 148 case ScalarTypeKind::Ptrdiff_t: 149 BuiltinStr = "Y"; 150 return; 151 case ScalarTypeKind::UnsignedLong: 152 BuiltinStr = "ULi"; 153 return; 154 case ScalarTypeKind::SignedLong: 155 BuiltinStr = "Li"; 156 return; 157 case ScalarTypeKind::Boolean: 158 assert(ElementBitwidth == 1); 159 BuiltinStr += "b"; 160 break; 161 case ScalarTypeKind::SignedInteger: 162 case ScalarTypeKind::UnsignedInteger: 163 switch (ElementBitwidth) { 164 case 8: 165 BuiltinStr += "c"; 166 break; 167 case 16: 168 BuiltinStr += "s"; 169 break; 170 case 32: 171 BuiltinStr += "i"; 172 break; 173 case 64: 174 BuiltinStr += "Wi"; 175 break; 176 default: 177 llvm_unreachable("Unhandled ElementBitwidth!"); 178 } 179 if (isSignedInteger()) 180 BuiltinStr = "S" + BuiltinStr; 181 else 182 BuiltinStr = "U" + BuiltinStr; 183 break; 184 case ScalarTypeKind::Float: 185 switch (ElementBitwidth) { 186 case 16: 187 BuiltinStr += "x"; 188 break; 189 case 32: 190 BuiltinStr += "f"; 191 break; 192 case 64: 193 BuiltinStr += "d"; 194 break; 195 default: 196 llvm_unreachable("Unhandled ElementBitwidth!"); 197 } 198 break; 199 default: 200 llvm_unreachable("ScalarType is invalid!"); 201 } 202 if (IsImmediate) 203 BuiltinStr = "I" + BuiltinStr; 204 if (isScalar()) { 205 if (IsConstant) 206 BuiltinStr += "C"; 207 if (IsPointer) 208 BuiltinStr += "*"; 209 return; 210 } 211 BuiltinStr = "q" + utostr(*Scale) + BuiltinStr; 212 // Pointer to vector types. Defined for segment load intrinsics. 213 // segment load intrinsics have pointer type arguments to store the loaded 214 // vector values. 215 if (IsPointer) 216 BuiltinStr += "*"; 217 } 218 219 void RVVType::initClangBuiltinStr() { 220 assert(isValid() && "RVVType is invalid"); 221 assert(isVector() && "Handle Vector type only"); 222 223 ClangBuiltinStr = "__rvv_"; 224 switch (ScalarType) { 225 case ScalarTypeKind::Boolean: 226 ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t"; 227 return; 228 case ScalarTypeKind::Float: 229 ClangBuiltinStr += "float"; 230 break; 231 case ScalarTypeKind::SignedInteger: 232 ClangBuiltinStr += "int"; 233 break; 234 case ScalarTypeKind::UnsignedInteger: 235 ClangBuiltinStr += "uint"; 236 break; 237 default: 238 llvm_unreachable("ScalarTypeKind is invalid"); 239 } 240 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; 241 } 242 243 void RVVType::initTypeStr() { 244 assert(isValid() && "RVVType is invalid"); 245 246 if (IsConstant) 247 Str += "const "; 248 249 auto getTypeString = [&](StringRef TypeStr) { 250 if (isScalar()) 251 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); 252 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") 253 .str(); 254 }; 255 256 switch (ScalarType) { 257 case ScalarTypeKind::Void: 258 Str = "void"; 259 return; 260 case ScalarTypeKind::Size_t: 261 Str = "size_t"; 262 if (IsPointer) 263 Str += " *"; 264 return; 265 case ScalarTypeKind::Ptrdiff_t: 266 Str = "ptrdiff_t"; 267 return; 268 case ScalarTypeKind::UnsignedLong: 269 Str = "unsigned long"; 270 return; 271 case ScalarTypeKind::SignedLong: 272 Str = "long"; 273 return; 274 case ScalarTypeKind::Boolean: 275 if (isScalar()) 276 Str += "bool"; 277 else 278 // Vector bool is special case, the formulate is 279 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 280 Str += "vbool" + utostr(64 / *Scale) + "_t"; 281 break; 282 case ScalarTypeKind::Float: 283 if (isScalar()) { 284 if (ElementBitwidth == 64) 285 Str += "double"; 286 else if (ElementBitwidth == 32) 287 Str += "float"; 288 else if (ElementBitwidth == 16) 289 Str += "_Float16"; 290 else 291 llvm_unreachable("Unhandled floating type."); 292 } else 293 Str += getTypeString("float"); 294 break; 295 case ScalarTypeKind::SignedInteger: 296 Str += getTypeString("int"); 297 break; 298 case ScalarTypeKind::UnsignedInteger: 299 Str += getTypeString("uint"); 300 break; 301 default: 302 llvm_unreachable("ScalarType is invalid!"); 303 } 304 if (IsPointer) 305 Str += " *"; 306 } 307 308 void RVVType::initShortStr() { 309 switch (ScalarType) { 310 case ScalarTypeKind::Boolean: 311 assert(isVector()); 312 ShortStr = "b" + utostr(64 / *Scale); 313 return; 314 case ScalarTypeKind::Float: 315 ShortStr = "f" + utostr(ElementBitwidth); 316 break; 317 case ScalarTypeKind::SignedInteger: 318 ShortStr = "i" + utostr(ElementBitwidth); 319 break; 320 case ScalarTypeKind::UnsignedInteger: 321 ShortStr = "u" + utostr(ElementBitwidth); 322 break; 323 default: 324 llvm_unreachable("Unhandled case!"); 325 } 326 if (isVector()) 327 ShortStr += LMUL.str(); 328 } 329 330 void RVVType::applyBasicType() { 331 switch (BT) { 332 case BasicType::Int8: 333 ElementBitwidth = 8; 334 ScalarType = ScalarTypeKind::SignedInteger; 335 break; 336 case BasicType::Int16: 337 ElementBitwidth = 16; 338 ScalarType = ScalarTypeKind::SignedInteger; 339 break; 340 case BasicType::Int32: 341 ElementBitwidth = 32; 342 ScalarType = ScalarTypeKind::SignedInteger; 343 break; 344 case BasicType::Int64: 345 ElementBitwidth = 64; 346 ScalarType = ScalarTypeKind::SignedInteger; 347 break; 348 case BasicType::Float16: 349 ElementBitwidth = 16; 350 ScalarType = ScalarTypeKind::Float; 351 break; 352 case BasicType::Float32: 353 ElementBitwidth = 32; 354 ScalarType = ScalarTypeKind::Float; 355 break; 356 case BasicType::Float64: 357 ElementBitwidth = 64; 358 ScalarType = ScalarTypeKind::Float; 359 break; 360 default: 361 llvm_unreachable("Unhandled type code!"); 362 } 363 assert(ElementBitwidth != 0 && "Bad element bitwidth!"); 364 } 365 366 std::optional<PrototypeDescriptor> 367 PrototypeDescriptor::parsePrototypeDescriptor( 368 llvm::StringRef PrototypeDescriptorStr) { 369 PrototypeDescriptor PD; 370 BaseTypeModifier PT = BaseTypeModifier::Invalid; 371 VectorTypeModifier VTM = VectorTypeModifier::NoModifier; 372 373 if (PrototypeDescriptorStr.empty()) 374 return PD; 375 376 // Handle base type modifier 377 auto PType = PrototypeDescriptorStr.back(); 378 switch (PType) { 379 case 'e': 380 PT = BaseTypeModifier::Scalar; 381 break; 382 case 'v': 383 PT = BaseTypeModifier::Vector; 384 break; 385 case 'w': 386 PT = BaseTypeModifier::Vector; 387 VTM = VectorTypeModifier::Widening2XVector; 388 break; 389 case 'q': 390 PT = BaseTypeModifier::Vector; 391 VTM = VectorTypeModifier::Widening4XVector; 392 break; 393 case 'o': 394 PT = BaseTypeModifier::Vector; 395 VTM = VectorTypeModifier::Widening8XVector; 396 break; 397 case 'm': 398 PT = BaseTypeModifier::Vector; 399 VTM = VectorTypeModifier::MaskVector; 400 break; 401 case '0': 402 PT = BaseTypeModifier::Void; 403 break; 404 case 'z': 405 PT = BaseTypeModifier::SizeT; 406 break; 407 case 't': 408 PT = BaseTypeModifier::Ptrdiff; 409 break; 410 case 'u': 411 PT = BaseTypeModifier::UnsignedLong; 412 break; 413 case 'l': 414 PT = BaseTypeModifier::SignedLong; 415 break; 416 default: 417 llvm_unreachable("Illegal primitive type transformers!"); 418 } 419 PD.PT = static_cast<uint8_t>(PT); 420 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back(); 421 422 // Compute the vector type transformers, it can only appear one time. 423 if (PrototypeDescriptorStr.startswith("(")) { 424 assert(VTM == VectorTypeModifier::NoModifier && 425 "VectorTypeModifier should only have one modifier"); 426 size_t Idx = PrototypeDescriptorStr.find(')'); 427 assert(Idx != StringRef::npos); 428 StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx); 429 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1); 430 assert(!PrototypeDescriptorStr.contains('(') && 431 "Only allow one vector type modifier"); 432 433 auto ComplexTT = ComplexType.split(":"); 434 if (ComplexTT.first == "Log2EEW") { 435 uint32_t Log2EEW; 436 if (ComplexTT.second.getAsInteger(10, Log2EEW)) { 437 llvm_unreachable("Invalid Log2EEW value!"); 438 return std::nullopt; 439 } 440 switch (Log2EEW) { 441 case 3: 442 VTM = VectorTypeModifier::Log2EEW3; 443 break; 444 case 4: 445 VTM = VectorTypeModifier::Log2EEW4; 446 break; 447 case 5: 448 VTM = VectorTypeModifier::Log2EEW5; 449 break; 450 case 6: 451 VTM = VectorTypeModifier::Log2EEW6; 452 break; 453 default: 454 llvm_unreachable("Invalid Log2EEW value, should be [3-6]"); 455 return std::nullopt; 456 } 457 } else if (ComplexTT.first == "FixedSEW") { 458 uint32_t NewSEW; 459 if (ComplexTT.second.getAsInteger(10, NewSEW)) { 460 llvm_unreachable("Invalid FixedSEW value!"); 461 return std::nullopt; 462 } 463 switch (NewSEW) { 464 case 8: 465 VTM = VectorTypeModifier::FixedSEW8; 466 break; 467 case 16: 468 VTM = VectorTypeModifier::FixedSEW16; 469 break; 470 case 32: 471 VTM = VectorTypeModifier::FixedSEW32; 472 break; 473 case 64: 474 VTM = VectorTypeModifier::FixedSEW64; 475 break; 476 default: 477 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64"); 478 return std::nullopt; 479 } 480 } else if (ComplexTT.first == "LFixedLog2LMUL") { 481 int32_t Log2LMUL; 482 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { 483 llvm_unreachable("Invalid LFixedLog2LMUL value!"); 484 return std::nullopt; 485 } 486 switch (Log2LMUL) { 487 case -3: 488 VTM = VectorTypeModifier::LFixedLog2LMULN3; 489 break; 490 case -2: 491 VTM = VectorTypeModifier::LFixedLog2LMULN2; 492 break; 493 case -1: 494 VTM = VectorTypeModifier::LFixedLog2LMULN1; 495 break; 496 case 0: 497 VTM = VectorTypeModifier::LFixedLog2LMUL0; 498 break; 499 case 1: 500 VTM = VectorTypeModifier::LFixedLog2LMUL1; 501 break; 502 case 2: 503 VTM = VectorTypeModifier::LFixedLog2LMUL2; 504 break; 505 case 3: 506 VTM = VectorTypeModifier::LFixedLog2LMUL3; 507 break; 508 default: 509 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); 510 return std::nullopt; 511 } 512 } else if (ComplexTT.first == "SFixedLog2LMUL") { 513 int32_t Log2LMUL; 514 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { 515 llvm_unreachable("Invalid SFixedLog2LMUL value!"); 516 return std::nullopt; 517 } 518 switch (Log2LMUL) { 519 case -3: 520 VTM = VectorTypeModifier::SFixedLog2LMULN3; 521 break; 522 case -2: 523 VTM = VectorTypeModifier::SFixedLog2LMULN2; 524 break; 525 case -1: 526 VTM = VectorTypeModifier::SFixedLog2LMULN1; 527 break; 528 case 0: 529 VTM = VectorTypeModifier::SFixedLog2LMUL0; 530 break; 531 case 1: 532 VTM = VectorTypeModifier::SFixedLog2LMUL1; 533 break; 534 case 2: 535 VTM = VectorTypeModifier::SFixedLog2LMUL2; 536 break; 537 case 3: 538 VTM = VectorTypeModifier::SFixedLog2LMUL3; 539 break; 540 default: 541 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); 542 return std::nullopt; 543 } 544 545 } else { 546 llvm_unreachable("Illegal complex type transformers!"); 547 } 548 } 549 PD.VTM = static_cast<uint8_t>(VTM); 550 551 // Compute the remain type transformers 552 TypeModifier TM = TypeModifier::NoModifier; 553 for (char I : PrototypeDescriptorStr) { 554 switch (I) { 555 case 'P': 556 if ((TM & TypeModifier::Const) == TypeModifier::Const) 557 llvm_unreachable("'P' transformer cannot be used after 'C'"); 558 if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer) 559 llvm_unreachable("'P' transformer cannot be used twice"); 560 TM |= TypeModifier::Pointer; 561 break; 562 case 'C': 563 TM |= TypeModifier::Const; 564 break; 565 case 'K': 566 TM |= TypeModifier::Immediate; 567 break; 568 case 'U': 569 TM |= TypeModifier::UnsignedInteger; 570 break; 571 case 'I': 572 TM |= TypeModifier::SignedInteger; 573 break; 574 case 'F': 575 TM |= TypeModifier::Float; 576 break; 577 case 'S': 578 TM |= TypeModifier::LMUL1; 579 break; 580 default: 581 llvm_unreachable("Illegal non-primitive type transformer!"); 582 } 583 } 584 PD.TM = static_cast<uint8_t>(TM); 585 586 return PD; 587 } 588 589 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) { 590 // Handle primitive type transformer 591 switch (static_cast<BaseTypeModifier>(Transformer.PT)) { 592 case BaseTypeModifier::Scalar: 593 Scale = 0; 594 break; 595 case BaseTypeModifier::Vector: 596 Scale = LMUL.getScale(ElementBitwidth); 597 break; 598 case BaseTypeModifier::Void: 599 ScalarType = ScalarTypeKind::Void; 600 break; 601 case BaseTypeModifier::SizeT: 602 ScalarType = ScalarTypeKind::Size_t; 603 break; 604 case BaseTypeModifier::Ptrdiff: 605 ScalarType = ScalarTypeKind::Ptrdiff_t; 606 break; 607 case BaseTypeModifier::UnsignedLong: 608 ScalarType = ScalarTypeKind::UnsignedLong; 609 break; 610 case BaseTypeModifier::SignedLong: 611 ScalarType = ScalarTypeKind::SignedLong; 612 break; 613 case BaseTypeModifier::Invalid: 614 ScalarType = ScalarTypeKind::Invalid; 615 return; 616 } 617 618 switch (static_cast<VectorTypeModifier>(Transformer.VTM)) { 619 case VectorTypeModifier::Widening2XVector: 620 ElementBitwidth *= 2; 621 LMUL.MulLog2LMUL(1); 622 Scale = LMUL.getScale(ElementBitwidth); 623 break; 624 case VectorTypeModifier::Widening4XVector: 625 ElementBitwidth *= 4; 626 LMUL.MulLog2LMUL(2); 627 Scale = LMUL.getScale(ElementBitwidth); 628 break; 629 case VectorTypeModifier::Widening8XVector: 630 ElementBitwidth *= 8; 631 LMUL.MulLog2LMUL(3); 632 Scale = LMUL.getScale(ElementBitwidth); 633 break; 634 case VectorTypeModifier::MaskVector: 635 ScalarType = ScalarTypeKind::Boolean; 636 Scale = LMUL.getScale(ElementBitwidth); 637 ElementBitwidth = 1; 638 break; 639 case VectorTypeModifier::Log2EEW3: 640 applyLog2EEW(3); 641 break; 642 case VectorTypeModifier::Log2EEW4: 643 applyLog2EEW(4); 644 break; 645 case VectorTypeModifier::Log2EEW5: 646 applyLog2EEW(5); 647 break; 648 case VectorTypeModifier::Log2EEW6: 649 applyLog2EEW(6); 650 break; 651 case VectorTypeModifier::FixedSEW8: 652 applyFixedSEW(8); 653 break; 654 case VectorTypeModifier::FixedSEW16: 655 applyFixedSEW(16); 656 break; 657 case VectorTypeModifier::FixedSEW32: 658 applyFixedSEW(32); 659 break; 660 case VectorTypeModifier::FixedSEW64: 661 applyFixedSEW(64); 662 break; 663 case VectorTypeModifier::LFixedLog2LMULN3: 664 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan); 665 break; 666 case VectorTypeModifier::LFixedLog2LMULN2: 667 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan); 668 break; 669 case VectorTypeModifier::LFixedLog2LMULN1: 670 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan); 671 break; 672 case VectorTypeModifier::LFixedLog2LMUL0: 673 applyFixedLog2LMUL(0, FixedLMULType::LargerThan); 674 break; 675 case VectorTypeModifier::LFixedLog2LMUL1: 676 applyFixedLog2LMUL(1, FixedLMULType::LargerThan); 677 break; 678 case VectorTypeModifier::LFixedLog2LMUL2: 679 applyFixedLog2LMUL(2, FixedLMULType::LargerThan); 680 break; 681 case VectorTypeModifier::LFixedLog2LMUL3: 682 applyFixedLog2LMUL(3, FixedLMULType::LargerThan); 683 break; 684 case VectorTypeModifier::SFixedLog2LMULN3: 685 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan); 686 break; 687 case VectorTypeModifier::SFixedLog2LMULN2: 688 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan); 689 break; 690 case VectorTypeModifier::SFixedLog2LMULN1: 691 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan); 692 break; 693 case VectorTypeModifier::SFixedLog2LMUL0: 694 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan); 695 break; 696 case VectorTypeModifier::SFixedLog2LMUL1: 697 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan); 698 break; 699 case VectorTypeModifier::SFixedLog2LMUL2: 700 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan); 701 break; 702 case VectorTypeModifier::SFixedLog2LMUL3: 703 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan); 704 break; 705 case VectorTypeModifier::NoModifier: 706 break; 707 } 708 709 for (unsigned TypeModifierMaskShift = 0; 710 TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset); 711 ++TypeModifierMaskShift) { 712 unsigned TypeModifierMask = 1 << TypeModifierMaskShift; 713 if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) != 714 TypeModifierMask) 715 continue; 716 switch (static_cast<TypeModifier>(TypeModifierMask)) { 717 case TypeModifier::Pointer: 718 IsPointer = true; 719 break; 720 case TypeModifier::Const: 721 IsConstant = true; 722 break; 723 case TypeModifier::Immediate: 724 IsImmediate = true; 725 IsConstant = true; 726 break; 727 case TypeModifier::UnsignedInteger: 728 ScalarType = ScalarTypeKind::UnsignedInteger; 729 break; 730 case TypeModifier::SignedInteger: 731 ScalarType = ScalarTypeKind::SignedInteger; 732 break; 733 case TypeModifier::Float: 734 ScalarType = ScalarTypeKind::Float; 735 break; 736 case TypeModifier::LMUL1: 737 LMUL = LMULType(0); 738 // Update ElementBitwidth need to update Scale too. 739 Scale = LMUL.getScale(ElementBitwidth); 740 break; 741 default: 742 llvm_unreachable("Unknown type modifier mask!"); 743 } 744 } 745 } 746 747 void RVVType::applyLog2EEW(unsigned Log2EEW) { 748 // update new elmul = (eew/sew) * lmul 749 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); 750 // update new eew 751 ElementBitwidth = 1 << Log2EEW; 752 ScalarType = ScalarTypeKind::SignedInteger; 753 Scale = LMUL.getScale(ElementBitwidth); 754 } 755 756 void RVVType::applyFixedSEW(unsigned NewSEW) { 757 // Set invalid type if src and dst SEW are same. 758 if (ElementBitwidth == NewSEW) { 759 ScalarType = ScalarTypeKind::Invalid; 760 return; 761 } 762 // Update new SEW 763 ElementBitwidth = NewSEW; 764 Scale = LMUL.getScale(ElementBitwidth); 765 } 766 767 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) { 768 switch (Type) { 769 case FixedLMULType::LargerThan: 770 if (Log2LMUL < LMUL.Log2LMUL) { 771 ScalarType = ScalarTypeKind::Invalid; 772 return; 773 } 774 break; 775 case FixedLMULType::SmallerThan: 776 if (Log2LMUL > LMUL.Log2LMUL) { 777 ScalarType = ScalarTypeKind::Invalid; 778 return; 779 } 780 break; 781 } 782 783 // Update new LMUL 784 LMUL = LMULType(Log2LMUL); 785 Scale = LMUL.getScale(ElementBitwidth); 786 } 787 788 std::optional<RVVTypes> 789 RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, 790 ArrayRef<PrototypeDescriptor> Prototype) { 791 // LMUL x NF must be less than or equal to 8. 792 if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) 793 return std::nullopt; 794 795 RVVTypes Types; 796 for (const PrototypeDescriptor &Proto : Prototype) { 797 auto T = computeType(BT, Log2LMUL, Proto); 798 if (!T) 799 return std::nullopt; 800 // Record legal type index 801 Types.push_back(*T); 802 } 803 return Types; 804 } 805 806 // Compute the hash value of RVVType, used for cache the result of computeType. 807 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL, 808 PrototypeDescriptor Proto) { 809 // Layout of hash value: 810 // 0 8 16 24 32 40 811 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM | 812 assert(Log2LMUL >= -3 && Log2LMUL <= 3); 813 return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 | 814 ((uint64_t)(Proto.PT & 0xff) << 16) | 815 ((uint64_t)(Proto.TM & 0xff) << 24) | 816 ((uint64_t)(Proto.VTM & 0xff) << 32); 817 } 818 819 std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL, 820 PrototypeDescriptor Proto) { 821 uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto); 822 // Search first 823 auto It = LegalTypes.find(Idx); 824 if (It != LegalTypes.end()) 825 return &(It->second); 826 827 if (IllegalTypes.count(Idx)) 828 return std::nullopt; 829 830 // Compute type and record the result. 831 RVVType T(BT, Log2LMUL, Proto); 832 if (T.isValid()) { 833 // Record legal type index and value. 834 std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool> 835 InsertResult = LegalTypes.insert({Idx, T}); 836 return &(InsertResult.first->second); 837 } 838 // Record illegal type index. 839 IllegalTypes.insert(Idx); 840 return std::nullopt; 841 } 842 843 //===----------------------------------------------------------------------===// 844 // RVVIntrinsic implementation 845 //===----------------------------------------------------------------------===// 846 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix, 847 StringRef NewOverloadedName, 848 StringRef OverloadedSuffix, StringRef IRName, 849 bool IsMasked, bool HasMaskedOffOperand, bool HasVL, 850 PolicyScheme Scheme, bool SupportOverloading, 851 bool HasBuiltinAlias, StringRef ManualCodegen, 852 const RVVTypes &OutInTypes, 853 const std::vector<int64_t> &NewIntrinsicTypes, 854 const std::vector<StringRef> &RequiredFeatures, 855 unsigned NF, Policy NewPolicyAttrs) 856 : IRName(IRName), IsMasked(IsMasked), 857 HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme), 858 SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias), 859 ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) { 860 861 // Init BuiltinName, Name and OverloadedName 862 BuiltinName = NewName.str(); 863 Name = BuiltinName; 864 if (NewOverloadedName.empty()) 865 OverloadedName = NewName.split("_").first.str(); 866 else 867 OverloadedName = NewOverloadedName.str(); 868 if (!Suffix.empty()) 869 Name += "_" + Suffix.str(); 870 if (!OverloadedSuffix.empty()) 871 OverloadedName += "_" + OverloadedSuffix.str(); 872 873 updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName, 874 PolicyAttrs); 875 876 // Init OutputType and InputTypes 877 OutputType = OutInTypes[0]; 878 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); 879 880 // IntrinsicTypes is unmasked TA version index. Need to update it 881 // if there is merge operand (It is always in first operand). 882 IntrinsicTypes = NewIntrinsicTypes; 883 if ((IsMasked && hasMaskedOffOperand()) || 884 (!IsMasked && hasPassthruOperand())) { 885 for (auto &I : IntrinsicTypes) { 886 if (I >= 0) 887 I += NF; 888 } 889 } 890 } 891 892 std::string RVVIntrinsic::getBuiltinTypeStr() const { 893 std::string S; 894 S += OutputType->getBuiltinStr(); 895 for (const auto &T : InputTypes) { 896 S += T->getBuiltinStr(); 897 } 898 return S; 899 } 900 901 std::string RVVIntrinsic::getSuffixStr( 902 RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL, 903 llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) { 904 SmallVector<std::string> SuffixStrs; 905 for (auto PD : PrototypeDescriptors) { 906 auto T = TypeCache.computeType(Type, Log2LMUL, PD); 907 SuffixStrs.push_back((*T)->getShortStr()); 908 } 909 return join(SuffixStrs, "_"); 910 } 911 912 llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes( 913 llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked, 914 bool HasMaskedOffOperand, bool HasVL, unsigned NF, 915 PolicyScheme DefaultScheme, Policy PolicyAttrs) { 916 SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(), 917 Prototype.end()); 918 bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand; 919 if (IsMasked) { 920 // If HasMaskedOffOperand, insert result type as first input operand if 921 // need. 922 if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) { 923 if (NF == 1) { 924 NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]); 925 } else if (NF > 1) { 926 // Convert 927 // (void, op0 address, op1 address, ...) 928 // to 929 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 930 PrototypeDescriptor MaskoffType = NewPrototype[1]; 931 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer); 932 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType); 933 } 934 } 935 if (HasMaskedOffOperand && NF > 1) { 936 // Convert 937 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 938 // to 939 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1, 940 // ...) 941 NewPrototype.insert(NewPrototype.begin() + NF + 1, 942 PrototypeDescriptor::Mask); 943 } else { 944 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand. 945 NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask); 946 } 947 } else { 948 if (NF == 1) { 949 if (PolicyAttrs.isTUPolicy() && HasPassthruOp) 950 NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]); 951 } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) { 952 // NF > 1 cases for segment load operations. 953 // Convert 954 // (void, op0 address, op1 address, ...) 955 // to 956 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...) 957 PrototypeDescriptor MaskoffType = Prototype[1]; 958 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer); 959 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType); 960 } 961 } 962 963 // If HasVL, append PrototypeDescriptor:VL to last operand 964 if (HasVL) 965 NewPrototype.push_back(PrototypeDescriptor::VL); 966 return NewPrototype; 967 } 968 969 llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() { 970 return {Policy(Policy::PolicyType::Undisturbed)}; // TU 971 } 972 973 llvm::SmallVector<Policy> 974 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy, 975 bool HasMaskPolicy) { 976 if (HasTailPolicy && HasMaskPolicy) 977 return {Policy(Policy::PolicyType::Undisturbed, 978 Policy::PolicyType::Agnostic), // TUM 979 Policy(Policy::PolicyType::Undisturbed, 980 Policy::PolicyType::Undisturbed), // TUMU 981 Policy(Policy::PolicyType::Agnostic, 982 Policy::PolicyType::Undisturbed)}; // MU 983 if (HasTailPolicy && !HasMaskPolicy) 984 return {Policy(Policy::PolicyType::Undisturbed, 985 Policy::PolicyType::Agnostic)}; // TU 986 if (!HasTailPolicy && HasMaskPolicy) 987 return {Policy(Policy::PolicyType::Agnostic, 988 Policy::PolicyType::Undisturbed)}; // MU 989 llvm_unreachable("An RVV instruction should not be without both tail policy " 990 "and mask policy"); 991 } 992 993 void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy, 994 std::string &Name, 995 std::string &BuiltinName, 996 std::string &OverloadedName, 997 Policy &PolicyAttrs) { 998 999 auto appendPolicySuffix = [&](const std::string &suffix) { 1000 Name += suffix; 1001 BuiltinName += suffix; 1002 OverloadedName += suffix; 1003 }; 1004 1005 // This follows the naming guideline under riscv-c-api-doc to add the 1006 // `__riscv_` suffix for all RVV intrinsics. 1007 Name = "__riscv_" + Name; 1008 OverloadedName = "__riscv_" + OverloadedName; 1009 1010 if (IsMasked) { 1011 if (PolicyAttrs.isTUMUPolicy()) 1012 appendPolicySuffix("_tumu"); 1013 else if (PolicyAttrs.isTUMAPolicy()) 1014 appendPolicySuffix("_tum"); 1015 else if (PolicyAttrs.isTAMUPolicy()) 1016 appendPolicySuffix("_mu"); 1017 else if (PolicyAttrs.isTAMAPolicy()) { 1018 Name += "_m"; 1019 if (HasPolicy) 1020 BuiltinName += "_tama"; 1021 else 1022 BuiltinName += "_m"; 1023 } else 1024 llvm_unreachable("Unhandled policy condition"); 1025 } else { 1026 if (PolicyAttrs.isTUPolicy()) 1027 appendPolicySuffix("_tu"); 1028 else if (PolicyAttrs.isTAPolicy()) { 1029 if (HasPolicy) 1030 BuiltinName += "_ta"; 1031 } else 1032 llvm_unreachable("Unhandled policy condition"); 1033 } 1034 } 1035 1036 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) { 1037 SmallVector<PrototypeDescriptor> PrototypeDescriptors; 1038 const StringRef Primaries("evwqom0ztul"); 1039 while (!Prototypes.empty()) { 1040 size_t Idx = 0; 1041 // Skip over complex prototype because it could contain primitive type 1042 // character. 1043 if (Prototypes[0] == '(') 1044 Idx = Prototypes.find_first_of(')'); 1045 Idx = Prototypes.find_first_of(Primaries, Idx); 1046 assert(Idx != StringRef::npos); 1047 auto PD = PrototypeDescriptor::parsePrototypeDescriptor( 1048 Prototypes.slice(0, Idx + 1)); 1049 if (!PD) 1050 llvm_unreachable("Error during parsing prototype."); 1051 PrototypeDescriptors.push_back(*PD); 1052 Prototypes = Prototypes.drop_front(Idx + 1); 1053 } 1054 return PrototypeDescriptors; 1055 } 1056 1057 raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) { 1058 OS << "{"; 1059 OS << "\"" << Record.Name << "\","; 1060 if (Record.OverloadedName == nullptr || 1061 StringRef(Record.OverloadedName).empty()) 1062 OS << "nullptr,"; 1063 else 1064 OS << "\"" << Record.OverloadedName << "\","; 1065 OS << Record.PrototypeIndex << ","; 1066 OS << Record.SuffixIndex << ","; 1067 OS << Record.OverloadedSuffixIndex << ","; 1068 OS << (int)Record.PrototypeLength << ","; 1069 OS << (int)Record.SuffixLength << ","; 1070 OS << (int)Record.OverloadedSuffixSize << ","; 1071 OS << (int)Record.RequiredExtensions << ","; 1072 OS << (int)Record.TypeRangeMask << ","; 1073 OS << (int)Record.Log2LMULMask << ","; 1074 OS << (int)Record.NF << ","; 1075 OS << (int)Record.HasMasked << ","; 1076 OS << (int)Record.HasVL << ","; 1077 OS << (int)Record.HasMaskedOffOperand << ","; 1078 OS << (int)Record.HasTailPolicy << ","; 1079 OS << (int)Record.HasMaskPolicy << ","; 1080 OS << (int)Record.UnMaskedPolicyScheme << ","; 1081 OS << (int)Record.MaskedPolicyScheme << ","; 1082 OS << "},\n"; 1083 return OS; 1084 } 1085 1086 } // end namespace RISCV 1087 } // end namespace clang 1088