xref: /freebsd/contrib/llvm-project/clang/lib/Support/RISCVVIntrinsicUtils.cpp (revision 78cd75393ec79565c63927bf200f06f839a1dc05)
1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/SmallSet.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringMap.h"
14 #include "llvm/ADT/StringSet.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Support/ErrorHandling.h"
17 #include "llvm/Support/raw_ostream.h"
18 #include <numeric>
19 #include <optional>
20 
21 using namespace llvm;
22 
23 namespace clang {
24 namespace RISCV {
25 
26 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
27     BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
28 const PrototypeDescriptor PrototypeDescriptor::VL =
29     PrototypeDescriptor(BaseTypeModifier::SizeT);
30 const PrototypeDescriptor PrototypeDescriptor::Vector =
31     PrototypeDescriptor(BaseTypeModifier::Vector);
32 
33 //===----------------------------------------------------------------------===//
34 // Type implementation
35 //===----------------------------------------------------------------------===//
36 
37 LMULType::LMULType(int NewLog2LMUL) {
38   // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
39   assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
40   Log2LMUL = NewLog2LMUL;
41 }
42 
43 std::string LMULType::str() const {
44   if (Log2LMUL < 0)
45     return "mf" + utostr(1ULL << (-Log2LMUL));
46   return "m" + utostr(1ULL << Log2LMUL);
47 }
48 
49 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
50   int Log2ScaleResult = 0;
51   switch (ElementBitwidth) {
52   default:
53     break;
54   case 8:
55     Log2ScaleResult = Log2LMUL + 3;
56     break;
57   case 16:
58     Log2ScaleResult = Log2LMUL + 2;
59     break;
60   case 32:
61     Log2ScaleResult = Log2LMUL + 1;
62     break;
63   case 64:
64     Log2ScaleResult = Log2LMUL;
65     break;
66   }
67   // Illegal vscale result would be less than 1
68   if (Log2ScaleResult < 0)
69     return std::nullopt;
70   return 1 << Log2ScaleResult;
71 }
72 
73 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
74 
75 RVVType::RVVType(BasicType BT, int Log2LMUL,
76                  const PrototypeDescriptor &prototype)
77     : BT(BT), LMUL(LMULType(Log2LMUL)) {
78   applyBasicType();
79   applyModifier(prototype);
80   Valid = verifyType();
81   if (Valid) {
82     initBuiltinStr();
83     initTypeStr();
84     if (isVector()) {
85       initClangBuiltinStr();
86     }
87   }
88 }
89 
90 // clang-format off
// Boolean (mask) types are named after the ratio n = SEW/LMUL:
// SEW/LMUL | 64        | 32        | 16        | 8        | 4         | 2         | 1
93 // c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
94 // IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
95 
96 // type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
97 // --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
98 // i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
99 // i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
100 // i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
101 // i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
102 // double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
103 // float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
104 // half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
105 // clang-format on
106 
107 bool RVVType::verifyType() const {
108   if (ScalarType == Invalid)
109     return false;
110   if (isScalar())
111     return true;
112   if (!Scale)
113     return false;
114   if (isFloat() && ElementBitwidth == 8)
115     return false;
116   if (IsTuple && (NF == 1 || NF > 8))
117     return false;
118   if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8)
119     return false;
120   unsigned V = *Scale;
121   switch (ElementBitwidth) {
122   case 1:
123   case 8:
124     // Check Scale is 1,2,4,8,16,32,64
125     return (V <= 64 && isPowerOf2_32(V));
126   case 16:
127     // Check Scale is 1,2,4,8,16,32
128     return (V <= 32 && isPowerOf2_32(V));
129   case 32:
130     // Check Scale is 1,2,4,8,16
131     return (V <= 16 && isPowerOf2_32(V));
132   case 64:
133     // Check Scale is 1,2,4,8
134     return (V <= 8 && isPowerOf2_32(V));
135   }
136   return false;
137 }
138 
139 void RVVType::initBuiltinStr() {
140   assert(isValid() && "RVVType is invalid");
141   switch (ScalarType) {
142   case ScalarTypeKind::Void:
143     BuiltinStr = "v";
144     return;
145   case ScalarTypeKind::Size_t:
146     BuiltinStr = "z";
147     if (IsImmediate)
148       BuiltinStr = "I" + BuiltinStr;
149     if (IsPointer)
150       BuiltinStr += "*";
151     return;
152   case ScalarTypeKind::Ptrdiff_t:
153     BuiltinStr = "Y";
154     return;
155   case ScalarTypeKind::UnsignedLong:
156     BuiltinStr = "ULi";
157     return;
158   case ScalarTypeKind::SignedLong:
159     BuiltinStr = "Li";
160     return;
161   case ScalarTypeKind::Boolean:
162     assert(ElementBitwidth == 1);
163     BuiltinStr += "b";
164     break;
165   case ScalarTypeKind::SignedInteger:
166   case ScalarTypeKind::UnsignedInteger:
167     switch (ElementBitwidth) {
168     case 8:
169       BuiltinStr += "c";
170       break;
171     case 16:
172       BuiltinStr += "s";
173       break;
174     case 32:
175       BuiltinStr += "i";
176       break;
177     case 64:
178       BuiltinStr += "Wi";
179       break;
180     default:
181       llvm_unreachable("Unhandled ElementBitwidth!");
182     }
183     if (isSignedInteger())
184       BuiltinStr = "S" + BuiltinStr;
185     else
186       BuiltinStr = "U" + BuiltinStr;
187     break;
188   case ScalarTypeKind::Float:
189     switch (ElementBitwidth) {
190     case 16:
191       BuiltinStr += "x";
192       break;
193     case 32:
194       BuiltinStr += "f";
195       break;
196     case 64:
197       BuiltinStr += "d";
198       break;
199     default:
200       llvm_unreachable("Unhandled ElementBitwidth!");
201     }
202     break;
203   default:
204     llvm_unreachable("ScalarType is invalid!");
205   }
206   if (IsImmediate)
207     BuiltinStr = "I" + BuiltinStr;
208   if (isScalar()) {
209     if (IsConstant)
210       BuiltinStr += "C";
211     if (IsPointer)
212       BuiltinStr += "*";
213     return;
214   }
215   BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
216   // Pointer to vector types. Defined for segment load intrinsics.
217   // segment load intrinsics have pointer type arguments to store the loaded
218   // vector values.
219   if (IsPointer)
220     BuiltinStr += "*";
221 
222   if (IsTuple)
223     BuiltinStr = "T" + utostr(NF) + BuiltinStr;
224 }
225 
226 void RVVType::initClangBuiltinStr() {
227   assert(isValid() && "RVVType is invalid");
228   assert(isVector() && "Handle Vector type only");
229 
230   ClangBuiltinStr = "__rvv_";
231   switch (ScalarType) {
232   case ScalarTypeKind::Boolean:
233     ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
234     return;
235   case ScalarTypeKind::Float:
236     ClangBuiltinStr += "float";
237     break;
238   case ScalarTypeKind::SignedInteger:
239     ClangBuiltinStr += "int";
240     break;
241   case ScalarTypeKind::UnsignedInteger:
242     ClangBuiltinStr += "uint";
243     break;
244   default:
245     llvm_unreachable("ScalarTypeKind is invalid");
246   }
247   ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() +
248                      (IsTuple ? "x" + utostr(NF) : "") + "_t";
249 }
250 
251 void RVVType::initTypeStr() {
252   assert(isValid() && "RVVType is invalid");
253 
254   if (IsConstant)
255     Str += "const ";
256 
257   auto getTypeString = [&](StringRef TypeStr) {
258     if (isScalar())
259       return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
260     return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() +
261                  (IsTuple ? "x" + utostr(NF) : "") + "_t")
262         .str();
263   };
264 
265   switch (ScalarType) {
266   case ScalarTypeKind::Void:
267     Str = "void";
268     return;
269   case ScalarTypeKind::Size_t:
270     Str = "size_t";
271     if (IsPointer)
272       Str += " *";
273     return;
274   case ScalarTypeKind::Ptrdiff_t:
275     Str = "ptrdiff_t";
276     return;
277   case ScalarTypeKind::UnsignedLong:
278     Str = "unsigned long";
279     return;
280   case ScalarTypeKind::SignedLong:
281     Str = "long";
282     return;
283   case ScalarTypeKind::Boolean:
284     if (isScalar())
285       Str += "bool";
286     else
287       // Vector bool is special case, the formulate is
288       // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
289       Str += "vbool" + utostr(64 / *Scale) + "_t";
290     break;
291   case ScalarTypeKind::Float:
292     if (isScalar()) {
293       if (ElementBitwidth == 64)
294         Str += "double";
295       else if (ElementBitwidth == 32)
296         Str += "float";
297       else if (ElementBitwidth == 16)
298         Str += "_Float16";
299       else
300         llvm_unreachable("Unhandled floating type.");
301     } else
302       Str += getTypeString("float");
303     break;
304   case ScalarTypeKind::SignedInteger:
305     Str += getTypeString("int");
306     break;
307   case ScalarTypeKind::UnsignedInteger:
308     Str += getTypeString("uint");
309     break;
310   default:
311     llvm_unreachable("ScalarType is invalid!");
312   }
313   if (IsPointer)
314     Str += " *";
315 }
316 
317 void RVVType::initShortStr() {
318   switch (ScalarType) {
319   case ScalarTypeKind::Boolean:
320     assert(isVector());
321     ShortStr = "b" + utostr(64 / *Scale);
322     return;
323   case ScalarTypeKind::Float:
324     ShortStr = "f" + utostr(ElementBitwidth);
325     break;
326   case ScalarTypeKind::SignedInteger:
327     ShortStr = "i" + utostr(ElementBitwidth);
328     break;
329   case ScalarTypeKind::UnsignedInteger:
330     ShortStr = "u" + utostr(ElementBitwidth);
331     break;
332   default:
333     llvm_unreachable("Unhandled case!");
334   }
335   if (isVector())
336     ShortStr += LMUL.str();
337   if (isTuple())
338     ShortStr += "x" + utostr(NF);
339 }
340 
341 static VectorTypeModifier getTupleVTM(unsigned NF) {
342   assert(2 <= NF && NF <= 8 && "2 <= NF <= 8");
343   return static_cast<VectorTypeModifier>(
344       static_cast<uint8_t>(VectorTypeModifier::Tuple2) + (NF - 2));
345 }
346 
347 void RVVType::applyBasicType() {
348   switch (BT) {
349   case BasicType::Int8:
350     ElementBitwidth = 8;
351     ScalarType = ScalarTypeKind::SignedInteger;
352     break;
353   case BasicType::Int16:
354     ElementBitwidth = 16;
355     ScalarType = ScalarTypeKind::SignedInteger;
356     break;
357   case BasicType::Int32:
358     ElementBitwidth = 32;
359     ScalarType = ScalarTypeKind::SignedInteger;
360     break;
361   case BasicType::Int64:
362     ElementBitwidth = 64;
363     ScalarType = ScalarTypeKind::SignedInteger;
364     break;
365   case BasicType::Float16:
366     ElementBitwidth = 16;
367     ScalarType = ScalarTypeKind::Float;
368     break;
369   case BasicType::Float32:
370     ElementBitwidth = 32;
371     ScalarType = ScalarTypeKind::Float;
372     break;
373   case BasicType::Float64:
374     ElementBitwidth = 64;
375     ScalarType = ScalarTypeKind::Float;
376     break;
377   default:
378     llvm_unreachable("Unhandled type code!");
379   }
380   assert(ElementBitwidth != 0 && "Bad element bitwidth!");
381 }
382 
383 std::optional<PrototypeDescriptor>
384 PrototypeDescriptor::parsePrototypeDescriptor(
385     llvm::StringRef PrototypeDescriptorStr) {
386   PrototypeDescriptor PD;
387   BaseTypeModifier PT = BaseTypeModifier::Invalid;
388   VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
389 
390   if (PrototypeDescriptorStr.empty())
391     return PD;
392 
393   // Handle base type modifier
394   auto PType = PrototypeDescriptorStr.back();
395   switch (PType) {
396   case 'e':
397     PT = BaseTypeModifier::Scalar;
398     break;
399   case 'v':
400     PT = BaseTypeModifier::Vector;
401     break;
402   case 'w':
403     PT = BaseTypeModifier::Vector;
404     VTM = VectorTypeModifier::Widening2XVector;
405     break;
406   case 'q':
407     PT = BaseTypeModifier::Vector;
408     VTM = VectorTypeModifier::Widening4XVector;
409     break;
410   case 'o':
411     PT = BaseTypeModifier::Vector;
412     VTM = VectorTypeModifier::Widening8XVector;
413     break;
414   case 'm':
415     PT = BaseTypeModifier::Vector;
416     VTM = VectorTypeModifier::MaskVector;
417     break;
418   case '0':
419     PT = BaseTypeModifier::Void;
420     break;
421   case 'z':
422     PT = BaseTypeModifier::SizeT;
423     break;
424   case 't':
425     PT = BaseTypeModifier::Ptrdiff;
426     break;
427   case 'u':
428     PT = BaseTypeModifier::UnsignedLong;
429     break;
430   case 'l':
431     PT = BaseTypeModifier::SignedLong;
432     break;
433   default:
434     llvm_unreachable("Illegal primitive type transformers!");
435   }
436   PD.PT = static_cast<uint8_t>(PT);
437   PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
438 
439   // Compute the vector type transformers, it can only appear one time.
440   if (PrototypeDescriptorStr.startswith("(")) {
441     assert(VTM == VectorTypeModifier::NoModifier &&
442            "VectorTypeModifier should only have one modifier");
443     size_t Idx = PrototypeDescriptorStr.find(')');
444     assert(Idx != StringRef::npos);
445     StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
446     PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
447     assert(!PrototypeDescriptorStr.contains('(') &&
448            "Only allow one vector type modifier");
449 
450     auto ComplexTT = ComplexType.split(":");
451     if (ComplexTT.first == "Log2EEW") {
452       uint32_t Log2EEW;
453       if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
454         llvm_unreachable("Invalid Log2EEW value!");
455         return std::nullopt;
456       }
457       switch (Log2EEW) {
458       case 3:
459         VTM = VectorTypeModifier::Log2EEW3;
460         break;
461       case 4:
462         VTM = VectorTypeModifier::Log2EEW4;
463         break;
464       case 5:
465         VTM = VectorTypeModifier::Log2EEW5;
466         break;
467       case 6:
468         VTM = VectorTypeModifier::Log2EEW6;
469         break;
470       default:
471         llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
472         return std::nullopt;
473       }
474     } else if (ComplexTT.first == "FixedSEW") {
475       uint32_t NewSEW;
476       if (ComplexTT.second.getAsInteger(10, NewSEW)) {
477         llvm_unreachable("Invalid FixedSEW value!");
478         return std::nullopt;
479       }
480       switch (NewSEW) {
481       case 8:
482         VTM = VectorTypeModifier::FixedSEW8;
483         break;
484       case 16:
485         VTM = VectorTypeModifier::FixedSEW16;
486         break;
487       case 32:
488         VTM = VectorTypeModifier::FixedSEW32;
489         break;
490       case 64:
491         VTM = VectorTypeModifier::FixedSEW64;
492         break;
493       default:
494         llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
495         return std::nullopt;
496       }
497     } else if (ComplexTT.first == "LFixedLog2LMUL") {
498       int32_t Log2LMUL;
499       if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
500         llvm_unreachable("Invalid LFixedLog2LMUL value!");
501         return std::nullopt;
502       }
503       switch (Log2LMUL) {
504       case -3:
505         VTM = VectorTypeModifier::LFixedLog2LMULN3;
506         break;
507       case -2:
508         VTM = VectorTypeModifier::LFixedLog2LMULN2;
509         break;
510       case -1:
511         VTM = VectorTypeModifier::LFixedLog2LMULN1;
512         break;
513       case 0:
514         VTM = VectorTypeModifier::LFixedLog2LMUL0;
515         break;
516       case 1:
517         VTM = VectorTypeModifier::LFixedLog2LMUL1;
518         break;
519       case 2:
520         VTM = VectorTypeModifier::LFixedLog2LMUL2;
521         break;
522       case 3:
523         VTM = VectorTypeModifier::LFixedLog2LMUL3;
524         break;
525       default:
526         llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
527         return std::nullopt;
528       }
529     } else if (ComplexTT.first == "SFixedLog2LMUL") {
530       int32_t Log2LMUL;
531       if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
532         llvm_unreachable("Invalid SFixedLog2LMUL value!");
533         return std::nullopt;
534       }
535       switch (Log2LMUL) {
536       case -3:
537         VTM = VectorTypeModifier::SFixedLog2LMULN3;
538         break;
539       case -2:
540         VTM = VectorTypeModifier::SFixedLog2LMULN2;
541         break;
542       case -1:
543         VTM = VectorTypeModifier::SFixedLog2LMULN1;
544         break;
545       case 0:
546         VTM = VectorTypeModifier::SFixedLog2LMUL0;
547         break;
548       case 1:
549         VTM = VectorTypeModifier::SFixedLog2LMUL1;
550         break;
551       case 2:
552         VTM = VectorTypeModifier::SFixedLog2LMUL2;
553         break;
554       case 3:
555         VTM = VectorTypeModifier::SFixedLog2LMUL3;
556         break;
557       default:
558         llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
559         return std::nullopt;
560       }
561 
562     } else if (ComplexTT.first == "Tuple") {
563       unsigned NF = 0;
564       if (ComplexTT.second.getAsInteger(10, NF)) {
565         llvm_unreachable("Invalid NF value!");
566         return std::nullopt;
567       }
568       VTM = getTupleVTM(NF);
569     } else {
570       llvm_unreachable("Illegal complex type transformers!");
571     }
572   }
573   PD.VTM = static_cast<uint8_t>(VTM);
574 
575   // Compute the remain type transformers
576   TypeModifier TM = TypeModifier::NoModifier;
577   for (char I : PrototypeDescriptorStr) {
578     switch (I) {
579     case 'P':
580       if ((TM & TypeModifier::Const) == TypeModifier::Const)
581         llvm_unreachable("'P' transformer cannot be used after 'C'");
582       if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
583         llvm_unreachable("'P' transformer cannot be used twice");
584       TM |= TypeModifier::Pointer;
585       break;
586     case 'C':
587       TM |= TypeModifier::Const;
588       break;
589     case 'K':
590       TM |= TypeModifier::Immediate;
591       break;
592     case 'U':
593       TM |= TypeModifier::UnsignedInteger;
594       break;
595     case 'I':
596       TM |= TypeModifier::SignedInteger;
597       break;
598     case 'F':
599       TM |= TypeModifier::Float;
600       break;
601     case 'S':
602       TM |= TypeModifier::LMUL1;
603       break;
604     default:
605       llvm_unreachable("Illegal non-primitive type transformer!");
606     }
607   }
608   PD.TM = static_cast<uint8_t>(TM);
609 
610   return PD;
611 }
612 
613 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
614   // Handle primitive type transformer
615   switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
616   case BaseTypeModifier::Scalar:
617     Scale = 0;
618     break;
619   case BaseTypeModifier::Vector:
620     Scale = LMUL.getScale(ElementBitwidth);
621     break;
622   case BaseTypeModifier::Void:
623     ScalarType = ScalarTypeKind::Void;
624     break;
625   case BaseTypeModifier::SizeT:
626     ScalarType = ScalarTypeKind::Size_t;
627     break;
628   case BaseTypeModifier::Ptrdiff:
629     ScalarType = ScalarTypeKind::Ptrdiff_t;
630     break;
631   case BaseTypeModifier::UnsignedLong:
632     ScalarType = ScalarTypeKind::UnsignedLong;
633     break;
634   case BaseTypeModifier::SignedLong:
635     ScalarType = ScalarTypeKind::SignedLong;
636     break;
637   case BaseTypeModifier::Invalid:
638     ScalarType = ScalarTypeKind::Invalid;
639     return;
640   }
641 
642   switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
643   case VectorTypeModifier::Widening2XVector:
644     ElementBitwidth *= 2;
645     LMUL.MulLog2LMUL(1);
646     Scale = LMUL.getScale(ElementBitwidth);
647     break;
648   case VectorTypeModifier::Widening4XVector:
649     ElementBitwidth *= 4;
650     LMUL.MulLog2LMUL(2);
651     Scale = LMUL.getScale(ElementBitwidth);
652     break;
653   case VectorTypeModifier::Widening8XVector:
654     ElementBitwidth *= 8;
655     LMUL.MulLog2LMUL(3);
656     Scale = LMUL.getScale(ElementBitwidth);
657     break;
658   case VectorTypeModifier::MaskVector:
659     ScalarType = ScalarTypeKind::Boolean;
660     Scale = LMUL.getScale(ElementBitwidth);
661     ElementBitwidth = 1;
662     break;
663   case VectorTypeModifier::Log2EEW3:
664     applyLog2EEW(3);
665     break;
666   case VectorTypeModifier::Log2EEW4:
667     applyLog2EEW(4);
668     break;
669   case VectorTypeModifier::Log2EEW5:
670     applyLog2EEW(5);
671     break;
672   case VectorTypeModifier::Log2EEW6:
673     applyLog2EEW(6);
674     break;
675   case VectorTypeModifier::FixedSEW8:
676     applyFixedSEW(8);
677     break;
678   case VectorTypeModifier::FixedSEW16:
679     applyFixedSEW(16);
680     break;
681   case VectorTypeModifier::FixedSEW32:
682     applyFixedSEW(32);
683     break;
684   case VectorTypeModifier::FixedSEW64:
685     applyFixedSEW(64);
686     break;
687   case VectorTypeModifier::LFixedLog2LMULN3:
688     applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
689     break;
690   case VectorTypeModifier::LFixedLog2LMULN2:
691     applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
692     break;
693   case VectorTypeModifier::LFixedLog2LMULN1:
694     applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
695     break;
696   case VectorTypeModifier::LFixedLog2LMUL0:
697     applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
698     break;
699   case VectorTypeModifier::LFixedLog2LMUL1:
700     applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
701     break;
702   case VectorTypeModifier::LFixedLog2LMUL2:
703     applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
704     break;
705   case VectorTypeModifier::LFixedLog2LMUL3:
706     applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
707     break;
708   case VectorTypeModifier::SFixedLog2LMULN3:
709     applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
710     break;
711   case VectorTypeModifier::SFixedLog2LMULN2:
712     applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
713     break;
714   case VectorTypeModifier::SFixedLog2LMULN1:
715     applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
716     break;
717   case VectorTypeModifier::SFixedLog2LMUL0:
718     applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
719     break;
720   case VectorTypeModifier::SFixedLog2LMUL1:
721     applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
722     break;
723   case VectorTypeModifier::SFixedLog2LMUL2:
724     applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
725     break;
726   case VectorTypeModifier::SFixedLog2LMUL3:
727     applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
728     break;
729   case VectorTypeModifier::Tuple2:
730   case VectorTypeModifier::Tuple3:
731   case VectorTypeModifier::Tuple4:
732   case VectorTypeModifier::Tuple5:
733   case VectorTypeModifier::Tuple6:
734   case VectorTypeModifier::Tuple7:
735   case VectorTypeModifier::Tuple8: {
736     IsTuple = true;
737     NF = 2 + static_cast<uint8_t>(Transformer.VTM) -
738          static_cast<uint8_t>(VectorTypeModifier::Tuple2);
739     break;
740   }
741   case VectorTypeModifier::NoModifier:
742     break;
743   }
744 
745   for (unsigned TypeModifierMaskShift = 0;
746        TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
747        ++TypeModifierMaskShift) {
748     unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
749     if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
750         TypeModifierMask)
751       continue;
752     switch (static_cast<TypeModifier>(TypeModifierMask)) {
753     case TypeModifier::Pointer:
754       IsPointer = true;
755       break;
756     case TypeModifier::Const:
757       IsConstant = true;
758       break;
759     case TypeModifier::Immediate:
760       IsImmediate = true;
761       IsConstant = true;
762       break;
763     case TypeModifier::UnsignedInteger:
764       ScalarType = ScalarTypeKind::UnsignedInteger;
765       break;
766     case TypeModifier::SignedInteger:
767       ScalarType = ScalarTypeKind::SignedInteger;
768       break;
769     case TypeModifier::Float:
770       ScalarType = ScalarTypeKind::Float;
771       break;
772     case TypeModifier::LMUL1:
773       LMUL = LMULType(0);
774       // Update ElementBitwidth need to update Scale too.
775       Scale = LMUL.getScale(ElementBitwidth);
776       break;
777     default:
778       llvm_unreachable("Unknown type modifier mask!");
779     }
780   }
781 }
782 
783 void RVVType::applyLog2EEW(unsigned Log2EEW) {
784   // update new elmul = (eew/sew) * lmul
785   LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
786   // update new eew
787   ElementBitwidth = 1 << Log2EEW;
788   ScalarType = ScalarTypeKind::SignedInteger;
789   Scale = LMUL.getScale(ElementBitwidth);
790 }
791 
792 void RVVType::applyFixedSEW(unsigned NewSEW) {
793   // Set invalid type if src and dst SEW are same.
794   if (ElementBitwidth == NewSEW) {
795     ScalarType = ScalarTypeKind::Invalid;
796     return;
797   }
798   // Update new SEW
799   ElementBitwidth = NewSEW;
800   Scale = LMUL.getScale(ElementBitwidth);
801 }
802 
803 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
804   switch (Type) {
805   case FixedLMULType::LargerThan:
806     if (Log2LMUL < LMUL.Log2LMUL) {
807       ScalarType = ScalarTypeKind::Invalid;
808       return;
809     }
810     break;
811   case FixedLMULType::SmallerThan:
812     if (Log2LMUL > LMUL.Log2LMUL) {
813       ScalarType = ScalarTypeKind::Invalid;
814       return;
815     }
816     break;
817   }
818 
819   // Update new LMUL
820   LMUL = LMULType(Log2LMUL);
821   Scale = LMUL.getScale(ElementBitwidth);
822 }
823 
824 std::optional<RVVTypes>
825 RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
826                            ArrayRef<PrototypeDescriptor> Prototype) {
827   RVVTypes Types;
828   for (const PrototypeDescriptor &Proto : Prototype) {
829     auto T = computeType(BT, Log2LMUL, Proto);
830     if (!T)
831       return std::nullopt;
832     // Record legal type index
833     Types.push_back(*T);
834   }
835   return Types;
836 }
837 
838 // Compute the hash value of RVVType, used for cache the result of computeType.
839 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
840                                         PrototypeDescriptor Proto) {
841   // Layout of hash value:
842   // 0               8    16          24        32          40
843   // | Log2LMUL + 3  | BT  | Proto.PT | Proto.TM | Proto.VTM |
844   assert(Log2LMUL >= -3 && Log2LMUL <= 3);
845   return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
846          ((uint64_t)(Proto.PT & 0xff) << 16) |
847          ((uint64_t)(Proto.TM & 0xff) << 24) |
848          ((uint64_t)(Proto.VTM & 0xff) << 32);
849 }
850 
851 std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
852                                                     PrototypeDescriptor Proto) {
853   uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
854   // Search first
855   auto It = LegalTypes.find(Idx);
856   if (It != LegalTypes.end())
857     return &(It->second);
858 
859   if (IllegalTypes.count(Idx))
860     return std::nullopt;
861 
862   // Compute type and record the result.
863   RVVType T(BT, Log2LMUL, Proto);
864   if (T.isValid()) {
865     // Record legal type index and value.
866     std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
867         InsertResult = LegalTypes.insert({Idx, T});
868     return &(InsertResult.first->second);
869   }
870   // Record illegal type index.
871   IllegalTypes.insert(Idx);
872   return std::nullopt;
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // RVVIntrinsic implementation
877 //===----------------------------------------------------------------------===//
878 RVVIntrinsic::RVVIntrinsic(
879     StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
880     StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
881     bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
882     bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen,
883     const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
884     const std::vector<StringRef> &RequiredFeatures, unsigned NF,
885     Policy NewPolicyAttrs, bool HasFRMRoundModeOp)
886     : IRName(IRName), IsMasked(IsMasked),
887       HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
888       SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
889       ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
890 
891   // Init BuiltinName, Name and OverloadedName
892   BuiltinName = NewName.str();
893   Name = BuiltinName;
894   if (NewOverloadedName.empty())
895     OverloadedName = NewName.split("_").first.str();
896   else
897     OverloadedName = NewOverloadedName.str();
898   if (!Suffix.empty())
899     Name += "_" + Suffix.str();
900   if (!OverloadedSuffix.empty())
901     OverloadedName += "_" + OverloadedSuffix.str();
902 
903   updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
904                        PolicyAttrs, HasFRMRoundModeOp);
905 
906   // Init OutputType and InputTypes
907   OutputType = OutInTypes[0];
908   InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
909 
910   // IntrinsicTypes is unmasked TA version index. Need to update it
911   // if there is merge operand (It is always in first operand).
912   IntrinsicTypes = NewIntrinsicTypes;
913   if ((IsMasked && hasMaskedOffOperand()) ||
914       (!IsMasked && hasPassthruOperand())) {
915     for (auto &I : IntrinsicTypes) {
916       if (I >= 0)
917         I += NF;
918     }
919   }
920 }
921 
922 std::string RVVIntrinsic::getBuiltinTypeStr() const {
923   std::string S;
924   S += OutputType->getBuiltinStr();
925   for (const auto &T : InputTypes) {
926     S += T->getBuiltinStr();
927   }
928   return S;
929 }
930 
931 std::string RVVIntrinsic::getSuffixStr(
932     RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
933     llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
934   SmallVector<std::string> SuffixStrs;
935   for (auto PD : PrototypeDescriptors) {
936     auto T = TypeCache.computeType(Type, Log2LMUL, PD);
937     SuffixStrs.push_back((*T)->getShortStr());
938   }
939   return join(SuffixStrs, "_");
940 }
941 
942 llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
943     llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
944     bool HasMaskedOffOperand, bool HasVL, unsigned NF,
945     PolicyScheme DefaultScheme, Policy PolicyAttrs, bool IsTuple) {
946   SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
947                                                 Prototype.end());
948   bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
949   if (IsMasked) {
950     // If HasMaskedOffOperand, insert result type as first input operand if
951     // need.
952     if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
953       if (NF == 1) {
954         NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
955       } else if (NF > 1) {
956         if (IsTuple) {
957           PrototypeDescriptor BasePtrOperand = Prototype[1];
958           PrototypeDescriptor MaskoffType = PrototypeDescriptor(
959               static_cast<uint8_t>(BaseTypeModifier::Vector),
960               static_cast<uint8_t>(getTupleVTM(NF)),
961               BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
962           NewPrototype.insert(NewPrototype.begin() + 1, MaskoffType);
963         } else {
964           // Convert
965           // (void, op0 address, op1 address, ...)
966           // to
967           // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
968           PrototypeDescriptor MaskoffType = NewPrototype[1];
969           MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
970           NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
971         }
972       }
973     }
974     if (HasMaskedOffOperand && NF > 1) {
975       // Convert
976       // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
977       // to
978       // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
979       // ...)
980       if (IsTuple)
981         NewPrototype.insert(NewPrototype.begin() + 1,
982                             PrototypeDescriptor::Mask);
983       else
984         NewPrototype.insert(NewPrototype.begin() + NF + 1,
985                             PrototypeDescriptor::Mask);
986     } else {
987       // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
988       NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
989     }
990   } else {
991     if (NF == 1) {
992       if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
993         NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
994     } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
995       if (IsTuple) {
996         PrototypeDescriptor BasePtrOperand = Prototype[0];
997         PrototypeDescriptor MaskoffType = PrototypeDescriptor(
998             static_cast<uint8_t>(BaseTypeModifier::Vector),
999             static_cast<uint8_t>(getTupleVTM(NF)),
1000             BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
1001         NewPrototype.insert(NewPrototype.begin(), MaskoffType);
1002       } else {
1003         // NF > 1 cases for segment load operations.
1004         // Convert
1005         // (void, op0 address, op1 address, ...)
1006         // to
1007         // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
1008         PrototypeDescriptor MaskoffType = Prototype[1];
1009         MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
1010         NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
1011       }
1012     }
1013  }
1014 
1015   // If HasVL, append PrototypeDescriptor:VL to last operand
1016   if (HasVL)
1017     NewPrototype.push_back(PrototypeDescriptor::VL);
1018 
1019   return NewPrototype;
1020 }
1021 
1022 llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() {
1023   return {Policy(Policy::PolicyType::Undisturbed)}; // TU
1024 }
1025 
1026 llvm::SmallVector<Policy>
1027 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
1028                                          bool HasMaskPolicy) {
1029   if (HasTailPolicy && HasMaskPolicy)
1030     return {Policy(Policy::PolicyType::Undisturbed,
1031                    Policy::PolicyType::Agnostic), // TUM
1032             Policy(Policy::PolicyType::Undisturbed,
1033                    Policy::PolicyType::Undisturbed), // TUMU
1034             Policy(Policy::PolicyType::Agnostic,
1035                    Policy::PolicyType::Undisturbed)}; // MU
1036   if (HasTailPolicy && !HasMaskPolicy)
1037     return {Policy(Policy::PolicyType::Undisturbed,
1038                    Policy::PolicyType::Agnostic)}; // TU
1039   if (!HasTailPolicy && HasMaskPolicy)
1040     return {Policy(Policy::PolicyType::Agnostic,
1041                    Policy::PolicyType::Undisturbed)}; // MU
1042   llvm_unreachable("An RVV instruction should not be without both tail policy "
1043                    "and mask policy");
1044 }
1045 
1046 void RVVIntrinsic::updateNamesAndPolicy(
1047     bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName,
1048     std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) {
1049 
1050   auto appendPolicySuffix = [&](const std::string &suffix) {
1051     Name += suffix;
1052     BuiltinName += suffix;
1053     OverloadedName += suffix;
1054   };
1055 
1056   // This follows the naming guideline under riscv-c-api-doc to add the
1057   // `__riscv_` suffix for all RVV intrinsics.
1058   Name = "__riscv_" + Name;
1059   OverloadedName = "__riscv_" + OverloadedName;
1060 
1061   if (HasFRMRoundModeOp) {
1062     Name += "_rm";
1063     BuiltinName += "_rm";
1064   }
1065 
1066   if (IsMasked) {
1067     if (PolicyAttrs.isTUMUPolicy())
1068       appendPolicySuffix("_tumu");
1069     else if (PolicyAttrs.isTUMAPolicy())
1070       appendPolicySuffix("_tum");
1071     else if (PolicyAttrs.isTAMUPolicy())
1072       appendPolicySuffix("_mu");
1073     else if (PolicyAttrs.isTAMAPolicy()) {
1074       Name += "_m";
1075       BuiltinName += "_m";
1076     } else
1077       llvm_unreachable("Unhandled policy condition");
1078   } else {
1079     if (PolicyAttrs.isTUPolicy())
1080       appendPolicySuffix("_tu");
1081     else if (PolicyAttrs.isTAPolicy()) // no suffix needed
1082       return;
1083     else
1084       llvm_unreachable("Unhandled policy condition");
1085   }
1086 }
1087 
1088 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
1089   SmallVector<PrototypeDescriptor> PrototypeDescriptors;
1090   const StringRef Primaries("evwqom0ztul");
1091   while (!Prototypes.empty()) {
1092     size_t Idx = 0;
1093     // Skip over complex prototype because it could contain primitive type
1094     // character.
1095     if (Prototypes[0] == '(')
1096       Idx = Prototypes.find_first_of(')');
1097     Idx = Prototypes.find_first_of(Primaries, Idx);
1098     assert(Idx != StringRef::npos);
1099     auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
1100         Prototypes.slice(0, Idx + 1));
1101     if (!PD)
1102       llvm_unreachable("Error during parsing prototype.");
1103     PrototypeDescriptors.push_back(*PD);
1104     Prototypes = Prototypes.drop_front(Idx + 1);
1105   }
1106   return PrototypeDescriptors;
1107 }
1108 
1109 raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
1110   OS << "{";
1111   OS << "\"" << Record.Name << "\",";
1112   if (Record.OverloadedName == nullptr ||
1113       StringRef(Record.OverloadedName).empty())
1114     OS << "nullptr,";
1115   else
1116     OS << "\"" << Record.OverloadedName << "\",";
1117   OS << Record.PrototypeIndex << ",";
1118   OS << Record.SuffixIndex << ",";
1119   OS << Record.OverloadedSuffixIndex << ",";
1120   OS << (int)Record.PrototypeLength << ",";
1121   OS << (int)Record.SuffixLength << ",";
1122   OS << (int)Record.OverloadedSuffixSize << ",";
1123   OS << (int)Record.RequiredExtensions << ",";
1124   OS << (int)Record.TypeRangeMask << ",";
1125   OS << (int)Record.Log2LMULMask << ",";
1126   OS << (int)Record.NF << ",";
1127   OS << (int)Record.HasMasked << ",";
1128   OS << (int)Record.HasVL << ",";
1129   OS << (int)Record.HasMaskedOffOperand << ",";
1130   OS << (int)Record.HasTailPolicy << ",";
1131   OS << (int)Record.HasMaskPolicy << ",";
1132   OS << (int)Record.HasFRMRoundModeOp << ",";
1133   OS << (int)Record.IsTuple << ",";
1134   OS << (int)Record.UnMaskedPolicyScheme << ",";
1135   OS << (int)Record.MaskedPolicyScheme << ",";
1136   OS << "},\n";
1137   return OS;
1138 }
1139 
1140 } // end namespace RISCV
1141 } // end namespace clang
1142