1 //===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the attributes in the CIR dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "clang/CIR/Dialect/IR/CIRDialect.h"
14
15 #include "mlir/IR/DialectImplementation.h"
16 #include "llvm/ADT/TypeSwitch.h"
17
18 //===-----------------------------------------------------------------===//
19 // IntLiteral
20 //===-----------------------------------------------------------------===//
21
// Forward declarations of the custom parse/print helpers defined later in
// this file. They are declared before CIROpsAttributes.cpp.inc is included
// below, presumably because the generated attribute code refers to them
// (e.g. through custom<...> assembly-format directives) — TODO confirm
// against the attribute .td definitions.

/// Print/parse the integer payload of an attribute whose integer type is `ty`.
static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
                            cir::IntTypeInterface ty);
static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
                                         llvm::APInt &value,
                                         cir::IntTypeInterface ty);
//===-----------------------------------------------------------------===//
// FloatLiteral
//===-----------------------------------------------------------------===//

/// Print/parse a floating-point payload; parsing uses the float semantics of
/// `fpType`.
static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
                              mlir::Type ty);
static mlir::ParseResult
parseFloatLiteral(mlir::AsmParser &parser,
                  mlir::FailureOr<llvm::APFloat> &value,
                  cir::FPTypeInterface fpType);

//===-----------------------------------------------------------------===//
// ConstPtr
//===-----------------------------------------------------------------===//

/// Parse/print a constant pointer value; the keyword `null` stands for a
/// zero-valued integer attribute.
static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
                                       mlir::IntegerAttr &value);

