xref: /freebsd/contrib/llvm-project/clang/lib/CIR/Dialect/IR/CIRTypes.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- CIRTypes.cpp - MLIR CIR Types --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the CIR dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "clang/CIR/Dialect/IR/CIRTypes.h"
14 
15 #include "mlir/IR/DialectImplementation.h"
16 #include "clang/CIR/Dialect/IR/CIRDialect.h"
17 #include "clang/CIR/Dialect/IR/CIRTypesDetails.h"
18 #include "clang/CIR/MissingFeatures.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 //===----------------------------------------------------------------------===//
22 // CIR Helpers
23 //===----------------------------------------------------------------------===//
isSized(mlir::Type ty)24 bool cir::isSized(mlir::Type ty) {
25   if (auto sizedTy = mlir::dyn_cast<cir::SizedTypeInterface>(ty))
26     return sizedTy.isSized();
27   assert(!cir::MissingFeatures::unsizedTypes());
28   return false;
29 }
30 
31 //===----------------------------------------------------------------------===//
32 // CIR Custom Parser/Printer Signatures
33 //===----------------------------------------------------------------------===//
34 
35 static mlir::ParseResult
36 parseFuncTypeParams(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
37                     bool &isVarArg);
38 static void printFuncTypeParams(mlir::AsmPrinter &p,
39                                 mlir::ArrayRef<mlir::Type> params,
40                                 bool isVarArg);
41 
42 //===----------------------------------------------------------------------===//
43 // Get autogenerated stuff
44 //===----------------------------------------------------------------------===//
45 
46 namespace cir {
47 
48 #include "clang/CIR/Dialect/IR/CIRTypeConstraints.cpp.inc"
49 
50 } // namespace cir
51 
52 #define GET_TYPEDEF_CLASSES
53 #include "clang/CIR/Dialect/IR/CIROpsTypes.cpp.inc"
54 
55 using namespace mlir;
56 using namespace cir;
57 
58 //===----------------------------------------------------------------------===//
59 // General CIR parsing / printing
60 //===----------------------------------------------------------------------===//
61 
parseType(DialectAsmParser & parser) const62 Type CIRDialect::parseType(DialectAsmParser &parser) const {
63   llvm::SMLoc typeLoc = parser.getCurrentLocation();
64   llvm::StringRef mnemonic;
65   Type genType;
66 
67   // Try to parse as a tablegen'd type.
68   OptionalParseResult parseResult =
69       generatedTypeParser(parser, &mnemonic, genType);
70   if (parseResult.has_value())
71     return genType;
72 
73   // Type is not tablegen'd: try to parse as a raw C++ type.
74   return StringSwitch<function_ref<Type()>>(mnemonic)
75       .Case("record", [&] { return RecordType::parse(parser); })
76       .Default([&] {
77         parser.emitError(typeLoc) << "unknown CIR type: " << mnemonic;
78         return Type();
79       })();
80 }
81 
printType(Type type,DialectAsmPrinter & os) const82 void CIRDialect::printType(Type type, DialectAsmPrinter &os) const {
83   // Try to print as a tablegen'd type.
84   if (generatedTypePrinter(type, os).succeeded())
85     return;
86 
87   // TODO(CIR) Attempt to print as a raw C++ type.
88   llvm::report_fatal_error("printer is missing a handler for this type");
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // RecordType Definitions
93 //===----------------------------------------------------------------------===//
94 
parse(mlir::AsmParser & parser)95 Type RecordType::parse(mlir::AsmParser &parser) {
96   FailureOr<AsmParser::CyclicParseReset> cyclicParseGuard;
97   const llvm::SMLoc loc = parser.getCurrentLocation();
98   const mlir::Location eLoc = parser.getEncodedSourceLoc(loc);
99   bool packed = false;
100   bool padded = false;
101   RecordKind kind;
102   mlir::MLIRContext *context = parser.getContext();
103 
104   if (parser.parseLess())
105     return {};
106 
107   // TODO(cir): in the future we should probably separate types for different
108   // source language declarations such as cir.record and cir.union
109   if (parser.parseOptionalKeyword("struct").succeeded())
110     kind = RecordKind::Struct;
111   else if (parser.parseOptionalKeyword("union").succeeded())
112     kind = RecordKind::Union;
113   else if (parser.parseOptionalKeyword("class").succeeded())
114     kind = RecordKind::Class;
115   else {
116     parser.emitError(loc, "unknown record type");
117     return {};
118   }
119 
120   mlir::StringAttr name;
121   parser.parseOptionalAttribute(name);
122 
123   // Is a self reference: ensure referenced type was parsed.
124   if (name && parser.parseOptionalGreater().succeeded()) {
125     RecordType type = getChecked(eLoc, context, name, kind);
126     if (succeeded(parser.tryStartCyclicParse(type))) {
127       parser.emitError(loc, "invalid self-reference within record");
128       return {};
129     }
130     return type;
131   }
132 
133   // Is a named record definition: ensure name has not been parsed yet.
134   if (name) {
135     RecordType type = getChecked(eLoc, context, name, kind);
136     cyclicParseGuard = parser.tryStartCyclicParse(type);
137     if (failed(cyclicParseGuard)) {
138       parser.emitError(loc, "record already defined");
139       return {};
140     }
141   }
142 
143   if (parser.parseOptionalKeyword("packed").succeeded())
144     packed = true;
145 
146   if (parser.parseOptionalKeyword("padded").succeeded())
147     padded = true;
148 
149   // Parse record members or lack thereof.
150   bool incomplete = true;
151   llvm::SmallVector<mlir::Type> members;
152   if (parser.parseOptionalKeyword("incomplete").failed()) {
153     incomplete = false;
154     const auto delimiter = AsmParser::Delimiter::Braces;
155     const auto parseElementFn = [&parser, &members]() {
156       return parser.parseType(members.emplace_back());
157     };
158     if (parser.parseCommaSeparatedList(delimiter, parseElementFn).failed())
159       return {};
160   }
161 
162   if (parser.parseGreater())
163     return {};
164 
165   // Try to create the proper record type.
166   ArrayRef<mlir::Type> membersRef(members); // Needed for template deduction.
167   mlir::Type type = {};
168   if (name && incomplete) { // Identified & incomplete
169     type = getChecked(eLoc, context, name, kind);
170   } else if (!name && !incomplete) { // Anonymous & complete
171     type = getChecked(eLoc, context, membersRef, packed, padded, kind);
172   } else if (!incomplete) { // Identified & complete
173     type = getChecked(eLoc, context, membersRef, name, packed, padded, kind);
174     // If the record has a self-reference, its type already exists in a
175     // incomplete state. In this case, we must complete it.
176     if (mlir::cast<RecordType>(type).isIncomplete())
177       mlir::cast<RecordType>(type).complete(membersRef, packed, padded);
178     assert(!cir::MissingFeatures::astRecordDeclAttr());
179   } else { // anonymous & incomplete
180     parser.emitError(loc, "anonymous records must be complete");
181     return {};
182   }
183 
184   return type;
185 }
186 
print(mlir::AsmPrinter & printer) const187 void RecordType::print(mlir::AsmPrinter &printer) const {
188   FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrintGuard;
189   printer << '<';
190 
191   switch (getKind()) {
192   case RecordKind::Struct:
193     printer << "struct ";
194     break;
195   case RecordKind::Union:
196     printer << "union ";
197     break;
198   case RecordKind::Class:
199     printer << "class ";
200     break;
201   }
202 
203   if (getName())
204     printer << getName();
205 
206   // Current type has already been printed: print as self reference.
207   cyclicPrintGuard = printer.tryStartCyclicPrint(*this);
208   if (failed(cyclicPrintGuard)) {
209     printer << '>';
210     return;
211   }
212 
213   // Type not yet printed: continue printing the entire record.
214   printer << ' ';
215 
216   if (getPacked())
217     printer << "packed ";
218 
219   if (getPadded())
220     printer << "padded ";
221 
222   if (isIncomplete()) {
223     printer << "incomplete";
224   } else {
225     printer << "{";
226     llvm::interleaveComma(getMembers(), printer);
227     printer << "}";
228   }
229 
230   printer << '>';
231 }
232 
233 mlir::LogicalResult
verify(function_ref<mlir::InFlightDiagnostic ()> emitError,llvm::ArrayRef<mlir::Type> members,mlir::StringAttr name,bool incomplete,bool packed,bool padded,RecordType::RecordKind kind)234 RecordType::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
235                    llvm::ArrayRef<mlir::Type> members, mlir::StringAttr name,
236                    bool incomplete, bool packed, bool padded,
237                    RecordType::RecordKind kind) {
238   if (name && name.getValue().empty())
239     return emitError() << "identified records cannot have an empty name";
240   return mlir::success();
241 }
242 
getMembers() const243 ::llvm::ArrayRef<mlir::Type> RecordType::getMembers() const {
244   return getImpl()->members;
245 }
246 
isIncomplete() const247 bool RecordType::isIncomplete() const { return getImpl()->incomplete; }
248 
getName() const249 mlir::StringAttr RecordType::getName() const { return getImpl()->name; }
250 
getIncomplete() const251 bool RecordType::getIncomplete() const { return getImpl()->incomplete; }
252 
getPacked() const253 bool RecordType::getPacked() const { return getImpl()->packed; }
254 
getPadded() const255 bool RecordType::getPadded() const { return getImpl()->padded; }
256 
getKind() const257 cir::RecordType::RecordKind RecordType::getKind() const {
258   return getImpl()->kind;
259 }
260 
complete(ArrayRef<Type> members,bool packed,bool padded)261 void RecordType::complete(ArrayRef<Type> members, bool packed, bool padded) {
262   assert(!cir::MissingFeatures::astRecordDeclAttr());
263   if (mutate(members, packed, padded).failed())
264     llvm_unreachable("failed to complete record");
265 }
266 
267 /// Return the largest member of in the type.
268 ///
269 /// Recurses into union members never returning a union as the largest member.
getLargestMember(const::mlir::DataLayout & dataLayout) const270 Type RecordType::getLargestMember(const ::mlir::DataLayout &dataLayout) const {
271   assert(isUnion() && "Only call getLargestMember on unions");
272   llvm::ArrayRef<Type> members = getMembers();
273   // If the union is padded, we need to ignore the last member,
274   // which is the padding.
275   return *std::max_element(
276       members.begin(), getPadded() ? members.end() - 1 : members.end(),
277       [&](Type lhs, Type rhs) {
278         return dataLayout.getTypeABIAlignment(lhs) <
279                    dataLayout.getTypeABIAlignment(rhs) ||
280                (dataLayout.getTypeABIAlignment(lhs) ==
281                     dataLayout.getTypeABIAlignment(rhs) &&
282                 dataLayout.getTypeSize(lhs) < dataLayout.getTypeSize(rhs));
283       });
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // Data Layout information for types
288 //===----------------------------------------------------------------------===//
289 
290 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const291 RecordType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
292                               mlir::DataLayoutEntryListRef params) const {
293   if (isUnion())
294     return dataLayout.getTypeSize(getLargestMember(dataLayout));
295 
296   unsigned recordSize = computeStructSize(dataLayout);
297   return llvm::TypeSize::getFixed(recordSize * 8);
298 }
299 
300 uint64_t
getABIAlignment(const::mlir::DataLayout & dataLayout,::mlir::DataLayoutEntryListRef params) const301 RecordType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
302                             ::mlir::DataLayoutEntryListRef params) const {
303   if (isUnion())
304     return dataLayout.getTypeABIAlignment(getLargestMember(dataLayout));
305 
306   // Packed structures always have an ABI alignment of 1.
307   if (getPacked())
308     return 1;
309   return computeStructAlignment(dataLayout);
310 }
311 
312 unsigned
computeStructSize(const mlir::DataLayout & dataLayout) const313 RecordType::computeStructSize(const mlir::DataLayout &dataLayout) const {
314   assert(isComplete() && "Cannot get layout of incomplete records");
315 
316   // This is a similar algorithm to LLVM's StructLayout.
317   unsigned recordSize = 0;
318   uint64_t recordAlignment = 1;
319 
320   for (mlir::Type ty : getMembers()) {
321     // This assumes that we're calculating size based on the ABI alignment, not
322     // the preferred alignment for each type.
323     const uint64_t tyAlign =
324         (getPacked() ? 1 : dataLayout.getTypeABIAlignment(ty));
325 
326     // Add padding to the struct size to align it to the abi alignment of the
327     // element type before than adding the size of the element.
328     recordSize = llvm::alignTo(recordSize, tyAlign);
329     recordSize += dataLayout.getTypeSize(ty);
330 
331     // The alignment requirement of a struct is equal to the strictest alignment
332     // requirement of its elements.
333     recordAlignment = std::max(tyAlign, recordAlignment);
334   }
335 
336   // At the end, add padding to the struct to satisfy its own alignment
337   // requirement. Otherwise structs inside of arrays would be misaligned.
338   recordSize = llvm::alignTo(recordSize, recordAlignment);
339   return recordSize;
340 }
341 
342 // We also compute the alignment as part of computeStructSize, but this is more
343 // efficient. Ideally, we'd like to compute both at once and cache the result,
344 // but that's implemented yet.
345 // TODO(CIR): Implement a way to cache the result.
346 uint64_t
computeStructAlignment(const mlir::DataLayout & dataLayout) const347 RecordType::computeStructAlignment(const mlir::DataLayout &dataLayout) const {
348   assert(isComplete() && "Cannot get layout of incomplete records");
349 
350   // This is a similar algorithm to LLVM's StructLayout.
351   uint64_t recordAlignment = 1;
352   for (mlir::Type ty : getMembers())
353     recordAlignment =
354         std::max(dataLayout.getTypeABIAlignment(ty), recordAlignment);
355 
356   return recordAlignment;
357 }
358 
getElementOffset(const::mlir::DataLayout & dataLayout,unsigned idx) const359 uint64_t RecordType::getElementOffset(const ::mlir::DataLayout &dataLayout,
360                                       unsigned idx) const {
361   assert(idx < getMembers().size() && "access not valid");
362 
363   // All union elements are at offset zero.
364   if (isUnion() || idx == 0)
365     return 0;
366 
367   assert(isComplete() && "Cannot get layout of incomplete records");
368   assert(idx < getNumElements());
369   llvm::ArrayRef<mlir::Type> members = getMembers();
370 
371   unsigned offset = 0;
372 
373   for (mlir::Type ty :
374        llvm::make_range(members.begin(), std::next(members.begin(), idx))) {
375     // This matches LLVM since it uses the ABI instead of preferred alignment.
376     const llvm::Align tyAlign =
377         llvm::Align(getPacked() ? 1 : dataLayout.getTypeABIAlignment(ty));
378 
379     // Add padding if necessary to align the data element properly.
380     offset = llvm::alignTo(offset, tyAlign);
381 
382     // Consume space for this data item
383     offset += dataLayout.getTypeSize(ty);
384   }
385 
386   // Account for padding, if necessary, for the alignment of the field whose
387   // offset we are calculating.
388   const llvm::Align tyAlign = llvm::Align(
389       getPacked() ? 1 : dataLayout.getTypeABIAlignment(members[idx]));
390   offset = llvm::alignTo(offset, tyAlign);
391 
392   return offset;
393 }
394 
395 //===----------------------------------------------------------------------===//
396 // IntType Definitions
397 //===----------------------------------------------------------------------===//
398 
parse(mlir::AsmParser & parser)399 Type IntType::parse(mlir::AsmParser &parser) {
400   mlir::MLIRContext *context = parser.getBuilder().getContext();
401   llvm::SMLoc loc = parser.getCurrentLocation();
402   bool isSigned;
403   unsigned width;
404 
405   if (parser.parseLess())
406     return {};
407 
408   // Fetch integer sign.
409   llvm::StringRef sign;
410   if (parser.parseKeyword(&sign))
411     return {};
412   if (sign == "s")
413     isSigned = true;
414   else if (sign == "u")
415     isSigned = false;
416   else {
417     parser.emitError(loc, "expected 's' or 'u'");
418     return {};
419   }
420 
421   if (parser.parseComma())
422     return {};
423 
424   // Fetch integer size.
425   if (parser.parseInteger(width))
426     return {};
427   if (width < IntType::minBitwidth() || width > IntType::maxBitwidth()) {
428     parser.emitError(loc, "expected integer width to be from ")
429         << IntType::minBitwidth() << " up to " << IntType::maxBitwidth();
430     return {};
431   }
432 
433   if (parser.parseGreater())
434     return {};
435 
436   return IntType::get(context, width, isSigned);
437 }
438 
print(mlir::AsmPrinter & printer) const439 void IntType::print(mlir::AsmPrinter &printer) const {
440   char sign = isSigned() ? 's' : 'u';
441   printer << '<' << sign << ", " << getWidth() << '>';
442 }
443 
444 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const445 IntType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
446                            mlir::DataLayoutEntryListRef params) const {
447   return llvm::TypeSize::getFixed(getWidth());
448 }
449 
getABIAlignment(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const450 uint64_t IntType::getABIAlignment(const mlir::DataLayout &dataLayout,
451                                   mlir::DataLayoutEntryListRef params) const {
452   return (uint64_t)(getWidth() / 8);
453 }
454 
455 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,unsigned width,bool isSigned)456 IntType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
457                 unsigned width, bool isSigned) {
458   if (width < IntType::minBitwidth() || width > IntType::maxBitwidth())
459     return emitError() << "IntType only supports widths from "
460                        << IntType::minBitwidth() << " up to "
461                        << IntType::maxBitwidth();
462   return mlir::success();
463 }
464 
isValidFundamentalIntWidth(unsigned width)465 bool cir::isValidFundamentalIntWidth(unsigned width) {
466   return width == 8 || width == 16 || width == 32 || width == 64;
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // Floating-point type definitions
471 //===----------------------------------------------------------------------===//
472 
getFloatSemantics() const473 const llvm::fltSemantics &SingleType::getFloatSemantics() const {
474   return llvm::APFloat::IEEEsingle();
475 }
476 
477 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const478 SingleType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
479                               mlir::DataLayoutEntryListRef params) const {
480   return llvm::TypeSize::getFixed(getWidth());
481 }
482 
483 uint64_t
getABIAlignment(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const484 SingleType::getABIAlignment(const mlir::DataLayout &dataLayout,
485                             mlir::DataLayoutEntryListRef params) const {
486   return (uint64_t)(getWidth() / 8);
487 }
488 
getFloatSemantics() const489 const llvm::fltSemantics &DoubleType::getFloatSemantics() const {
490   return llvm::APFloat::IEEEdouble();
491 }
492 
493 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const494 DoubleType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
495                               mlir::DataLayoutEntryListRef params) const {
496   return llvm::TypeSize::getFixed(getWidth());
497 }
498 
499 uint64_t
getABIAlignment(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const500 DoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
501                             mlir::DataLayoutEntryListRef params) const {
502   return (uint64_t)(getWidth() / 8);
503 }
504 
getFloatSemantics() const505 const llvm::fltSemantics &FP16Type::getFloatSemantics() const {
506   return llvm::APFloat::IEEEhalf();
507 }
508 
509 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const510 FP16Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
511                             mlir::DataLayoutEntryListRef params) const {
512   return llvm::TypeSize::getFixed(getWidth());
513 }
514 
getABIAlignment(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const515 uint64_t FP16Type::getABIAlignment(const mlir::DataLayout &dataLayout,
516                                    mlir::DataLayoutEntryListRef params) const {
517   return (uint64_t)(getWidth() / 8);
518 }
519 
getFloatSemantics() const520 const llvm::fltSemantics &BF16Type::getFloatSemantics() const {
521   return llvm::APFloat::BFloat();
522 }
523 
524 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const525 BF16Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
526                             mlir::DataLayoutEntryListRef params) const {
527   return llvm::TypeSize::getFixed(getWidth());
528 }
529 
getABIAlignment(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const530 uint64_t BF16Type::getABIAlignment(const mlir::DataLayout &dataLayout,
531                                    mlir::DataLayoutEntryListRef params) const {
532   return (uint64_t)(getWidth() / 8);
533 }
534 
getFloatSemantics() const535 const llvm::fltSemantics &FP80Type::getFloatSemantics() const {
536   return llvm::APFloat::x87DoubleExtended();
537 }
538 
539 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const540 FP80Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
541                             mlir::DataLayoutEntryListRef params) const {
542   // Though only 80 bits are used for the value, the type is 128 bits in size.
543   return llvm::TypeSize::getFixed(128);
544 }
545 
getABIAlignment(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const546 uint64_t FP80Type::getABIAlignment(const mlir::DataLayout &dataLayout,
547                                    mlir::DataLayoutEntryListRef params) const {
548   return 16;
549 }
550 
getFloatSemantics() const551 const llvm::fltSemantics &FP128Type::getFloatSemantics() const {
552   return llvm::APFloat::IEEEquad();
553 }
554 
555 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const556 FP128Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
557                              mlir::DataLayoutEntryListRef params) const {
558   return llvm::TypeSize::getFixed(getWidth());
559 }
560 
getABIAlignment(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const561 uint64_t FP128Type::getABIAlignment(const mlir::DataLayout &dataLayout,
562                                     mlir::DataLayoutEntryListRef params) const {
563   return 16;
564 }
565 
getFloatSemantics() const566 const llvm::fltSemantics &LongDoubleType::getFloatSemantics() const {
567   return mlir::cast<cir::FPTypeInterface>(getUnderlying()).getFloatSemantics();
568 }
569 
570 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const571 LongDoubleType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
572                                   mlir::DataLayoutEntryListRef params) const {
573   return mlir::cast<mlir::DataLayoutTypeInterface>(getUnderlying())
574       .getTypeSizeInBits(dataLayout, params);
575 }
576 
577 uint64_t
getABIAlignment(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const578 LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
579                                 mlir::DataLayoutEntryListRef params) const {
580   return mlir::cast<mlir::DataLayoutTypeInterface>(getUnderlying())
581       .getABIAlignment(dataLayout, params);
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // ComplexType Definitions
586 //===----------------------------------------------------------------------===//
587 
588 llvm::TypeSize
getTypeSizeInBits(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const589 cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
590                                     mlir::DataLayoutEntryListRef params) const {
591   // C17 6.2.5p13:
592   //   Each complex type has the same representation and alignment requirements
593   //   as an array type containing exactly two elements of the corresponding
594   //   real type.
595 
596   return dataLayout.getTypeSizeInBits(getElementType()) * 2;
597 }
598 
599 uint64_t
getABIAlignment(const mlir::DataLayout & dataLayout,mlir::DataLayoutEntryListRef params) const600 cir::ComplexType::getABIAlignment(const mlir::DataLayout &dataLayout,
601                                   mlir::DataLayoutEntryListRef params) const {
602   // C17 6.2.5p13:
603   //   Each complex type has the same representation and alignment requirements
604   //   as an array type containing exactly two elements of the corresponding
605   //   real type.
606 
607   return dataLayout.getTypeABIAlignment(getElementType());
608 }
609 
clone(TypeRange inputs,TypeRange results) const610 FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
611   assert(results.size() == 1 && "expected exactly one result type");
612   return get(llvm::to_vector(inputs), results[0], isVarArg());
613 }
614 
615 // Custom parser that parses function parameters of form `(<type>*, ...)`.
616 static mlir::ParseResult
parseFuncTypeParams(mlir::AsmParser & p,llvm::SmallVector<mlir::Type> & params,bool & isVarArg)617 parseFuncTypeParams(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
618                     bool &isVarArg) {
619   isVarArg = false;
620   return p.parseCommaSeparatedList(
621       AsmParser::Delimiter::Paren, [&]() -> mlir::ParseResult {
622         if (isVarArg)
623           return p.emitError(p.getCurrentLocation(),
624                              "variadic `...` must be the last parameter");
625         if (succeeded(p.parseOptionalEllipsis())) {
626           isVarArg = true;
627           return success();
628         }
629         mlir::Type type;
630         if (failed(p.parseType(type)))
631           return failure();
632         params.push_back(type);
633         return success();
634       });
635 }
636 
printFuncTypeParams(mlir::AsmPrinter & p,mlir::ArrayRef<mlir::Type> params,bool isVarArg)637 static void printFuncTypeParams(mlir::AsmPrinter &p,
638                                 mlir::ArrayRef<mlir::Type> params,
639                                 bool isVarArg) {
640   p << '(';
641   llvm::interleaveComma(params, p,
642                         [&p](mlir::Type type) { p.printType(type); });
643   if (isVarArg) {
644     if (!params.empty())
645       p << ", ";
646     p << "...";
647   }
648   p << ')';
649 }
650 
651 /// Get the C-style return type of the function, which is !cir.void if the
652 /// function returns nothing and the actual return type otherwise.
getReturnType() const653 mlir::Type FuncType::getReturnType() const {
654   if (hasVoidReturn())
655     return cir::VoidType::get(getContext());
656   return getOptionalReturnType();
657 }
658 
659 /// Get the MLIR-style return type of the function, which is an empty
660 /// ArrayRef if the function returns nothing and a single-element ArrayRef
661 /// with the actual return type otherwise.
getReturnTypes() const662 llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
663   if (hasVoidReturn())
664     return {};
665   // Can't use getOptionalReturnType() here because llvm::ArrayRef hold a
666   // pointer to its elements and doesn't do lifetime extension.  That would
667   // result in returning a pointer to a temporary that has gone out of scope.
668   return getImpl()->optionalReturnType;
669 }
670 
671 // Does the fuction type return nothing?
hasVoidReturn() const672 bool FuncType::hasVoidReturn() const { return !getOptionalReturnType(); }
673 
674 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,llvm::ArrayRef<mlir::Type> argTypes,mlir::Type returnType,bool isVarArg)675 FuncType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
676                  llvm::ArrayRef<mlir::Type> argTypes, mlir::Type returnType,
677                  bool isVarArg) {
678   if (mlir::isa_and_nonnull<cir::VoidType>(returnType))
679     return emitError()
680            << "!cir.func cannot have an explicit 'void' return type";
681   return mlir::success();
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // BoolType
686 //===----------------------------------------------------------------------===//
687 
688 llvm::TypeSize
getTypeSizeInBits(const::mlir::DataLayout & dataLayout,::mlir::DataLayoutEntryListRef params) const689 BoolType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
690                             ::mlir::DataLayoutEntryListRef params) const {
691   return llvm::TypeSize::getFixed(8);
692 }
693 
694 uint64_t
getABIAlignment(const::mlir::DataLayout & dataLayout,::mlir::DataLayoutEntryListRef params) const695 BoolType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
696                           ::mlir::DataLayoutEntryListRef params) const {
697   return 1;
698 }
699 
700 //===----------------------------------------------------------------------===//
701 //  ArrayType Definitions
702 //===----------------------------------------------------------------------===//
703 
704 llvm::TypeSize
getTypeSizeInBits(const::mlir::DataLayout & dataLayout,::mlir::DataLayoutEntryListRef params) const705 ArrayType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
706                              ::mlir::DataLayoutEntryListRef params) const {
707   return getSize() * dataLayout.getTypeSizeInBits(getElementType());
708 }
709 
710 uint64_t
getABIAlignment(const::mlir::DataLayout & dataLayout,::mlir::DataLayoutEntryListRef params) const711 ArrayType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
712                            ::mlir::DataLayoutEntryListRef params) const {
713   return dataLayout.getTypeABIAlignment(getElementType());
714 }
715 
716 //===----------------------------------------------------------------------===//
717 // VectorType Definitions
718 //===----------------------------------------------------------------------===//
719 
getTypeSizeInBits(const::mlir::DataLayout & dataLayout,::mlir::DataLayoutEntryListRef params) const720 llvm::TypeSize cir::VectorType::getTypeSizeInBits(
721     const ::mlir::DataLayout &dataLayout,
722     ::mlir::DataLayoutEntryListRef params) const {
723   return llvm::TypeSize::getFixed(
724       getSize() * dataLayout.getTypeSizeInBits(getElementType()));
725 }
726 
727 uint64_t
getABIAlignment(const::mlir::DataLayout & dataLayout,::mlir::DataLayoutEntryListRef params) const728 cir::VectorType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
729                                  ::mlir::DataLayoutEntryListRef params) const {
730   return llvm::NextPowerOf2(dataLayout.getTypeSizeInBits(*this));
731 }
732 
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type elementType,uint64_t size)733 mlir::LogicalResult cir::VectorType::verify(
734     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
735     mlir::Type elementType, uint64_t size) {
736   if (size == 0)
737     return emitError() << "the number of vector elements must be non-zero";
738   return success();
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // PointerType Definitions
743 //===----------------------------------------------------------------------===//
744 
745 llvm::TypeSize
getTypeSizeInBits(const::mlir::DataLayout & dataLayout,::mlir::DataLayoutEntryListRef params) const746 PointerType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
747                                ::mlir::DataLayoutEntryListRef params) const {
748   // FIXME: improve this in face of address spaces
749   return llvm::TypeSize::getFixed(64);
750 }
751 
752 uint64_t
getABIAlignment(const::mlir::DataLayout & dataLayout,::mlir::DataLayoutEntryListRef params) const753 PointerType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
754                              ::mlir::DataLayoutEntryListRef params) const {
755   // FIXME: improve this in face of address spaces
756   return 8;
757 }
758 
759 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type pointee)760 PointerType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
761                     mlir::Type pointee) {
762   // TODO(CIR): Verification of the address space goes here.
763   return mlir::success();
764 }
765 
766 //===----------------------------------------------------------------------===//
767 // CIR Dialect
768 //===----------------------------------------------------------------------===//
769 
registerTypes()770 void CIRDialect::registerTypes() {
771   // Register tablegen'd types.
772   addTypes<
773 #define GET_TYPEDEF_LIST
774 #include "clang/CIR/Dialect/IR/CIROpsTypes.cpp.inc"
775       >();
776 
777   // Register raw C++ types.
778   // TODO(CIR) addTypes<RecordType>();
779 }
780