1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "clang/Support/RISCVVIntrinsicUtils.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/ADT/Optional.h" 12 #include "llvm/ADT/SmallSet.h" 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringMap.h" 15 #include "llvm/ADT/StringSet.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/raw_ostream.h" 18 #include <numeric> 19 #include <set> 20 #include <unordered_map> 21 22 using namespace llvm; 23 24 namespace clang { 25 namespace RISCV { 26 27 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor( 28 BaseTypeModifier::Vector, VectorTypeModifier::MaskVector); 29 const PrototypeDescriptor PrototypeDescriptor::VL = 30 PrototypeDescriptor(BaseTypeModifier::SizeT); 31 const PrototypeDescriptor PrototypeDescriptor::Vector = 32 PrototypeDescriptor(BaseTypeModifier::Vector); 33 34 //===----------------------------------------------------------------------===// 35 // Type implementation 36 //===----------------------------------------------------------------------===// 37 38 LMULType::LMULType(int NewLog2LMUL) { 39 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 40 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); 41 Log2LMUL = NewLog2LMUL; 42 } 43 44 std::string LMULType::str() const { 45 if (Log2LMUL < 0) 46 return "mf" + utostr(1ULL << (-Log2LMUL)); 47 return "m" + utostr(1ULL << Log2LMUL); 48 } 49 50 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { 51 int Log2ScaleResult = 0; 52 switch (ElementBitwidth) { 53 default: 54 break; 55 case 8: 56 Log2ScaleResult = Log2LMUL + 3; 57 break; 58 case 16: 59 Log2ScaleResult = Log2LMUL + 2; 60 break; 61 case 32: 62 Log2ScaleResult = Log2LMUL + 1; 63 break; 64 case 64: 65 Log2ScaleResult = Log2LMUL; 66 break; 67 } 68 // Illegal vscale result would be less than 1 69 if (Log2ScaleResult < 0) 70 return llvm::None; 71 return 1 << Log2ScaleResult; 72 } 73 74 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } 75 76 RVVType::RVVType(BasicType BT, int Log2LMUL, 77 const PrototypeDescriptor &prototype) 78 : BT(BT), LMUL(LMULType(Log2LMUL)) { 79 applyBasicType(); 80 applyModifier(prototype); 81 Valid = verifyType(); 82 if (Valid) { 83 initBuiltinStr(); 84 initTypeStr(); 85 if (isVector()) { 86 initClangBuiltinStr(); 87 } 88 } 89 } 90 91 // clang-format off 92 // boolean type are encoded the ratio of n (SEW/LMUL) 93 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 94 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t 95 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 96 97 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 98 // -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- 99 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 100 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 101 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 102 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 103 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 104 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 105 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 106 // clang-format on 107 108 bool RVVType::verifyType() const { 109 if (ScalarType == Invalid) 110 return false; 111 if (isScalar()) 112 return true; 113 if (!Scale) 114 return false; 115 if (isFloat() && ElementBitwidth == 8) 116 return false; 117 unsigned V = Scale.value(); 118 switch (ElementBitwidth) { 119 case 1: 120 case 8: 121 // Check Scale is 1,2,4,8,16,32,64 122 return (V <= 64 && isPowerOf2_32(V)); 123 case 16: 124 // Check Scale is 1,2,4,8,16,32 125 return (V <= 32 && isPowerOf2_32(V)); 126 case 32: 127 // Check Scale is 1,2,4,8,16 128 return (V <= 16 && isPowerOf2_32(V)); 129 case 64: 130 // Check Scale is 1,2,4,8 131 return (V <= 8 && isPowerOf2_32(V)); 132 } 133 return false; 134 } 135 136 void RVVType::initBuiltinStr() { 137 assert(isValid() && "RVVType is invalid"); 138 switch (ScalarType) { 139 case ScalarTypeKind::Void: 140 BuiltinStr = "v"; 141 return; 142 case ScalarTypeKind::Size_t: 143 BuiltinStr = "z"; 144 if (IsImmediate) 145 BuiltinStr = "I" + BuiltinStr; 146 if (IsPointer) 147 BuiltinStr += "*"; 148 return; 149 case ScalarTypeKind::Ptrdiff_t: 150 BuiltinStr = "Y"; 151 return; 152 case ScalarTypeKind::UnsignedLong: 153 BuiltinStr = "ULi"; 154 return; 155 case ScalarTypeKind::SignedLong: 156 BuiltinStr = "Li"; 157 return; 158 case ScalarTypeKind::Boolean: 159 assert(ElementBitwidth == 1); 160 BuiltinStr += "b"; 161 break; 162 case ScalarTypeKind::SignedInteger: 163 case ScalarTypeKind::UnsignedInteger: 164 switch (ElementBitwidth) { 165 case 8: 166 BuiltinStr += "c"; 167 break; 168 case 16: 169 BuiltinStr += "s"; 170 break; 171 case 32: 172 BuiltinStr += "i"; 173 break; 174 case 64: 175 BuiltinStr += "Wi"; 176 break; 177 default: 178 llvm_unreachable("Unhandled ElementBitwidth!"); 179 } 180 if (isSignedInteger()) 181 BuiltinStr = "S" + BuiltinStr; 182 else 183 BuiltinStr = "U" + BuiltinStr; 184 break; 185 case ScalarTypeKind::Float: 186 switch (ElementBitwidth) { 187 case 16: 188 BuiltinStr += "x"; 189 break; 190 case 32: 191 BuiltinStr += "f"; 192 break; 193 case 64: 194 BuiltinStr += "d"; 195 break; 196 default: 197 llvm_unreachable("Unhandled ElementBitwidth!"); 198 } 199 break; 200 default: 201 llvm_unreachable("ScalarType is invalid!"); 202 } 203 if (IsImmediate) 204 BuiltinStr = "I" + BuiltinStr; 205 if (isScalar()) { 206 if (IsConstant) 207 BuiltinStr += "C"; 208 if (IsPointer) 209 BuiltinStr += "*"; 210 return; 211 } 212 BuiltinStr = "q" + utostr(*Scale) + BuiltinStr; 213 // Pointer to vector types. Defined for segment load intrinsics. 214 // segment load intrinsics have pointer type arguments to store the loaded 215 // vector values. 216 if (IsPointer) 217 BuiltinStr += "*"; 218 } 219 220 void RVVType::initClangBuiltinStr() { 221 assert(isValid() && "RVVType is invalid"); 222 assert(isVector() && "Handle Vector type only"); 223 224 ClangBuiltinStr = "__rvv_"; 225 switch (ScalarType) { 226 case ScalarTypeKind::Boolean: 227 ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t"; 228 return; 229 case ScalarTypeKind::Float: 230 ClangBuiltinStr += "float"; 231 break; 232 case ScalarTypeKind::SignedInteger: 233 ClangBuiltinStr += "int"; 234 break; 235 case ScalarTypeKind::UnsignedInteger: 236 ClangBuiltinStr += "uint"; 237 break; 238 default: 239 llvm_unreachable("ScalarTypeKind is invalid"); 240 } 241 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; 242 } 243 244 void RVVType::initTypeStr() { 245 assert(isValid() && "RVVType is invalid"); 246 247 if (IsConstant) 248 Str += "const "; 249 250 auto getTypeString = [&](StringRef TypeStr) { 251 if (isScalar()) 252 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); 253 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") 254 .str(); 255 }; 256 257 switch (ScalarType) { 258 case ScalarTypeKind::Void: 259 Str = "void"; 260 return; 261 case ScalarTypeKind::Size_t: 262 Str = "size_t"; 263 if (IsPointer) 264 Str += " *"; 265 return; 266 case ScalarTypeKind::Ptrdiff_t: 267 Str = "ptrdiff_t"; 268 return; 269 case ScalarTypeKind::UnsignedLong: 270 Str = "unsigned long"; 271 return; 272 case ScalarTypeKind::SignedLong: 273 Str = "long"; 274 return; 275 case ScalarTypeKind::Boolean: 276 if (isScalar()) 277 Str += "bool"; 278 else 279 // Vector bool is special case, the formulate is 280 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 281 Str += "vbool" + utostr(64 / *Scale) + "_t"; 282 break; 283 case ScalarTypeKind::Float: 284 if (isScalar()) { 285 if (ElementBitwidth == 64) 286 Str += "double"; 287 else if (ElementBitwidth == 32) 288 Str += "float"; 289 else if (ElementBitwidth == 16) 290 Str += "_Float16"; 291 else 292 llvm_unreachable("Unhandled floating type."); 293 } else 294 Str += getTypeString("float"); 295 break; 296 case ScalarTypeKind::SignedInteger: 297 Str += getTypeString("int"); 298 break; 299 case ScalarTypeKind::UnsignedInteger: 300 Str += getTypeString("uint"); 301 break; 302 default: 303 llvm_unreachable("ScalarType is invalid!"); 304 } 305 if (IsPointer) 306 Str += " *"; 307 } 308 309 void RVVType::initShortStr() { 310 switch (ScalarType) { 311 case ScalarTypeKind::Boolean: 312 assert(isVector()); 313 ShortStr = "b" + utostr(64 / *Scale); 314 return; 315 case ScalarTypeKind::Float: 316 ShortStr = "f" + utostr(ElementBitwidth); 317 break; 318 case ScalarTypeKind::SignedInteger: 319 ShortStr = "i" + utostr(ElementBitwidth); 320 break; 321 case ScalarTypeKind::UnsignedInteger: 322 ShortStr = "u" + utostr(ElementBitwidth); 323 break; 324 default: 325 llvm_unreachable("Unhandled case!"); 326 } 327 if (isVector()) 328 ShortStr += LMUL.str(); 329 } 330 331 void RVVType::applyBasicType() { 332 switch (BT) { 333 case BasicType::Int8: 334 ElementBitwidth = 8; 335 ScalarType = ScalarTypeKind::SignedInteger; 336 break; 337 case BasicType::Int16: 338 ElementBitwidth = 16; 339 ScalarType = ScalarTypeKind::SignedInteger; 340 break; 341 case BasicType::Int32: 342 ElementBitwidth = 32; 343 ScalarType = ScalarTypeKind::SignedInteger; 344 break; 345 case BasicType::Int64: 346 ElementBitwidth = 64; 347 ScalarType = ScalarTypeKind::SignedInteger; 348 break; 349 case BasicType::Float16: 350 ElementBitwidth = 16; 351 ScalarType = ScalarTypeKind::Float; 352 break; 353 case BasicType::Float32: 354 ElementBitwidth = 32; 355 ScalarType = ScalarTypeKind::Float; 356 break; 357 case BasicType::Float64: 358 ElementBitwidth = 64; 359 ScalarType = ScalarTypeKind::Float; 360 break; 361 default: 362 llvm_unreachable("Unhandled type code!"); 363 } 364 assert(ElementBitwidth != 0 && "Bad element bitwidth!"); 365 } 366 367 Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor( 368 llvm::StringRef PrototypeDescriptorStr) { 369 PrototypeDescriptor PD; 370 BaseTypeModifier PT = BaseTypeModifier::Invalid; 371 VectorTypeModifier VTM = VectorTypeModifier::NoModifier; 372 373 if (PrototypeDescriptorStr.empty()) 374 return PD; 375 376 // Handle base type modifier 377 auto PType = PrototypeDescriptorStr.back(); 378 switch (PType) { 379 case 'e': 380 PT = BaseTypeModifier::Scalar; 381 break; 382 case 'v': 383 PT = BaseTypeModifier::Vector; 384 break; 385 case 'w': 386 PT = BaseTypeModifier::Vector; 387 VTM = VectorTypeModifier::Widening2XVector; 388 break; 389 case 'q': 390 PT = BaseTypeModifier::Vector; 391 VTM = VectorTypeModifier::Widening4XVector; 392 break; 393 case 'o': 394 PT = BaseTypeModifier::Vector; 395 VTM = VectorTypeModifier::Widening8XVector; 396 break; 397 case 'm': 398 PT = BaseTypeModifier::Vector; 399 VTM = VectorTypeModifier::MaskVector; 400 break; 401 case '0': 402 PT = BaseTypeModifier::Void; 403 break; 404 case 'z': 405 PT = BaseTypeModifier::SizeT; 406 break; 407 case 't': 408 PT = BaseTypeModifier::Ptrdiff; 409 break; 410 case 'u': 411 PT = BaseTypeModifier::UnsignedLong; 412 break; 413 case 'l': 414 PT = BaseTypeModifier::SignedLong; 415 break; 416 default: 417 llvm_unreachable("Illegal primitive type transformers!"); 418 } 419 PD.PT = static_cast<uint8_t>(PT); 420 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back(); 421 422 // Compute the vector type transformers, it can only appear one time. 423 if (PrototypeDescriptorStr.startswith("(")) { 424 assert(VTM == VectorTypeModifier::NoModifier && 425 "VectorTypeModifier should only have one modifier"); 426 size_t Idx = PrototypeDescriptorStr.find(')'); 427 assert(Idx != StringRef::npos); 428 StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx); 429 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1); 430 assert(!PrototypeDescriptorStr.contains('(') && 431 "Only allow one vector type modifier"); 432 433 auto ComplexTT = ComplexType.split(":"); 434 if (ComplexTT.first == "Log2EEW") { 435 uint32_t Log2EEW; 436 if (ComplexTT.second.getAsInteger(10, Log2EEW)) { 437 llvm_unreachable("Invalid Log2EEW value!"); 438 return None; 439 } 440 switch (Log2EEW) { 441 case 3: 442 VTM = VectorTypeModifier::Log2EEW3; 443 break; 444 case 4: 445 VTM = VectorTypeModifier::Log2EEW4; 446 break; 447 case 5: 448 VTM = VectorTypeModifier::Log2EEW5; 449 break; 450 case 6: 451 VTM = VectorTypeModifier::Log2EEW6; 452 break; 453 default: 454 llvm_unreachable("Invalid Log2EEW value, should be [3-6]"); 455 return None; 456 } 457 } else if (ComplexTT.first == "FixedSEW") { 458 uint32_t NewSEW; 459 if (ComplexTT.second.getAsInteger(10, NewSEW)) { 460 llvm_unreachable("Invalid FixedSEW value!"); 461 return None; 462 } 463 switch (NewSEW) { 464 case 8: 465 VTM = VectorTypeModifier::FixedSEW8; 466 break; 467 case 16: 468 VTM = VectorTypeModifier::FixedSEW16; 469 break; 470 case 32: 471 VTM = VectorTypeModifier::FixedSEW32; 472 break; 473 case 64: 474 VTM = VectorTypeModifier::FixedSEW64; 475 break; 476 default: 477 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64"); 478 return None; 479 } 480 } else if (ComplexTT.first == "LFixedLog2LMUL") { 481 int32_t Log2LMUL; 482 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { 483 llvm_unreachable("Invalid LFixedLog2LMUL value!"); 484 return None; 485 } 486 switch (Log2LMUL) { 487 case -3: 488 VTM = VectorTypeModifier::LFixedLog2LMULN3; 489 break; 490 case -2: 491 VTM = VectorTypeModifier::LFixedLog2LMULN2; 492 break; 493 case -1: 494 VTM = VectorTypeModifier::LFixedLog2LMULN1; 495 break; 496 case 0: 497 VTM = VectorTypeModifier::LFixedLog2LMUL0; 498 break; 499 case 1: 500 VTM = VectorTypeModifier::LFixedLog2LMUL1; 501 break; 502 case 2: 503 VTM = VectorTypeModifier::LFixedLog2LMUL2; 504 break; 505 case 3: 506 VTM = VectorTypeModifier::LFixedLog2LMUL3; 507 break; 508 default: 509 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); 510 return None; 511 } 512 } else if (ComplexTT.first == "SFixedLog2LMUL") { 513 int32_t Log2LMUL; 514 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { 515 llvm_unreachable("Invalid SFixedLog2LMUL value!"); 516 return None; 517 } 518 switch (Log2LMUL) { 519 case -3: 520 VTM = VectorTypeModifier::SFixedLog2LMULN3; 521 break; 522 case -2: 523 VTM = VectorTypeModifier::SFixedLog2LMULN2; 524 break; 525 case -1: 526 VTM = VectorTypeModifier::SFixedLog2LMULN1; 527 break; 528 case 0: 529 VTM = VectorTypeModifier::SFixedLog2LMUL0; 530 break; 531 case 1: 532 VTM = VectorTypeModifier::SFixedLog2LMUL1; 533 break; 534 case 2: 535 VTM = VectorTypeModifier::SFixedLog2LMUL2; 536 break; 537 case 3: 538 VTM = VectorTypeModifier::SFixedLog2LMUL3; 539 break; 540 default: 541 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); 542 return None; 543 } 544 545 } else { 546 llvm_unreachable("Illegal complex type transformers!"); 547 } 548 } 549 PD.VTM = static_cast<uint8_t>(VTM); 550 551 // Compute the remain type transformers 552 TypeModifier TM = TypeModifier::NoModifier; 553 for (char I : PrototypeDescriptorStr) { 554 switch (I) { 555 case 'P': 556 if ((TM & TypeModifier::Const) == TypeModifier::Const) 557 llvm_unreachable("'P' transformer cannot be used after 'C'"); 558 if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer) 559 llvm_unreachable("'P' transformer cannot be used twice"); 560 TM |= TypeModifier::Pointer; 561 break; 562 case 'C': 563 TM |= TypeModifier::Const; 564 break; 565 case 'K': 566 TM |= TypeModifier::Immediate; 567 break; 568 case 'U': 569 TM |= TypeModifier::UnsignedInteger; 570 break; 571 case 'I': 572 TM |= TypeModifier::SignedInteger; 573 break; 574 case 'F': 575 TM |= TypeModifier::Float; 576 break; 577 case 'S': 578 TM |= TypeModifier::LMUL1; 579 break; 580 default: 581 llvm_unreachable("Illegal non-primitive type transformer!"); 582 } 583 } 584 PD.TM = static_cast<uint8_t>(TM); 585 586 return PD; 587 } 588 589 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) { 590 // Handle primitive type transformer 591 switch (static_cast<BaseTypeModifier>(Transformer.PT)) { 592 case BaseTypeModifier::Scalar: 593 Scale = 0; 594 break; 595 case BaseTypeModifier::Vector: 596 Scale = LMUL.getScale(ElementBitwidth); 597 break; 598 case BaseTypeModifier::Void: 599 ScalarType = ScalarTypeKind::Void; 600 break; 601 case BaseTypeModifier::SizeT: 602 ScalarType = ScalarTypeKind::Size_t; 603 break; 604 case BaseTypeModifier::Ptrdiff: 605 ScalarType = ScalarTypeKind::Ptrdiff_t; 606 break; 607 case BaseTypeModifier::UnsignedLong: 608 ScalarType = ScalarTypeKind::UnsignedLong; 609 break; 610 case BaseTypeModifier::SignedLong: 611 ScalarType = ScalarTypeKind::SignedLong; 612 break; 613 case BaseTypeModifier::Invalid: 614 ScalarType = ScalarTypeKind::Invalid; 615 return; 616 } 617 618 switch (static_cast<VectorTypeModifier>(Transformer.VTM)) { 619 case VectorTypeModifier::Widening2XVector: 620 ElementBitwidth *= 2; 621 LMUL.MulLog2LMUL(1); 622 Scale = LMUL.getScale(ElementBitwidth); 623 break; 624 case VectorTypeModifier::Widening4XVector: 625 ElementBitwidth *= 4; 626 LMUL.MulLog2LMUL(2); 627 Scale = LMUL.getScale(ElementBitwidth); 628 break; 629 case VectorTypeModifier::Widening8XVector: 630 ElementBitwidth *= 8; 631 LMUL.MulLog2LMUL(3); 632 Scale = LMUL.getScale(ElementBitwidth); 633 break; 634 case VectorTypeModifier::MaskVector: 635 ScalarType = ScalarTypeKind::Boolean; 636 Scale = LMUL.getScale(ElementBitwidth); 637 ElementBitwidth = 1; 638 break; 639 case VectorTypeModifier::Log2EEW3: 640 applyLog2EEW(3); 641 break; 642 case VectorTypeModifier::Log2EEW4: 643 applyLog2EEW(4); 644 break; 645 case VectorTypeModifier::Log2EEW5: 646 applyLog2EEW(5); 647 break; 648 case VectorTypeModifier::Log2EEW6: 649 applyLog2EEW(6); 650 break; 651 case VectorTypeModifier::FixedSEW8: 652 applyFixedSEW(8); 653 break; 654 case VectorTypeModifier::FixedSEW16: 655 applyFixedSEW(16); 656 break; 657 case VectorTypeModifier::FixedSEW32: 658 applyFixedSEW(32); 659 break; 660 case VectorTypeModifier::FixedSEW64: 661 applyFixedSEW(64); 662 break; 663 case VectorTypeModifier::LFixedLog2LMULN3: 664 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan); 665 break; 666 case VectorTypeModifier::LFixedLog2LMULN2: 667 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan); 668 break; 669 case VectorTypeModifier::LFixedLog2LMULN1: 670 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan); 671 break; 672 case VectorTypeModifier::LFixedLog2LMUL0: 673 applyFixedLog2LMUL(0, FixedLMULType::LargerThan); 674 break; 675 case VectorTypeModifier::LFixedLog2LMUL1: 676 applyFixedLog2LMUL(1, FixedLMULType::LargerThan); 677 break; 678 case VectorTypeModifier::LFixedLog2LMUL2: 679 applyFixedLog2LMUL(2, FixedLMULType::LargerThan); 680 break; 681 case VectorTypeModifier::LFixedLog2LMUL3: 682 applyFixedLog2LMUL(3, FixedLMULType::LargerThan); 683 break; 684 case VectorTypeModifier::SFixedLog2LMULN3: 685 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan); 686 break; 687 case VectorTypeModifier::SFixedLog2LMULN2: 688 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan); 689 break; 690 case VectorTypeModifier::SFixedLog2LMULN1: 691 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan); 692 break; 693 case VectorTypeModifier::SFixedLog2LMUL0: 694 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan); 695 break; 696 case VectorTypeModifier::SFixedLog2LMUL1: 697 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan); 698 break; 699 case VectorTypeModifier::SFixedLog2LMUL2: 700 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan); 701 break; 702 case VectorTypeModifier::SFixedLog2LMUL3: 703 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan); 704 break; 705 case VectorTypeModifier::NoModifier: 706 break; 707 } 708 709 for (unsigned TypeModifierMaskShift = 0; 710 TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset); 711 ++TypeModifierMaskShift) { 712 unsigned TypeModifierMask = 1 << TypeModifierMaskShift; 713 if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) != 714 TypeModifierMask) 715 continue; 716 switch (static_cast<TypeModifier>(TypeModifierMask)) { 717 case TypeModifier::Pointer: 718 IsPointer = true; 719 break; 720 case TypeModifier::Const: 721 IsConstant = true; 722 break; 723 case TypeModifier::Immediate: 724 IsImmediate = true; 725 IsConstant = true; 726 break; 727 case TypeModifier::UnsignedInteger: 728 ScalarType = ScalarTypeKind::UnsignedInteger; 729 break; 730 case TypeModifier::SignedInteger: 731 ScalarType = ScalarTypeKind::SignedInteger; 732 break; 733 case TypeModifier::Float: 734 ScalarType = ScalarTypeKind::Float; 735 break; 736 case TypeModifier::LMUL1: 737 LMUL = LMULType(0); 738 // Update ElementBitwidth need to update Scale too. 739 Scale = LMUL.getScale(ElementBitwidth); 740 break; 741 default: 742 llvm_unreachable("Unknown type modifier mask!"); 743 } 744 } 745 } 746 747 void RVVType::applyLog2EEW(unsigned Log2EEW) { 748 // update new elmul = (eew/sew) * lmul 749 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); 750 // update new eew 751 ElementBitwidth = 1 << Log2EEW; 752 ScalarType = ScalarTypeKind::SignedInteger; 753 Scale = LMUL.getScale(ElementBitwidth); 754 } 755 756 void RVVType::applyFixedSEW(unsigned NewSEW) { 757 // Set invalid type if src and dst SEW are same. 758 if (ElementBitwidth == NewSEW) { 759 ScalarType = ScalarTypeKind::Invalid; 760 return; 761 } 762 // Update new SEW 763 ElementBitwidth = NewSEW; 764 Scale = LMUL.getScale(ElementBitwidth); 765 } 766 767 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) { 768 switch (Type) { 769 case FixedLMULType::LargerThan: 770 if (Log2LMUL < LMUL.Log2LMUL) { 771 ScalarType = ScalarTypeKind::Invalid; 772 return; 773 } 774 break; 775 case FixedLMULType::SmallerThan: 776 if (Log2LMUL > LMUL.Log2LMUL) { 777 ScalarType = ScalarTypeKind::Invalid; 778 return; 779 } 780 break; 781 } 782 783 // Update new LMUL 784 LMUL = LMULType(Log2LMUL); 785 Scale = LMUL.getScale(ElementBitwidth); 786 } 787 788 Optional<RVVTypes> 789 RVVType::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, 790 ArrayRef<PrototypeDescriptor> Prototype) { 791 // LMUL x NF must be less than or equal to 8. 792 if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) 793 return llvm::None; 794 795 RVVTypes Types; 796 for (const PrototypeDescriptor &Proto : Prototype) { 797 auto T = computeType(BT, Log2LMUL, Proto); 798 if (!T) 799 return llvm::None; 800 // Record legal type index 801 Types.push_back(T.value()); 802 } 803 return Types; 804 } 805 806 // Compute the hash value of RVVType, used for cache the result of computeType. 807 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL, 808 PrototypeDescriptor Proto) { 809 // Layout of hash value: 810 // 0 8 16 24 32 40 811 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM | 812 assert(Log2LMUL >= -3 && Log2LMUL <= 3); 813 return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 | 814 ((uint64_t)(Proto.PT & 0xff) << 16) | 815 ((uint64_t)(Proto.TM & 0xff) << 24) | 816 ((uint64_t)(Proto.VTM & 0xff) << 32); 817 } 818 819 Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL, 820 PrototypeDescriptor Proto) { 821 // Concat BasicType, LMUL and Proto as key 822 static std::unordered_map<uint64_t, RVVType> LegalTypes; 823 static std::set<uint64_t> IllegalTypes; 824 uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto); 825 // Search first 826 auto It = LegalTypes.find(Idx); 827 if (It != LegalTypes.end()) 828 return &(It->second); 829 830 if (IllegalTypes.count(Idx)) 831 return llvm::None; 832 833 // Compute type and record the result. 834 RVVType T(BT, Log2LMUL, Proto); 835 if (T.isValid()) { 836 // Record legal type index and value. 837 LegalTypes.insert({Idx, T}); 838 return &(LegalTypes[Idx]); 839 } 840 // Record illegal type index. 841 IllegalTypes.insert(Idx); 842 return llvm::None; 843 } 844 845 //===----------------------------------------------------------------------===// 846 // RVVIntrinsic implementation 847 //===----------------------------------------------------------------------===// 848 RVVIntrinsic::RVVIntrinsic( 849 StringRef NewName, StringRef Suffix, StringRef NewOverloadedName, 850 StringRef OverloadedSuffix, StringRef IRName, bool IsMasked, 851 bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, 852 bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen, 853 const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes, 854 const std::vector<StringRef> &RequiredFeatures, unsigned NF) 855 : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme), 856 HasUnMaskedOverloaded(HasUnMaskedOverloaded), 857 HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()), 858 NF(NF) { 859 860 // Init BuiltinName, Name and OverloadedName 861 BuiltinName = NewName.str(); 862 Name = BuiltinName; 863 if (NewOverloadedName.empty()) 864 OverloadedName = NewName.split("_").first.str(); 865 else 866 OverloadedName = NewOverloadedName.str(); 867 if (!Suffix.empty()) 868 Name += "_" + Suffix.str(); 869 if (!OverloadedSuffix.empty()) 870 OverloadedName += "_" + OverloadedSuffix.str(); 871 if (IsMasked) { 872 BuiltinName += "_m"; 873 Name += "_m"; 874 } 875 876 // Init OutputType and InputTypes 877 OutputType = OutInTypes[0]; 878 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); 879 880 // IntrinsicTypes is unmasked TA version index. Need to update it 881 // if there is merge operand (It is always in first operand). 882 IntrinsicTypes = NewIntrinsicTypes; 883 if ((IsMasked && HasMaskedOffOperand) || 884 (!IsMasked && hasPassthruOperand())) { 885 for (auto &I : IntrinsicTypes) { 886 if (I >= 0) 887 I += NF; 888 } 889 } 890 } 891 892 std::string RVVIntrinsic::getBuiltinTypeStr() const { 893 std::string S; 894 S += OutputType->getBuiltinStr(); 895 for (const auto &T : InputTypes) { 896 S += T->getBuiltinStr(); 897 } 898 return S; 899 } 900 901 std::string RVVIntrinsic::getSuffixStr( 902 BasicType Type, int Log2LMUL, 903 llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) { 904 SmallVector<std::string> SuffixStrs; 905 for (auto PD : PrototypeDescriptors) { 906 auto T = RVVType::computeType(Type, Log2LMUL, PD); 907 SuffixStrs.push_back((*T)->getShortStr()); 908 } 909 return join(SuffixStrs, "_"); 910 } 911 912 llvm::SmallVector<PrototypeDescriptor> 913 RVVIntrinsic::computeBuiltinTypes(llvm::ArrayRef<PrototypeDescriptor> Prototype, 914 bool IsMasked, bool HasMaskedOffOperand, 915 bool HasVL, unsigned NF) { 916 SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(), 917 Prototype.end()); 918 if (IsMasked) { 919 // If HasMaskedOffOperand, insert result type as first input operand. 920 if (HasMaskedOffOperand) { 921 if (NF == 1) { 922 NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]); 923 } else { 924 // Convert 925 // (void, op0 address, op1 address, ...) 926 // to 927 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 928 PrototypeDescriptor MaskoffType = NewPrototype[1]; 929 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer); 930 for (unsigned I = 0; I < NF; ++I) 931 NewPrototype.insert(NewPrototype.begin() + NF + 1, MaskoffType); 932 } 933 } 934 if (HasMaskedOffOperand && NF > 1) { 935 // Convert 936 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 937 // to 938 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1, 939 // ...) 940 NewPrototype.insert(NewPrototype.begin() + NF + 1, 941 PrototypeDescriptor::Mask); 942 } else { 943 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand. 944 NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask); 945 } 946 } 947 948 // If HasVL, append PrototypeDescriptor:VL to last operand 949 if (HasVL) 950 NewPrototype.push_back(PrototypeDescriptor::VL); 951 return NewPrototype; 952 } 953 954 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) { 955 SmallVector<PrototypeDescriptor> PrototypeDescriptors; 956 const StringRef Primaries("evwqom0ztul"); 957 while (!Prototypes.empty()) { 958 size_t Idx = 0; 959 // Skip over complex prototype because it could contain primitive type 960 // character. 961 if (Prototypes[0] == '(') 962 Idx = Prototypes.find_first_of(')'); 963 Idx = Prototypes.find_first_of(Primaries, Idx); 964 assert(Idx != StringRef::npos); 965 auto PD = PrototypeDescriptor::parsePrototypeDescriptor( 966 Prototypes.slice(0, Idx + 1)); 967 if (!PD) 968 llvm_unreachable("Error during parsing prototype."); 969 PrototypeDescriptors.push_back(*PD); 970 Prototypes = Prototypes.drop_front(Idx + 1); 971 } 972 return PrototypeDescriptors; 973 } 974 975 raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) { 976 OS << "{"; 977 OS << "\"" << Record.Name << "\","; 978 if (Record.OverloadedName == nullptr || 979 StringRef(Record.OverloadedName).empty()) 980 OS << "nullptr,"; 981 else 982 OS << "\"" << Record.OverloadedName << "\","; 983 OS << Record.PrototypeIndex << ","; 984 OS << Record.SuffixIndex << ","; 985 OS << Record.OverloadedSuffixIndex << ","; 986 OS << (int)Record.PrototypeLength << ","; 987 OS << (int)Record.SuffixLength << ","; 988 OS << (int)Record.OverloadedSuffixSize << ","; 989 OS << (int)Record.RequiredExtensions << ","; 990 OS << (int)Record.TypeRangeMask << ","; 991 OS << (int)Record.Log2LMULMask << ","; 992 OS << (int)Record.NF << ","; 993 OS << (int)Record.HasMasked << ","; 994 OS << (int)Record.HasVL << ","; 995 OS << (int)Record.HasMaskedOffOperand << ","; 996 OS << "},\n"; 997 return OS; 998 } 999 1000 } // end namespace RISCV 1001 } // end namespace clang 1002