static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);
42
43 #define GET_ATTRDEF_CLASSES
44 #include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
45
46 using namespace mlir;
47 using namespace cir;
48
49 //===----------------------------------------------------------------------===//
50 // General CIR parsing / printing
51 //===----------------------------------------------------------------------===//
52
parseAttribute(DialectAsmParser & parser,Type type) const53 Attribute CIRDialect::parseAttribute(DialectAsmParser &parser,
54 Type type) const {
55 llvm::SMLoc typeLoc = parser.getCurrentLocation();
56 llvm::StringRef mnemonic;
57 Attribute genAttr;
58 OptionalParseResult parseResult =
59 generatedAttributeParser(parser, &mnemonic, type, genAttr);
60 if (parseResult.has_value())
61 return genAttr;
62 parser.emitError(typeLoc, "unknown attribute in CIR dialect");
63 return Attribute();
64 }
65
printAttribute(Attribute attr,DialectAsmPrinter & os) const66 void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
67 if (failed(generatedAttributePrinter(attr, os)))
68 llvm_unreachable("unexpected CIR type kind");
69 }
70
71 //===----------------------------------------------------------------------===//
72 // OptInfoAttr definitions
73 //===----------------------------------------------------------------------===//
74
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned level,unsigned size)75 LogicalResult OptInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,
76 unsigned level, unsigned size) {
77 if (level > 3)
78 return emitError()
79 << "optimization level must be between 0 and 3 inclusive";
80 if (size > 2)
81 return emitError()
82 << "size optimization level must be between 0 and 2 inclusive";
83 return success();
84 }
85
86 //===----------------------------------------------------------------------===//
87 // ConstPtrAttr definitions
88 //===----------------------------------------------------------------------===//
89
90 // TODO(CIR): Consider encoding the null value differently and use conditional
91 // assembly format instead of custom parsing/printing.
parseConstPtr(AsmParser & parser,mlir::IntegerAttr & value)92 static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {
93
94 if (parser.parseOptionalKeyword("null").succeeded()) {
95 value = parser.getBuilder().getI64IntegerAttr(0);
96 return success();
97 }
98
99 return parser.parseAttribute(value);
100 }
101
printConstPtr(AsmPrinter & p,mlir::IntegerAttr value)102 static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
103 if (!value.getInt())
104 p << "null";
105 else
106 p << value;
107 }
108
109 //===----------------------------------------------------------------------===//
110 // IntAttr definitions
111 //===----------------------------------------------------------------------===//
112
113 template <typename IntT>
isTooLargeForType(const mlir::APInt & value,IntT expectedValue)114 static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue) {
115 if constexpr (std::is_signed_v<IntT>) {
116 return value.getSExtValue() != expectedValue;
117 } else {
118 return value.getZExtValue() != expectedValue;
119 }
120 }
121
122 template <typename IntT>
parseIntLiteralImpl(mlir::AsmParser & p,llvm::APInt & value,cir::IntTypeInterface ty)123 static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p,
124 llvm::APInt &value,
125 cir::IntTypeInterface ty) {
126 IntT ivalue;
127 const bool isSigned = ty.isSigned();
128 if (p.parseInteger(ivalue))
129 return p.emitError(p.getCurrentLocation(), "expected integer value");
130
131 value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);
132 if (isTooLargeForType(value, ivalue))
133 return p.emitError(p.getCurrentLocation(),
134 "integer value too large for the given type");
135
136 return success();
137 }
138
parseIntLiteral(mlir::AsmParser & parser,llvm::APInt & value,cir::IntTypeInterface ty)139 mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,
140 cir::IntTypeInterface ty) {
141 if (ty.isSigned())
142 return parseIntLiteralImpl<int64_t>(parser, value, ty);
143 return parseIntLiteralImpl<uint64_t>(parser, value, ty);
144 }
145
printIntLiteral(mlir::AsmPrinter & p,llvm::APInt value,cir::IntTypeInterface ty)146 void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
147 cir::IntTypeInterface ty) {
148 if (ty.isSigned())
149 p << value.getSExtValue();
150 else
151 p << value.getZExtValue();
152 }
153
verify(function_ref<InFlightDiagnostic ()> emitError,cir::IntTypeInterface type,llvm::APInt value)154 LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
155 cir::IntTypeInterface type, llvm::APInt value) {
156 if (value.getBitWidth() != type.getWidth())
157 return emitError() << "type and value bitwidth mismatch: "
158 << type.getWidth() << " != " << value.getBitWidth();
159 return success();
160 }
161
162 //===----------------------------------------------------------------------===//
163 // FPAttr definitions
164 //===----------------------------------------------------------------------===//
165
printFloatLiteral(AsmPrinter & p,APFloat value,Type ty)166 static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
167 p << value;
168 }
169
parseFloatLiteral(AsmParser & parser,FailureOr<APFloat> & value,cir::FPTypeInterface fpType)170 static ParseResult parseFloatLiteral(AsmParser &parser,
171 FailureOr<APFloat> &value,
172 cir::FPTypeInterface fpType) {
173
174 APFloat parsedValue(0.0);
175 if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
176 return failure();
177
178 value.emplace(parsedValue);
179 return success();
180 }
181
getZero(Type type)182 FPAttr FPAttr::getZero(Type type) {
183 return get(type,
184 APFloat::getZero(
185 mlir::cast<cir::FPTypeInterface>(type).getFloatSemantics()));
186 }
187
verify(function_ref<InFlightDiagnostic ()> emitError,cir::FPTypeInterface fpType,APFloat value)188 LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
189 cir::FPTypeInterface fpType, APFloat value) {
190 if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
191 APFloat::SemanticsToEnum(value.getSemantics()))
192 return emitError() << "floating-point semantics mismatch";
193
194 return success();
195 }
196
197 //===----------------------------------------------------------------------===//
198 // ConstComplexAttr definitions
199 //===----------------------------------------------------------------------===//
200
201 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,cir::ComplexType type,mlir::TypedAttr real,mlir::TypedAttr imag)202 ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
203 cir::ComplexType type, mlir::TypedAttr real,
204 mlir::TypedAttr imag) {
205 mlir::Type elemType = type.getElementType();
206 if (real.getType() != elemType)
207 return emitError()
208 << "type of the real part does not match the complex type";
209
210 if (imag.getType() != elemType)
211 return emitError()
212 << "type of the imaginary part does not match the complex type";
213
214 return success();
215 }
216
217 //===----------------------------------------------------------------------===//
218 // CIR ConstArrayAttr
219 //===----------------------------------------------------------------------===//
220
221 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type type,Attribute elts,int trailingZerosNum)222 ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type,
223 Attribute elts, int trailingZerosNum) {
224
225 if (!(mlir::isa<ArrayAttr, StringAttr>(elts)))
226 return emitError() << "constant array expects ArrayAttr or StringAttr";
227
228 if (auto strAttr = mlir::dyn_cast<StringAttr>(elts)) {
229 const auto arrayTy = mlir::cast<ArrayType>(type);
230 const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType());
231
232 // TODO: add CIR type for char.
233 if (!intTy || intTy.getWidth() != 8)
234 return emitError()
235 << "constant array element for string literals expects "
236 "!cir.int<u, 8> element type";
237 return success();
238 }
239
240 assert(mlir::isa<ArrayAttr>(elts));
241 const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(elts);
242 const auto arrayTy = mlir::cast<ArrayType>(type);
243
244 // Make sure both number of elements and subelement types match type.
245 if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum)
246 return emitError() << "constant array size should match type size";
247 return success();
248 }
249
parse(AsmParser & parser,Type type)250 Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) {
251 mlir::FailureOr<Type> resultTy;
252 mlir::FailureOr<Attribute> resultVal;
253
254 // Parse literal '<'
255 if (parser.parseLess())
256 return {};
257
258 // Parse variable 'value'
259 resultVal = FieldParser<Attribute>::parse(parser);
260 if (failed(resultVal)) {
261 parser.emitError(
262 parser.getCurrentLocation(),
263 "failed to parse ConstArrayAttr parameter 'value' which is "
264 "to be a `Attribute`");
265 return {};
266 }
267
268 // ArrayAttrrs have per-element type, not the type of the array...
269 if (mlir::isa<ArrayAttr>(*resultVal)) {
270 // Array has implicit type: infer from const array type.
271 if (parser.parseOptionalColon().failed()) {
272 resultTy = type;
273 } else { // Array has explicit type: parse it.
274 resultTy = FieldParser<Type>::parse(parser);
275 if (failed(resultTy)) {
276 parser.emitError(
277 parser.getCurrentLocation(),
278 "failed to parse ConstArrayAttr parameter 'type' which is "
279 "to be a `::mlir::Type`");
280 return {};
281 }
282 }
283 } else {
284 auto ta = mlir::cast<TypedAttr>(*resultVal);
285 resultTy = ta.getType();
286 if (mlir::isa<mlir::NoneType>(*resultTy)) {
287 parser.emitError(parser.getCurrentLocation(),
288 "expected type declaration for string literal");
289 return {};
290 }
291 }
292
293 unsigned zeros = 0;
294 if (parser.parseOptionalComma().succeeded()) {
295 if (parser.parseOptionalKeyword("trailing_zeros").succeeded()) {
296 unsigned typeSize =
297 mlir::cast<cir::ArrayType>(resultTy.value()).getSize();
298 mlir::Attribute elts = resultVal.value();
299 if (auto str = mlir::dyn_cast<mlir::StringAttr>(elts))
300 zeros = typeSize - str.size();
301 else
302 zeros = typeSize - mlir::cast<mlir::ArrayAttr>(elts).size();
303 } else {
304 return {};
305 }
306 }
307
308 // Parse literal '>'
309 if (parser.parseGreater())
310 return {};
311
312 return parser.getChecked<ConstArrayAttr>(
313 parser.getCurrentLocation(), parser.getContext(), resultTy.value(),
314 resultVal.value(), zeros);
315 }
316
print(AsmPrinter & printer) const317 void ConstArrayAttr::print(AsmPrinter &printer) const {
318 printer << "<";
319 printer.printStrippedAttrOrType(getElts());
320 if (getTrailingZerosNum())
321 printer << ", trailing_zeros";
322 printer << ">";
323 }
324
325 //===----------------------------------------------------------------------===//
326 // CIR ConstVectorAttr
327 //===----------------------------------------------------------------------===//
328
329 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type type,ArrayAttr elts)330 cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError,
331 Type type, ArrayAttr elts) {
332
333 if (!mlir::isa<cir::VectorType>(type))
334 return emitError() << "type of cir::ConstVectorAttr is not a "
335 "cir::VectorType: "
336 << type;
337
338 const auto vecType = mlir::cast<cir::VectorType>(type);
339
340 if (vecType.getSize() != elts.size())
341 return emitError()
342 << "number of constant elements should match vector size";
343
344 // Check if the types of the elements match
345 LogicalResult elementTypeCheck = success();
346 elts.walkImmediateSubElements(
347 [&](Attribute element) {
348 if (elementTypeCheck.failed()) {
349 // An earlier element didn't match
350 return;
351 }
352 auto typedElement = mlir::dyn_cast<TypedAttr>(element);
353 if (!typedElement ||
354 typedElement.getType() != vecType.getElementType()) {
355 elementTypeCheck = failure();
356 emitError() << "constant type should match vector element type";
357 }
358 },
359 [&](Type) {});
360
361 return elementTypeCheck;
362 }
363
364 //===----------------------------------------------------------------------===//
365 // CIR Dialect
366 //===----------------------------------------------------------------------===//
367
/// Register with the dialect every attribute class in the list produced by
/// the GET_ATTRDEF_LIST section of CIROpsAttributes.cpp.inc.
void CIRDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
      >();
}
374