xref: /freebsd/contrib/llvm-project/clang/lib/Support/RISCVVIntrinsicUtils.cpp (revision 19fae0f66023a97a9b464b3beeeabb2081f575b3)
1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/SmallSet.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringMap.h"
14 #include "llvm/ADT/StringSet.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Support/ErrorHandling.h"
17 #include "llvm/Support/raw_ostream.h"
18 #include <numeric>
19 #include <optional>
20 
21 using namespace llvm;
22 
23 namespace clang {
24 namespace RISCV {
25 
26 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
27     BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
28 const PrototypeDescriptor PrototypeDescriptor::VL =
29     PrototypeDescriptor(BaseTypeModifier::SizeT);
30 const PrototypeDescriptor PrototypeDescriptor::Vector =
31     PrototypeDescriptor(BaseTypeModifier::Vector);
32 
33 //===----------------------------------------------------------------------===//
34 // Type implementation
35 //===----------------------------------------------------------------------===//
36 
37 LMULType::LMULType(int NewLog2LMUL) {
38   // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
39   assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
40   Log2LMUL = NewLog2LMUL;
41 }
42 
43 std::string LMULType::str() const {
44   if (Log2LMUL < 0)
45     return "mf" + utostr(1ULL << (-Log2LMUL));
46   return "m" + utostr(1ULL << Log2LMUL);
47 }
48 
49 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
50   int Log2ScaleResult = 0;
51   switch (ElementBitwidth) {
52   default:
53     break;
54   case 8:
55     Log2ScaleResult = Log2LMUL + 3;
56     break;
57   case 16:
58     Log2ScaleResult = Log2LMUL + 2;
59     break;
60   case 32:
61     Log2ScaleResult = Log2LMUL + 1;
62     break;
63   case 64:
64     Log2ScaleResult = Log2LMUL;
65     break;
66   }
67   // Illegal vscale result would be less than 1
68   if (Log2ScaleResult < 0)
69     return std::nullopt;
70   return 1 << Log2ScaleResult;
71 }
72 
73 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
74 
75 RVVType::RVVType(BasicType BT, int Log2LMUL,
76                  const PrototypeDescriptor &prototype)
77     : BT(BT), LMUL(LMULType(Log2LMUL)) {
78   applyBasicType();
79   applyModifier(prototype);
80   Valid = verifyType();
81   if (Valid) {
82     initBuiltinStr();
83     initTypeStr();
84     if (isVector()) {
85       initClangBuiltinStr();
86     }
87   }
88 }
89 
90 // clang-format off
// boolean types are encoded by the ratio n = SEW/LMUL
92 // SEW/LMUL | 1         | 2         | 4         | 8        | 16        | 32        | 64
93 // c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
94 // IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
95 
96 // type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
97 // --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
98 // i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
99 // i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
100 // i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
101 // i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
102 // double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
103 // float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
104 // half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
105 // clang-format on
106 
107 bool RVVType::verifyType() const {
108   if (ScalarType == Invalid)
109     return false;
110   if (isScalar())
111     return true;
112   if (!Scale)
113     return false;
114   if (isFloat() && ElementBitwidth == 8)
115     return false;
116   unsigned V = *Scale;
117   switch (ElementBitwidth) {
118   case 1:
119   case 8:
120     // Check Scale is 1,2,4,8,16,32,64
121     return (V <= 64 && isPowerOf2_32(V));
122   case 16:
123     // Check Scale is 1,2,4,8,16,32
124     return (V <= 32 && isPowerOf2_32(V));
125   case 32:
126     // Check Scale is 1,2,4,8,16
127     return (V <= 16 && isPowerOf2_32(V));
128   case 64:
129     // Check Scale is 1,2,4,8
130     return (V <= 8 && isPowerOf2_32(V));
131   }
132   return false;
133 }
134 
135 void RVVType::initBuiltinStr() {
136   assert(isValid() && "RVVType is invalid");
137   switch (ScalarType) {
138   case ScalarTypeKind::Void:
139     BuiltinStr = "v";
140     return;
141   case ScalarTypeKind::Size_t:
142     BuiltinStr = "z";
143     if (IsImmediate)
144       BuiltinStr = "I" + BuiltinStr;
145     if (IsPointer)
146       BuiltinStr += "*";
147     return;
148   case ScalarTypeKind::Ptrdiff_t:
149     BuiltinStr = "Y";
150     return;
151   case ScalarTypeKind::UnsignedLong:
152     BuiltinStr = "ULi";
153     return;
154   case ScalarTypeKind::SignedLong:
155     BuiltinStr = "Li";
156     return;
157   case ScalarTypeKind::Boolean:
158     assert(ElementBitwidth == 1);
159     BuiltinStr += "b";
160     break;
161   case ScalarTypeKind::SignedInteger:
162   case ScalarTypeKind::UnsignedInteger:
163     switch (ElementBitwidth) {
164     case 8:
165       BuiltinStr += "c";
166       break;
167     case 16:
168       BuiltinStr += "s";
169       break;
170     case 32:
171       BuiltinStr += "i";
172       break;
173     case 64:
174       BuiltinStr += "Wi";
175       break;
176     default:
177       llvm_unreachable("Unhandled ElementBitwidth!");
178     }
179     if (isSignedInteger())
180       BuiltinStr = "S" + BuiltinStr;
181     else
182       BuiltinStr = "U" + BuiltinStr;
183     break;
184   case ScalarTypeKind::Float:
185     switch (ElementBitwidth) {
186     case 16:
187       BuiltinStr += "x";
188       break;
189     case 32:
190       BuiltinStr += "f";
191       break;
192     case 64:
193       BuiltinStr += "d";
194       break;
195     default:
196       llvm_unreachable("Unhandled ElementBitwidth!");
197     }
198     break;
199   default:
200     llvm_unreachable("ScalarType is invalid!");
201   }
202   if (IsImmediate)
203     BuiltinStr = "I" + BuiltinStr;
204   if (isScalar()) {
205     if (IsConstant)
206       BuiltinStr += "C";
207     if (IsPointer)
208       BuiltinStr += "*";
209     return;
210   }
211   BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
212   // Pointer to vector types. Defined for segment load intrinsics.
213   // segment load intrinsics have pointer type arguments to store the loaded
214   // vector values.
215   if (IsPointer)
216     BuiltinStr += "*";
217 }
218 
219 void RVVType::initClangBuiltinStr() {
220   assert(isValid() && "RVVType is invalid");
221   assert(isVector() && "Handle Vector type only");
222 
223   ClangBuiltinStr = "__rvv_";
224   switch (ScalarType) {
225   case ScalarTypeKind::Boolean:
226     ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
227     return;
228   case ScalarTypeKind::Float:
229     ClangBuiltinStr += "float";
230     break;
231   case ScalarTypeKind::SignedInteger:
232     ClangBuiltinStr += "int";
233     break;
234   case ScalarTypeKind::UnsignedInteger:
235     ClangBuiltinStr += "uint";
236     break;
237   default:
238     llvm_unreachable("ScalarTypeKind is invalid");
239   }
240   ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
241 }
242 
243 void RVVType::initTypeStr() {
244   assert(isValid() && "RVVType is invalid");
245 
246   if (IsConstant)
247     Str += "const ";
248 
249   auto getTypeString = [&](StringRef TypeStr) {
250     if (isScalar())
251       return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
252     return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
253         .str();
254   };
255 
256   switch (ScalarType) {
257   case ScalarTypeKind::Void:
258     Str = "void";
259     return;
260   case ScalarTypeKind::Size_t:
261     Str = "size_t";
262     if (IsPointer)
263       Str += " *";
264     return;
265   case ScalarTypeKind::Ptrdiff_t:
266     Str = "ptrdiff_t";
267     return;
268   case ScalarTypeKind::UnsignedLong:
269     Str = "unsigned long";
270     return;
271   case ScalarTypeKind::SignedLong:
272     Str = "long";
273     return;
274   case ScalarTypeKind::Boolean:
275     if (isScalar())
276       Str += "bool";
277     else
278       // Vector bool is special case, the formulate is
279       // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
280       Str += "vbool" + utostr(64 / *Scale) + "_t";
281     break;
282   case ScalarTypeKind::Float:
283     if (isScalar()) {
284       if (ElementBitwidth == 64)
285         Str += "double";
286       else if (ElementBitwidth == 32)
287         Str += "float";
288       else if (ElementBitwidth == 16)
289         Str += "_Float16";
290       else
291         llvm_unreachable("Unhandled floating type.");
292     } else
293       Str += getTypeString("float");
294     break;
295   case ScalarTypeKind::SignedInteger:
296     Str += getTypeString("int");
297     break;
298   case ScalarTypeKind::UnsignedInteger:
299     Str += getTypeString("uint");
300     break;
301   default:
302     llvm_unreachable("ScalarType is invalid!");
303   }
304   if (IsPointer)
305     Str += " *";
306 }
307 
308 void RVVType::initShortStr() {
309   switch (ScalarType) {
310   case ScalarTypeKind::Boolean:
311     assert(isVector());
312     ShortStr = "b" + utostr(64 / *Scale);
313     return;
314   case ScalarTypeKind::Float:
315     ShortStr = "f" + utostr(ElementBitwidth);
316     break;
317   case ScalarTypeKind::SignedInteger:
318     ShortStr = "i" + utostr(ElementBitwidth);
319     break;
320   case ScalarTypeKind::UnsignedInteger:
321     ShortStr = "u" + utostr(ElementBitwidth);
322     break;
323   default:
324     llvm_unreachable("Unhandled case!");
325   }
326   if (isVector())
327     ShortStr += LMUL.str();
328 }
329 
330 void RVVType::applyBasicType() {
331   switch (BT) {
332   case BasicType::Int8:
333     ElementBitwidth = 8;
334     ScalarType = ScalarTypeKind::SignedInteger;
335     break;
336   case BasicType::Int16:
337     ElementBitwidth = 16;
338     ScalarType = ScalarTypeKind::SignedInteger;
339     break;
340   case BasicType::Int32:
341     ElementBitwidth = 32;
342     ScalarType = ScalarTypeKind::SignedInteger;
343     break;
344   case BasicType::Int64:
345     ElementBitwidth = 64;
346     ScalarType = ScalarTypeKind::SignedInteger;
347     break;
348   case BasicType::Float16:
349     ElementBitwidth = 16;
350     ScalarType = ScalarTypeKind::Float;
351     break;
352   case BasicType::Float32:
353     ElementBitwidth = 32;
354     ScalarType = ScalarTypeKind::Float;
355     break;
356   case BasicType::Float64:
357     ElementBitwidth = 64;
358     ScalarType = ScalarTypeKind::Float;
359     break;
360   default:
361     llvm_unreachable("Unhandled type code!");
362   }
363   assert(ElementBitwidth != 0 && "Bad element bitwidth!");
364 }
365 
366 std::optional<PrototypeDescriptor>
367 PrototypeDescriptor::parsePrototypeDescriptor(
368     llvm::StringRef PrototypeDescriptorStr) {
369   PrototypeDescriptor PD;
370   BaseTypeModifier PT = BaseTypeModifier::Invalid;
371   VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
372 
373   if (PrototypeDescriptorStr.empty())
374     return PD;
375 
376   // Handle base type modifier
377   auto PType = PrototypeDescriptorStr.back();
378   switch (PType) {
379   case 'e':
380     PT = BaseTypeModifier::Scalar;
381     break;
382   case 'v':
383     PT = BaseTypeModifier::Vector;
384     break;
385   case 'w':
386     PT = BaseTypeModifier::Vector;
387     VTM = VectorTypeModifier::Widening2XVector;
388     break;
389   case 'q':
390     PT = BaseTypeModifier::Vector;
391     VTM = VectorTypeModifier::Widening4XVector;
392     break;
393   case 'o':
394     PT = BaseTypeModifier::Vector;
395     VTM = VectorTypeModifier::Widening8XVector;
396     break;
397   case 'm':
398     PT = BaseTypeModifier::Vector;
399     VTM = VectorTypeModifier::MaskVector;
400     break;
401   case '0':
402     PT = BaseTypeModifier::Void;
403     break;
404   case 'z':
405     PT = BaseTypeModifier::SizeT;
406     break;
407   case 't':
408     PT = BaseTypeModifier::Ptrdiff;
409     break;
410   case 'u':
411     PT = BaseTypeModifier::UnsignedLong;
412     break;
413   case 'l':
414     PT = BaseTypeModifier::SignedLong;
415     break;
416   default:
417     llvm_unreachable("Illegal primitive type transformers!");
418   }
419   PD.PT = static_cast<uint8_t>(PT);
420   PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
421 
422   // Compute the vector type transformers, it can only appear one time.
423   if (PrototypeDescriptorStr.startswith("(")) {
424     assert(VTM == VectorTypeModifier::NoModifier &&
425            "VectorTypeModifier should only have one modifier");
426     size_t Idx = PrototypeDescriptorStr.find(')');
427     assert(Idx != StringRef::npos);
428     StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
429     PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
430     assert(!PrototypeDescriptorStr.contains('(') &&
431            "Only allow one vector type modifier");
432 
433     auto ComplexTT = ComplexType.split(":");
434     if (ComplexTT.first == "Log2EEW") {
435       uint32_t Log2EEW;
436       if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
437         llvm_unreachable("Invalid Log2EEW value!");
438         return std::nullopt;
439       }
440       switch (Log2EEW) {
441       case 3:
442         VTM = VectorTypeModifier::Log2EEW3;
443         break;
444       case 4:
445         VTM = VectorTypeModifier::Log2EEW4;
446         break;
447       case 5:
448         VTM = VectorTypeModifier::Log2EEW5;
449         break;
450       case 6:
451         VTM = VectorTypeModifier::Log2EEW6;
452         break;
453       default:
454         llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
455         return std::nullopt;
456       }
457     } else if (ComplexTT.first == "FixedSEW") {
458       uint32_t NewSEW;
459       if (ComplexTT.second.getAsInteger(10, NewSEW)) {
460         llvm_unreachable("Invalid FixedSEW value!");
461         return std::nullopt;
462       }
463       switch (NewSEW) {
464       case 8:
465         VTM = VectorTypeModifier::FixedSEW8;
466         break;
467       case 16:
468         VTM = VectorTypeModifier::FixedSEW16;
469         break;
470       case 32:
471         VTM = VectorTypeModifier::FixedSEW32;
472         break;
473       case 64:
474         VTM = VectorTypeModifier::FixedSEW64;
475         break;
476       default:
477         llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
478         return std::nullopt;
479       }
480     } else if (ComplexTT.first == "LFixedLog2LMUL") {
481       int32_t Log2LMUL;
482       if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
483         llvm_unreachable("Invalid LFixedLog2LMUL value!");
484         return std::nullopt;
485       }
486       switch (Log2LMUL) {
487       case -3:
488         VTM = VectorTypeModifier::LFixedLog2LMULN3;
489         break;
490       case -2:
491         VTM = VectorTypeModifier::LFixedLog2LMULN2;
492         break;
493       case -1:
494         VTM = VectorTypeModifier::LFixedLog2LMULN1;
495         break;
496       case 0:
497         VTM = VectorTypeModifier::LFixedLog2LMUL0;
498         break;
499       case 1:
500         VTM = VectorTypeModifier::LFixedLog2LMUL1;
501         break;
502       case 2:
503         VTM = VectorTypeModifier::LFixedLog2LMUL2;
504         break;
505       case 3:
506         VTM = VectorTypeModifier::LFixedLog2LMUL3;
507         break;
508       default:
509         llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
510         return std::nullopt;
511       }
512     } else if (ComplexTT.first == "SFixedLog2LMUL") {
513       int32_t Log2LMUL;
514       if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
515         llvm_unreachable("Invalid SFixedLog2LMUL value!");
516         return std::nullopt;
517       }
518       switch (Log2LMUL) {
519       case -3:
520         VTM = VectorTypeModifier::SFixedLog2LMULN3;
521         break;
522       case -2:
523         VTM = VectorTypeModifier::SFixedLog2LMULN2;
524         break;
525       case -1:
526         VTM = VectorTypeModifier::SFixedLog2LMULN1;
527         break;
528       case 0:
529         VTM = VectorTypeModifier::SFixedLog2LMUL0;
530         break;
531       case 1:
532         VTM = VectorTypeModifier::SFixedLog2LMUL1;
533         break;
534       case 2:
535         VTM = VectorTypeModifier::SFixedLog2LMUL2;
536         break;
537       case 3:
538         VTM = VectorTypeModifier::SFixedLog2LMUL3;
539         break;
540       default:
541         llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
542         return std::nullopt;
543       }
544 
545     } else {
546       llvm_unreachable("Illegal complex type transformers!");
547     }
548   }
549   PD.VTM = static_cast<uint8_t>(VTM);
550 
551   // Compute the remain type transformers
552   TypeModifier TM = TypeModifier::NoModifier;
553   for (char I : PrototypeDescriptorStr) {
554     switch (I) {
555     case 'P':
556       if ((TM & TypeModifier::Const) == TypeModifier::Const)
557         llvm_unreachable("'P' transformer cannot be used after 'C'");
558       if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
559         llvm_unreachable("'P' transformer cannot be used twice");
560       TM |= TypeModifier::Pointer;
561       break;
562     case 'C':
563       TM |= TypeModifier::Const;
564       break;
565     case 'K':
566       TM |= TypeModifier::Immediate;
567       break;
568     case 'U':
569       TM |= TypeModifier::UnsignedInteger;
570       break;
571     case 'I':
572       TM |= TypeModifier::SignedInteger;
573       break;
574     case 'F':
575       TM |= TypeModifier::Float;
576       break;
577     case 'S':
578       TM |= TypeModifier::LMUL1;
579       break;
580     default:
581       llvm_unreachable("Illegal non-primitive type transformer!");
582     }
583   }
584   PD.TM = static_cast<uint8_t>(TM);
585 
586   return PD;
587 }
588 
589 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
590   // Handle primitive type transformer
591   switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
592   case BaseTypeModifier::Scalar:
593     Scale = 0;
594     break;
595   case BaseTypeModifier::Vector:
596     Scale = LMUL.getScale(ElementBitwidth);
597     break;
598   case BaseTypeModifier::Void:
599     ScalarType = ScalarTypeKind::Void;
600     break;
601   case BaseTypeModifier::SizeT:
602     ScalarType = ScalarTypeKind::Size_t;
603     break;
604   case BaseTypeModifier::Ptrdiff:
605     ScalarType = ScalarTypeKind::Ptrdiff_t;
606     break;
607   case BaseTypeModifier::UnsignedLong:
608     ScalarType = ScalarTypeKind::UnsignedLong;
609     break;
610   case BaseTypeModifier::SignedLong:
611     ScalarType = ScalarTypeKind::SignedLong;
612     break;
613   case BaseTypeModifier::Invalid:
614     ScalarType = ScalarTypeKind::Invalid;
615     return;
616   }
617 
618   switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
619   case VectorTypeModifier::Widening2XVector:
620     ElementBitwidth *= 2;
621     LMUL.MulLog2LMUL(1);
622     Scale = LMUL.getScale(ElementBitwidth);
623     break;
624   case VectorTypeModifier::Widening4XVector:
625     ElementBitwidth *= 4;
626     LMUL.MulLog2LMUL(2);
627     Scale = LMUL.getScale(ElementBitwidth);
628     break;
629   case VectorTypeModifier::Widening8XVector:
630     ElementBitwidth *= 8;
631     LMUL.MulLog2LMUL(3);
632     Scale = LMUL.getScale(ElementBitwidth);
633     break;
634   case VectorTypeModifier::MaskVector:
635     ScalarType = ScalarTypeKind::Boolean;
636     Scale = LMUL.getScale(ElementBitwidth);
637     ElementBitwidth = 1;
638     break;
639   case VectorTypeModifier::Log2EEW3:
640     applyLog2EEW(3);
641     break;
642   case VectorTypeModifier::Log2EEW4:
643     applyLog2EEW(4);
644     break;
645   case VectorTypeModifier::Log2EEW5:
646     applyLog2EEW(5);
647     break;
648   case VectorTypeModifier::Log2EEW6:
649     applyLog2EEW(6);
650     break;
651   case VectorTypeModifier::FixedSEW8:
652     applyFixedSEW(8);
653     break;
654   case VectorTypeModifier::FixedSEW16:
655     applyFixedSEW(16);
656     break;
657   case VectorTypeModifier::FixedSEW32:
658     applyFixedSEW(32);
659     break;
660   case VectorTypeModifier::FixedSEW64:
661     applyFixedSEW(64);
662     break;
663   case VectorTypeModifier::LFixedLog2LMULN3:
664     applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
665     break;
666   case VectorTypeModifier::LFixedLog2LMULN2:
667     applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
668     break;
669   case VectorTypeModifier::LFixedLog2LMULN1:
670     applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
671     break;
672   case VectorTypeModifier::LFixedLog2LMUL0:
673     applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
674     break;
675   case VectorTypeModifier::LFixedLog2LMUL1:
676     applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
677     break;
678   case VectorTypeModifier::LFixedLog2LMUL2:
679     applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
680     break;
681   case VectorTypeModifier::LFixedLog2LMUL3:
682     applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
683     break;
684   case VectorTypeModifier::SFixedLog2LMULN3:
685     applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
686     break;
687   case VectorTypeModifier::SFixedLog2LMULN2:
688     applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
689     break;
690   case VectorTypeModifier::SFixedLog2LMULN1:
691     applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
692     break;
693   case VectorTypeModifier::SFixedLog2LMUL0:
694     applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
695     break;
696   case VectorTypeModifier::SFixedLog2LMUL1:
697     applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
698     break;
699   case VectorTypeModifier::SFixedLog2LMUL2:
700     applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
701     break;
702   case VectorTypeModifier::SFixedLog2LMUL3:
703     applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
704     break;
705   case VectorTypeModifier::NoModifier:
706     break;
707   }
708 
709   for (unsigned TypeModifierMaskShift = 0;
710        TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
711        ++TypeModifierMaskShift) {
712     unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
713     if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
714         TypeModifierMask)
715       continue;
716     switch (static_cast<TypeModifier>(TypeModifierMask)) {
717     case TypeModifier::Pointer:
718       IsPointer = true;
719       break;
720     case TypeModifier::Const:
721       IsConstant = true;
722       break;
723     case TypeModifier::Immediate:
724       IsImmediate = true;
725       IsConstant = true;
726       break;
727     case TypeModifier::UnsignedInteger:
728       ScalarType = ScalarTypeKind::UnsignedInteger;
729       break;
730     case TypeModifier::SignedInteger:
731       ScalarType = ScalarTypeKind::SignedInteger;
732       break;
733     case TypeModifier::Float:
734       ScalarType = ScalarTypeKind::Float;
735       break;
736     case TypeModifier::LMUL1:
737       LMUL = LMULType(0);
738       // Update ElementBitwidth need to update Scale too.
739       Scale = LMUL.getScale(ElementBitwidth);
740       break;
741     default:
742       llvm_unreachable("Unknown type modifier mask!");
743     }
744   }
745 }
746 
747 void RVVType::applyLog2EEW(unsigned Log2EEW) {
748   // update new elmul = (eew/sew) * lmul
749   LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
750   // update new eew
751   ElementBitwidth = 1 << Log2EEW;
752   ScalarType = ScalarTypeKind::SignedInteger;
753   Scale = LMUL.getScale(ElementBitwidth);
754 }
755 
756 void RVVType::applyFixedSEW(unsigned NewSEW) {
757   // Set invalid type if src and dst SEW are same.
758   if (ElementBitwidth == NewSEW) {
759     ScalarType = ScalarTypeKind::Invalid;
760     return;
761   }
762   // Update new SEW
763   ElementBitwidth = NewSEW;
764   Scale = LMUL.getScale(ElementBitwidth);
765 }
766 
767 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
768   switch (Type) {
769   case FixedLMULType::LargerThan:
770     if (Log2LMUL < LMUL.Log2LMUL) {
771       ScalarType = ScalarTypeKind::Invalid;
772       return;
773     }
774     break;
775   case FixedLMULType::SmallerThan:
776     if (Log2LMUL > LMUL.Log2LMUL) {
777       ScalarType = ScalarTypeKind::Invalid;
778       return;
779     }
780     break;
781   }
782 
783   // Update new LMUL
784   LMUL = LMULType(Log2LMUL);
785   Scale = LMUL.getScale(ElementBitwidth);
786 }
787 
788 std::optional<RVVTypes>
789 RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
790                            ArrayRef<PrototypeDescriptor> Prototype) {
791   // LMUL x NF must be less than or equal to 8.
792   if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
793     return std::nullopt;
794 
795   RVVTypes Types;
796   for (const PrototypeDescriptor &Proto : Prototype) {
797     auto T = computeType(BT, Log2LMUL, Proto);
798     if (!T)
799       return std::nullopt;
800     // Record legal type index
801     Types.push_back(*T);
802   }
803   return Types;
804 }
805 
806 // Compute the hash value of RVVType, used for cache the result of computeType.
807 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
808                                         PrototypeDescriptor Proto) {
809   // Layout of hash value:
810   // 0               8    16          24        32          40
811   // | Log2LMUL + 3  | BT  | Proto.PT | Proto.TM | Proto.VTM |
812   assert(Log2LMUL >= -3 && Log2LMUL <= 3);
813   return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
814          ((uint64_t)(Proto.PT & 0xff) << 16) |
815          ((uint64_t)(Proto.TM & 0xff) << 24) |
816          ((uint64_t)(Proto.VTM & 0xff) << 32);
817 }
818 
819 std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
820                                                     PrototypeDescriptor Proto) {
821   uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
822   // Search first
823   auto It = LegalTypes.find(Idx);
824   if (It != LegalTypes.end())
825     return &(It->second);
826 
827   if (IllegalTypes.count(Idx))
828     return std::nullopt;
829 
830   // Compute type and record the result.
831   RVVType T(BT, Log2LMUL, Proto);
832   if (T.isValid()) {
833     // Record legal type index and value.
834     std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
835         InsertResult = LegalTypes.insert({Idx, T});
836     return &(InsertResult.first->second);
837   }
838   // Record illegal type index.
839   IllegalTypes.insert(Idx);
840   return std::nullopt;
841 }
842 
843 //===----------------------------------------------------------------------===//
844 // RVVIntrinsic implementation
845 //===----------------------------------------------------------------------===//
846 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
847                            StringRef NewOverloadedName,
848                            StringRef OverloadedSuffix, StringRef IRName,
849                            bool IsMasked, bool HasMaskedOffOperand, bool HasVL,
850                            PolicyScheme Scheme, bool SupportOverloading,
851                            bool HasBuiltinAlias, StringRef ManualCodegen,
852                            const RVVTypes &OutInTypes,
853                            const std::vector<int64_t> &NewIntrinsicTypes,
854                            const std::vector<StringRef> &RequiredFeatures,
855                            unsigned NF, Policy NewPolicyAttrs)
856     : IRName(IRName), IsMasked(IsMasked),
857       HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
858       SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
859       ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
860 
861   // Init BuiltinName, Name and OverloadedName
862   BuiltinName = NewName.str();
863   Name = BuiltinName;
864   if (NewOverloadedName.empty())
865     OverloadedName = NewName.split("_").first.str();
866   else
867     OverloadedName = NewOverloadedName.str();
868   if (!Suffix.empty())
869     Name += "_" + Suffix.str();
870   if (!OverloadedSuffix.empty())
871     OverloadedName += "_" + OverloadedSuffix.str();
872 
873   updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
874                        PolicyAttrs);
875 
876   // Init OutputType and InputTypes
877   OutputType = OutInTypes[0];
878   InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
879 
880   // IntrinsicTypes is unmasked TA version index. Need to update it
881   // if there is merge operand (It is always in first operand).
882   IntrinsicTypes = NewIntrinsicTypes;
883   if ((IsMasked && hasMaskedOffOperand()) ||
884       (!IsMasked && hasPassthruOperand())) {
885     for (auto &I : IntrinsicTypes) {
886       if (I >= 0)
887         I += NF;
888     }
889   }
890 }
891 
892 std::string RVVIntrinsic::getBuiltinTypeStr() const {
893   std::string S;
894   S += OutputType->getBuiltinStr();
895   for (const auto &T : InputTypes) {
896     S += T->getBuiltinStr();
897   }
898   return S;
899 }
900 
901 std::string RVVIntrinsic::getSuffixStr(
902     RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
903     llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
904   SmallVector<std::string> SuffixStrs;
905   for (auto PD : PrototypeDescriptors) {
906     auto T = TypeCache.computeType(Type, Log2LMUL, PD);
907     SuffixStrs.push_back((*T)->getShortStr());
908   }
909   return join(SuffixStrs, "_");
910 }
911 
912 llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
913     llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
914     bool HasMaskedOffOperand, bool HasVL, unsigned NF,
915     PolicyScheme DefaultScheme, Policy PolicyAttrs) {
916   SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
917                                                 Prototype.end());
918   bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
919   if (IsMasked) {
920     // If HasMaskedOffOperand, insert result type as first input operand if
921     // need.
922     if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
923       if (NF == 1) {
924         NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
925       } else if (NF > 1) {
926         // Convert
927         // (void, op0 address, op1 address, ...)
928         // to
929         // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
930         PrototypeDescriptor MaskoffType = NewPrototype[1];
931         MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
932         NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
933       }
934     }
935     if (HasMaskedOffOperand && NF > 1) {
936       // Convert
937       // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
938       // to
939       // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
940       // ...)
941       NewPrototype.insert(NewPrototype.begin() + NF + 1,
942                           PrototypeDescriptor::Mask);
943     } else {
944       // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
945       NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
946     }
947   } else {
948     if (NF == 1) {
949       if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
950         NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
951     } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
952       // NF > 1 cases for segment load operations.
953       // Convert
954       // (void, op0 address, op1 address, ...)
955       // to
956       // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
957       PrototypeDescriptor MaskoffType = Prototype[1];
958       MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
959       NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
960     }
961  }
962 
963   // If HasVL, append PrototypeDescriptor:VL to last operand
964   if (HasVL)
965     NewPrototype.push_back(PrototypeDescriptor::VL);
966   return NewPrototype;
967 }
968 
969 llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() {
970   return {Policy(Policy::PolicyType::Undisturbed)}; // TU
971 }
972 
973 llvm::SmallVector<Policy>
974 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
975                                          bool HasMaskPolicy) {
976   if (HasTailPolicy && HasMaskPolicy)
977     return {Policy(Policy::PolicyType::Undisturbed,
978                    Policy::PolicyType::Agnostic), // TUM
979             Policy(Policy::PolicyType::Undisturbed,
980                    Policy::PolicyType::Undisturbed), // TUMU
981             Policy(Policy::PolicyType::Agnostic,
982                    Policy::PolicyType::Undisturbed)}; // MU
983   if (HasTailPolicy && !HasMaskPolicy)
984     return {Policy(Policy::PolicyType::Undisturbed,
985                    Policy::PolicyType::Agnostic)}; // TU
986   if (!HasTailPolicy && HasMaskPolicy)
987     return {Policy(Policy::PolicyType::Agnostic,
988                    Policy::PolicyType::Undisturbed)}; // MU
989   llvm_unreachable("An RVV instruction should not be without both tail policy "
990                    "and mask policy");
991 }
992 
993 void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy,
994                                         std::string &Name,
995                                         std::string &BuiltinName,
996                                         std::string &OverloadedName,
997                                         Policy &PolicyAttrs) {
998 
999   auto appendPolicySuffix = [&](const std::string &suffix) {
1000     Name += suffix;
1001     BuiltinName += suffix;
1002     OverloadedName += suffix;
1003   };
1004 
1005   // This follows the naming guideline under riscv-c-api-doc to add the
1006   // `__riscv_` suffix for all RVV intrinsics.
1007   Name = "__riscv_" + Name;
1008   OverloadedName = "__riscv_" + OverloadedName;
1009 
1010   if (IsMasked) {
1011     if (PolicyAttrs.isTUMUPolicy())
1012       appendPolicySuffix("_tumu");
1013     else if (PolicyAttrs.isTUMAPolicy())
1014       appendPolicySuffix("_tum");
1015     else if (PolicyAttrs.isTAMUPolicy())
1016       appendPolicySuffix("_mu");
1017     else if (PolicyAttrs.isTAMAPolicy()) {
1018       Name += "_m";
1019       if (HasPolicy)
1020         BuiltinName += "_tama";
1021       else
1022         BuiltinName += "_m";
1023     } else
1024       llvm_unreachable("Unhandled policy condition");
1025   } else {
1026     if (PolicyAttrs.isTUPolicy())
1027       appendPolicySuffix("_tu");
1028     else if (PolicyAttrs.isTAPolicy()) {
1029       if (HasPolicy)
1030         BuiltinName += "_ta";
1031     } else
1032       llvm_unreachable("Unhandled policy condition");
1033   }
1034 }
1035 
1036 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
1037   SmallVector<PrototypeDescriptor> PrototypeDescriptors;
1038   const StringRef Primaries("evwqom0ztul");
1039   while (!Prototypes.empty()) {
1040     size_t Idx = 0;
1041     // Skip over complex prototype because it could contain primitive type
1042     // character.
1043     if (Prototypes[0] == '(')
1044       Idx = Prototypes.find_first_of(')');
1045     Idx = Prototypes.find_first_of(Primaries, Idx);
1046     assert(Idx != StringRef::npos);
1047     auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
1048         Prototypes.slice(0, Idx + 1));
1049     if (!PD)
1050       llvm_unreachable("Error during parsing prototype.");
1051     PrototypeDescriptors.push_back(*PD);
1052     Prototypes = Prototypes.drop_front(Idx + 1);
1053   }
1054   return PrototypeDescriptors;
1055 }
1056 
1057 raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
1058   OS << "{";
1059   OS << "\"" << Record.Name << "\",";
1060   if (Record.OverloadedName == nullptr ||
1061       StringRef(Record.OverloadedName).empty())
1062     OS << "nullptr,";
1063   else
1064     OS << "\"" << Record.OverloadedName << "\",";
1065   OS << Record.PrototypeIndex << ",";
1066   OS << Record.SuffixIndex << ",";
1067   OS << Record.OverloadedSuffixIndex << ",";
1068   OS << (int)Record.PrototypeLength << ",";
1069   OS << (int)Record.SuffixLength << ",";
1070   OS << (int)Record.OverloadedSuffixSize << ",";
1071   OS << (int)Record.RequiredExtensions << ",";
1072   OS << (int)Record.TypeRangeMask << ",";
1073   OS << (int)Record.Log2LMULMask << ",";
1074   OS << (int)Record.NF << ",";
1075   OS << (int)Record.HasMasked << ",";
1076   OS << (int)Record.HasVL << ",";
1077   OS << (int)Record.HasMaskedOffOperand << ",";
1078   OS << (int)Record.HasTailPolicy << ",";
1079   OS << (int)Record.HasMaskPolicy << ",";
1080   OS << (int)Record.UnMaskedPolicyScheme << ",";
1081   OS << (int)Record.MaskedPolicyScheme << ",";
1082   OS << "},\n";
1083   return OS;
1084 }
1085 
1086 } // end namespace RISCV
1087 } // end namespace clang
1088