xref: /freebsd/contrib/llvm-project/clang/lib/Support/RISCVVIntrinsicUtils.cpp (revision c9539b89010900499a200cdd6c0265ea5d950875)
1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/Optional.h"
12 #include "llvm/ADT/SmallSet.h"
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringMap.h"
15 #include "llvm/ADT/StringSet.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/raw_ostream.h"
18 #include <numeric>
19 #include <set>
20 #include <unordered_map>
21 
22 using namespace llvm;
23 
24 namespace clang {
25 namespace RISCV {
26 
27 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
28     BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
29 const PrototypeDescriptor PrototypeDescriptor::VL =
30     PrototypeDescriptor(BaseTypeModifier::SizeT);
31 const PrototypeDescriptor PrototypeDescriptor::Vector =
32     PrototypeDescriptor(BaseTypeModifier::Vector);
33 
34 //===----------------------------------------------------------------------===//
35 // Type implementation
36 //===----------------------------------------------------------------------===//
37 
38 LMULType::LMULType(int NewLog2LMUL) {
39   // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
40   assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
41   Log2LMUL = NewLog2LMUL;
42 }
43 
44 std::string LMULType::str() const {
45   if (Log2LMUL < 0)
46     return "mf" + utostr(1ULL << (-Log2LMUL));
47   return "m" + utostr(1ULL << Log2LMUL);
48 }
49 
50 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
51   int Log2ScaleResult = 0;
52   switch (ElementBitwidth) {
53   default:
54     break;
55   case 8:
56     Log2ScaleResult = Log2LMUL + 3;
57     break;
58   case 16:
59     Log2ScaleResult = Log2LMUL + 2;
60     break;
61   case 32:
62     Log2ScaleResult = Log2LMUL + 1;
63     break;
64   case 64:
65     Log2ScaleResult = Log2LMUL;
66     break;
67   }
68   // Illegal vscale result would be less than 1
69   if (Log2ScaleResult < 0)
70     return llvm::None;
71   return 1 << Log2ScaleResult;
72 }
73 
74 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
75 
76 RVVType::RVVType(BasicType BT, int Log2LMUL,
77                  const PrototypeDescriptor &prototype)
78     : BT(BT), LMUL(LMULType(Log2LMUL)) {
79   applyBasicType();
80   applyModifier(prototype);
81   Valid = verifyType();
82   if (Valid) {
83     initBuiltinStr();
84     initTypeStr();
85     if (isVector()) {
86       initClangBuiltinStr();
87     }
88   }
89 }
90 
91 // clang-format off
92 // boolean type are encoded the ratio of n (SEW/LMUL)
93 // SEW/LMUL | 1         | 2         | 4         | 8        | 16        | 32        | 64
94 // c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
95 // IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
96 
97 // type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
98 // --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
99 // i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
100 // i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
101 // i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
102 // i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
103 // double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
104 // float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
105 // half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
106 // clang-format on
107 
108 bool RVVType::verifyType() const {
109   if (ScalarType == Invalid)
110     return false;
111   if (isScalar())
112     return true;
113   if (!Scale)
114     return false;
115   if (isFloat() && ElementBitwidth == 8)
116     return false;
117   unsigned V = Scale.value();
118   switch (ElementBitwidth) {
119   case 1:
120   case 8:
121     // Check Scale is 1,2,4,8,16,32,64
122     return (V <= 64 && isPowerOf2_32(V));
123   case 16:
124     // Check Scale is 1,2,4,8,16,32
125     return (V <= 32 && isPowerOf2_32(V));
126   case 32:
127     // Check Scale is 1,2,4,8,16
128     return (V <= 16 && isPowerOf2_32(V));
129   case 64:
130     // Check Scale is 1,2,4,8
131     return (V <= 8 && isPowerOf2_32(V));
132   }
133   return false;
134 }
135 
136 void RVVType::initBuiltinStr() {
137   assert(isValid() && "RVVType is invalid");
138   switch (ScalarType) {
139   case ScalarTypeKind::Void:
140     BuiltinStr = "v";
141     return;
142   case ScalarTypeKind::Size_t:
143     BuiltinStr = "z";
144     if (IsImmediate)
145       BuiltinStr = "I" + BuiltinStr;
146     if (IsPointer)
147       BuiltinStr += "*";
148     return;
149   case ScalarTypeKind::Ptrdiff_t:
150     BuiltinStr = "Y";
151     return;
152   case ScalarTypeKind::UnsignedLong:
153     BuiltinStr = "ULi";
154     return;
155   case ScalarTypeKind::SignedLong:
156     BuiltinStr = "Li";
157     return;
158   case ScalarTypeKind::Boolean:
159     assert(ElementBitwidth == 1);
160     BuiltinStr += "b";
161     break;
162   case ScalarTypeKind::SignedInteger:
163   case ScalarTypeKind::UnsignedInteger:
164     switch (ElementBitwidth) {
165     case 8:
166       BuiltinStr += "c";
167       break;
168     case 16:
169       BuiltinStr += "s";
170       break;
171     case 32:
172       BuiltinStr += "i";
173       break;
174     case 64:
175       BuiltinStr += "Wi";
176       break;
177     default:
178       llvm_unreachable("Unhandled ElementBitwidth!");
179     }
180     if (isSignedInteger())
181       BuiltinStr = "S" + BuiltinStr;
182     else
183       BuiltinStr = "U" + BuiltinStr;
184     break;
185   case ScalarTypeKind::Float:
186     switch (ElementBitwidth) {
187     case 16:
188       BuiltinStr += "x";
189       break;
190     case 32:
191       BuiltinStr += "f";
192       break;
193     case 64:
194       BuiltinStr += "d";
195       break;
196     default:
197       llvm_unreachable("Unhandled ElementBitwidth!");
198     }
199     break;
200   default:
201     llvm_unreachable("ScalarType is invalid!");
202   }
203   if (IsImmediate)
204     BuiltinStr = "I" + BuiltinStr;
205   if (isScalar()) {
206     if (IsConstant)
207       BuiltinStr += "C";
208     if (IsPointer)
209       BuiltinStr += "*";
210     return;
211   }
212   BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
213   // Pointer to vector types. Defined for segment load intrinsics.
214   // segment load intrinsics have pointer type arguments to store the loaded
215   // vector values.
216   if (IsPointer)
217     BuiltinStr += "*";
218 }
219 
220 void RVVType::initClangBuiltinStr() {
221   assert(isValid() && "RVVType is invalid");
222   assert(isVector() && "Handle Vector type only");
223 
224   ClangBuiltinStr = "__rvv_";
225   switch (ScalarType) {
226   case ScalarTypeKind::Boolean:
227     ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
228     return;
229   case ScalarTypeKind::Float:
230     ClangBuiltinStr += "float";
231     break;
232   case ScalarTypeKind::SignedInteger:
233     ClangBuiltinStr += "int";
234     break;
235   case ScalarTypeKind::UnsignedInteger:
236     ClangBuiltinStr += "uint";
237     break;
238   default:
239     llvm_unreachable("ScalarTypeKind is invalid");
240   }
241   ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
242 }
243 
244 void RVVType::initTypeStr() {
245   assert(isValid() && "RVVType is invalid");
246 
247   if (IsConstant)
248     Str += "const ";
249 
250   auto getTypeString = [&](StringRef TypeStr) {
251     if (isScalar())
252       return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
253     return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
254         .str();
255   };
256 
257   switch (ScalarType) {
258   case ScalarTypeKind::Void:
259     Str = "void";
260     return;
261   case ScalarTypeKind::Size_t:
262     Str = "size_t";
263     if (IsPointer)
264       Str += " *";
265     return;
266   case ScalarTypeKind::Ptrdiff_t:
267     Str = "ptrdiff_t";
268     return;
269   case ScalarTypeKind::UnsignedLong:
270     Str = "unsigned long";
271     return;
272   case ScalarTypeKind::SignedLong:
273     Str = "long";
274     return;
275   case ScalarTypeKind::Boolean:
276     if (isScalar())
277       Str += "bool";
278     else
279       // Vector bool is special case, the formulate is
280       // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
281       Str += "vbool" + utostr(64 / *Scale) + "_t";
282     break;
283   case ScalarTypeKind::Float:
284     if (isScalar()) {
285       if (ElementBitwidth == 64)
286         Str += "double";
287       else if (ElementBitwidth == 32)
288         Str += "float";
289       else if (ElementBitwidth == 16)
290         Str += "_Float16";
291       else
292         llvm_unreachable("Unhandled floating type.");
293     } else
294       Str += getTypeString("float");
295     break;
296   case ScalarTypeKind::SignedInteger:
297     Str += getTypeString("int");
298     break;
299   case ScalarTypeKind::UnsignedInteger:
300     Str += getTypeString("uint");
301     break;
302   default:
303     llvm_unreachable("ScalarType is invalid!");
304   }
305   if (IsPointer)
306     Str += " *";
307 }
308 
309 void RVVType::initShortStr() {
310   switch (ScalarType) {
311   case ScalarTypeKind::Boolean:
312     assert(isVector());
313     ShortStr = "b" + utostr(64 / *Scale);
314     return;
315   case ScalarTypeKind::Float:
316     ShortStr = "f" + utostr(ElementBitwidth);
317     break;
318   case ScalarTypeKind::SignedInteger:
319     ShortStr = "i" + utostr(ElementBitwidth);
320     break;
321   case ScalarTypeKind::UnsignedInteger:
322     ShortStr = "u" + utostr(ElementBitwidth);
323     break;
324   default:
325     llvm_unreachable("Unhandled case!");
326   }
327   if (isVector())
328     ShortStr += LMUL.str();
329 }
330 
331 void RVVType::applyBasicType() {
332   switch (BT) {
333   case BasicType::Int8:
334     ElementBitwidth = 8;
335     ScalarType = ScalarTypeKind::SignedInteger;
336     break;
337   case BasicType::Int16:
338     ElementBitwidth = 16;
339     ScalarType = ScalarTypeKind::SignedInteger;
340     break;
341   case BasicType::Int32:
342     ElementBitwidth = 32;
343     ScalarType = ScalarTypeKind::SignedInteger;
344     break;
345   case BasicType::Int64:
346     ElementBitwidth = 64;
347     ScalarType = ScalarTypeKind::SignedInteger;
348     break;
349   case BasicType::Float16:
350     ElementBitwidth = 16;
351     ScalarType = ScalarTypeKind::Float;
352     break;
353   case BasicType::Float32:
354     ElementBitwidth = 32;
355     ScalarType = ScalarTypeKind::Float;
356     break;
357   case BasicType::Float64:
358     ElementBitwidth = 64;
359     ScalarType = ScalarTypeKind::Float;
360     break;
361   default:
362     llvm_unreachable("Unhandled type code!");
363   }
364   assert(ElementBitwidth != 0 && "Bad element bitwidth!");
365 }
366 
367 Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
368     llvm::StringRef PrototypeDescriptorStr) {
369   PrototypeDescriptor PD;
370   BaseTypeModifier PT = BaseTypeModifier::Invalid;
371   VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
372 
373   if (PrototypeDescriptorStr.empty())
374     return PD;
375 
376   // Handle base type modifier
377   auto PType = PrototypeDescriptorStr.back();
378   switch (PType) {
379   case 'e':
380     PT = BaseTypeModifier::Scalar;
381     break;
382   case 'v':
383     PT = BaseTypeModifier::Vector;
384     break;
385   case 'w':
386     PT = BaseTypeModifier::Vector;
387     VTM = VectorTypeModifier::Widening2XVector;
388     break;
389   case 'q':
390     PT = BaseTypeModifier::Vector;
391     VTM = VectorTypeModifier::Widening4XVector;
392     break;
393   case 'o':
394     PT = BaseTypeModifier::Vector;
395     VTM = VectorTypeModifier::Widening8XVector;
396     break;
397   case 'm':
398     PT = BaseTypeModifier::Vector;
399     VTM = VectorTypeModifier::MaskVector;
400     break;
401   case '0':
402     PT = BaseTypeModifier::Void;
403     break;
404   case 'z':
405     PT = BaseTypeModifier::SizeT;
406     break;
407   case 't':
408     PT = BaseTypeModifier::Ptrdiff;
409     break;
410   case 'u':
411     PT = BaseTypeModifier::UnsignedLong;
412     break;
413   case 'l':
414     PT = BaseTypeModifier::SignedLong;
415     break;
416   default:
417     llvm_unreachable("Illegal primitive type transformers!");
418   }
419   PD.PT = static_cast<uint8_t>(PT);
420   PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
421 
422   // Compute the vector type transformers, it can only appear one time.
423   if (PrototypeDescriptorStr.startswith("(")) {
424     assert(VTM == VectorTypeModifier::NoModifier &&
425            "VectorTypeModifier should only have one modifier");
426     size_t Idx = PrototypeDescriptorStr.find(')');
427     assert(Idx != StringRef::npos);
428     StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
429     PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
430     assert(!PrototypeDescriptorStr.contains('(') &&
431            "Only allow one vector type modifier");
432 
433     auto ComplexTT = ComplexType.split(":");
434     if (ComplexTT.first == "Log2EEW") {
435       uint32_t Log2EEW;
436       if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
437         llvm_unreachable("Invalid Log2EEW value!");
438         return None;
439       }
440       switch (Log2EEW) {
441       case 3:
442         VTM = VectorTypeModifier::Log2EEW3;
443         break;
444       case 4:
445         VTM = VectorTypeModifier::Log2EEW4;
446         break;
447       case 5:
448         VTM = VectorTypeModifier::Log2EEW5;
449         break;
450       case 6:
451         VTM = VectorTypeModifier::Log2EEW6;
452         break;
453       default:
454         llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
455         return None;
456       }
457     } else if (ComplexTT.first == "FixedSEW") {
458       uint32_t NewSEW;
459       if (ComplexTT.second.getAsInteger(10, NewSEW)) {
460         llvm_unreachable("Invalid FixedSEW value!");
461         return None;
462       }
463       switch (NewSEW) {
464       case 8:
465         VTM = VectorTypeModifier::FixedSEW8;
466         break;
467       case 16:
468         VTM = VectorTypeModifier::FixedSEW16;
469         break;
470       case 32:
471         VTM = VectorTypeModifier::FixedSEW32;
472         break;
473       case 64:
474         VTM = VectorTypeModifier::FixedSEW64;
475         break;
476       default:
477         llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
478         return None;
479       }
480     } else if (ComplexTT.first == "LFixedLog2LMUL") {
481       int32_t Log2LMUL;
482       if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
483         llvm_unreachable("Invalid LFixedLog2LMUL value!");
484         return None;
485       }
486       switch (Log2LMUL) {
487       case -3:
488         VTM = VectorTypeModifier::LFixedLog2LMULN3;
489         break;
490       case -2:
491         VTM = VectorTypeModifier::LFixedLog2LMULN2;
492         break;
493       case -1:
494         VTM = VectorTypeModifier::LFixedLog2LMULN1;
495         break;
496       case 0:
497         VTM = VectorTypeModifier::LFixedLog2LMUL0;
498         break;
499       case 1:
500         VTM = VectorTypeModifier::LFixedLog2LMUL1;
501         break;
502       case 2:
503         VTM = VectorTypeModifier::LFixedLog2LMUL2;
504         break;
505       case 3:
506         VTM = VectorTypeModifier::LFixedLog2LMUL3;
507         break;
508       default:
509         llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
510         return None;
511       }
512     } else if (ComplexTT.first == "SFixedLog2LMUL") {
513       int32_t Log2LMUL;
514       if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
515         llvm_unreachable("Invalid SFixedLog2LMUL value!");
516         return None;
517       }
518       switch (Log2LMUL) {
519       case -3:
520         VTM = VectorTypeModifier::SFixedLog2LMULN3;
521         break;
522       case -2:
523         VTM = VectorTypeModifier::SFixedLog2LMULN2;
524         break;
525       case -1:
526         VTM = VectorTypeModifier::SFixedLog2LMULN1;
527         break;
528       case 0:
529         VTM = VectorTypeModifier::SFixedLog2LMUL0;
530         break;
531       case 1:
532         VTM = VectorTypeModifier::SFixedLog2LMUL1;
533         break;
534       case 2:
535         VTM = VectorTypeModifier::SFixedLog2LMUL2;
536         break;
537       case 3:
538         VTM = VectorTypeModifier::SFixedLog2LMUL3;
539         break;
540       default:
541         llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
542         return None;
543       }
544 
545     } else {
546       llvm_unreachable("Illegal complex type transformers!");
547     }
548   }
549   PD.VTM = static_cast<uint8_t>(VTM);
550 
551   // Compute the remain type transformers
552   TypeModifier TM = TypeModifier::NoModifier;
553   for (char I : PrototypeDescriptorStr) {
554     switch (I) {
555     case 'P':
556       if ((TM & TypeModifier::Const) == TypeModifier::Const)
557         llvm_unreachable("'P' transformer cannot be used after 'C'");
558       if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
559         llvm_unreachable("'P' transformer cannot be used twice");
560       TM |= TypeModifier::Pointer;
561       break;
562     case 'C':
563       TM |= TypeModifier::Const;
564       break;
565     case 'K':
566       TM |= TypeModifier::Immediate;
567       break;
568     case 'U':
569       TM |= TypeModifier::UnsignedInteger;
570       break;
571     case 'I':
572       TM |= TypeModifier::SignedInteger;
573       break;
574     case 'F':
575       TM |= TypeModifier::Float;
576       break;
577     case 'S':
578       TM |= TypeModifier::LMUL1;
579       break;
580     default:
581       llvm_unreachable("Illegal non-primitive type transformer!");
582     }
583   }
584   PD.TM = static_cast<uint8_t>(TM);
585 
586   return PD;
587 }
588 
589 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
590   // Handle primitive type transformer
591   switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
592   case BaseTypeModifier::Scalar:
593     Scale = 0;
594     break;
595   case BaseTypeModifier::Vector:
596     Scale = LMUL.getScale(ElementBitwidth);
597     break;
598   case BaseTypeModifier::Void:
599     ScalarType = ScalarTypeKind::Void;
600     break;
601   case BaseTypeModifier::SizeT:
602     ScalarType = ScalarTypeKind::Size_t;
603     break;
604   case BaseTypeModifier::Ptrdiff:
605     ScalarType = ScalarTypeKind::Ptrdiff_t;
606     break;
607   case BaseTypeModifier::UnsignedLong:
608     ScalarType = ScalarTypeKind::UnsignedLong;
609     break;
610   case BaseTypeModifier::SignedLong:
611     ScalarType = ScalarTypeKind::SignedLong;
612     break;
613   case BaseTypeModifier::Invalid:
614     ScalarType = ScalarTypeKind::Invalid;
615     return;
616   }
617 
618   switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
619   case VectorTypeModifier::Widening2XVector:
620     ElementBitwidth *= 2;
621     LMUL.MulLog2LMUL(1);
622     Scale = LMUL.getScale(ElementBitwidth);
623     break;
624   case VectorTypeModifier::Widening4XVector:
625     ElementBitwidth *= 4;
626     LMUL.MulLog2LMUL(2);
627     Scale = LMUL.getScale(ElementBitwidth);
628     break;
629   case VectorTypeModifier::Widening8XVector:
630     ElementBitwidth *= 8;
631     LMUL.MulLog2LMUL(3);
632     Scale = LMUL.getScale(ElementBitwidth);
633     break;
634   case VectorTypeModifier::MaskVector:
635     ScalarType = ScalarTypeKind::Boolean;
636     Scale = LMUL.getScale(ElementBitwidth);
637     ElementBitwidth = 1;
638     break;
639   case VectorTypeModifier::Log2EEW3:
640     applyLog2EEW(3);
641     break;
642   case VectorTypeModifier::Log2EEW4:
643     applyLog2EEW(4);
644     break;
645   case VectorTypeModifier::Log2EEW5:
646     applyLog2EEW(5);
647     break;
648   case VectorTypeModifier::Log2EEW6:
649     applyLog2EEW(6);
650     break;
651   case VectorTypeModifier::FixedSEW8:
652     applyFixedSEW(8);
653     break;
654   case VectorTypeModifier::FixedSEW16:
655     applyFixedSEW(16);
656     break;
657   case VectorTypeModifier::FixedSEW32:
658     applyFixedSEW(32);
659     break;
660   case VectorTypeModifier::FixedSEW64:
661     applyFixedSEW(64);
662     break;
663   case VectorTypeModifier::LFixedLog2LMULN3:
664     applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
665     break;
666   case VectorTypeModifier::LFixedLog2LMULN2:
667     applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
668     break;
669   case VectorTypeModifier::LFixedLog2LMULN1:
670     applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
671     break;
672   case VectorTypeModifier::LFixedLog2LMUL0:
673     applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
674     break;
675   case VectorTypeModifier::LFixedLog2LMUL1:
676     applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
677     break;
678   case VectorTypeModifier::LFixedLog2LMUL2:
679     applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
680     break;
681   case VectorTypeModifier::LFixedLog2LMUL3:
682     applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
683     break;
684   case VectorTypeModifier::SFixedLog2LMULN3:
685     applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
686     break;
687   case VectorTypeModifier::SFixedLog2LMULN2:
688     applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
689     break;
690   case VectorTypeModifier::SFixedLog2LMULN1:
691     applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
692     break;
693   case VectorTypeModifier::SFixedLog2LMUL0:
694     applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
695     break;
696   case VectorTypeModifier::SFixedLog2LMUL1:
697     applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
698     break;
699   case VectorTypeModifier::SFixedLog2LMUL2:
700     applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
701     break;
702   case VectorTypeModifier::SFixedLog2LMUL3:
703     applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
704     break;
705   case VectorTypeModifier::NoModifier:
706     break;
707   }
708 
709   for (unsigned TypeModifierMaskShift = 0;
710        TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
711        ++TypeModifierMaskShift) {
712     unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
713     if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
714         TypeModifierMask)
715       continue;
716     switch (static_cast<TypeModifier>(TypeModifierMask)) {
717     case TypeModifier::Pointer:
718       IsPointer = true;
719       break;
720     case TypeModifier::Const:
721       IsConstant = true;
722       break;
723     case TypeModifier::Immediate:
724       IsImmediate = true;
725       IsConstant = true;
726       break;
727     case TypeModifier::UnsignedInteger:
728       ScalarType = ScalarTypeKind::UnsignedInteger;
729       break;
730     case TypeModifier::SignedInteger:
731       ScalarType = ScalarTypeKind::SignedInteger;
732       break;
733     case TypeModifier::Float:
734       ScalarType = ScalarTypeKind::Float;
735       break;
736     case TypeModifier::LMUL1:
737       LMUL = LMULType(0);
738       // Update ElementBitwidth need to update Scale too.
739       Scale = LMUL.getScale(ElementBitwidth);
740       break;
741     default:
742       llvm_unreachable("Unknown type modifier mask!");
743     }
744   }
745 }
746 
747 void RVVType::applyLog2EEW(unsigned Log2EEW) {
748   // update new elmul = (eew/sew) * lmul
749   LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
750   // update new eew
751   ElementBitwidth = 1 << Log2EEW;
752   ScalarType = ScalarTypeKind::SignedInteger;
753   Scale = LMUL.getScale(ElementBitwidth);
754 }
755 
756 void RVVType::applyFixedSEW(unsigned NewSEW) {
757   // Set invalid type if src and dst SEW are same.
758   if (ElementBitwidth == NewSEW) {
759     ScalarType = ScalarTypeKind::Invalid;
760     return;
761   }
762   // Update new SEW
763   ElementBitwidth = NewSEW;
764   Scale = LMUL.getScale(ElementBitwidth);
765 }
766 
767 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
768   switch (Type) {
769   case FixedLMULType::LargerThan:
770     if (Log2LMUL < LMUL.Log2LMUL) {
771       ScalarType = ScalarTypeKind::Invalid;
772       return;
773     }
774     break;
775   case FixedLMULType::SmallerThan:
776     if (Log2LMUL > LMUL.Log2LMUL) {
777       ScalarType = ScalarTypeKind::Invalid;
778       return;
779     }
780     break;
781   }
782 
783   // Update new LMUL
784   LMUL = LMULType(Log2LMUL);
785   Scale = LMUL.getScale(ElementBitwidth);
786 }
787 
788 Optional<RVVTypes>
789 RVVType::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
790                       ArrayRef<PrototypeDescriptor> Prototype) {
791   // LMUL x NF must be less than or equal to 8.
792   if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
793     return llvm::None;
794 
795   RVVTypes Types;
796   for (const PrototypeDescriptor &Proto : Prototype) {
797     auto T = computeType(BT, Log2LMUL, Proto);
798     if (!T)
799       return llvm::None;
800     // Record legal type index
801     Types.push_back(T.value());
802   }
803   return Types;
804 }
805 
806 // Compute the hash value of RVVType, used for cache the result of computeType.
807 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
808                                         PrototypeDescriptor Proto) {
809   // Layout of hash value:
810   // 0               8    16          24        32          40
811   // | Log2LMUL + 3  | BT  | Proto.PT | Proto.TM | Proto.VTM |
812   assert(Log2LMUL >= -3 && Log2LMUL <= 3);
813   return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
814          ((uint64_t)(Proto.PT & 0xff) << 16) |
815          ((uint64_t)(Proto.TM & 0xff) << 24) |
816          ((uint64_t)(Proto.VTM & 0xff) << 32);
817 }
818 
819 Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL,
820                                           PrototypeDescriptor Proto) {
821   // Concat BasicType, LMUL and Proto as key
822   static std::unordered_map<uint64_t, RVVType> LegalTypes;
823   static std::set<uint64_t> IllegalTypes;
824   uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
825   // Search first
826   auto It = LegalTypes.find(Idx);
827   if (It != LegalTypes.end())
828     return &(It->second);
829 
830   if (IllegalTypes.count(Idx))
831     return llvm::None;
832 
833   // Compute type and record the result.
834   RVVType T(BT, Log2LMUL, Proto);
835   if (T.isValid()) {
836     // Record legal type index and value.
837     LegalTypes.insert({Idx, T});
838     return &(LegalTypes[Idx]);
839   }
840   // Record illegal type index.
841   IllegalTypes.insert(Idx);
842   return llvm::None;
843 }
844 
845 //===----------------------------------------------------------------------===//
846 // RVVIntrinsic implementation
847 //===----------------------------------------------------------------------===//
848 RVVIntrinsic::RVVIntrinsic(
849     StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
850     StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
851     bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
852     bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen,
853     const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
854     const std::vector<StringRef> &RequiredFeatures, unsigned NF)
855     : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme),
856       HasUnMaskedOverloaded(HasUnMaskedOverloaded),
857       HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()),
858       NF(NF) {
859 
860   // Init BuiltinName, Name and OverloadedName
861   BuiltinName = NewName.str();
862   Name = BuiltinName;
863   if (NewOverloadedName.empty())
864     OverloadedName = NewName.split("_").first.str();
865   else
866     OverloadedName = NewOverloadedName.str();
867   if (!Suffix.empty())
868     Name += "_" + Suffix.str();
869   if (!OverloadedSuffix.empty())
870     OverloadedName += "_" + OverloadedSuffix.str();
871   if (IsMasked) {
872     BuiltinName += "_m";
873     Name += "_m";
874   }
875 
876   // Init OutputType and InputTypes
877   OutputType = OutInTypes[0];
878   InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
879 
880   // IntrinsicTypes is unmasked TA version index. Need to update it
881   // if there is merge operand (It is always in first operand).
882   IntrinsicTypes = NewIntrinsicTypes;
883   if ((IsMasked && HasMaskedOffOperand) ||
884       (!IsMasked && hasPassthruOperand())) {
885     for (auto &I : IntrinsicTypes) {
886       if (I >= 0)
887         I += NF;
888     }
889   }
890 }
891 
892 std::string RVVIntrinsic::getBuiltinTypeStr() const {
893   std::string S;
894   S += OutputType->getBuiltinStr();
895   for (const auto &T : InputTypes) {
896     S += T->getBuiltinStr();
897   }
898   return S;
899 }
900 
901 std::string RVVIntrinsic::getSuffixStr(
902     BasicType Type, int Log2LMUL,
903     llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
904   SmallVector<std::string> SuffixStrs;
905   for (auto PD : PrototypeDescriptors) {
906     auto T = RVVType::computeType(Type, Log2LMUL, PD);
907     SuffixStrs.push_back((*T)->getShortStr());
908   }
909   return join(SuffixStrs, "_");
910 }
911 
912 llvm::SmallVector<PrototypeDescriptor>
913 RVVIntrinsic::computeBuiltinTypes(llvm::ArrayRef<PrototypeDescriptor> Prototype,
914                                   bool IsMasked, bool HasMaskedOffOperand,
915                                   bool HasVL, unsigned NF) {
916   SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
917                                                 Prototype.end());
918   if (IsMasked) {
919     // If HasMaskedOffOperand, insert result type as first input operand.
920     if (HasMaskedOffOperand) {
921       if (NF == 1) {
922         NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
923       } else {
924         // Convert
925         // (void, op0 address, op1 address, ...)
926         // to
927         // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
928         PrototypeDescriptor MaskoffType = NewPrototype[1];
929         MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
930         for (unsigned I = 0; I < NF; ++I)
931           NewPrototype.insert(NewPrototype.begin() + NF + 1, MaskoffType);
932       }
933     }
934     if (HasMaskedOffOperand && NF > 1) {
935       // Convert
936       // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
937       // to
938       // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
939       // ...)
940       NewPrototype.insert(NewPrototype.begin() + NF + 1,
941                           PrototypeDescriptor::Mask);
942     } else {
943       // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
944       NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
945     }
946   }
947 
948   // If HasVL, append PrototypeDescriptor:VL to last operand
949   if (HasVL)
950     NewPrototype.push_back(PrototypeDescriptor::VL);
951   return NewPrototype;
952 }
953 
954 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
955   SmallVector<PrototypeDescriptor> PrototypeDescriptors;
956   const StringRef Primaries("evwqom0ztul");
957   while (!Prototypes.empty()) {
958     size_t Idx = 0;
959     // Skip over complex prototype because it could contain primitive type
960     // character.
961     if (Prototypes[0] == '(')
962       Idx = Prototypes.find_first_of(')');
963     Idx = Prototypes.find_first_of(Primaries, Idx);
964     assert(Idx != StringRef::npos);
965     auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
966         Prototypes.slice(0, Idx + 1));
967     if (!PD)
968       llvm_unreachable("Error during parsing prototype.");
969     PrototypeDescriptors.push_back(*PD);
970     Prototypes = Prototypes.drop_front(Idx + 1);
971   }
972   return PrototypeDescriptors;
973 }
974 
975 raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
976   OS << "{";
977   OS << "\"" << Record.Name << "\",";
978   if (Record.OverloadedName == nullptr ||
979       StringRef(Record.OverloadedName).empty())
980     OS << "nullptr,";
981   else
982     OS << "\"" << Record.OverloadedName << "\",";
983   OS << Record.PrototypeIndex << ",";
984   OS << Record.SuffixIndex << ",";
985   OS << Record.OverloadedSuffixIndex << ",";
986   OS << (int)Record.PrototypeLength << ",";
987   OS << (int)Record.SuffixLength << ",";
988   OS << (int)Record.OverloadedSuffixSize << ",";
989   OS << (int)Record.RequiredExtensions << ",";
990   OS << (int)Record.TypeRangeMask << ",";
991   OS << (int)Record.Log2LMULMask << ",";
992   OS << (int)Record.NF << ",";
993   OS << (int)Record.HasMasked << ",";
994   OS << (int)Record.HasVL << ",";
995   OS << (int)Record.HasMaskedOffOperand << ",";
996   OS << "},\n";
997   return OS;
998 }
999 
1000 } // end namespace RISCV
1001 } // end namespace clang
1002