xref: /freebsd/contrib/llvm-project/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the attributes in the CIR dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "clang/CIR/Dialect/IR/CIRDialect.h"
14 
15 #include "mlir/IR/DialectImplementation.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 
18 //===-----------------------------------------------------------------===//
19 // IntLiteral
20 //===-----------------------------------------------------------------===//
21 
22 static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
23                             cir::IntTypeInterface ty);
24 static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
25                                          llvm::APInt &value,
26                                          cir::IntTypeInterface ty);
27 //===-----------------------------------------------------------------===//
28 // FloatLiteral
29 //===-----------------------------------------------------------------===//
30 
31 static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
32                               mlir::Type ty);
33 static mlir::ParseResult
34 parseFloatLiteral(mlir::AsmParser &parser,
35                   mlir::FailureOr<llvm::APFloat> &value,
36                   cir::FPTypeInterface fpType);
37 
38 static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
39                                        mlir::IntegerAttr &value);
40 
41 static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);
42 
43 #define GET_ATTRDEF_CLASSES
44 #include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
45 
46 using namespace mlir;
47 using namespace cir;
48 
49 //===----------------------------------------------------------------------===//
50 // General CIR parsing / printing
51 //===----------------------------------------------------------------------===//
52 
parseAttribute(DialectAsmParser & parser,Type type) const53 Attribute CIRDialect::parseAttribute(DialectAsmParser &parser,
54                                      Type type) const {
55   llvm::SMLoc typeLoc = parser.getCurrentLocation();
56   llvm::StringRef mnemonic;
57   Attribute genAttr;
58   OptionalParseResult parseResult =
59       generatedAttributeParser(parser, &mnemonic, type, genAttr);
60   if (parseResult.has_value())
61     return genAttr;
62   parser.emitError(typeLoc, "unknown attribute in CIR dialect");
63   return Attribute();
64 }
65 
printAttribute(Attribute attr,DialectAsmPrinter & os) const66 void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
67   if (failed(generatedAttributePrinter(attr, os)))
68     llvm_unreachable("unexpected CIR type kind");
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // OptInfoAttr definitions
73 //===----------------------------------------------------------------------===//
74 
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned level,unsigned size)75 LogicalResult OptInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,
76                                   unsigned level, unsigned size) {
77   if (level > 3)
78     return emitError()
79            << "optimization level must be between 0 and 3 inclusive";
80   if (size > 2)
81     return emitError()
82            << "size optimization level must be between 0 and 2 inclusive";
83   return success();
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // ConstPtrAttr definitions
88 //===----------------------------------------------------------------------===//
89 
90 // TODO(CIR): Consider encoding the null value differently and use conditional
91 // assembly format instead of custom parsing/printing.
parseConstPtr(AsmParser & parser,mlir::IntegerAttr & value)92 static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {
93 
94   if (parser.parseOptionalKeyword("null").succeeded()) {
95     value = parser.getBuilder().getI64IntegerAttr(0);
96     return success();
97   }
98 
99   return parser.parseAttribute(value);
100 }
101 
printConstPtr(AsmPrinter & p,mlir::IntegerAttr value)102 static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
103   if (!value.getInt())
104     p << "null";
105   else
106     p << value;
107 }
108 
109 //===----------------------------------------------------------------------===//
110 // IntAttr definitions
111 //===----------------------------------------------------------------------===//
112 
113 template <typename IntT>
isTooLargeForType(const mlir::APInt & value,IntT expectedValue)114 static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue) {
115   if constexpr (std::is_signed_v<IntT>) {
116     return value.getSExtValue() != expectedValue;
117   } else {
118     return value.getZExtValue() != expectedValue;
119   }
120 }
121 
122 template <typename IntT>
parseIntLiteralImpl(mlir::AsmParser & p,llvm::APInt & value,cir::IntTypeInterface ty)123 static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p,
124                                              llvm::APInt &value,
125                                              cir::IntTypeInterface ty) {
126   IntT ivalue;
127   const bool isSigned = ty.isSigned();
128   if (p.parseInteger(ivalue))
129     return p.emitError(p.getCurrentLocation(), "expected integer value");
130 
131   value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);
132   if (isTooLargeForType(value, ivalue))
133     return p.emitError(p.getCurrentLocation(),
134                        "integer value too large for the given type");
135 
136   return success();
137 }
138 
parseIntLiteral(mlir::AsmParser & parser,llvm::APInt & value,cir::IntTypeInterface ty)139 mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,
140                                   cir::IntTypeInterface ty) {
141   if (ty.isSigned())
142     return parseIntLiteralImpl<int64_t>(parser, value, ty);
143   return parseIntLiteralImpl<uint64_t>(parser, value, ty);
144 }
145 
printIntLiteral(mlir::AsmPrinter & p,llvm::APInt value,cir::IntTypeInterface ty)146 void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
147                      cir::IntTypeInterface ty) {
148   if (ty.isSigned())
149     p << value.getSExtValue();
150   else
151     p << value.getZExtValue();
152 }
153 
verify(function_ref<InFlightDiagnostic ()> emitError,cir::IntTypeInterface type,llvm::APInt value)154 LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
155                               cir::IntTypeInterface type, llvm::APInt value) {
156   if (value.getBitWidth() != type.getWidth())
157     return emitError() << "type and value bitwidth mismatch: "
158                        << type.getWidth() << " != " << value.getBitWidth();
159   return success();
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // FPAttr definitions
164 //===----------------------------------------------------------------------===//
165 
printFloatLiteral(AsmPrinter & p,APFloat value,Type ty)166 static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
167   p << value;
168 }
169 
parseFloatLiteral(AsmParser & parser,FailureOr<APFloat> & value,cir::FPTypeInterface fpType)170 static ParseResult parseFloatLiteral(AsmParser &parser,
171                                      FailureOr<APFloat> &value,
172                                      cir::FPTypeInterface fpType) {
173 
174   APFloat parsedValue(0.0);
175   if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
176     return failure();
177 
178   value.emplace(parsedValue);
179   return success();
180 }
181 
getZero(Type type)182 FPAttr FPAttr::getZero(Type type) {
183   return get(type,
184              APFloat::getZero(
185                  mlir::cast<cir::FPTypeInterface>(type).getFloatSemantics()));
186 }
187 
verify(function_ref<InFlightDiagnostic ()> emitError,cir::FPTypeInterface fpType,APFloat value)188 LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
189                              cir::FPTypeInterface fpType, APFloat value) {
190   if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
191       APFloat::SemanticsToEnum(value.getSemantics()))
192     return emitError() << "floating-point semantics mismatch";
193 
194   return success();
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // ConstComplexAttr definitions
199 //===----------------------------------------------------------------------===//
200 
201 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,cir::ComplexType type,mlir::TypedAttr real,mlir::TypedAttr imag)202 ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
203                          cir::ComplexType type, mlir::TypedAttr real,
204                          mlir::TypedAttr imag) {
205   mlir::Type elemType = type.getElementType();
206   if (real.getType() != elemType)
207     return emitError()
208            << "type of the real part does not match the complex type";
209 
210   if (imag.getType() != elemType)
211     return emitError()
212            << "type of the imaginary part does not match the complex type";
213 
214   return success();
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // CIR ConstArrayAttr
219 //===----------------------------------------------------------------------===//
220 
221 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type type,Attribute elts,int trailingZerosNum)222 ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type,
223                        Attribute elts, int trailingZerosNum) {
224 
225   if (!(mlir::isa<ArrayAttr, StringAttr>(elts)))
226     return emitError() << "constant array expects ArrayAttr or StringAttr";
227 
228   if (auto strAttr = mlir::dyn_cast<StringAttr>(elts)) {
229     const auto arrayTy = mlir::cast<ArrayType>(type);
230     const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType());
231 
232     // TODO: add CIR type for char.
233     if (!intTy || intTy.getWidth() != 8)
234       return emitError()
235              << "constant array element for string literals expects "
236                 "!cir.int<u, 8> element type";
237     return success();
238   }
239 
240   assert(mlir::isa<ArrayAttr>(elts));
241   const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(elts);
242   const auto arrayTy = mlir::cast<ArrayType>(type);
243 
244   // Make sure both number of elements and subelement types match type.
245   if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum)
246     return emitError() << "constant array size should match type size";
247   return success();
248 }
249 
parse(AsmParser & parser,Type type)250 Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) {
251   mlir::FailureOr<Type> resultTy;
252   mlir::FailureOr<Attribute> resultVal;
253 
254   // Parse literal '<'
255   if (parser.parseLess())
256     return {};
257 
258   // Parse variable 'value'
259   resultVal = FieldParser<Attribute>::parse(parser);
260   if (failed(resultVal)) {
261     parser.emitError(
262         parser.getCurrentLocation(),
263         "failed to parse ConstArrayAttr parameter 'value' which is "
264         "to be a `Attribute`");
265     return {};
266   }
267 
268   // ArrayAttrrs have per-element type, not the type of the array...
269   if (mlir::isa<ArrayAttr>(*resultVal)) {
270     // Array has implicit type: infer from const array type.
271     if (parser.parseOptionalColon().failed()) {
272       resultTy = type;
273     } else { // Array has explicit type: parse it.
274       resultTy = FieldParser<Type>::parse(parser);
275       if (failed(resultTy)) {
276         parser.emitError(
277             parser.getCurrentLocation(),
278             "failed to parse ConstArrayAttr parameter 'type' which is "
279             "to be a `::mlir::Type`");
280         return {};
281       }
282     }
283   } else {
284     auto ta = mlir::cast<TypedAttr>(*resultVal);
285     resultTy = ta.getType();
286     if (mlir::isa<mlir::NoneType>(*resultTy)) {
287       parser.emitError(parser.getCurrentLocation(),
288                        "expected type declaration for string literal");
289       return {};
290     }
291   }
292 
293   unsigned zeros = 0;
294   if (parser.parseOptionalComma().succeeded()) {
295     if (parser.parseOptionalKeyword("trailing_zeros").succeeded()) {
296       unsigned typeSize =
297           mlir::cast<cir::ArrayType>(resultTy.value()).getSize();
298       mlir::Attribute elts = resultVal.value();
299       if (auto str = mlir::dyn_cast<mlir::StringAttr>(elts))
300         zeros = typeSize - str.size();
301       else
302         zeros = typeSize - mlir::cast<mlir::ArrayAttr>(elts).size();
303     } else {
304       return {};
305     }
306   }
307 
308   // Parse literal '>'
309   if (parser.parseGreater())
310     return {};
311 
312   return parser.getChecked<ConstArrayAttr>(
313       parser.getCurrentLocation(), parser.getContext(), resultTy.value(),
314       resultVal.value(), zeros);
315 }
316 
print(AsmPrinter & printer) const317 void ConstArrayAttr::print(AsmPrinter &printer) const {
318   printer << "<";
319   printer.printStrippedAttrOrType(getElts());
320   if (getTrailingZerosNum())
321     printer << ", trailing_zeros";
322   printer << ">";
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // CIR ConstVectorAttr
327 //===----------------------------------------------------------------------===//
328 
329 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type type,ArrayAttr elts)330 cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError,
331                              Type type, ArrayAttr elts) {
332 
333   if (!mlir::isa<cir::VectorType>(type))
334     return emitError() << "type of cir::ConstVectorAttr is not a "
335                           "cir::VectorType: "
336                        << type;
337 
338   const auto vecType = mlir::cast<cir::VectorType>(type);
339 
340   if (vecType.getSize() != elts.size())
341     return emitError()
342            << "number of constant elements should match vector size";
343 
344   // Check if the types of the elements match
345   LogicalResult elementTypeCheck = success();
346   elts.walkImmediateSubElements(
347       [&](Attribute element) {
348         if (elementTypeCheck.failed()) {
349           // An earlier element didn't match
350           return;
351         }
352         auto typedElement = mlir::dyn_cast<TypedAttr>(element);
353         if (!typedElement ||
354             typedElement.getType() != vecType.getElementType()) {
355           elementTypeCheck = failure();
356           emitError() << "constant type should match vector element type";
357         }
358       },
359       [&](Type) {});
360 
361   return elementTypeCheck;
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // CIR Dialect
366 //===----------------------------------------------------------------------===//
367 
registerAttributes()368 void CIRDialect::registerAttributes() {
369   addAttributes<
370 #define GET_ATTRDEF_LIST
371 #include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
372       >();
373 }
374