xref: /freebsd/contrib/llvm-project/clang/lib/CIR/Dialect/IR/CIRDialect.cpp (revision 9c77fb6aaa366cbabc80ee1b834bcfe4df135491)
1 //===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the CIR dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "clang/CIR/Dialect/IR/CIRDialect.h"
14 
15 #include "clang/CIR/Dialect/IR/CIROpsEnums.h"
16 #include "clang/CIR/Dialect/IR/CIRTypes.h"
17 
18 #include "mlir/Interfaces/ControlFlowInterfaces.h"
19 #include "mlir/Interfaces/FunctionImplementation.h"
20 
21 #include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
22 #include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
23 #include "clang/CIR/MissingFeatures.h"
24 #include "llvm/Support/LogicalResult.h"
25 
26 #include <numeric>
27 
28 using namespace mlir;
29 using namespace cir;
30 
31 //===----------------------------------------------------------------------===//
32 // CIR Dialect
33 //===----------------------------------------------------------------------===//
34 namespace {
35 struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
36   using OpAsmDialectInterface::OpAsmDialectInterface;
37 
38   AliasResult getAlias(Type type, raw_ostream &os) const final {
39     if (auto recordType = dyn_cast<cir::RecordType>(type)) {
40       StringAttr nameAttr = recordType.getName();
41       if (!nameAttr)
42         os << "rec_anon_" << recordType.getKindAsStr();
43       else
44         os << "rec_" << nameAttr.getValue();
45       return AliasResult::OverridableAlias;
46     }
47     if (auto intType = dyn_cast<cir::IntType>(type)) {
48       // We only provide alias for standard integer types (i.e. integer types
49       // whose width is a power of 2 and at least 8).
50       unsigned width = intType.getWidth();
51       if (width < 8 || !llvm::isPowerOf2_32(width))
52         return AliasResult::NoAlias;
53       os << intType.getAlias();
54       return AliasResult::OverridableAlias;
55     }
56     if (auto voidType = dyn_cast<cir::VoidType>(type)) {
57       os << voidType.getAlias();
58       return AliasResult::OverridableAlias;
59     }
60 
61     return AliasResult::NoAlias;
62   }
63 
64   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
65     if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
66       os << (boolAttr.getValue() ? "true" : "false");
67       return AliasResult::FinalAlias;
68     }
69     if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
70       os << "bfi_" << bitfield.getName().str();
71       return AliasResult::FinalAlias;
72     }
73     return AliasResult::NoAlias;
74   }
75 };
76 } // namespace
77 
78 void cir::CIRDialect::initialize() {
79   registerTypes();
80   registerAttributes();
81   addOperations<
82 #define GET_OP_LIST
83 #include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
84       >();
85   addInterfaces<CIROpAsmDialectInterface>();
86 }
87 
88 Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
89                                                 mlir::Attribute value,
90                                                 mlir::Type type,
91                                                 mlir::Location loc) {
92   return builder.create<cir::ConstantOp>(loc, type,
93                                          mlir::cast<mlir::TypedAttr>(value));
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // Helpers
98 //===----------------------------------------------------------------------===//
99 
100 // Parses one of the keywords provided in the list `keywords` and returns the
101 // position of the parsed keyword in the list. If none of the keywords from the
102 // list is parsed, returns -1.
103 static int parseOptionalKeywordAlternative(AsmParser &parser,
104                                            ArrayRef<llvm::StringRef> keywords) {
105   for (auto en : llvm::enumerate(keywords)) {
106     if (succeeded(parser.parseOptionalKeyword(en.value())))
107       return en.index();
108   }
109   return -1;
110 }
111 
112 namespace {
113 template <typename Ty> struct EnumTraits {};
114 
115 #define REGISTER_ENUM_TYPE(Ty)                                                 \
116   template <> struct EnumTraits<cir::Ty> {                                     \
117     static llvm::StringRef stringify(cir::Ty value) {                          \
118       return stringify##Ty(value);                                             \
119     }                                                                          \
120     static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); }    \
121   }
122 
123 REGISTER_ENUM_TYPE(GlobalLinkageKind);
124 REGISTER_ENUM_TYPE(VisibilityKind);
125 REGISTER_ENUM_TYPE(SideEffect);
126 } // namespace
127 
128 /// Parse an enum from the keyword, or default to the provided default value.
129 /// The return type is the enum type by default, unless overriden with the
130 /// second template argument.
131 template <typename EnumTy, typename RetTy = EnumTy>
132 static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
133   llvm::SmallVector<llvm::StringRef, 10> names;
134   for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
135     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
136 
137   int index = parseOptionalKeywordAlternative(parser, names);
138   if (index == -1)
139     return static_cast<RetTy>(defaultValue);
140   return static_cast<RetTy>(index);
141 }
142 
143 /// Parse an enum from the keyword, return failure if the keyword is not found.
144 template <typename EnumTy, typename RetTy = EnumTy>
145 static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
146   llvm::SmallVector<llvm::StringRef, 10> names;
147   for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
148     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
149 
150   int index = parseOptionalKeywordAlternative(parser, names);
151   if (index == -1)
152     return failure();
153   result = static_cast<RetTy>(index);
154   return success();
155 }
156 
157 // Check if a region's termination omission is valid and, if so, creates and
158 // inserts the omitted terminator into the region.
159 static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
160                                       SMLoc errLoc) {
161   Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
162   OpBuilder builder(parser.getBuilder().getContext());
163 
164   // Insert empty block in case the region is empty to ensure the terminator
165   // will be inserted
166   if (region.empty())
167     builder.createBlock(&region);
168 
169   Block &block = region.back();
170   // Region is properly terminated: nothing to do.
171   if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
172     return success();
173 
174   // Check for invalid terminator omissions.
175   if (!region.hasOneBlock())
176     return parser.emitError(errLoc,
177                             "multi-block region must not omit terminator");
178 
179   // Terminator was omitted correctly: recreate it.
180   builder.setInsertionPointToEnd(&block);
181   builder.create<cir::YieldOp>(eLoc);
182   return success();
183 }
184 
185 // True if the region's terminator should be omitted.
186 static bool omitRegionTerm(mlir::Region &r) {
187   const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
188   const auto yieldsNothing = [&r]() {
189     auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
190     return y && y.getArgs().empty();
191   };
192   return singleNonEmptyBlock && yieldsNothing();
193 }
194 
195 void printVisibilityAttr(OpAsmPrinter &printer,
196                          cir::VisibilityAttr &visibility) {
197   switch (visibility.getValue()) {
198   case cir::VisibilityKind::Hidden:
199     printer << "hidden";
200     break;
201   case cir::VisibilityKind::Protected:
202     printer << "protected";
203     break;
204   case cir::VisibilityKind::Default:
205     break;
206   }
207 }
208 
209 void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility) {
210   cir::VisibilityKind visibilityKind =
211       parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
212   visibility = cir::VisibilityAttr::get(parser.getContext(), visibilityKind);
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // CIR Custom Parsers/Printers
217 //===----------------------------------------------------------------------===//
218 
219 static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
220                                                       mlir::Region &region) {
221   auto regionLoc = parser.getCurrentLocation();
222   if (parser.parseRegion(region))
223     return failure();
224   if (ensureRegionTerm(parser, region, regionLoc).failed())
225     return failure();
226   return success();
227 }
228 
229 static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
230                                          cir::ScopeOp &op,
231                                          mlir::Region &region) {
232   printer.printRegion(region,
233                       /*printEntryBlockArgs=*/false,
234                       /*printBlockTerminators=*/!omitRegionTerm(region));
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // AllocaOp
239 //===----------------------------------------------------------------------===//
240 
241 void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
242                           mlir::OperationState &odsState, mlir::Type addr,
243                           mlir::Type allocaType, llvm::StringRef name,
244                           mlir::IntegerAttr alignment) {
245   odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
246                         mlir::TypeAttr::get(allocaType));
247   odsState.addAttribute(getNameAttrName(odsState.name),
248                         odsBuilder.getStringAttr(name));
249   if (alignment) {
250     odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
251   }
252   odsState.addTypes(addr);
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // BreakOp
257 //===----------------------------------------------------------------------===//
258 
259 LogicalResult cir::BreakOp::verify() {
260   assert(!cir::MissingFeatures::switchOp());
261   if (!getOperation()->getParentOfType<LoopOpInterface>() &&
262       !getOperation()->getParentOfType<SwitchOp>())
263     return emitOpError("must be within a loop");
264   return success();
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // ConditionOp
269 //===----------------------------------------------------------------------===//
270 
271 //===----------------------------------
272 // BranchOpTerminatorInterface Methods
273 //===----------------------------------
274 
275 void cir::ConditionOp::getSuccessorRegions(
276     ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
277   // TODO(cir): The condition value may be folded to a constant, narrowing
278   // down its list of possible successors.
279 
280   // Parent is a loop: condition may branch to the body or to the parent op.
281   if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
282     regions.emplace_back(&loopOp.getBody(), loopOp.getBody().getArguments());
283     regions.emplace_back(loopOp->getResults());
284   }
285 
286   assert(!cir::MissingFeatures::awaitOp());
287 }
288 
289 MutableOperandRange
290 cir::ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
291   // No values are yielded to the successor region.
292   return MutableOperandRange(getOperation(), 0, 0);
293 }
294 
295 LogicalResult cir::ConditionOp::verify() {
296   assert(!cir::MissingFeatures::awaitOp());
297   if (!isa<LoopOpInterface>(getOperation()->getParentOp()))
298     return emitOpError("condition must be within a conditional region");
299   return success();
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // ConstantOp
304 //===----------------------------------------------------------------------===//
305 
306 static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
307                                         mlir::Attribute attrType) {
308   if (isa<cir::ConstPtrAttr>(attrType)) {
309     if (!mlir::isa<cir::PointerType>(opType))
310       return op->emitOpError(
311           "pointer constant initializing a non-pointer type");
312     return success();
313   }
314 
315   if (isa<cir::ZeroAttr>(attrType)) {
316     if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
317             opType))
318       return success();
319     return op->emitOpError(
320         "zero expects struct, array, vector, or complex type");
321   }
322 
323   if (mlir::isa<cir::BoolAttr>(attrType)) {
324     if (!mlir::isa<cir::BoolType>(opType))
325       return op->emitOpError("result type (")
326              << opType << ") must be '!cir.bool' for '" << attrType << "'";
327     return success();
328   }
329 
330   if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
331     auto at = cast<TypedAttr>(attrType);
332     if (at.getType() != opType) {
333       return op->emitOpError("result type (")
334              << opType << ") does not match value type (" << at.getType()
335              << ")";
336     }
337     return success();
338   }
339 
340   if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
341                 cir::ConstComplexAttr>(attrType))
342     return success();
343 
344   assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
345   return op->emitOpError("global with type ")
346          << cast<TypedAttr>(attrType).getType() << " not yet supported";
347 }
348 
349 LogicalResult cir::ConstantOp::verify() {
350   // ODS already generates checks to make sure the result type is valid. We just
351   // need to additionally check that the value's attribute type is consistent
352   // with the result type.
353   return checkConstantTypes(getOperation(), getType(), getValue());
354 }
355 
356 OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
357   return getValue();
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // ContinueOp
362 //===----------------------------------------------------------------------===//
363 
364 LogicalResult cir::ContinueOp::verify() {
365   if (!getOperation()->getParentOfType<LoopOpInterface>())
366     return emitOpError("must be within a loop");
367   return success();
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // CastOp
372 //===----------------------------------------------------------------------===//
373 
374 LogicalResult cir::CastOp::verify() {
375   mlir::Type resType = getType();
376   mlir::Type srcType = getSrc().getType();
377 
378   if (mlir::isa<cir::VectorType>(srcType) &&
379       mlir::isa<cir::VectorType>(resType)) {
380     // Use the element type of the vector to verify the cast kind. (Except for
381     // bitcast, see below.)
382     srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
383     resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
384   }
385 
386   switch (getKind()) {
387   case cir::CastKind::int_to_bool: {
388     if (!mlir::isa<cir::BoolType>(resType))
389       return emitOpError() << "requires !cir.bool type for result";
390     if (!mlir::isa<cir::IntType>(srcType))
391       return emitOpError() << "requires !cir.int type for source";
392     return success();
393   }
394   case cir::CastKind::ptr_to_bool: {
395     if (!mlir::isa<cir::BoolType>(resType))
396       return emitOpError() << "requires !cir.bool type for result";
397     if (!mlir::isa<cir::PointerType>(srcType))
398       return emitOpError() << "requires !cir.ptr type for source";
399     return success();
400   }
401   case cir::CastKind::integral: {
402     if (!mlir::isa<cir::IntType>(resType))
403       return emitOpError() << "requires !cir.int type for result";
404     if (!mlir::isa<cir::IntType>(srcType))
405       return emitOpError() << "requires !cir.int type for source";
406     return success();
407   }
408   case cir::CastKind::array_to_ptrdecay: {
409     const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
410     const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
411     if (!arrayPtrTy || !flatPtrTy)
412       return emitOpError() << "requires !cir.ptr type for source and result";
413 
414     // TODO(CIR): Make sure the AddrSpace of both types are equals
415     return success();
416   }
417   case cir::CastKind::bitcast: {
418     // Handle the pointer types first.
419     auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
420     auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
421 
422     if (srcPtrTy && resPtrTy) {
423       return success();
424     }
425 
426     return success();
427   }
428   case cir::CastKind::floating: {
429     if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
430         !mlir::isa<cir::FPTypeInterface>(resType))
431       return emitOpError() << "requires !cir.float type for source and result";
432     return success();
433   }
434   case cir::CastKind::float_to_int: {
435     if (!mlir::isa<cir::FPTypeInterface>(srcType))
436       return emitOpError() << "requires !cir.float type for source";
437     if (!mlir::dyn_cast<cir::IntType>(resType))
438       return emitOpError() << "requires !cir.int type for result";
439     return success();
440   }
441   case cir::CastKind::int_to_ptr: {
442     if (!mlir::dyn_cast<cir::IntType>(srcType))
443       return emitOpError() << "requires !cir.int type for source";
444     if (!mlir::dyn_cast<cir::PointerType>(resType))
445       return emitOpError() << "requires !cir.ptr type for result";
446     return success();
447   }
448   case cir::CastKind::ptr_to_int: {
449     if (!mlir::dyn_cast<cir::PointerType>(srcType))
450       return emitOpError() << "requires !cir.ptr type for source";
451     if (!mlir::dyn_cast<cir::IntType>(resType))
452       return emitOpError() << "requires !cir.int type for result";
453     return success();
454   }
455   case cir::CastKind::float_to_bool: {
456     if (!mlir::isa<cir::FPTypeInterface>(srcType))
457       return emitOpError() << "requires !cir.float type for source";
458     if (!mlir::isa<cir::BoolType>(resType))
459       return emitOpError() << "requires !cir.bool type for result";
460     return success();
461   }
462   case cir::CastKind::bool_to_int: {
463     if (!mlir::isa<cir::BoolType>(srcType))
464       return emitOpError() << "requires !cir.bool type for source";
465     if (!mlir::isa<cir::IntType>(resType))
466       return emitOpError() << "requires !cir.int type for result";
467     return success();
468   }
469   case cir::CastKind::int_to_float: {
470     if (!mlir::isa<cir::IntType>(srcType))
471       return emitOpError() << "requires !cir.int type for source";
472     if (!mlir::isa<cir::FPTypeInterface>(resType))
473       return emitOpError() << "requires !cir.float type for result";
474     return success();
475   }
476   case cir::CastKind::bool_to_float: {
477     if (!mlir::isa<cir::BoolType>(srcType))
478       return emitOpError() << "requires !cir.bool type for source";
479     if (!mlir::isa<cir::FPTypeInterface>(resType))
480       return emitOpError() << "requires !cir.float type for result";
481     return success();
482   }
483   case cir::CastKind::address_space: {
484     auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
485     auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
486     if (!srcPtrTy || !resPtrTy)
487       return emitOpError() << "requires !cir.ptr type for source and result";
488     if (srcPtrTy.getPointee() != resPtrTy.getPointee())
489       return emitOpError() << "requires two types differ in addrspace only";
490     return success();
491   }
492   default:
493     llvm_unreachable("Unknown CastOp kind?");
494   }
495 }
496 
497 static bool isIntOrBoolCast(cir::CastOp op) {
498   auto kind = op.getKind();
499   return kind == cir::CastKind::bool_to_int ||
500          kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
501 }
502 
503 static Value tryFoldCastChain(cir::CastOp op) {
504   cir::CastOp head = op, tail = op;
505 
506   while (op) {
507     if (!isIntOrBoolCast(op))
508       break;
509     head = op;
510     op = dyn_cast_or_null<cir::CastOp>(head.getSrc().getDefiningOp());
511   }
512 
513   if (head == tail)
514     return {};
515 
516   // if bool_to_int -> ...  -> int_to_bool: take the bool
517   // as we had it was before all casts
518   if (head.getKind() == cir::CastKind::bool_to_int &&
519       tail.getKind() == cir::CastKind::int_to_bool)
520     return head.getSrc();
521 
522   // if int_to_bool -> ...  -> int_to_bool: take the result
523   // of the first one, as no other casts (and ext casts as well)
524   // don't change the first result
525   if (head.getKind() == cir::CastKind::int_to_bool &&
526       tail.getKind() == cir::CastKind::int_to_bool)
527     return head.getResult();
528 
529   return {};
530 }
531 
532 OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
533   if (getSrc().getType() == getType()) {
534     switch (getKind()) {
535     case cir::CastKind::integral: {
536       // TODO: for sign differences, it's possible in certain conditions to
537       // create a new attribute that's capable of representing the source.
538       llvm::SmallVector<mlir::OpFoldResult, 1> foldResults;
539       auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
540       if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
541         return mlir::cast<mlir::Attribute>(foldResults[0]);
542       return {};
543     }
544     case cir::CastKind::bitcast:
545     case cir::CastKind::address_space:
546     case cir::CastKind::float_complex:
547     case cir::CastKind::int_complex: {
548       return getSrc();
549     }
550     default:
551       return {};
552     }
553   }
554   return tryFoldCastChain(*this);
555 }
556 
557 //===----------------------------------------------------------------------===//
558 // CallOp
559 //===----------------------------------------------------------------------===//
560 
561 mlir::OperandRange cir::CallOp::getArgOperands() {
562   if (isIndirect())
563     return getArgs().drop_front(1);
564   return getArgs();
565 }
566 
567 mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
568   mlir::MutableOperandRange args = getArgsMutable();
569   if (isIndirect())
570     return args.slice(1, args.size() - 1);
571   return args;
572 }
573 
574 mlir::Value cir::CallOp::getIndirectCall() {
575   assert(isIndirect());
576   return getOperand(0);
577 }
578 
579 /// Return the operand at index 'i'.
580 Value cir::CallOp::getArgOperand(unsigned i) {
581   if (isIndirect())
582     ++i;
583   return getOperand(i);
584 }
585 
586 /// Return the number of operands.
587 unsigned cir::CallOp::getNumArgOperands() {
588   if (isIndirect())
589     return this->getOperation()->getNumOperands() - 1;
590   return this->getOperation()->getNumOperands();
591 }
592 
593 static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
594                                          mlir::OperationState &result) {
595   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
596   llvm::SMLoc opsLoc;
597   mlir::FlatSymbolRefAttr calleeAttr;
598   llvm::ArrayRef<mlir::Type> allResultTypes;
599 
600   // If we cannot parse a string callee, it means this is an indirect call.
601   if (!parser
602            .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
603                                    result.attributes)
604            .has_value()) {
605     OpAsmParser::UnresolvedOperand indirectVal;
606     // Do not resolve right now, since we need to figure out the type
607     if (parser.parseOperand(indirectVal).failed())
608       return failure();
609     ops.push_back(indirectVal);
610   }
611 
612   if (parser.parseLParen())
613     return mlir::failure();
614 
615   opsLoc = parser.getCurrentLocation();
616   if (parser.parseOperandList(ops))
617     return mlir::failure();
618   if (parser.parseRParen())
619     return mlir::failure();
620 
621   if (parser.parseOptionalKeyword("nothrow").succeeded())
622     result.addAttribute(CIRDialect::getNoThrowAttrName(),
623                         mlir::UnitAttr::get(parser.getContext()));
624 
625   if (parser.parseOptionalKeyword("side_effect").succeeded()) {
626     if (parser.parseLParen().failed())
627       return failure();
628     cir::SideEffect sideEffect;
629     if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
630       return failure();
631     if (parser.parseRParen().failed())
632       return failure();
633     auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
634     result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
635   }
636 
637   if (parser.parseOptionalAttrDict(result.attributes))
638     return ::mlir::failure();
639 
640   if (parser.parseColon())
641     return ::mlir::failure();
642 
643   mlir::FunctionType opsFnTy;
644   if (parser.parseType(opsFnTy))
645     return mlir::failure();
646 
647   allResultTypes = opsFnTy.getResults();
648   result.addTypes(allResultTypes);
649 
650   if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
651     return mlir::failure();
652 
653   return mlir::success();
654 }
655 
656 static void printCallCommon(mlir::Operation *op,
657                             mlir::FlatSymbolRefAttr calleeSym,
658                             mlir::Value indirectCallee,
659                             mlir::OpAsmPrinter &printer, bool isNothrow,
660                             cir::SideEffect sideEffect) {
661   printer << ' ';
662 
663   auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
664   auto ops = callLikeOp.getArgOperands();
665 
666   if (calleeSym) {
667     // Direct calls
668     printer.printAttributeWithoutType(calleeSym);
669   } else {
670     // Indirect calls
671     assert(indirectCallee);
672     printer << indirectCallee;
673   }
674   printer << "(" << ops << ")";
675 
676   if (isNothrow)
677     printer << " nothrow";
678 
679   if (sideEffect != cir::SideEffect::All) {
680     printer << " side_effect(";
681     printer << stringifySideEffect(sideEffect);
682     printer << ")";
683   }
684 
685   printer.printOptionalAttrDict(op->getAttrs(),
686                                 {CIRDialect::getCalleeAttrName(),
687                                  CIRDialect::getNoThrowAttrName(),
688                                  CIRDialect::getSideEffectAttrName()});
689 
690   printer << " : ";
691   printer.printFunctionalType(op->getOperands().getTypes(),
692                               op->getResultTypes());
693 }
694 
695 mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
696                                      mlir::OperationState &result) {
697   return parseCallCommon(parser, result);
698 }
699 
700 void cir::CallOp::print(mlir::OpAsmPrinter &p) {
701   mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
702   cir::SideEffect sideEffect = getSideEffect();
703   printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
704                   sideEffect);
705 }
706 
707 static LogicalResult
708 verifyCallCommInSymbolUses(mlir::Operation *op,
709                            SymbolTableCollection &symbolTable) {
710   auto fnAttr =
711       op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
712   if (!fnAttr) {
713     // This is an indirect call, thus we don't have to check the symbol uses.
714     return mlir::success();
715   }
716 
717   auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
718   if (!fn)
719     return op->emitOpError() << "'" << fnAttr.getValue()
720                              << "' does not reference a valid function";
721 
722   auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
723   assert(callIf && "expected CIR call interface to be always available");
724 
725   // Verify that the operand and result types match the callee. Note that
726   // argument-checking is disabled for functions without a prototype.
727   auto fnType = fn.getFunctionType();
728   if (!fn.getNoProto()) {
729     unsigned numCallOperands = callIf.getNumArgOperands();
730     unsigned numFnOpOperands = fnType.getNumInputs();
731 
732     if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
733       return op->emitOpError("incorrect number of operands for callee");
734     if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
735       return op->emitOpError("too few operands for callee");
736 
737     for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
738       if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
739         return op->emitOpError("operand type mismatch: expected operand type ")
740                << fnType.getInput(i) << ", but provided "
741                << op->getOperand(i).getType() << " for operand number " << i;
742   }
743 
744   assert(!cir::MissingFeatures::opCallCallConv());
745 
746   // Void function must not return any results.
747   if (fnType.hasVoidReturn() && op->getNumResults() != 0)
748     return op->emitOpError("callee returns void but call has results");
749 
750   // Non-void function calls must return exactly one result.
751   if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
752     return op->emitOpError("incorrect number of results for callee");
753 
754   // Parent function and return value types must match.
755   if (!fnType.hasVoidReturn() &&
756       op->getResultTypes().front() != fnType.getReturnType()) {
757     return op->emitOpError("result type mismatch: expected ")
758            << fnType.getReturnType() << ", but provided "
759            << op->getResult(0).getType();
760   }
761 
762   return mlir::success();
763 }
764 
765 LogicalResult
766 cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
767   return verifyCallCommInSymbolUses(*this, symbolTable);
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // ReturnOp
772 //===----------------------------------------------------------------------===//
773 
774 static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
775                                                   cir::FuncOp function) {
776   // ReturnOps currently only have a single optional operand.
777   if (op.getNumOperands() > 1)
778     return op.emitOpError() << "expects at most 1 return operand";
779 
780   // Ensure returned type matches the function signature.
781   auto expectedTy = function.getFunctionType().getReturnType();
782   auto actualTy =
783       (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
784                                 : op.getOperand(0).getType());
785   if (actualTy != expectedTy)
786     return op.emitOpError() << "returns " << actualTy
787                             << " but enclosing function returns " << expectedTy;
788 
789   return mlir::success();
790 }
791 
792 mlir::LogicalResult cir::ReturnOp::verify() {
793   // Returns can be present in multiple different scopes, get the
794   // wrapping function and start from there.
795   auto *fnOp = getOperation()->getParentOp();
796   while (!isa<cir::FuncOp>(fnOp))
797     fnOp = fnOp->getParentOp();
798 
799   // Make sure return types match function return type.
800   if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
801     return failure();
802 
803   return success();
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // IfOp
808 //===----------------------------------------------------------------------===//
809 
810 ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
811   // create the regions for 'then'.
812   result.regions.reserve(2);
813   Region *thenRegion = result.addRegion();
814   Region *elseRegion = result.addRegion();
815 
816   mlir::Builder &builder = parser.getBuilder();
817   OpAsmParser::UnresolvedOperand cond;
818   Type boolType = cir::BoolType::get(builder.getContext());
819 
820   if (parser.parseOperand(cond) ||
821       parser.resolveOperand(cond, boolType, result.operands))
822     return failure();
823 
824   // Parse 'then' region.
825   mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
826   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
827     return failure();
828 
829   if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
830     return failure();
831 
832   // If we find an 'else' keyword, parse the 'else' region.
833   if (!parser.parseOptionalKeyword("else")) {
834     mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
835     if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
836       return failure();
837     if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
838       return failure();
839   }
840 
841   // Parse the optional attribute list.
842   if (parser.parseOptionalAttrDict(result.attributes))
843     return failure();
844   return success();
845 }
846 
847 void cir::IfOp::print(OpAsmPrinter &p) {
848   p << " " << getCondition() << " ";
849   mlir::Region &thenRegion = this->getThenRegion();
850   p.printRegion(thenRegion,
851                 /*printEntryBlockArgs=*/false,
852                 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
853 
854   // Print the 'else' regions if it exists and has a block.
855   mlir::Region &elseRegion = this->getElseRegion();
856   if (!elseRegion.empty()) {
857     p << " else ";
858     p.printRegion(elseRegion,
859                   /*printEntryBlockArgs=*/false,
860                   /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
861   }
862 
863   p.printOptionalAttrDict(getOperation()->getAttrs());
864 }
865 
866 /// Default callback for IfOp builders.
867 void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
868   // add cir.yield to end of the block
869   builder.create<cir::YieldOp>(loc);
870 }
871 
872 /// Given the region at `index`, or the parent operation if `index` is None,
873 /// return the successor regions. These are the regions that may be selected
874 /// during the flow of control. `operands` is a set of optional attributes that
875 /// correspond to a constant value for each operand, or null if that operand is
876 /// not a constant.
877 void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
878                                     SmallVectorImpl<RegionSuccessor> &regions) {
879   // The `then` and the `else` region branch back to the parent operation.
880   if (!point.isParent()) {
881     regions.push_back(RegionSuccessor());
882     return;
883   }
884 
885   // Don't consider the else region if it is empty.
886   Region *elseRegion = &this->getElseRegion();
887   if (elseRegion->empty())
888     elseRegion = nullptr;
889 
890   // If the condition isn't constant, both regions may be executed.
891   regions.push_back(RegionSuccessor(&getThenRegion()));
892   // If the else region does not exist, it is not a viable successor.
893   if (elseRegion)
894     regions.push_back(RegionSuccessor(elseRegion));
895 
896   return;
897 }
898 
899 void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
900                       bool withElseRegion, BuilderCallbackRef thenBuilder,
901                       BuilderCallbackRef elseBuilder) {
902   assert(thenBuilder && "the builder callback for 'then' must be present");
903   result.addOperands(cond);
904 
905   OpBuilder::InsertionGuard guard(builder);
906   Region *thenRegion = result.addRegion();
907   builder.createBlock(thenRegion);
908   thenBuilder(builder, result.location);
909 
910   Region *elseRegion = result.addRegion();
911   if (!withElseRegion)
912     return;
913 
914   builder.createBlock(elseRegion);
915   elseBuilder(builder, result.location);
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // ScopeOp
920 //===----------------------------------------------------------------------===//
921 
922 /// Given the region at `index`, or the parent operation if `index` is None,
923 /// return the successor regions. These are the regions that may be selected
924 /// during the flow of control. `operands` is a set of optional attributes
925 /// that correspond to a constant value for each operand, or null if that
926 /// operand is not a constant.
927 void cir::ScopeOp::getSuccessorRegions(
928     mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
929   // The only region always branch back to the parent operation.
930   if (!point.isParent()) {
931     regions.push_back(RegionSuccessor(getODSResults(0)));
932     return;
933   }
934 
935   // If the condition isn't constant, both regions may be executed.
936   regions.push_back(RegionSuccessor(&getScopeRegion()));
937 }
938 
939 void cir::ScopeOp::build(
940     OpBuilder &builder, OperationState &result,
941     function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
942   assert(scopeBuilder && "the builder callback for 'then' must be present");
943 
944   OpBuilder::InsertionGuard guard(builder);
945   Region *scopeRegion = result.addRegion();
946   builder.createBlock(scopeRegion);
947   assert(!cir::MissingFeatures::opScopeCleanupRegion());
948 
949   mlir::Type yieldTy;
950   scopeBuilder(builder, yieldTy, result.location);
951 
952   if (yieldTy)
953     result.addTypes(TypeRange{yieldTy});
954 }
955 
956 void cir::ScopeOp::build(
957     OpBuilder &builder, OperationState &result,
958     function_ref<void(OpBuilder &, Location)> scopeBuilder) {
959   assert(scopeBuilder && "the builder callback for 'then' must be present");
960   OpBuilder::InsertionGuard guard(builder);
961   Region *scopeRegion = result.addRegion();
962   builder.createBlock(scopeRegion);
963   assert(!cir::MissingFeatures::opScopeCleanupRegion());
964   scopeBuilder(builder, result.location);
965 }
966 
967 LogicalResult cir::ScopeOp::verify() {
968   if (getRegion().empty()) {
969     return emitOpError() << "cir.scope must not be empty since it should "
970                             "include at least an implicit cir.yield ";
971   }
972 
973   mlir::Block &lastBlock = getRegion().back();
974   if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
975       !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
976     return emitOpError() << "last block of cir.scope must be terminated";
977   return success();
978 }
979 
980 //===----------------------------------------------------------------------===//
981 // BrOp
982 //===----------------------------------------------------------------------===//
983 
984 mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
985   assert(index == 0 && "invalid successor index");
986   return mlir::SuccessorOperands(getDestOperandsMutable());
987 }
988 
989 Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
990   return getDest();
991 }
992 
993 //===----------------------------------------------------------------------===//
994 // BrCondOp
995 //===----------------------------------------------------------------------===//
996 
997 mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
998   assert(index < getNumSuccessors() && "invalid successor index");
999   return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1000                                       : getDestOperandsFalseMutable());
1001 }
1002 
1003 Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1004   if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1005     return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1006   return nullptr;
1007 }
1008 
1009 //===----------------------------------------------------------------------===//
1010 // CaseOp
1011 //===----------------------------------------------------------------------===//
1012 
1013 void cir::CaseOp::getSuccessorRegions(
1014     mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1015   if (!point.isParent()) {
1016     regions.push_back(RegionSuccessor());
1017     return;
1018   }
1019   regions.push_back(RegionSuccessor(&getCaseRegion()));
1020 }
1021 
1022 void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1023                         ArrayAttr value, CaseOpKind kind,
1024                         OpBuilder::InsertPoint &insertPoint) {
1025   OpBuilder::InsertionGuard guardSwitch(builder);
1026   result.addAttribute("value", value);
1027   result.getOrAddProperties<Properties>().kind =
1028       cir::CaseOpKindAttr::get(builder.getContext(), kind);
1029   Region *caseRegion = result.addRegion();
1030   builder.createBlock(caseRegion);
1031 
1032   insertPoint = builder.saveInsertionPoint();
1033 }
1034 
1035 //===----------------------------------------------------------------------===//
1036 // SwitchOp
1037 //===----------------------------------------------------------------------===//
1038 
1039 static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions,
1040                                  mlir::OpAsmParser::UnresolvedOperand &cond,
1041                                  mlir::Type &condType) {
1042   cir::IntType intCondType;
1043 
1044   if (parser.parseLParen())
1045     return mlir::failure();
1046 
1047   if (parser.parseOperand(cond))
1048     return mlir::failure();
1049   if (parser.parseColon())
1050     return mlir::failure();
1051   if (parser.parseCustomTypeWithFallback(intCondType))
1052     return mlir::failure();
1053   condType = intCondType;
1054 
1055   if (parser.parseRParen())
1056     return mlir::failure();
1057   if (parser.parseRegion(regions, /*arguments=*/{}, /*argTypes=*/{}))
1058     return failure();
1059 
1060   return mlir::success();
1061 }
1062 
1063 static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op,
1064                           mlir::Region &bodyRegion, mlir::Value condition,
1065                           mlir::Type condType) {
1066   p << "(";
1067   p << condition;
1068   p << " : ";
1069   p.printStrippedAttrOrType(condType);
1070   p << ")";
1071 
1072   p << ' ';
1073   p.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
1074                 /*printBlockTerminators=*/true);
1075 }
1076 
1077 void cir::SwitchOp::getSuccessorRegions(
1078     mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1079   if (!point.isParent()) {
1080     region.push_back(RegionSuccessor());
1081     return;
1082   }
1083 
1084   region.push_back(RegionSuccessor(&getBody()));
1085 }
1086 
1087 void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1088                           Value cond, BuilderOpStateCallbackRef switchBuilder) {
1089   assert(switchBuilder && "the builder callback for regions must be present");
1090   OpBuilder::InsertionGuard guardSwitch(builder);
1091   Region *switchRegion = result.addRegion();
1092   builder.createBlock(switchRegion);
1093   result.addOperands({cond});
1094   switchBuilder(builder, result.location, result);
1095 }
1096 
1097 void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1098   walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1099     // Don't walk in nested switch op.
1100     if (isa<cir::SwitchOp>(op) && op != *this)
1101       return WalkResult::skip();
1102 
1103     if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1104       cases.push_back(caseOp);
1105 
1106     return WalkResult::advance();
1107   });
1108 }
1109 
1110 bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1111   collectCases(cases);
1112 
1113   if (getBody().empty())
1114     return false;
1115 
1116   if (!isa<YieldOp>(getBody().front().back()))
1117     return false;
1118 
1119   if (!llvm::all_of(getBody().front(),
1120                     [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1121     return false;
1122 
1123   return llvm::all_of(cases, [this](CaseOp op) {
1124     return op->getParentOfType<SwitchOp>() == *this;
1125   });
1126 }
1127 
1128 //===----------------------------------------------------------------------===//
1129 // SwitchFlatOp
1130 //===----------------------------------------------------------------------===//
1131 
1132 void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1133                               Value value, Block *defaultDestination,
1134                               ValueRange defaultOperands,
1135                               ArrayRef<APInt> caseValues,
1136                               BlockRange caseDestinations,
1137                               ArrayRef<ValueRange> caseOperands) {
1138 
1139   std::vector<mlir::Attribute> caseValuesAttrs;
1140   for (const APInt &val : caseValues)
1141     caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1142   mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1143 
1144   build(builder, result, value, defaultOperands, caseOperands, attrs,
1145         defaultDestination, caseDestinations);
1146 }
1147 
1148 /// <cases> ::= `[` (case (`,` case )* )? `]`
1149 /// <case>  ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
1150 static ParseResult parseSwitchFlatOpCases(
1151     OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1152     SmallVectorImpl<Block *> &caseDestinations,
1153     SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
1154         &caseOperands,
1155     SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1156   if (failed(parser.parseLSquare()))
1157     return failure();
1158   if (succeeded(parser.parseOptionalRSquare()))
1159     return success();
1160   llvm::SmallVector<mlir::Attribute> values;
1161 
1162   auto parseCase = [&]() {
1163     int64_t value = 0;
1164     if (failed(parser.parseInteger(value)))
1165       return failure();
1166 
1167     values.push_back(cir::IntAttr::get(flagType, value));
1168 
1169     Block *destination;
1170     llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
1171     llvm::SmallVector<Type> operandTypes;
1172     if (parser.parseColon() || parser.parseSuccessor(destination))
1173       return failure();
1174     if (!parser.parseOptionalLParen()) {
1175       if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
1176                                   /*allowResultNumber=*/false) ||
1177           parser.parseColonTypeList(operandTypes) || parser.parseRParen())
1178         return failure();
1179     }
1180     caseDestinations.push_back(destination);
1181     caseOperands.emplace_back(operands);
1182     caseOperandTypes.emplace_back(operandTypes);
1183     return success();
1184   };
1185   if (failed(parser.parseCommaSeparatedList(parseCase)))
1186     return failure();
1187 
1188   caseValues = ArrayAttr::get(flagType.getContext(), values);
1189 
1190   return parser.parseRSquare();
1191 }
1192 
1193 static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1194                                    Type flagType, mlir::ArrayAttr caseValues,
1195                                    SuccessorRange caseDestinations,
1196                                    OperandRangeRange caseOperands,
1197                                    const TypeRangeRange &caseOperandTypes) {
1198   p << '[';
1199   p.printNewline();
1200   if (!caseValues) {
1201     p << ']';
1202     return;
1203   }
1204 
1205   size_t index = 0;
1206   llvm::interleave(
1207       llvm::zip(caseValues, caseDestinations),
1208       [&](auto i) {
1209         p << "  ";
1210         mlir::Attribute a = std::get<0>(i);
1211         p << mlir::cast<cir::IntAttr>(a).getValue();
1212         p << ": ";
1213         p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
1214       },
1215       [&] {
1216         p << ',';
1217         p.printNewline();
1218       });
1219   p.printNewline();
1220   p << ']';
1221 }
1222 
1223 //===----------------------------------------------------------------------===//
1224 // GlobalOp
1225 //===----------------------------------------------------------------------===//
1226 
1227 static ParseResult parseConstantValue(OpAsmParser &parser,
1228                                       mlir::Attribute &valueAttr) {
1229   NamedAttrList attr;
1230   return parser.parseAttribute(valueAttr, "value", attr);
1231 }
1232 
1233 static void printConstant(OpAsmPrinter &p, Attribute value) {
1234   p.printAttribute(value);
1235 }
1236 
1237 mlir::LogicalResult cir::GlobalOp::verify() {
1238   // Verify that the initial value, if present, is either a unit attribute or
1239   // an attribute CIR supports.
1240   if (getInitialValue().has_value()) {
1241     if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1242             .failed())
1243       return failure();
1244   }
1245 
1246   // TODO(CIR): Many other checks for properties that haven't been upstreamed
1247   // yet.
1248 
1249   return success();
1250 }
1251 
1252 void cir::GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1253                           llvm::StringRef sym_name, mlir::Type sym_type,
1254                           cir::GlobalLinkageKind linkage) {
1255   odsState.addAttribute(getSymNameAttrName(odsState.name),
1256                         odsBuilder.getStringAttr(sym_name));
1257   odsState.addAttribute(getSymTypeAttrName(odsState.name),
1258                         mlir::TypeAttr::get(sym_type));
1259 
1260   cir::GlobalLinkageKindAttr linkageAttr =
1261       cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1262   odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1263 
1264   odsState.addAttribute(getGlobalVisibilityAttrName(odsState.name),
1265                         cir::VisibilityAttr::get(odsBuilder.getContext()));
1266 }
1267 
1268 static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1269                                              TypeAttr type,
1270                                              Attribute initAttr) {
1271   if (!op.isDeclaration()) {
1272     p << "= ";
1273     // This also prints the type...
1274     if (initAttr)
1275       printConstant(p, initAttr);
1276   } else {
1277     p << ": " << type;
1278   }
1279 }
1280 
1281 static ParseResult
1282 parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1283                                  Attribute &initialValueAttr) {
1284   mlir::Type opTy;
1285   if (parser.parseOptionalEqual().failed()) {
1286     // Absence of equal means a declaration, so we need to parse the type.
1287     //  cir.global @a : !cir.int<s, 32>
1288     if (parser.parseColonType(opTy))
1289       return failure();
1290   } else {
1291     // Parse constant with initializer, examples:
1292     //  cir.global @y = #cir.fp<1.250000e+00> : !cir.double
1293     //  cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
1294     if (parseConstantValue(parser, initialValueAttr).failed())
1295       return failure();
1296 
1297     assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
1298            "Non-typed attrs shouldn't appear here.");
1299     auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
1300     opTy = typedAttr.getType();
1301   }
1302 
1303   typeAttr = TypeAttr::get(opTy);
1304   return success();
1305 }
1306 
1307 //===----------------------------------------------------------------------===//
1308 // GetGlobalOp
1309 //===----------------------------------------------------------------------===//
1310 
1311 LogicalResult
1312 cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1313   // Verify that the result type underlying pointer type matches the type of
1314   // the referenced cir.global or cir.func op.
1315   mlir::Operation *op =
1316       symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
1317   if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
1318     return emitOpError("'")
1319            << getName()
1320            << "' does not reference a valid cir.global or cir.func";
1321 
1322   mlir::Type symTy;
1323   if (auto g = dyn_cast<GlobalOp>(op)) {
1324     symTy = g.getSymType();
1325     assert(!cir::MissingFeatures::addressSpace());
1326     assert(!cir::MissingFeatures::opGlobalThreadLocal());
1327   } else if (auto f = dyn_cast<FuncOp>(op)) {
1328     symTy = f.getFunctionType();
1329   } else {
1330     llvm_unreachable("Unexpected operation for GetGlobalOp");
1331   }
1332 
1333   auto resultType = dyn_cast<PointerType>(getAddr().getType());
1334   if (!resultType || symTy != resultType.getPointee())
1335     return emitOpError("result type pointee type '")
1336            << resultType.getPointee() << "' does not match type " << symTy
1337            << " of the global @" << getName();
1338 
1339   return success();
1340 }
1341 
1342 //===----------------------------------------------------------------------===//
1343 // FuncOp
1344 //===----------------------------------------------------------------------===//
1345 
1346 /// Returns the name used for the linkage attribute. This *must* correspond to
1347 /// the name of the attribute in ODS.
1348 static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
1349 
1350 void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
1351                         StringRef name, FuncType type,
1352                         GlobalLinkageKind linkage) {
1353   result.addRegion();
1354   result.addAttribute(SymbolTable::getSymbolAttrName(),
1355                       builder.getStringAttr(name));
1356   result.addAttribute(getFunctionTypeAttrName(result.name),
1357                       TypeAttr::get(type));
1358   result.addAttribute(
1359       getLinkageAttrNameString(),
1360       GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1361   result.addAttribute(getGlobalVisibilityAttrName(result.name),
1362                       cir::VisibilityAttr::get(builder.getContext()));
1363 }
1364 
1365 ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
1366   llvm::SMLoc loc = parser.getCurrentLocation();
1367   mlir::Builder &builder = parser.getBuilder();
1368 
1369   mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
1370   mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
1371   mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
1372 
1373   // Default to external linkage if no keyword is provided.
1374   state.addAttribute(getLinkageAttrNameString(),
1375                      GlobalLinkageKindAttr::get(
1376                          parser.getContext(),
1377                          parseOptionalCIRKeyword<GlobalLinkageKind>(
1378                              parser, GlobalLinkageKind::ExternalLinkage)));
1379 
1380   ::llvm::StringRef visAttrStr;
1381   if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
1382           .succeeded()) {
1383     state.addAttribute(visNameAttr,
1384                        parser.getBuilder().getStringAttr(visAttrStr));
1385   }
1386 
1387   cir::VisibilityAttr cirVisibilityAttr;
1388   parseVisibilityAttr(parser, cirVisibilityAttr);
1389   state.addAttribute(visibilityNameAttr, cirVisibilityAttr);
1390 
1391   if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
1392     state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
1393 
1394   StringAttr nameAttr;
1395   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1396                              state.attributes))
1397     return failure();
1398   llvm::SmallVector<OpAsmParser::Argument, 8> arguments;
1399   llvm::SmallVector<mlir::Type> resultTypes;
1400   llvm::SmallVector<DictionaryAttr> resultAttrs;
1401   bool isVariadic = false;
1402   if (function_interface_impl::parseFunctionSignatureWithArguments(
1403           parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
1404           resultAttrs))
1405     return failure();
1406   llvm::SmallVector<mlir::Type> argTypes;
1407   for (OpAsmParser::Argument &arg : arguments)
1408     argTypes.push_back(arg.type);
1409 
1410   if (resultTypes.size() > 1) {
1411     return parser.emitError(
1412         loc, "functions with multiple return types are not supported");
1413   }
1414 
1415   mlir::Type returnType =
1416       (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
1417                            : resultTypes.front());
1418 
1419   cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
1420   if (!fnType)
1421     return failure();
1422   state.addAttribute(getFunctionTypeAttrName(state.name),
1423                      TypeAttr::get(fnType));
1424 
1425   bool hasAlias = false;
1426   mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
1427   if (parser.parseOptionalKeyword("alias").succeeded()) {
1428     if (parser.parseLParen().failed())
1429       return failure();
1430     mlir::StringAttr aliaseeAttr;
1431     if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
1432       return failure();
1433     state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
1434     if (parser.parseRParen().failed())
1435       return failure();
1436     hasAlias = true;
1437   }
1438 
1439   // Parse the optional function body.
1440   auto *body = state.addRegion();
1441   OptionalParseResult parseResult = parser.parseOptionalRegion(
1442       *body, arguments, /*enableNameShadowing=*/false);
1443   if (parseResult.has_value()) {
1444     if (hasAlias)
1445       return parser.emitError(loc, "function alias shall not have a body");
1446     if (failed(*parseResult))
1447       return failure();
1448     // Function body was parsed, make sure its not empty.
1449     if (body->empty())
1450       return parser.emitError(loc, "expected non-empty function body");
1451   }
1452 
1453   return success();
1454 }
1455 
1456 // This function corresponds to `llvm::GlobalValue::isDeclaration` and should
1457 // have a similar implementation. We don't currently ifuncs or materializable
1458 // functions, but those should be handled here as they are implemented.
1459 bool cir::FuncOp::isDeclaration() {
1460   assert(!cir::MissingFeatures::supportIFuncAttr());
1461 
1462   std::optional<StringRef> aliasee = getAliasee();
1463   if (!aliasee)
1464     return getFunctionBody().empty();
1465 
1466   // Aliases are always definitions.
1467   return false;
1468 }
1469 
1470 mlir::Region *cir::FuncOp::getCallableRegion() {
1471   // TODO(CIR): This function will have special handling for aliases and a
1472   // check for an external function, once those features have been upstreamed.
1473   return &getBody();
1474 }
1475 
1476 void cir::FuncOp::print(OpAsmPrinter &p) {
1477   if (getComdat())
1478     p << " comdat";
1479 
1480   if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
1481     p << ' ' << stringifyGlobalLinkageKind(getLinkage());
1482 
1483   mlir::SymbolTable::Visibility vis = getVisibility();
1484   if (vis != mlir::SymbolTable::Visibility::Public)
1485     p << ' ' << vis;
1486 
1487   cir::VisibilityAttr cirVisibilityAttr = getGlobalVisibilityAttr();
1488   if (!cirVisibilityAttr.isDefault()) {
1489     p << ' ';
1490     printVisibilityAttr(p, cirVisibilityAttr);
1491   }
1492 
1493   if (getDsoLocal())
1494     p << " dso_local";
1495 
1496   p << ' ';
1497   p.printSymbolName(getSymName());
1498   cir::FuncType fnType = getFunctionType();
1499   function_interface_impl::printFunctionSignature(
1500       p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
1501 
1502   if (std::optional<StringRef> aliaseeName = getAliasee()) {
1503     p << " alias(";
1504     p.printSymbolName(*aliaseeName);
1505     p << ")";
1506   }
1507 
1508   // Print the body if this is not an external function.
1509   Region &body = getOperation()->getRegion(0);
1510   if (!body.empty()) {
1511     p << ' ';
1512     p.printRegion(body, /*printEntryBlockArgs=*/false,
1513                   /*printBlockTerminators=*/true);
1514   }
1515 }
1516 
1517 // TODO(CIR): The properties of functions that require verification haven't
1518 // been implemented yet.
1519 mlir::LogicalResult cir::FuncOp::verify() { return success(); }
1520 
1521 //===----------------------------------------------------------------------===//
1522 // BinOp
1523 //===----------------------------------------------------------------------===//
1524 LogicalResult cir::BinOp::verify() {
1525   bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
1526   bool saturated = getSaturated();
1527 
1528   if (!isa<cir::IntType>(getType()) && noWrap)
1529     return emitError()
1530            << "only operations on integer values may have nsw/nuw flags";
1531 
1532   bool noWrapOps = getKind() == cir::BinOpKind::Add ||
1533                    getKind() == cir::BinOpKind::Sub ||
1534                    getKind() == cir::BinOpKind::Mul;
1535 
1536   bool saturatedOps =
1537       getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;
1538 
1539   if (noWrap && !noWrapOps)
1540     return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
1541                           "'sub' and 'mul'";
1542   if (saturated && !saturatedOps)
1543     return emitError() << "The saturated flag is applicable to opcodes: 'add' "
1544                           "and 'sub'";
1545   if (noWrap && saturated)
1546     return emitError() << "The nsw/nuw flags and the saturated flag are "
1547                           "mutually exclusive";
1548 
1549   assert(!cir::MissingFeatures::complexType());
1550   // TODO(cir): verify for complex binops
1551 
1552   return mlir::success();
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // TernaryOp
1557 //===----------------------------------------------------------------------===//
1558 
1559 /// Given the region at `point`, or the parent operation if `point` is None,
1560 /// return the successor regions. These are the regions that may be selected
1561 /// during the flow of control. `operands` is a set of optional attributes that
1562 /// correspond to a constant value for each operand, or null if that operand is
1563 /// not a constant.
1564 void cir::TernaryOp::getSuccessorRegions(
1565     mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1566   // The `true` and the `false` region branch back to the parent operation.
1567   if (!point.isParent()) {
1568     regions.push_back(RegionSuccessor(this->getODSResults(0)));
1569     return;
1570   }
1571 
1572   // When branching from the parent operation, both the true and false
1573   // regions are considered possible successors
1574   regions.push_back(RegionSuccessor(&getTrueRegion()));
1575   regions.push_back(RegionSuccessor(&getFalseRegion()));
1576 }
1577 
1578 void cir::TernaryOp::build(
1579     OpBuilder &builder, OperationState &result, Value cond,
1580     function_ref<void(OpBuilder &, Location)> trueBuilder,
1581     function_ref<void(OpBuilder &, Location)> falseBuilder) {
1582   result.addOperands(cond);
1583   OpBuilder::InsertionGuard guard(builder);
1584   Region *trueRegion = result.addRegion();
1585   Block *block = builder.createBlock(trueRegion);
1586   trueBuilder(builder, result.location);
1587   Region *falseRegion = result.addRegion();
1588   builder.createBlock(falseRegion);
1589   falseBuilder(builder, result.location);
1590 
1591   auto yield = dyn_cast<YieldOp>(block->getTerminator());
1592   assert((yield && yield.getNumOperands() <= 1) &&
1593          "expected zero or one result type");
1594   if (yield.getNumOperands() == 1)
1595     result.addTypes(TypeRange{yield.getOperandTypes().front()});
1596 }
1597 
1598 //===----------------------------------------------------------------------===//
1599 // SelectOp
1600 //===----------------------------------------------------------------------===//
1601 
1602 OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
1603   mlir::Attribute condition = adaptor.getCondition();
1604   if (condition) {
1605     bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
1606     return conditionValue ? getTrueValue() : getFalseValue();
1607   }
1608 
1609   // cir.select if %0 then x else x -> x
1610   mlir::Attribute trueValue = adaptor.getTrueValue();
1611   mlir::Attribute falseValue = adaptor.getFalseValue();
1612   if (trueValue == falseValue)
1613     return trueValue;
1614   if (getTrueValue() == getFalseValue())
1615     return getTrueValue();
1616 
1617   return {};
1618 }
1619 
1620 //===----------------------------------------------------------------------===//
1621 // ShiftOp
1622 //===----------------------------------------------------------------------===//
1623 LogicalResult cir::ShiftOp::verify() {
1624   mlir::Operation *op = getOperation();
1625   auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
1626   auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
1627   if (!op0VecTy ^ !op1VecTy)
1628     return emitOpError() << "input types cannot be one vector and one scalar";
1629 
1630   if (op0VecTy) {
1631     if (op0VecTy.getSize() != op1VecTy.getSize())
1632       return emitOpError() << "input vector types must have the same size";
1633 
1634     auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
1635     if (!opResultTy)
1636       return emitOpError() << "the type of the result must be a vector "
1637                            << "if it is vector shift";
1638 
1639     auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
1640     auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
1641     if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
1642       return emitOpError()
1643              << "vector operands do not have the same elements sizes";
1644 
1645     auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
1646     if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
1647       return emitOpError() << "vector operands and result type do not have the "
1648                               "same elements sizes";
1649   }
1650 
1651   return mlir::success();
1652 }
1653 
1654 //===----------------------------------------------------------------------===//
1655 // UnaryOp
1656 //===----------------------------------------------------------------------===//
1657 
1658 LogicalResult cir::UnaryOp::verify() {
1659   switch (getKind()) {
1660   case cir::UnaryOpKind::Inc:
1661   case cir::UnaryOpKind::Dec:
1662   case cir::UnaryOpKind::Plus:
1663   case cir::UnaryOpKind::Minus:
1664   case cir::UnaryOpKind::Not:
1665     // Nothing to verify.
1666     return success();
1667   }
1668 
1669   llvm_unreachable("Unknown UnaryOp kind?");
1670 }
1671 
1672 static bool isBoolNot(cir::UnaryOp op) {
1673   return isa<cir::BoolType>(op.getInput().getType()) &&
1674          op.getKind() == cir::UnaryOpKind::Not;
1675 }
1676 
1677 // This folder simplifies the sequential boolean not operations.
1678 // For instance, the next two unary operations will be eliminated:
1679 //
1680 // ```mlir
1681 // %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
1682 // %2 = cir.unary(not, %1) : !cir.bool, !cir.bool
1683 // ```
1684 //
1685 // and the argument of the first one (%0) will be used instead.
1686 OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
1687   if (isBoolNot(*this))
1688     if (auto previous = dyn_cast_or_null<UnaryOp>(getInput().getDefiningOp()))
1689       if (isBoolNot(previous))
1690         return previous.getInput();
1691 
1692   return {};
1693 }
1694 
1695 //===----------------------------------------------------------------------===//
1696 // GetMemberOp Definitions
1697 //===----------------------------------------------------------------------===//
1698 
1699 LogicalResult cir::GetMemberOp::verify() {
1700   const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
1701   if (!recordTy)
1702     return emitError() << "expected pointer to a record type";
1703 
1704   if (recordTy.getMembers().size() <= getIndex())
1705     return emitError() << "member index out of bounds";
1706 
1707   if (recordTy.getMembers()[getIndex()] != getType().getPointee())
1708     return emitError() << "member type mismatch";
1709 
1710   return mlir::success();
1711 }
1712 
1713 //===----------------------------------------------------------------------===//
1714 // VecCreateOp
1715 //===----------------------------------------------------------------------===//
1716 
1717 OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
1718   if (llvm::any_of(getElements(), [](mlir::Value value) {
1719         return !mlir::isa<cir::ConstantOp>(value.getDefiningOp());
1720       }))
1721     return {};
1722 
1723   return cir::ConstVectorAttr::get(
1724       getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
1725 }
1726 
1727 LogicalResult cir::VecCreateOp::verify() {
1728   // Verify that the number of arguments matches the number of elements in the
1729   // vector, and that the type of all the arguments matches the type of the
1730   // elements in the vector.
1731   const cir::VectorType vecTy = getType();
1732   if (getElements().size() != vecTy.getSize()) {
1733     return emitOpError() << "operand count of " << getElements().size()
1734                          << " doesn't match vector type " << vecTy
1735                          << " element count of " << vecTy.getSize();
1736   }
1737 
1738   const mlir::Type elementType = vecTy.getElementType();
1739   for (const mlir::Value element : getElements()) {
1740     if (element.getType() != elementType) {
1741       return emitOpError() << "operand type " << element.getType()
1742                            << " doesn't match vector element type "
1743                            << elementType;
1744     }
1745   }
1746 
1747   return success();
1748 }
1749 
1750 //===----------------------------------------------------------------------===//
1751 // VecExtractOp
1752 //===----------------------------------------------------------------------===//
1753 
1754 OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
1755   const auto vectorAttr =
1756       llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
1757   if (!vectorAttr)
1758     return {};
1759 
1760   const auto indexAttr =
1761       llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
1762   if (!indexAttr)
1763     return {};
1764 
1765   const mlir::ArrayAttr elements = vectorAttr.getElts();
1766   const uint64_t index = indexAttr.getUInt();
1767   if (index >= elements.size())
1768     return {};
1769 
1770   return elements[index];
1771 }
1772 
1773 //===----------------------------------------------------------------------===//
1774 // VecCmpOp
1775 //===----------------------------------------------------------------------===//
1776 
1777 OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1778   auto lhsVecAttr =
1779       mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
1780   auto rhsVecAttr =
1781       mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
1782   if (!lhsVecAttr || !rhsVecAttr)
1783     return {};
1784 
1785   mlir::Type inputElemTy =
1786       mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
1787   if (!isAnyIntegerOrFloatingPointType(inputElemTy))
1788     return {};
1789 
1790   cir::CmpOpKind opKind = adaptor.getKind();
1791   mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
1792   mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
1793   uint64_t vecSize = lhsVecElhs.size();
1794 
1795   SmallVector<mlir::Attribute, 16> elements(vecSize);
1796   bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
1797   for (uint64_t i = 0; i < vecSize; i++) {
1798     mlir::Attribute lhsAttr = lhsVecElhs[i];
1799     mlir::Attribute rhsAttr = rhsVecElhs[i];
1800     int cmpResult = 0;
1801     switch (opKind) {
1802     case cir::CmpOpKind::lt: {
1803       if (isIntAttr) {
1804         cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
1805                     mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1806       } else {
1807         cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
1808                     mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1809       }
1810       break;
1811     }
1812     case cir::CmpOpKind::le: {
1813       if (isIntAttr) {
1814         cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
1815                     mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1816       } else {
1817         cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
1818                     mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1819       }
1820       break;
1821     }
1822     case cir::CmpOpKind::gt: {
1823       if (isIntAttr) {
1824         cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
1825                     mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1826       } else {
1827         cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
1828                     mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1829       }
1830       break;
1831     }
1832     case cir::CmpOpKind::ge: {
1833       if (isIntAttr) {
1834         cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
1835                     mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1836       } else {
1837         cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
1838                     mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1839       }
1840       break;
1841     }
1842     case cir::CmpOpKind::eq: {
1843       if (isIntAttr) {
1844         cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
1845                     mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1846       } else {
1847         cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
1848                     mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1849       }
1850       break;
1851     }
1852     case cir::CmpOpKind::ne: {
1853       if (isIntAttr) {
1854         cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
1855                     mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1856       } else {
1857         cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
1858                     mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1859       }
1860       break;
1861     }
1862     }
1863 
1864     elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
1865   }
1866 
1867   return cir::ConstVectorAttr::get(
1868       getType(), mlir::ArrayAttr::get(getContext(), elements));
1869 }
1870 
1871 //===----------------------------------------------------------------------===//
1872 // VecShuffleOp
1873 //===----------------------------------------------------------------------===//
1874 
1875 OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
1876   auto vec1Attr =
1877       mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
1878   auto vec2Attr =
1879       mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
1880   if (!vec1Attr || !vec2Attr)
1881     return {};
1882 
1883   mlir::Type vec1ElemTy =
1884       mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
1885 
1886   mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
1887   mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
1888   mlir::ArrayAttr indicesElts = adaptor.getIndices();
1889 
1890   SmallVector<mlir::Attribute, 16> elements;
1891   elements.reserve(indicesElts.size());
1892 
1893   uint64_t vec1Size = vec1Elts.size();
1894   for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
1895     if (idxAttr.getSInt() == -1) {
1896       elements.push_back(cir::UndefAttr::get(vec1ElemTy));
1897       continue;
1898     }
1899 
1900     uint64_t idxValue = idxAttr.getUInt();
1901     elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
1902                                            : vec2Elts[idxValue - vec1Size]);
1903   }
1904 
1905   return cir::ConstVectorAttr::get(
1906       getType(), mlir::ArrayAttr::get(getContext(), elements));
1907 }
1908 
1909 LogicalResult cir::VecShuffleOp::verify() {
1910   // The number of elements in the indices array must match the number of
1911   // elements in the result type.
1912   if (getIndices().size() != getResult().getType().getSize()) {
1913     return emitOpError() << ": the number of elements in " << getIndices()
1914                          << " and " << getResult().getType() << " don't match";
1915   }
1916 
1917   // The element types of the two input vectors and of the result type must
1918   // match.
1919   if (getVec1().getType().getElementType() !=
1920       getResult().getType().getElementType()) {
1921     return emitOpError() << ": element types of " << getVec1().getType()
1922                          << " and " << getResult().getType() << " don't match";
1923   }
1924 
1925   const uint64_t maxValidIndex =
1926       getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
1927   if (llvm::any_of(
1928           getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
1929             return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
1930           })) {
1931     return emitOpError() << ": index for __builtin_shufflevector must be "
1932                             "less than the total number of vector elements";
1933   }
1934   return success();
1935 }
1936 
1937 //===----------------------------------------------------------------------===//
1938 // VecShuffleDynamicOp
1939 //===----------------------------------------------------------------------===//
1940 
1941 OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
1942   mlir::Attribute vec = adaptor.getVec();
1943   mlir::Attribute indices = adaptor.getIndices();
1944   if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
1945       mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
1946     auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
1947     auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
1948 
1949     mlir::ArrayAttr vecElts = vecAttr.getElts();
1950     mlir::ArrayAttr indicesElts = indicesAttr.getElts();
1951 
1952     const uint64_t numElements = vecElts.size();
1953 
1954     SmallVector<mlir::Attribute, 16> elements;
1955     elements.reserve(numElements);
1956 
1957     const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
1958     for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
1959       uint64_t idxValue = idxAttr.getUInt();
1960       uint64_t newIdx = idxValue & maskBits;
1961       elements.push_back(vecElts[newIdx]);
1962     }
1963 
1964     return cir::ConstVectorAttr::get(
1965         getType(), mlir::ArrayAttr::get(getContext(), elements));
1966   }
1967 
1968   return {};
1969 }
1970 
1971 LogicalResult cir::VecShuffleDynamicOp::verify() {
1972   // The number of elements in the two input vectors must match.
1973   if (getVec().getType().getSize() !=
1974       mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
1975     return emitOpError() << ": the number of elements in " << getVec().getType()
1976                          << " and " << getIndices().getType() << " don't match";
1977   }
1978   return success();
1979 }
1980 
1981 //===----------------------------------------------------------------------===//
1982 // VecTernaryOp
1983 //===----------------------------------------------------------------------===//
1984 
1985 LogicalResult cir::VecTernaryOp::verify() {
1986   // Verify that the condition operand has the same number of elements as the
1987   // other operands.  (The automatic verification already checked that all
1988   // operands are vector types and that the second and third operands are the
1989   // same type.)
1990   if (getCond().getType().getSize() != getLhs().getType().getSize()) {
1991     return emitOpError() << ": the number of elements in "
1992                          << getCond().getType() << " and " << getLhs().getType()
1993                          << " don't match";
1994   }
1995   return success();
1996 }
1997 
1998 OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
1999   mlir::Attribute cond = adaptor.getCond();
2000   mlir::Attribute lhs = adaptor.getLhs();
2001   mlir::Attribute rhs = adaptor.getRhs();
2002 
2003   if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
2004       !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
2005       !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
2006     return {};
2007   auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
2008   auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
2009   auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
2010 
2011   mlir::ArrayAttr condElts = condVec.getElts();
2012 
2013   SmallVector<mlir::Attribute, 16> elements;
2014   elements.reserve(condElts.size());
2015 
2016   for (const auto &[idx, condAttr] :
2017        llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
2018     if (condAttr.getSInt()) {
2019       elements.push_back(lhsVec.getElts()[idx]);
2020     } else {
2021       elements.push_back(rhsVec.getElts()[idx]);
2022     }
2023   }
2024 
2025   cir::VectorType vecTy = getLhs().getType();
2026   return cir::ConstVectorAttr::get(
2027       vecTy, mlir::ArrayAttr::get(getContext(), elements));
2028 }
2029 
2030 //===----------------------------------------------------------------------===//
2031 // ComplexCreateOp
2032 //===----------------------------------------------------------------------===//
2033 
2034 LogicalResult cir::ComplexCreateOp::verify() {
2035   if (getType().getElementType() != getReal().getType()) {
2036     emitOpError()
2037         << "operand type of cir.complex.create does not match its result type";
2038     return failure();
2039   }
2040 
2041   return success();
2042 }
2043 
2044 OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
2045   mlir::Attribute real = adaptor.getReal();
2046   mlir::Attribute imag = adaptor.getImag();
2047   if (!real || !imag)
2048     return {};
2049 
2050   // When both of real and imag are constants, we can fold the operation into an
2051   // `#cir.const_complex` operation.
2052   auto realAttr = mlir::cast<mlir::TypedAttr>(real);
2053   auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
2054   return cir::ConstComplexAttr::get(realAttr, imagAttr);
2055 }
2056 
2057 //===----------------------------------------------------------------------===//
2058 // ComplexRealOp
2059 //===----------------------------------------------------------------------===//
2060 
2061 LogicalResult cir::ComplexRealOp::verify() {
2062   if (getType() != getOperand().getType().getElementType()) {
2063     emitOpError() << ": result type does not match operand type";
2064     return failure();
2065   }
2066   return success();
2067 }
2068 
2069 OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
2070   if (auto complexCreateOp =
2071           dyn_cast_or_null<cir::ComplexCreateOp>(getOperand().getDefiningOp()))
2072     return complexCreateOp.getOperand(0);
2073 
2074   auto complex =
2075       mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2076   return complex ? complex.getReal() : nullptr;
2077 }
2078 
2079 //===----------------------------------------------------------------------===//
2080 // ComplexImagOp
2081 //===----------------------------------------------------------------------===//
2082 
2083 LogicalResult cir::ComplexImagOp::verify() {
2084   if (getType() != getOperand().getType().getElementType()) {
2085     emitOpError() << ": result type does not match operand type";
2086     return failure();
2087   }
2088   return success();
2089 }
2090 
2091 OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
2092   if (auto complexCreateOp =
2093           dyn_cast_or_null<cir::ComplexCreateOp>(getOperand().getDefiningOp()))
2094     return complexCreateOp.getOperand(1);
2095 
2096   auto complex =
2097       mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2098   return complex ? complex.getImag() : nullptr;
2099 }
2100 
2101 //===----------------------------------------------------------------------===//
2102 // ComplexRealPtrOp
2103 //===----------------------------------------------------------------------===//
2104 
2105 LogicalResult cir::ComplexRealPtrOp::verify() {
2106   mlir::Type resultPointeeTy = getType().getPointee();
2107   cir::PointerType operandPtrTy = getOperand().getType();
2108   auto operandPointeeTy =
2109       mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2110 
2111   if (resultPointeeTy != operandPointeeTy.getElementType()) {
2112     return emitOpError() << ": result type does not match operand type";
2113   }
2114 
2115   return success();
2116 }
2117 
2118 //===----------------------------------------------------------------------===//
2119 // ComplexImagPtrOp
2120 //===----------------------------------------------------------------------===//
2121 
2122 LogicalResult cir::ComplexImagPtrOp::verify() {
2123   mlir::Type resultPointeeTy = getType().getPointee();
2124   cir::PointerType operandPtrTy = getOperand().getType();
2125   auto operandPointeeTy =
2126       mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2127 
2128   if (resultPointeeTy != operandPointeeTy.getElementType()) {
2129     return emitOpError()
2130            << "cir.complex.imag_ptr result type does not match operand type";
2131   }
2132   return success();
2133 }
2134 
2135 //===----------------------------------------------------------------------===//
2136 // TableGen'd op method definitions
2137 //===----------------------------------------------------------------------===//
2138 
2139 #define GET_OP_CLASSES
2140 #include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
2141