1 //===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the CIR dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "clang/CIR/Dialect/IR/CIRDialect.h" 14 15 #include "clang/CIR/Dialect/IR/CIROpsEnums.h" 16 #include "clang/CIR/Dialect/IR/CIRTypes.h" 17 18 #include "mlir/Interfaces/ControlFlowInterfaces.h" 19 #include "mlir/Interfaces/FunctionImplementation.h" 20 21 #include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc" 22 #include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc" 23 #include "clang/CIR/MissingFeatures.h" 24 #include "llvm/Support/LogicalResult.h" 25 26 #include <numeric> 27 28 using namespace mlir; 29 using namespace cir; 30 31 //===----------------------------------------------------------------------===// 32 // CIR Dialect 33 //===----------------------------------------------------------------------===// 34 namespace { 35 struct CIROpAsmDialectInterface : public OpAsmDialectInterface { 36 using OpAsmDialectInterface::OpAsmDialectInterface; 37 38 AliasResult getAlias(Type type, raw_ostream &os) const final { 39 if (auto recordType = dyn_cast<cir::RecordType>(type)) { 40 StringAttr nameAttr = recordType.getName(); 41 if (!nameAttr) 42 os << "rec_anon_" << recordType.getKindAsStr(); 43 else 44 os << "rec_" << nameAttr.getValue(); 45 return AliasResult::OverridableAlias; 46 } 47 if (auto intType = dyn_cast<cir::IntType>(type)) { 48 // We only provide alias for standard integer types (i.e. integer types 49 // whose width is a power of 2 and at least 8). 50 unsigned width = intType.getWidth(); 51 if (width < 8 || !llvm::isPowerOf2_32(width)) 52 return AliasResult::NoAlias; 53 os << intType.getAlias(); 54 return AliasResult::OverridableAlias; 55 } 56 if (auto voidType = dyn_cast<cir::VoidType>(type)) { 57 os << voidType.getAlias(); 58 return AliasResult::OverridableAlias; 59 } 60 61 return AliasResult::NoAlias; 62 } 63 64 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 65 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) { 66 os << (boolAttr.getValue() ? "true" : "false"); 67 return AliasResult::FinalAlias; 68 } 69 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) { 70 os << "bfi_" << bitfield.getName().str(); 71 return AliasResult::FinalAlias; 72 } 73 return AliasResult::NoAlias; 74 } 75 }; 76 } // namespace 77 78 void cir::CIRDialect::initialize() { 79 registerTypes(); 80 registerAttributes(); 81 addOperations< 82 #define GET_OP_LIST 83 #include "clang/CIR/Dialect/IR/CIROps.cpp.inc" 84 >(); 85 addInterfaces<CIROpAsmDialectInterface>(); 86 } 87 88 Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder, 89 mlir::Attribute value, 90 mlir::Type type, 91 mlir::Location loc) { 92 return builder.create<cir::ConstantOp>(loc, type, 93 mlir::cast<mlir::TypedAttr>(value)); 94 } 95 96 //===----------------------------------------------------------------------===// 97 // Helpers 98 //===----------------------------------------------------------------------===// 99 100 // Parses one of the keywords provided in the list `keywords` and returns the 101 // position of the parsed keyword in the list. If none of the keywords from the 102 // list is parsed, returns -1. 103 static int parseOptionalKeywordAlternative(AsmParser &parser, 104 ArrayRef<llvm::StringRef> keywords) { 105 for (auto en : llvm::enumerate(keywords)) { 106 if (succeeded(parser.parseOptionalKeyword(en.value()))) 107 return en.index(); 108 } 109 return -1; 110 } 111 112 namespace { 113 template <typename Ty> struct EnumTraits {}; 114 115 #define REGISTER_ENUM_TYPE(Ty) \ 116 template <> struct EnumTraits<cir::Ty> { \ 117 static llvm::StringRef stringify(cir::Ty value) { \ 118 return stringify##Ty(value); \ 119 } \ 120 static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); } \ 121 } 122 123 REGISTER_ENUM_TYPE(GlobalLinkageKind); 124 REGISTER_ENUM_TYPE(VisibilityKind); 125 REGISTER_ENUM_TYPE(SideEffect); 126 } // namespace 127 128 /// Parse an enum from the keyword, or default to the provided default value. 129 /// The return type is the enum type by default, unless overriden with the 130 /// second template argument. 131 template <typename EnumTy, typename RetTy = EnumTy> 132 static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) { 133 llvm::SmallVector<llvm::StringRef, 10> names; 134 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 135 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 136 137 int index = parseOptionalKeywordAlternative(parser, names); 138 if (index == -1) 139 return static_cast<RetTy>(defaultValue); 140 return static_cast<RetTy>(index); 141 } 142 143 /// Parse an enum from the keyword, return failure if the keyword is not found. 144 template <typename EnumTy, typename RetTy = EnumTy> 145 static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) { 146 llvm::SmallVector<llvm::StringRef, 10> names; 147 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 148 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 149 150 int index = parseOptionalKeywordAlternative(parser, names); 151 if (index == -1) 152 return failure(); 153 result = static_cast<RetTy>(index); 154 return success(); 155 } 156 157 // Check if a region's termination omission is valid and, if so, creates and 158 // inserts the omitted terminator into the region. 159 static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region ®ion, 160 SMLoc errLoc) { 161 Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); 162 OpBuilder builder(parser.getBuilder().getContext()); 163 164 // Insert empty block in case the region is empty to ensure the terminator 165 // will be inserted 166 if (region.empty()) 167 builder.createBlock(®ion); 168 169 Block &block = region.back(); 170 // Region is properly terminated: nothing to do. 171 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) 172 return success(); 173 174 // Check for invalid terminator omissions. 175 if (!region.hasOneBlock()) 176 return parser.emitError(errLoc, 177 "multi-block region must not omit terminator"); 178 179 // Terminator was omitted correctly: recreate it. 180 builder.setInsertionPointToEnd(&block); 181 builder.create<cir::YieldOp>(eLoc); 182 return success(); 183 } 184 185 // True if the region's terminator should be omitted. 186 static bool omitRegionTerm(mlir::Region &r) { 187 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty(); 188 const auto yieldsNothing = [&r]() { 189 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator()); 190 return y && y.getArgs().empty(); 191 }; 192 return singleNonEmptyBlock && yieldsNothing(); 193 } 194 195 void printVisibilityAttr(OpAsmPrinter &printer, 196 cir::VisibilityAttr &visibility) { 197 switch (visibility.getValue()) { 198 case cir::VisibilityKind::Hidden: 199 printer << "hidden"; 200 break; 201 case cir::VisibilityKind::Protected: 202 printer << "protected"; 203 break; 204 case cir::VisibilityKind::Default: 205 break; 206 } 207 } 208 209 void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility) { 210 cir::VisibilityKind visibilityKind = 211 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default); 212 visibility = cir::VisibilityAttr::get(parser.getContext(), visibilityKind); 213 } 214 215 //===----------------------------------------------------------------------===// 216 // CIR Custom Parsers/Printers 217 //===----------------------------------------------------------------------===// 218 219 static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser, 220 mlir::Region ®ion) { 221 auto regionLoc = parser.getCurrentLocation(); 222 if (parser.parseRegion(region)) 223 return failure(); 224 if (ensureRegionTerm(parser, region, regionLoc).failed()) 225 return failure(); 226 return success(); 227 } 228 229 static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, 230 cir::ScopeOp &op, 231 mlir::Region ®ion) { 232 printer.printRegion(region, 233 /*printEntryBlockArgs=*/false, 234 /*printBlockTerminators=*/!omitRegionTerm(region)); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // AllocaOp 239 //===----------------------------------------------------------------------===// 240 241 void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder, 242 mlir::OperationState &odsState, mlir::Type addr, 243 mlir::Type allocaType, llvm::StringRef name, 244 mlir::IntegerAttr alignment) { 245 odsState.addAttribute(getAllocaTypeAttrName(odsState.name), 246 mlir::TypeAttr::get(allocaType)); 247 odsState.addAttribute(getNameAttrName(odsState.name), 248 odsBuilder.getStringAttr(name)); 249 if (alignment) { 250 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment); 251 } 252 odsState.addTypes(addr); 253 } 254 255 //===----------------------------------------------------------------------===// 256 // BreakOp 257 //===----------------------------------------------------------------------===// 258 259 LogicalResult cir::BreakOp::verify() { 260 assert(!cir::MissingFeatures::switchOp()); 261 if (!getOperation()->getParentOfType<LoopOpInterface>() && 262 !getOperation()->getParentOfType<SwitchOp>()) 263 return emitOpError("must be within a loop"); 264 return success(); 265 } 266 267 //===----------------------------------------------------------------------===// 268 // ConditionOp 269 //===----------------------------------------------------------------------===// 270 271 //===---------------------------------- 272 // BranchOpTerminatorInterface Methods 273 //===---------------------------------- 274 275 void cir::ConditionOp::getSuccessorRegions( 276 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) { 277 // TODO(cir): The condition value may be folded to a constant, narrowing 278 // down its list of possible successors. 279 280 // Parent is a loop: condition may branch to the body or to the parent op. 281 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) { 282 regions.emplace_back(&loopOp.getBody(), loopOp.getBody().getArguments()); 283 regions.emplace_back(loopOp->getResults()); 284 } 285 286 assert(!cir::MissingFeatures::awaitOp()); 287 } 288 289 MutableOperandRange 290 cir::ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { 291 // No values are yielded to the successor region. 292 return MutableOperandRange(getOperation(), 0, 0); 293 } 294 295 LogicalResult cir::ConditionOp::verify() { 296 assert(!cir::MissingFeatures::awaitOp()); 297 if (!isa<LoopOpInterface>(getOperation()->getParentOp())) 298 return emitOpError("condition must be within a conditional region"); 299 return success(); 300 } 301 302 //===----------------------------------------------------------------------===// 303 // ConstantOp 304 //===----------------------------------------------------------------------===// 305 306 static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, 307 mlir::Attribute attrType) { 308 if (isa<cir::ConstPtrAttr>(attrType)) { 309 if (!mlir::isa<cir::PointerType>(opType)) 310 return op->emitOpError( 311 "pointer constant initializing a non-pointer type"); 312 return success(); 313 } 314 315 if (isa<cir::ZeroAttr>(attrType)) { 316 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>( 317 opType)) 318 return success(); 319 return op->emitOpError( 320 "zero expects struct, array, vector, or complex type"); 321 } 322 323 if (mlir::isa<cir::BoolAttr>(attrType)) { 324 if (!mlir::isa<cir::BoolType>(opType)) 325 return op->emitOpError("result type (") 326 << opType << ") must be '!cir.bool' for '" << attrType << "'"; 327 return success(); 328 } 329 330 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) { 331 auto at = cast<TypedAttr>(attrType); 332 if (at.getType() != opType) { 333 return op->emitOpError("result type (") 334 << opType << ") does not match value type (" << at.getType() 335 << ")"; 336 } 337 return success(); 338 } 339 340 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr, 341 cir::ConstComplexAttr>(attrType)) 342 return success(); 343 344 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?"); 345 return op->emitOpError("global with type ") 346 << cast<TypedAttr>(attrType).getType() << " not yet supported"; 347 } 348 349 LogicalResult cir::ConstantOp::verify() { 350 // ODS already generates checks to make sure the result type is valid. We just 351 // need to additionally check that the value's attribute type is consistent 352 // with the result type. 353 return checkConstantTypes(getOperation(), getType(), getValue()); 354 } 355 356 OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) { 357 return getValue(); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // ContinueOp 362 //===----------------------------------------------------------------------===// 363 364 LogicalResult cir::ContinueOp::verify() { 365 if (!getOperation()->getParentOfType<LoopOpInterface>()) 366 return emitOpError("must be within a loop"); 367 return success(); 368 } 369 370 //===----------------------------------------------------------------------===// 371 // CastOp 372 //===----------------------------------------------------------------------===// 373 374 LogicalResult cir::CastOp::verify() { 375 mlir::Type resType = getType(); 376 mlir::Type srcType = getSrc().getType(); 377 378 if (mlir::isa<cir::VectorType>(srcType) && 379 mlir::isa<cir::VectorType>(resType)) { 380 // Use the element type of the vector to verify the cast kind. (Except for 381 // bitcast, see below.) 382 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType(); 383 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType(); 384 } 385 386 switch (getKind()) { 387 case cir::CastKind::int_to_bool: { 388 if (!mlir::isa<cir::BoolType>(resType)) 389 return emitOpError() << "requires !cir.bool type for result"; 390 if (!mlir::isa<cir::IntType>(srcType)) 391 return emitOpError() << "requires !cir.int type for source"; 392 return success(); 393 } 394 case cir::CastKind::ptr_to_bool: { 395 if (!mlir::isa<cir::BoolType>(resType)) 396 return emitOpError() << "requires !cir.bool type for result"; 397 if (!mlir::isa<cir::PointerType>(srcType)) 398 return emitOpError() << "requires !cir.ptr type for source"; 399 return success(); 400 } 401 case cir::CastKind::integral: { 402 if (!mlir::isa<cir::IntType>(resType)) 403 return emitOpError() << "requires !cir.int type for result"; 404 if (!mlir::isa<cir::IntType>(srcType)) 405 return emitOpError() << "requires !cir.int type for source"; 406 return success(); 407 } 408 case cir::CastKind::array_to_ptrdecay: { 409 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType); 410 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType); 411 if (!arrayPtrTy || !flatPtrTy) 412 return emitOpError() << "requires !cir.ptr type for source and result"; 413 414 // TODO(CIR): Make sure the AddrSpace of both types are equals 415 return success(); 416 } 417 case cir::CastKind::bitcast: { 418 // Handle the pointer types first. 419 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType); 420 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType); 421 422 if (srcPtrTy && resPtrTy) { 423 return success(); 424 } 425 426 return success(); 427 } 428 case cir::CastKind::floating: { 429 if (!mlir::isa<cir::FPTypeInterface>(srcType) || 430 !mlir::isa<cir::FPTypeInterface>(resType)) 431 return emitOpError() << "requires !cir.float type for source and result"; 432 return success(); 433 } 434 case cir::CastKind::float_to_int: { 435 if (!mlir::isa<cir::FPTypeInterface>(srcType)) 436 return emitOpError() << "requires !cir.float type for source"; 437 if (!mlir::dyn_cast<cir::IntType>(resType)) 438 return emitOpError() << "requires !cir.int type for result"; 439 return success(); 440 } 441 case cir::CastKind::int_to_ptr: { 442 if (!mlir::dyn_cast<cir::IntType>(srcType)) 443 return emitOpError() << "requires !cir.int type for source"; 444 if (!mlir::dyn_cast<cir::PointerType>(resType)) 445 return emitOpError() << "requires !cir.ptr type for result"; 446 return success(); 447 } 448 case cir::CastKind::ptr_to_int: { 449 if (!mlir::dyn_cast<cir::PointerType>(srcType)) 450 return emitOpError() << "requires !cir.ptr type for source"; 451 if (!mlir::dyn_cast<cir::IntType>(resType)) 452 return emitOpError() << "requires !cir.int type for result"; 453 return success(); 454 } 455 case cir::CastKind::float_to_bool: { 456 if (!mlir::isa<cir::FPTypeInterface>(srcType)) 457 return emitOpError() << "requires !cir.float type for source"; 458 if (!mlir::isa<cir::BoolType>(resType)) 459 return emitOpError() << "requires !cir.bool type for result"; 460 return success(); 461 } 462 case cir::CastKind::bool_to_int: { 463 if (!mlir::isa<cir::BoolType>(srcType)) 464 return emitOpError() << "requires !cir.bool type for source"; 465 if (!mlir::isa<cir::IntType>(resType)) 466 return emitOpError() << "requires !cir.int type for result"; 467 return success(); 468 } 469 case cir::CastKind::int_to_float: { 470 if (!mlir::isa<cir::IntType>(srcType)) 471 return emitOpError() << "requires !cir.int type for source"; 472 if (!mlir::isa<cir::FPTypeInterface>(resType)) 473 return emitOpError() << "requires !cir.float type for result"; 474 return success(); 475 } 476 case cir::CastKind::bool_to_float: { 477 if (!mlir::isa<cir::BoolType>(srcType)) 478 return emitOpError() << "requires !cir.bool type for source"; 479 if (!mlir::isa<cir::FPTypeInterface>(resType)) 480 return emitOpError() << "requires !cir.float type for result"; 481 return success(); 482 } 483 case cir::CastKind::address_space: { 484 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType); 485 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType); 486 if (!srcPtrTy || !resPtrTy) 487 return emitOpError() << "requires !cir.ptr type for source and result"; 488 if (srcPtrTy.getPointee() != resPtrTy.getPointee()) 489 return emitOpError() << "requires two types differ in addrspace only"; 490 return success(); 491 } 492 default: 493 llvm_unreachable("Unknown CastOp kind?"); 494 } 495 } 496 497 static bool isIntOrBoolCast(cir::CastOp op) { 498 auto kind = op.getKind(); 499 return kind == cir::CastKind::bool_to_int || 500 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral; 501 } 502 503 static Value tryFoldCastChain(cir::CastOp op) { 504 cir::CastOp head = op, tail = op; 505 506 while (op) { 507 if (!isIntOrBoolCast(op)) 508 break; 509 head = op; 510 op = dyn_cast_or_null<cir::CastOp>(head.getSrc().getDefiningOp()); 511 } 512 513 if (head == tail) 514 return {}; 515 516 // if bool_to_int -> ... -> int_to_bool: take the bool 517 // as we had it was before all casts 518 if (head.getKind() == cir::CastKind::bool_to_int && 519 tail.getKind() == cir::CastKind::int_to_bool) 520 return head.getSrc(); 521 522 // if int_to_bool -> ... -> int_to_bool: take the result 523 // of the first one, as no other casts (and ext casts as well) 524 // don't change the first result 525 if (head.getKind() == cir::CastKind::int_to_bool && 526 tail.getKind() == cir::CastKind::int_to_bool) 527 return head.getResult(); 528 529 return {}; 530 } 531 532 OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) { 533 if (getSrc().getType() == getType()) { 534 switch (getKind()) { 535 case cir::CastKind::integral: { 536 // TODO: for sign differences, it's possible in certain conditions to 537 // create a new attribute that's capable of representing the source. 538 llvm::SmallVector<mlir::OpFoldResult, 1> foldResults; 539 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults); 540 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0])) 541 return mlir::cast<mlir::Attribute>(foldResults[0]); 542 return {}; 543 } 544 case cir::CastKind::bitcast: 545 case cir::CastKind::address_space: 546 case cir::CastKind::float_complex: 547 case cir::CastKind::int_complex: { 548 return getSrc(); 549 } 550 default: 551 return {}; 552 } 553 } 554 return tryFoldCastChain(*this); 555 } 556 557 //===----------------------------------------------------------------------===// 558 // CallOp 559 //===----------------------------------------------------------------------===// 560 561 mlir::OperandRange cir::CallOp::getArgOperands() { 562 if (isIndirect()) 563 return getArgs().drop_front(1); 564 return getArgs(); 565 } 566 567 mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() { 568 mlir::MutableOperandRange args = getArgsMutable(); 569 if (isIndirect()) 570 return args.slice(1, args.size() - 1); 571 return args; 572 } 573 574 mlir::Value cir::CallOp::getIndirectCall() { 575 assert(isIndirect()); 576 return getOperand(0); 577 } 578 579 /// Return the operand at index 'i'. 580 Value cir::CallOp::getArgOperand(unsigned i) { 581 if (isIndirect()) 582 ++i; 583 return getOperand(i); 584 } 585 586 /// Return the number of operands. 587 unsigned cir::CallOp::getNumArgOperands() { 588 if (isIndirect()) 589 return this->getOperation()->getNumOperands() - 1; 590 return this->getOperation()->getNumOperands(); 591 } 592 593 static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, 594 mlir::OperationState &result) { 595 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops; 596 llvm::SMLoc opsLoc; 597 mlir::FlatSymbolRefAttr calleeAttr; 598 llvm::ArrayRef<mlir::Type> allResultTypes; 599 600 // If we cannot parse a string callee, it means this is an indirect call. 601 if (!parser 602 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(), 603 result.attributes) 604 .has_value()) { 605 OpAsmParser::UnresolvedOperand indirectVal; 606 // Do not resolve right now, since we need to figure out the type 607 if (parser.parseOperand(indirectVal).failed()) 608 return failure(); 609 ops.push_back(indirectVal); 610 } 611 612 if (parser.parseLParen()) 613 return mlir::failure(); 614 615 opsLoc = parser.getCurrentLocation(); 616 if (parser.parseOperandList(ops)) 617 return mlir::failure(); 618 if (parser.parseRParen()) 619 return mlir::failure(); 620 621 if (parser.parseOptionalKeyword("nothrow").succeeded()) 622 result.addAttribute(CIRDialect::getNoThrowAttrName(), 623 mlir::UnitAttr::get(parser.getContext())); 624 625 if (parser.parseOptionalKeyword("side_effect").succeeded()) { 626 if (parser.parseLParen().failed()) 627 return failure(); 628 cir::SideEffect sideEffect; 629 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed()) 630 return failure(); 631 if (parser.parseRParen().failed()) 632 return failure(); 633 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect); 634 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr); 635 } 636 637 if (parser.parseOptionalAttrDict(result.attributes)) 638 return ::mlir::failure(); 639 640 if (parser.parseColon()) 641 return ::mlir::failure(); 642 643 mlir::FunctionType opsFnTy; 644 if (parser.parseType(opsFnTy)) 645 return mlir::failure(); 646 647 allResultTypes = opsFnTy.getResults(); 648 result.addTypes(allResultTypes); 649 650 if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands)) 651 return mlir::failure(); 652 653 return mlir::success(); 654 } 655 656 static void printCallCommon(mlir::Operation *op, 657 mlir::FlatSymbolRefAttr calleeSym, 658 mlir::Value indirectCallee, 659 mlir::OpAsmPrinter &printer, bool isNothrow, 660 cir::SideEffect sideEffect) { 661 printer << ' '; 662 663 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op); 664 auto ops = callLikeOp.getArgOperands(); 665 666 if (calleeSym) { 667 // Direct calls 668 printer.printAttributeWithoutType(calleeSym); 669 } else { 670 // Indirect calls 671 assert(indirectCallee); 672 printer << indirectCallee; 673 } 674 printer << "(" << ops << ")"; 675 676 if (isNothrow) 677 printer << " nothrow"; 678 679 if (sideEffect != cir::SideEffect::All) { 680 printer << " side_effect("; 681 printer << stringifySideEffect(sideEffect); 682 printer << ")"; 683 } 684 685 printer.printOptionalAttrDict(op->getAttrs(), 686 {CIRDialect::getCalleeAttrName(), 687 CIRDialect::getNoThrowAttrName(), 688 CIRDialect::getSideEffectAttrName()}); 689 690 printer << " : "; 691 printer.printFunctionalType(op->getOperands().getTypes(), 692 op->getResultTypes()); 693 } 694 695 mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser, 696 mlir::OperationState &result) { 697 return parseCallCommon(parser, result); 698 } 699 700 void cir::CallOp::print(mlir::OpAsmPrinter &p) { 701 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr; 702 cir::SideEffect sideEffect = getSideEffect(); 703 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(), 704 sideEffect); 705 } 706 707 static LogicalResult 708 verifyCallCommInSymbolUses(mlir::Operation *op, 709 SymbolTableCollection &symbolTable) { 710 auto fnAttr = 711 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName()); 712 if (!fnAttr) { 713 // This is an indirect call, thus we don't have to check the symbol uses. 714 return mlir::success(); 715 } 716 717 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr); 718 if (!fn) 719 return op->emitOpError() << "'" << fnAttr.getValue() 720 << "' does not reference a valid function"; 721 722 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op); 723 assert(callIf && "expected CIR call interface to be always available"); 724 725 // Verify that the operand and result types match the callee. Note that 726 // argument-checking is disabled for functions without a prototype. 727 auto fnType = fn.getFunctionType(); 728 if (!fn.getNoProto()) { 729 unsigned numCallOperands = callIf.getNumArgOperands(); 730 unsigned numFnOpOperands = fnType.getNumInputs(); 731 732 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands) 733 return op->emitOpError("incorrect number of operands for callee"); 734 if (fnType.isVarArg() && numCallOperands < numFnOpOperands) 735 return op->emitOpError("too few operands for callee"); 736 737 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i) 738 if (callIf.getArgOperand(i).getType() != fnType.getInput(i)) 739 return op->emitOpError("operand type mismatch: expected operand type ") 740 << fnType.getInput(i) << ", but provided " 741 << op->getOperand(i).getType() << " for operand number " << i; 742 } 743 744 assert(!cir::MissingFeatures::opCallCallConv()); 745 746 // Void function must not return any results. 747 if (fnType.hasVoidReturn() && op->getNumResults() != 0) 748 return op->emitOpError("callee returns void but call has results"); 749 750 // Non-void function calls must return exactly one result. 751 if (!fnType.hasVoidReturn() && op->getNumResults() != 1) 752 return op->emitOpError("incorrect number of results for callee"); 753 754 // Parent function and return value types must match. 755 if (!fnType.hasVoidReturn() && 756 op->getResultTypes().front() != fnType.getReturnType()) { 757 return op->emitOpError("result type mismatch: expected ") 758 << fnType.getReturnType() << ", but provided " 759 << op->getResult(0).getType(); 760 } 761 762 return mlir::success(); 763 } 764 765 LogicalResult 766 cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 767 return verifyCallCommInSymbolUses(*this, symbolTable); 768 } 769 770 //===----------------------------------------------------------------------===// 771 // ReturnOp 772 //===----------------------------------------------------------------------===// 773 774 static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op, 775 cir::FuncOp function) { 776 // ReturnOps currently only have a single optional operand. 777 if (op.getNumOperands() > 1) 778 return op.emitOpError() << "expects at most 1 return operand"; 779 780 // Ensure returned type matches the function signature. 781 auto expectedTy = function.getFunctionType().getReturnType(); 782 auto actualTy = 783 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext()) 784 : op.getOperand(0).getType()); 785 if (actualTy != expectedTy) 786 return op.emitOpError() << "returns " << actualTy 787 << " but enclosing function returns " << expectedTy; 788 789 return mlir::success(); 790 } 791 792 mlir::LogicalResult cir::ReturnOp::verify() { 793 // Returns can be present in multiple different scopes, get the 794 // wrapping function and start from there. 795 auto *fnOp = getOperation()->getParentOp(); 796 while (!isa<cir::FuncOp>(fnOp)) 797 fnOp = fnOp->getParentOp(); 798 799 // Make sure return types match function return type. 800 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed()) 801 return failure(); 802 803 return success(); 804 } 805 806 //===----------------------------------------------------------------------===// 807 // IfOp 808 //===----------------------------------------------------------------------===// 809 810 ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) { 811 // create the regions for 'then'. 812 result.regions.reserve(2); 813 Region *thenRegion = result.addRegion(); 814 Region *elseRegion = result.addRegion(); 815 816 mlir::Builder &builder = parser.getBuilder(); 817 OpAsmParser::UnresolvedOperand cond; 818 Type boolType = cir::BoolType::get(builder.getContext()); 819 820 if (parser.parseOperand(cond) || 821 parser.resolveOperand(cond, boolType, result.operands)) 822 return failure(); 823 824 // Parse 'then' region. 825 mlir::SMLoc parseThenLoc = parser.getCurrentLocation(); 826 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) 827 return failure(); 828 829 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed()) 830 return failure(); 831 832 // If we find an 'else' keyword, parse the 'else' region. 833 if (!parser.parseOptionalKeyword("else")) { 834 mlir::SMLoc parseElseLoc = parser.getCurrentLocation(); 835 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) 836 return failure(); 837 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed()) 838 return failure(); 839 } 840 841 // Parse the optional attribute list. 842 if (parser.parseOptionalAttrDict(result.attributes)) 843 return failure(); 844 return success(); 845 } 846 847 void cir::IfOp::print(OpAsmPrinter &p) { 848 p << " " << getCondition() << " "; 849 mlir::Region &thenRegion = this->getThenRegion(); 850 p.printRegion(thenRegion, 851 /*printEntryBlockArgs=*/false, 852 /*printBlockTerminators=*/!omitRegionTerm(thenRegion)); 853 854 // Print the 'else' regions if it exists and has a block. 855 mlir::Region &elseRegion = this->getElseRegion(); 856 if (!elseRegion.empty()) { 857 p << " else "; 858 p.printRegion(elseRegion, 859 /*printEntryBlockArgs=*/false, 860 /*printBlockTerminators=*/!omitRegionTerm(elseRegion)); 861 } 862 863 p.printOptionalAttrDict(getOperation()->getAttrs()); 864 } 865 866 /// Default callback for IfOp builders. 867 void cir::buildTerminatedBody(OpBuilder &builder, Location loc) { 868 // add cir.yield to end of the block 869 builder.create<cir::YieldOp>(loc); 870 } 871 872 /// Given the region at `index`, or the parent operation if `index` is None, 873 /// return the successor regions. These are the regions that may be selected 874 /// during the flow of control. `operands` is a set of optional attributes that 875 /// correspond to a constant value for each operand, or null if that operand is 876 /// not a constant. 877 void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point, 878 SmallVectorImpl<RegionSuccessor> ®ions) { 879 // The `then` and the `else` region branch back to the parent operation. 880 if (!point.isParent()) { 881 regions.push_back(RegionSuccessor()); 882 return; 883 } 884 885 // Don't consider the else region if it is empty. 886 Region *elseRegion = &this->getElseRegion(); 887 if (elseRegion->empty()) 888 elseRegion = nullptr; 889 890 // If the condition isn't constant, both regions may be executed. 891 regions.push_back(RegionSuccessor(&getThenRegion())); 892 // If the else region does not exist, it is not a viable successor. 893 if (elseRegion) 894 regions.push_back(RegionSuccessor(elseRegion)); 895 896 return; 897 } 898 899 void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond, 900 bool withElseRegion, BuilderCallbackRef thenBuilder, 901 BuilderCallbackRef elseBuilder) { 902 assert(thenBuilder && "the builder callback for 'then' must be present"); 903 result.addOperands(cond); 904 905 OpBuilder::InsertionGuard guard(builder); 906 Region *thenRegion = result.addRegion(); 907 builder.createBlock(thenRegion); 908 thenBuilder(builder, result.location); 909 910 Region *elseRegion = result.addRegion(); 911 if (!withElseRegion) 912 return; 913 914 builder.createBlock(elseRegion); 915 elseBuilder(builder, result.location); 916 } 917 918 //===----------------------------------------------------------------------===// 919 // ScopeOp 920 //===----------------------------------------------------------------------===// 921 922 /// Given the region at `index`, or the parent operation if `index` is None, 923 /// return the successor regions. These are the regions that may be selected 924 /// during the flow of control. `operands` is a set of optional attributes 925 /// that correspond to a constant value for each operand, or null if that 926 /// operand is not a constant. 927 void cir::ScopeOp::getSuccessorRegions( 928 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 929 // The only region always branch back to the parent operation. 930 if (!point.isParent()) { 931 regions.push_back(RegionSuccessor(getODSResults(0))); 932 return; 933 } 934 935 // If the condition isn't constant, both regions may be executed. 936 regions.push_back(RegionSuccessor(&getScopeRegion())); 937 } 938 939 void cir::ScopeOp::build( 940 OpBuilder &builder, OperationState &result, 941 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) { 942 assert(scopeBuilder && "the builder callback for 'then' must be present"); 943 944 OpBuilder::InsertionGuard guard(builder); 945 Region *scopeRegion = result.addRegion(); 946 builder.createBlock(scopeRegion); 947 assert(!cir::MissingFeatures::opScopeCleanupRegion()); 948 949 mlir::Type yieldTy; 950 scopeBuilder(builder, yieldTy, result.location); 951 952 if (yieldTy) 953 result.addTypes(TypeRange{yieldTy}); 954 } 955 956 void cir::ScopeOp::build( 957 OpBuilder &builder, OperationState &result, 958 function_ref<void(OpBuilder &, Location)> scopeBuilder) { 959 assert(scopeBuilder && "the builder callback for 'then' must be present"); 960 OpBuilder::InsertionGuard guard(builder); 961 Region *scopeRegion = result.addRegion(); 962 builder.createBlock(scopeRegion); 963 assert(!cir::MissingFeatures::opScopeCleanupRegion()); 964 scopeBuilder(builder, result.location); 965 } 966 967 LogicalResult cir::ScopeOp::verify() { 968 if (getRegion().empty()) { 969 return emitOpError() << "cir.scope must not be empty since it should " 970 "include at least an implicit cir.yield "; 971 } 972 973 mlir::Block &lastBlock = getRegion().back(); 974 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() || 975 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>()) 976 return emitOpError() << "last block of cir.scope must be terminated"; 977 return success(); 978 } 979 980 //===----------------------------------------------------------------------===// 981 // BrOp 982 //===----------------------------------------------------------------------===// 983 984 mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) { 985 assert(index == 0 && "invalid successor index"); 986 return mlir::SuccessorOperands(getDestOperandsMutable()); 987 } 988 989 Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) { 990 return getDest(); 991 } 992 993 //===----------------------------------------------------------------------===// 994 // BrCondOp 995 //===----------------------------------------------------------------------===// 996 997 mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) { 998 assert(index < getNumSuccessors() && "invalid successor index"); 999 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable() 1000 : getDestOperandsFalseMutable()); 1001 } 1002 1003 Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { 1004 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front())) 1005 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse(); 1006 return nullptr; 1007 } 1008 1009 //===----------------------------------------------------------------------===// 1010 // CaseOp 1011 //===----------------------------------------------------------------------===// 1012 1013 void cir::CaseOp::getSuccessorRegions( 1014 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1015 if (!point.isParent()) { 1016 regions.push_back(RegionSuccessor()); 1017 return; 1018 } 1019 regions.push_back(RegionSuccessor(&getCaseRegion())); 1020 } 1021 1022 void cir::CaseOp::build(OpBuilder &builder, OperationState &result, 1023 ArrayAttr value, CaseOpKind kind, 1024 OpBuilder::InsertPoint &insertPoint) { 1025 OpBuilder::InsertionGuard guardSwitch(builder); 1026 result.addAttribute("value", value); 1027 result.getOrAddProperties<Properties>().kind = 1028 cir::CaseOpKindAttr::get(builder.getContext(), kind); 1029 Region *caseRegion = result.addRegion(); 1030 builder.createBlock(caseRegion); 1031 1032 insertPoint = builder.saveInsertionPoint(); 1033 } 1034 1035 //===----------------------------------------------------------------------===// 1036 // SwitchOp 1037 //===----------------------------------------------------------------------===// 1038 1039 static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region ®ions, 1040 mlir::OpAsmParser::UnresolvedOperand &cond, 1041 mlir::Type &condType) { 1042 cir::IntType intCondType; 1043 1044 if (parser.parseLParen()) 1045 return mlir::failure(); 1046 1047 if (parser.parseOperand(cond)) 1048 return mlir::failure(); 1049 if (parser.parseColon()) 1050 return mlir::failure(); 1051 if (parser.parseCustomTypeWithFallback(intCondType)) 1052 return mlir::failure(); 1053 condType = intCondType; 1054 1055 if (parser.parseRParen()) 1056 return mlir::failure(); 1057 if (parser.parseRegion(regions, /*arguments=*/{}, /*argTypes=*/{})) 1058 return failure(); 1059 1060 return mlir::success(); 1061 } 1062 1063 static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op, 1064 mlir::Region &bodyRegion, mlir::Value condition, 1065 mlir::Type condType) { 1066 p << "("; 1067 p << condition; 1068 p << " : "; 1069 p.printStrippedAttrOrType(condType); 1070 p << ")"; 1071 1072 p << ' '; 1073 p.printRegion(bodyRegion, /*printEntryBlockArgs=*/false, 1074 /*printBlockTerminators=*/true); 1075 } 1076 1077 void cir::SwitchOp::getSuccessorRegions( 1078 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ion) { 1079 if (!point.isParent()) { 1080 region.push_back(RegionSuccessor()); 1081 return; 1082 } 1083 1084 region.push_back(RegionSuccessor(&getBody())); 1085 } 1086 1087 void cir::SwitchOp::build(OpBuilder &builder, OperationState &result, 1088 Value cond, BuilderOpStateCallbackRef switchBuilder) { 1089 assert(switchBuilder && "the builder callback for regions must be present"); 1090 OpBuilder::InsertionGuard guardSwitch(builder); 1091 Region *switchRegion = result.addRegion(); 1092 builder.createBlock(switchRegion); 1093 result.addOperands({cond}); 1094 switchBuilder(builder, result.location, result); 1095 } 1096 1097 void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) { 1098 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) { 1099 // Don't walk in nested switch op. 1100 if (isa<cir::SwitchOp>(op) && op != *this) 1101 return WalkResult::skip(); 1102 1103 if (auto caseOp = dyn_cast<cir::CaseOp>(op)) 1104 cases.push_back(caseOp); 1105 1106 return WalkResult::advance(); 1107 }); 1108 } 1109 1110 bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) { 1111 collectCases(cases); 1112 1113 if (getBody().empty()) 1114 return false; 1115 1116 if (!isa<YieldOp>(getBody().front().back())) 1117 return false; 1118 1119 if (!llvm::all_of(getBody().front(), 1120 [](Operation &op) { return isa<CaseOp, YieldOp>(op); })) 1121 return false; 1122 1123 return llvm::all_of(cases, [this](CaseOp op) { 1124 return op->getParentOfType<SwitchOp>() == *this; 1125 }); 1126 } 1127 1128 //===----------------------------------------------------------------------===// 1129 // SwitchFlatOp 1130 //===----------------------------------------------------------------------===// 1131 1132 void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result, 1133 Value value, Block *defaultDestination, 1134 ValueRange defaultOperands, 1135 ArrayRef<APInt> caseValues, 1136 BlockRange caseDestinations, 1137 ArrayRef<ValueRange> caseOperands) { 1138 1139 std::vector<mlir::Attribute> caseValuesAttrs; 1140 for (const APInt &val : caseValues) 1141 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val)); 1142 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs); 1143 1144 build(builder, result, value, defaultOperands, caseOperands, attrs, 1145 defaultDestination, caseDestinations); 1146 } 1147 1148 /// <cases> ::= `[` (case (`,` case )* )? `]` 1149 /// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 1150 static ParseResult parseSwitchFlatOpCases( 1151 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, 1152 SmallVectorImpl<Block *> &caseDestinations, 1153 SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>> 1154 &caseOperands, 1155 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) { 1156 if (failed(parser.parseLSquare())) 1157 return failure(); 1158 if (succeeded(parser.parseOptionalRSquare())) 1159 return success(); 1160 llvm::SmallVector<mlir::Attribute> values; 1161 1162 auto parseCase = [&]() { 1163 int64_t value = 0; 1164 if (failed(parser.parseInteger(value))) 1165 return failure(); 1166 1167 values.push_back(cir::IntAttr::get(flagType, value)); 1168 1169 Block *destination; 1170 llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands; 1171 llvm::SmallVector<Type> operandTypes; 1172 if (parser.parseColon() || parser.parseSuccessor(destination)) 1173 return failure(); 1174 if (!parser.parseOptionalLParen()) { 1175 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 1176 /*allowResultNumber=*/false) || 1177 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 1178 return failure(); 1179 } 1180 caseDestinations.push_back(destination); 1181 caseOperands.emplace_back(operands); 1182 caseOperandTypes.emplace_back(operandTypes); 1183 return success(); 1184 }; 1185 if (failed(parser.parseCommaSeparatedList(parseCase))) 1186 return failure(); 1187 1188 caseValues = ArrayAttr::get(flagType.getContext(), values); 1189 1190 return parser.parseRSquare(); 1191 } 1192 1193 static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, 1194 Type flagType, mlir::ArrayAttr caseValues, 1195 SuccessorRange caseDestinations, 1196 OperandRangeRange caseOperands, 1197 const TypeRangeRange &caseOperandTypes) { 1198 p << '['; 1199 p.printNewline(); 1200 if (!caseValues) { 1201 p << ']'; 1202 return; 1203 } 1204 1205 size_t index = 0; 1206 llvm::interleave( 1207 llvm::zip(caseValues, caseDestinations), 1208 [&](auto i) { 1209 p << " "; 1210 mlir::Attribute a = std::get<0>(i); 1211 p << mlir::cast<cir::IntAttr>(a).getValue(); 1212 p << ": "; 1213 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 1214 }, 1215 [&] { 1216 p << ','; 1217 p.printNewline(); 1218 }); 1219 p.printNewline(); 1220 p << ']'; 1221 } 1222 1223 //===----------------------------------------------------------------------===// 1224 // GlobalOp 1225 //===----------------------------------------------------------------------===// 1226 1227 static ParseResult parseConstantValue(OpAsmParser &parser, 1228 mlir::Attribute &valueAttr) { 1229 NamedAttrList attr; 1230 return parser.parseAttribute(valueAttr, "value", attr); 1231 } 1232 1233 static void printConstant(OpAsmPrinter &p, Attribute value) { 1234 p.printAttribute(value); 1235 } 1236 1237 mlir::LogicalResult cir::GlobalOp::verify() { 1238 // Verify that the initial value, if present, is either a unit attribute or 1239 // an attribute CIR supports. 1240 if (getInitialValue().has_value()) { 1241 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue()) 1242 .failed()) 1243 return failure(); 1244 } 1245 1246 // TODO(CIR): Many other checks for properties that haven't been upstreamed 1247 // yet. 1248 1249 return success(); 1250 } 1251 1252 void cir::GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1253 llvm::StringRef sym_name, mlir::Type sym_type, 1254 cir::GlobalLinkageKind linkage) { 1255 odsState.addAttribute(getSymNameAttrName(odsState.name), 1256 odsBuilder.getStringAttr(sym_name)); 1257 odsState.addAttribute(getSymTypeAttrName(odsState.name), 1258 mlir::TypeAttr::get(sym_type)); 1259 1260 cir::GlobalLinkageKindAttr linkageAttr = 1261 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage); 1262 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr); 1263 1264 odsState.addAttribute(getGlobalVisibilityAttrName(odsState.name), 1265 cir::VisibilityAttr::get(odsBuilder.getContext())); 1266 } 1267 1268 static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op, 1269 TypeAttr type, 1270 Attribute initAttr) { 1271 if (!op.isDeclaration()) { 1272 p << "= "; 1273 // This also prints the type... 1274 if (initAttr) 1275 printConstant(p, initAttr); 1276 } else { 1277 p << ": " << type; 1278 } 1279 } 1280 1281 static ParseResult 1282 parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1283 Attribute &initialValueAttr) { 1284 mlir::Type opTy; 1285 if (parser.parseOptionalEqual().failed()) { 1286 // Absence of equal means a declaration, so we need to parse the type. 1287 // cir.global @a : !cir.int<s, 32> 1288 if (parser.parseColonType(opTy)) 1289 return failure(); 1290 } else { 1291 // Parse constant with initializer, examples: 1292 // cir.global @y = #cir.fp<1.250000e+00> : !cir.double 1293 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>> 1294 if (parseConstantValue(parser, initialValueAttr).failed()) 1295 return failure(); 1296 1297 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) && 1298 "Non-typed attrs shouldn't appear here."); 1299 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr); 1300 opTy = typedAttr.getType(); 1301 } 1302 1303 typeAttr = TypeAttr::get(opTy); 1304 return success(); 1305 } 1306 1307 //===----------------------------------------------------------------------===// 1308 // GetGlobalOp 1309 //===----------------------------------------------------------------------===// 1310 1311 LogicalResult 1312 cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1313 // Verify that the result type underlying pointer type matches the type of 1314 // the referenced cir.global or cir.func op. 1315 mlir::Operation *op = 1316 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr()); 1317 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op))) 1318 return emitOpError("'") 1319 << getName() 1320 << "' does not reference a valid cir.global or cir.func"; 1321 1322 mlir::Type symTy; 1323 if (auto g = dyn_cast<GlobalOp>(op)) { 1324 symTy = g.getSymType(); 1325 assert(!cir::MissingFeatures::addressSpace()); 1326 assert(!cir::MissingFeatures::opGlobalThreadLocal()); 1327 } else if (auto f = dyn_cast<FuncOp>(op)) { 1328 symTy = f.getFunctionType(); 1329 } else { 1330 llvm_unreachable("Unexpected operation for GetGlobalOp"); 1331 } 1332 1333 auto resultType = dyn_cast<PointerType>(getAddr().getType()); 1334 if (!resultType || symTy != resultType.getPointee()) 1335 return emitOpError("result type pointee type '") 1336 << resultType.getPointee() << "' does not match type " << symTy 1337 << " of the global @" << getName(); 1338 1339 return success(); 1340 } 1341 1342 //===----------------------------------------------------------------------===// 1343 // FuncOp 1344 //===----------------------------------------------------------------------===// 1345 1346 /// Returns the name used for the linkage attribute. This *must* correspond to 1347 /// the name of the attribute in ODS. 1348 static llvm::StringRef getLinkageAttrNameString() { return "linkage"; } 1349 1350 void cir::FuncOp::build(OpBuilder &builder, OperationState &result, 1351 StringRef name, FuncType type, 1352 GlobalLinkageKind linkage) { 1353 result.addRegion(); 1354 result.addAttribute(SymbolTable::getSymbolAttrName(), 1355 builder.getStringAttr(name)); 1356 result.addAttribute(getFunctionTypeAttrName(result.name), 1357 TypeAttr::get(type)); 1358 result.addAttribute( 1359 getLinkageAttrNameString(), 1360 GlobalLinkageKindAttr::get(builder.getContext(), linkage)); 1361 result.addAttribute(getGlobalVisibilityAttrName(result.name), 1362 cir::VisibilityAttr::get(builder.getContext())); 1363 } 1364 1365 ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) { 1366 llvm::SMLoc loc = parser.getCurrentLocation(); 1367 mlir::Builder &builder = parser.getBuilder(); 1368 1369 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name); 1370 mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name); 1371 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name); 1372 1373 // Default to external linkage if no keyword is provided. 1374 state.addAttribute(getLinkageAttrNameString(), 1375 GlobalLinkageKindAttr::get( 1376 parser.getContext(), 1377 parseOptionalCIRKeyword<GlobalLinkageKind>( 1378 parser, GlobalLinkageKind::ExternalLinkage))); 1379 1380 ::llvm::StringRef visAttrStr; 1381 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"}) 1382 .succeeded()) { 1383 state.addAttribute(visNameAttr, 1384 parser.getBuilder().getStringAttr(visAttrStr)); 1385 } 1386 1387 cir::VisibilityAttr cirVisibilityAttr; 1388 parseVisibilityAttr(parser, cirVisibilityAttr); 1389 state.addAttribute(visibilityNameAttr, cirVisibilityAttr); 1390 1391 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded()) 1392 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr()); 1393 1394 StringAttr nameAttr; 1395 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 1396 state.attributes)) 1397 return failure(); 1398 llvm::SmallVector<OpAsmParser::Argument, 8> arguments; 1399 llvm::SmallVector<mlir::Type> resultTypes; 1400 llvm::SmallVector<DictionaryAttr> resultAttrs; 1401 bool isVariadic = false; 1402 if (function_interface_impl::parseFunctionSignatureWithArguments( 1403 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes, 1404 resultAttrs)) 1405 return failure(); 1406 llvm::SmallVector<mlir::Type> argTypes; 1407 for (OpAsmParser::Argument &arg : arguments) 1408 argTypes.push_back(arg.type); 1409 1410 if (resultTypes.size() > 1) { 1411 return parser.emitError( 1412 loc, "functions with multiple return types are not supported"); 1413 } 1414 1415 mlir::Type returnType = 1416 (resultTypes.empty() ? cir::VoidType::get(builder.getContext()) 1417 : resultTypes.front()); 1418 1419 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic); 1420 if (!fnType) 1421 return failure(); 1422 state.addAttribute(getFunctionTypeAttrName(state.name), 1423 TypeAttr::get(fnType)); 1424 1425 bool hasAlias = false; 1426 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name); 1427 if (parser.parseOptionalKeyword("alias").succeeded()) { 1428 if (parser.parseLParen().failed()) 1429 return failure(); 1430 mlir::StringAttr aliaseeAttr; 1431 if (parser.parseOptionalSymbolName(aliaseeAttr).failed()) 1432 return failure(); 1433 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr)); 1434 if (parser.parseRParen().failed()) 1435 return failure(); 1436 hasAlias = true; 1437 } 1438 1439 // Parse the optional function body. 1440 auto *body = state.addRegion(); 1441 OptionalParseResult parseResult = parser.parseOptionalRegion( 1442 *body, arguments, /*enableNameShadowing=*/false); 1443 if (parseResult.has_value()) { 1444 if (hasAlias) 1445 return parser.emitError(loc, "function alias shall not have a body"); 1446 if (failed(*parseResult)) 1447 return failure(); 1448 // Function body was parsed, make sure its not empty. 1449 if (body->empty()) 1450 return parser.emitError(loc, "expected non-empty function body"); 1451 } 1452 1453 return success(); 1454 } 1455 1456 // This function corresponds to `llvm::GlobalValue::isDeclaration` and should 1457 // have a similar implementation. We don't currently ifuncs or materializable 1458 // functions, but those should be handled here as they are implemented. 1459 bool cir::FuncOp::isDeclaration() { 1460 assert(!cir::MissingFeatures::supportIFuncAttr()); 1461 1462 std::optional<StringRef> aliasee = getAliasee(); 1463 if (!aliasee) 1464 return getFunctionBody().empty(); 1465 1466 // Aliases are always definitions. 1467 return false; 1468 } 1469 1470 mlir::Region *cir::FuncOp::getCallableRegion() { 1471 // TODO(CIR): This function will have special handling for aliases and a 1472 // check for an external function, once those features have been upstreamed. 1473 return &getBody(); 1474 } 1475 1476 void cir::FuncOp::print(OpAsmPrinter &p) { 1477 if (getComdat()) 1478 p << " comdat"; 1479 1480 if (getLinkage() != GlobalLinkageKind::ExternalLinkage) 1481 p << ' ' << stringifyGlobalLinkageKind(getLinkage()); 1482 1483 mlir::SymbolTable::Visibility vis = getVisibility(); 1484 if (vis != mlir::SymbolTable::Visibility::Public) 1485 p << ' ' << vis; 1486 1487 cir::VisibilityAttr cirVisibilityAttr = getGlobalVisibilityAttr(); 1488 if (!cirVisibilityAttr.isDefault()) { 1489 p << ' '; 1490 printVisibilityAttr(p, cirVisibilityAttr); 1491 } 1492 1493 if (getDsoLocal()) 1494 p << " dso_local"; 1495 1496 p << ' '; 1497 p.printSymbolName(getSymName()); 1498 cir::FuncType fnType = getFunctionType(); 1499 function_interface_impl::printFunctionSignature( 1500 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes()); 1501 1502 if (std::optional<StringRef> aliaseeName = getAliasee()) { 1503 p << " alias("; 1504 p.printSymbolName(*aliaseeName); 1505 p << ")"; 1506 } 1507 1508 // Print the body if this is not an external function. 1509 Region &body = getOperation()->getRegion(0); 1510 if (!body.empty()) { 1511 p << ' '; 1512 p.printRegion(body, /*printEntryBlockArgs=*/false, 1513 /*printBlockTerminators=*/true); 1514 } 1515 } 1516 1517 // TODO(CIR): The properties of functions that require verification haven't 1518 // been implemented yet. 1519 mlir::LogicalResult cir::FuncOp::verify() { return success(); } 1520 1521 //===----------------------------------------------------------------------===// 1522 // BinOp 1523 //===----------------------------------------------------------------------===// 1524 LogicalResult cir::BinOp::verify() { 1525 bool noWrap = getNoUnsignedWrap() || getNoSignedWrap(); 1526 bool saturated = getSaturated(); 1527 1528 if (!isa<cir::IntType>(getType()) && noWrap) 1529 return emitError() 1530 << "only operations on integer values may have nsw/nuw flags"; 1531 1532 bool noWrapOps = getKind() == cir::BinOpKind::Add || 1533 getKind() == cir::BinOpKind::Sub || 1534 getKind() == cir::BinOpKind::Mul; 1535 1536 bool saturatedOps = 1537 getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub; 1538 1539 if (noWrap && !noWrapOps) 1540 return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', " 1541 "'sub' and 'mul'"; 1542 if (saturated && !saturatedOps) 1543 return emitError() << "The saturated flag is applicable to opcodes: 'add' " 1544 "and 'sub'"; 1545 if (noWrap && saturated) 1546 return emitError() << "The nsw/nuw flags and the saturated flag are " 1547 "mutually exclusive"; 1548 1549 assert(!cir::MissingFeatures::complexType()); 1550 // TODO(cir): verify for complex binops 1551 1552 return mlir::success(); 1553 } 1554 1555 //===----------------------------------------------------------------------===// 1556 // TernaryOp 1557 //===----------------------------------------------------------------------===// 1558 1559 /// Given the region at `point`, or the parent operation if `point` is None, 1560 /// return the successor regions. These are the regions that may be selected 1561 /// during the flow of control. `operands` is a set of optional attributes that 1562 /// correspond to a constant value for each operand, or null if that operand is 1563 /// not a constant. 1564 void cir::TernaryOp::getSuccessorRegions( 1565 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1566 // The `true` and the `false` region branch back to the parent operation. 1567 if (!point.isParent()) { 1568 regions.push_back(RegionSuccessor(this->getODSResults(0))); 1569 return; 1570 } 1571 1572 // When branching from the parent operation, both the true and false 1573 // regions are considered possible successors 1574 regions.push_back(RegionSuccessor(&getTrueRegion())); 1575 regions.push_back(RegionSuccessor(&getFalseRegion())); 1576 } 1577 1578 void cir::TernaryOp::build( 1579 OpBuilder &builder, OperationState &result, Value cond, 1580 function_ref<void(OpBuilder &, Location)> trueBuilder, 1581 function_ref<void(OpBuilder &, Location)> falseBuilder) { 1582 result.addOperands(cond); 1583 OpBuilder::InsertionGuard guard(builder); 1584 Region *trueRegion = result.addRegion(); 1585 Block *block = builder.createBlock(trueRegion); 1586 trueBuilder(builder, result.location); 1587 Region *falseRegion = result.addRegion(); 1588 builder.createBlock(falseRegion); 1589 falseBuilder(builder, result.location); 1590 1591 auto yield = dyn_cast<YieldOp>(block->getTerminator()); 1592 assert((yield && yield.getNumOperands() <= 1) && 1593 "expected zero or one result type"); 1594 if (yield.getNumOperands() == 1) 1595 result.addTypes(TypeRange{yield.getOperandTypes().front()}); 1596 } 1597 1598 //===----------------------------------------------------------------------===// 1599 // SelectOp 1600 //===----------------------------------------------------------------------===// 1601 1602 OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) { 1603 mlir::Attribute condition = adaptor.getCondition(); 1604 if (condition) { 1605 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue(); 1606 return conditionValue ? getTrueValue() : getFalseValue(); 1607 } 1608 1609 // cir.select if %0 then x else x -> x 1610 mlir::Attribute trueValue = adaptor.getTrueValue(); 1611 mlir::Attribute falseValue = adaptor.getFalseValue(); 1612 if (trueValue == falseValue) 1613 return trueValue; 1614 if (getTrueValue() == getFalseValue()) 1615 return getTrueValue(); 1616 1617 return {}; 1618 } 1619 1620 //===----------------------------------------------------------------------===// 1621 // ShiftOp 1622 //===----------------------------------------------------------------------===// 1623 LogicalResult cir::ShiftOp::verify() { 1624 mlir::Operation *op = getOperation(); 1625 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType()); 1626 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType()); 1627 if (!op0VecTy ^ !op1VecTy) 1628 return emitOpError() << "input types cannot be one vector and one scalar"; 1629 1630 if (op0VecTy) { 1631 if (op0VecTy.getSize() != op1VecTy.getSize()) 1632 return emitOpError() << "input vector types must have the same size"; 1633 1634 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType()); 1635 if (!opResultTy) 1636 return emitOpError() << "the type of the result must be a vector " 1637 << "if it is vector shift"; 1638 1639 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType()); 1640 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType()); 1641 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth()) 1642 return emitOpError() 1643 << "vector operands do not have the same elements sizes"; 1644 1645 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType()); 1646 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth()) 1647 return emitOpError() << "vector operands and result type do not have the " 1648 "same elements sizes"; 1649 } 1650 1651 return mlir::success(); 1652 } 1653 1654 //===----------------------------------------------------------------------===// 1655 // UnaryOp 1656 //===----------------------------------------------------------------------===// 1657 1658 LogicalResult cir::UnaryOp::verify() { 1659 switch (getKind()) { 1660 case cir::UnaryOpKind::Inc: 1661 case cir::UnaryOpKind::Dec: 1662 case cir::UnaryOpKind::Plus: 1663 case cir::UnaryOpKind::Minus: 1664 case cir::UnaryOpKind::Not: 1665 // Nothing to verify. 1666 return success(); 1667 } 1668 1669 llvm_unreachable("Unknown UnaryOp kind?"); 1670 } 1671 1672 static bool isBoolNot(cir::UnaryOp op) { 1673 return isa<cir::BoolType>(op.getInput().getType()) && 1674 op.getKind() == cir::UnaryOpKind::Not; 1675 } 1676 1677 // This folder simplifies the sequential boolean not operations. 1678 // For instance, the next two unary operations will be eliminated: 1679 // 1680 // ```mlir 1681 // %1 = cir.unary(not, %0) : !cir.bool, !cir.bool 1682 // %2 = cir.unary(not, %1) : !cir.bool, !cir.bool 1683 // ``` 1684 // 1685 // and the argument of the first one (%0) will be used instead. 1686 OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) { 1687 if (isBoolNot(*this)) 1688 if (auto previous = dyn_cast_or_null<UnaryOp>(getInput().getDefiningOp())) 1689 if (isBoolNot(previous)) 1690 return previous.getInput(); 1691 1692 return {}; 1693 } 1694 1695 //===----------------------------------------------------------------------===// 1696 // GetMemberOp Definitions 1697 //===----------------------------------------------------------------------===// 1698 1699 LogicalResult cir::GetMemberOp::verify() { 1700 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee()); 1701 if (!recordTy) 1702 return emitError() << "expected pointer to a record type"; 1703 1704 if (recordTy.getMembers().size() <= getIndex()) 1705 return emitError() << "member index out of bounds"; 1706 1707 if (recordTy.getMembers()[getIndex()] != getType().getPointee()) 1708 return emitError() << "member type mismatch"; 1709 1710 return mlir::success(); 1711 } 1712 1713 //===----------------------------------------------------------------------===// 1714 // VecCreateOp 1715 //===----------------------------------------------------------------------===// 1716 1717 OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) { 1718 if (llvm::any_of(getElements(), [](mlir::Value value) { 1719 return !mlir::isa<cir::ConstantOp>(value.getDefiningOp()); 1720 })) 1721 return {}; 1722 1723 return cir::ConstVectorAttr::get( 1724 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements())); 1725 } 1726 1727 LogicalResult cir::VecCreateOp::verify() { 1728 // Verify that the number of arguments matches the number of elements in the 1729 // vector, and that the type of all the arguments matches the type of the 1730 // elements in the vector. 1731 const cir::VectorType vecTy = getType(); 1732 if (getElements().size() != vecTy.getSize()) { 1733 return emitOpError() << "operand count of " << getElements().size() 1734 << " doesn't match vector type " << vecTy 1735 << " element count of " << vecTy.getSize(); 1736 } 1737 1738 const mlir::Type elementType = vecTy.getElementType(); 1739 for (const mlir::Value element : getElements()) { 1740 if (element.getType() != elementType) { 1741 return emitOpError() << "operand type " << element.getType() 1742 << " doesn't match vector element type " 1743 << elementType; 1744 } 1745 } 1746 1747 return success(); 1748 } 1749 1750 //===----------------------------------------------------------------------===// 1751 // VecExtractOp 1752 //===----------------------------------------------------------------------===// 1753 1754 OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) { 1755 const auto vectorAttr = 1756 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec()); 1757 if (!vectorAttr) 1758 return {}; 1759 1760 const auto indexAttr = 1761 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex()); 1762 if (!indexAttr) 1763 return {}; 1764 1765 const mlir::ArrayAttr elements = vectorAttr.getElts(); 1766 const uint64_t index = indexAttr.getUInt(); 1767 if (index >= elements.size()) 1768 return {}; 1769 1770 return elements[index]; 1771 } 1772 1773 //===----------------------------------------------------------------------===// 1774 // VecCmpOp 1775 //===----------------------------------------------------------------------===// 1776 1777 OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) { 1778 auto lhsVecAttr = 1779 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs()); 1780 auto rhsVecAttr = 1781 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs()); 1782 if (!lhsVecAttr || !rhsVecAttr) 1783 return {}; 1784 1785 mlir::Type inputElemTy = 1786 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType(); 1787 if (!isAnyIntegerOrFloatingPointType(inputElemTy)) 1788 return {}; 1789 1790 cir::CmpOpKind opKind = adaptor.getKind(); 1791 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts(); 1792 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts(); 1793 uint64_t vecSize = lhsVecElhs.size(); 1794 1795 SmallVector<mlir::Attribute, 16> elements(vecSize); 1796 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]); 1797 for (uint64_t i = 0; i < vecSize; i++) { 1798 mlir::Attribute lhsAttr = lhsVecElhs[i]; 1799 mlir::Attribute rhsAttr = rhsVecElhs[i]; 1800 int cmpResult = 0; 1801 switch (opKind) { 1802 case cir::CmpOpKind::lt: { 1803 if (isIntAttr) { 1804 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() < 1805 mlir::cast<cir::IntAttr>(rhsAttr).getSInt(); 1806 } else { 1807 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() < 1808 mlir::cast<cir::FPAttr>(rhsAttr).getValue(); 1809 } 1810 break; 1811 } 1812 case cir::CmpOpKind::le: { 1813 if (isIntAttr) { 1814 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <= 1815 mlir::cast<cir::IntAttr>(rhsAttr).getSInt(); 1816 } else { 1817 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <= 1818 mlir::cast<cir::FPAttr>(rhsAttr).getValue(); 1819 } 1820 break; 1821 } 1822 case cir::CmpOpKind::gt: { 1823 if (isIntAttr) { 1824 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() > 1825 mlir::cast<cir::IntAttr>(rhsAttr).getSInt(); 1826 } else { 1827 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() > 1828 mlir::cast<cir::FPAttr>(rhsAttr).getValue(); 1829 } 1830 break; 1831 } 1832 case cir::CmpOpKind::ge: { 1833 if (isIntAttr) { 1834 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >= 1835 mlir::cast<cir::IntAttr>(rhsAttr).getSInt(); 1836 } else { 1837 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >= 1838 mlir::cast<cir::FPAttr>(rhsAttr).getValue(); 1839 } 1840 break; 1841 } 1842 case cir::CmpOpKind::eq: { 1843 if (isIntAttr) { 1844 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() == 1845 mlir::cast<cir::IntAttr>(rhsAttr).getSInt(); 1846 } else { 1847 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() == 1848 mlir::cast<cir::FPAttr>(rhsAttr).getValue(); 1849 } 1850 break; 1851 } 1852 case cir::CmpOpKind::ne: { 1853 if (isIntAttr) { 1854 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() != 1855 mlir::cast<cir::IntAttr>(rhsAttr).getSInt(); 1856 } else { 1857 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() != 1858 mlir::cast<cir::FPAttr>(rhsAttr).getValue(); 1859 } 1860 break; 1861 } 1862 } 1863 1864 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult); 1865 } 1866 1867 return cir::ConstVectorAttr::get( 1868 getType(), mlir::ArrayAttr::get(getContext(), elements)); 1869 } 1870 1871 //===----------------------------------------------------------------------===// 1872 // VecShuffleOp 1873 //===----------------------------------------------------------------------===// 1874 1875 OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) { 1876 auto vec1Attr = 1877 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1()); 1878 auto vec2Attr = 1879 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2()); 1880 if (!vec1Attr || !vec2Attr) 1881 return {}; 1882 1883 mlir::Type vec1ElemTy = 1884 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType(); 1885 1886 mlir::ArrayAttr vec1Elts = vec1Attr.getElts(); 1887 mlir::ArrayAttr vec2Elts = vec2Attr.getElts(); 1888 mlir::ArrayAttr indicesElts = adaptor.getIndices(); 1889 1890 SmallVector<mlir::Attribute, 16> elements; 1891 elements.reserve(indicesElts.size()); 1892 1893 uint64_t vec1Size = vec1Elts.size(); 1894 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) { 1895 if (idxAttr.getSInt() == -1) { 1896 elements.push_back(cir::UndefAttr::get(vec1ElemTy)); 1897 continue; 1898 } 1899 1900 uint64_t idxValue = idxAttr.getUInt(); 1901 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue] 1902 : vec2Elts[idxValue - vec1Size]); 1903 } 1904 1905 return cir::ConstVectorAttr::get( 1906 getType(), mlir::ArrayAttr::get(getContext(), elements)); 1907 } 1908 1909 LogicalResult cir::VecShuffleOp::verify() { 1910 // The number of elements in the indices array must match the number of 1911 // elements in the result type. 1912 if (getIndices().size() != getResult().getType().getSize()) { 1913 return emitOpError() << ": the number of elements in " << getIndices() 1914 << " and " << getResult().getType() << " don't match"; 1915 } 1916 1917 // The element types of the two input vectors and of the result type must 1918 // match. 1919 if (getVec1().getType().getElementType() != 1920 getResult().getType().getElementType()) { 1921 return emitOpError() << ": element types of " << getVec1().getType() 1922 << " and " << getResult().getType() << " don't match"; 1923 } 1924 1925 const uint64_t maxValidIndex = 1926 getVec1().getType().getSize() + getVec2().getType().getSize() - 1; 1927 if (llvm::any_of( 1928 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) { 1929 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex; 1930 })) { 1931 return emitOpError() << ": index for __builtin_shufflevector must be " 1932 "less than the total number of vector elements"; 1933 } 1934 return success(); 1935 } 1936 1937 //===----------------------------------------------------------------------===// 1938 // VecShuffleDynamicOp 1939 //===----------------------------------------------------------------------===// 1940 1941 OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) { 1942 mlir::Attribute vec = adaptor.getVec(); 1943 mlir::Attribute indices = adaptor.getIndices(); 1944 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) && 1945 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) { 1946 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec); 1947 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices); 1948 1949 mlir::ArrayAttr vecElts = vecAttr.getElts(); 1950 mlir::ArrayAttr indicesElts = indicesAttr.getElts(); 1951 1952 const uint64_t numElements = vecElts.size(); 1953 1954 SmallVector<mlir::Attribute, 16> elements; 1955 elements.reserve(numElements); 1956 1957 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1; 1958 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) { 1959 uint64_t idxValue = idxAttr.getUInt(); 1960 uint64_t newIdx = idxValue & maskBits; 1961 elements.push_back(vecElts[newIdx]); 1962 } 1963 1964 return cir::ConstVectorAttr::get( 1965 getType(), mlir::ArrayAttr::get(getContext(), elements)); 1966 } 1967 1968 return {}; 1969 } 1970 1971 LogicalResult cir::VecShuffleDynamicOp::verify() { 1972 // The number of elements in the two input vectors must match. 1973 if (getVec().getType().getSize() != 1974 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) { 1975 return emitOpError() << ": the number of elements in " << getVec().getType() 1976 << " and " << getIndices().getType() << " don't match"; 1977 } 1978 return success(); 1979 } 1980 1981 //===----------------------------------------------------------------------===// 1982 // VecTernaryOp 1983 //===----------------------------------------------------------------------===// 1984 1985 LogicalResult cir::VecTernaryOp::verify() { 1986 // Verify that the condition operand has the same number of elements as the 1987 // other operands. (The automatic verification already checked that all 1988 // operands are vector types and that the second and third operands are the 1989 // same type.) 1990 if (getCond().getType().getSize() != getLhs().getType().getSize()) { 1991 return emitOpError() << ": the number of elements in " 1992 << getCond().getType() << " and " << getLhs().getType() 1993 << " don't match"; 1994 } 1995 return success(); 1996 } 1997 1998 OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) { 1999 mlir::Attribute cond = adaptor.getCond(); 2000 mlir::Attribute lhs = adaptor.getLhs(); 2001 mlir::Attribute rhs = adaptor.getRhs(); 2002 2003 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) || 2004 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) || 2005 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) 2006 return {}; 2007 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond); 2008 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs); 2009 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs); 2010 2011 mlir::ArrayAttr condElts = condVec.getElts(); 2012 2013 SmallVector<mlir::Attribute, 16> elements; 2014 elements.reserve(condElts.size()); 2015 2016 for (const auto &[idx, condAttr] : 2017 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) { 2018 if (condAttr.getSInt()) { 2019 elements.push_back(lhsVec.getElts()[idx]); 2020 } else { 2021 elements.push_back(rhsVec.getElts()[idx]); 2022 } 2023 } 2024 2025 cir::VectorType vecTy = getLhs().getType(); 2026 return cir::ConstVectorAttr::get( 2027 vecTy, mlir::ArrayAttr::get(getContext(), elements)); 2028 } 2029 2030 //===----------------------------------------------------------------------===// 2031 // ComplexCreateOp 2032 //===----------------------------------------------------------------------===// 2033 2034 LogicalResult cir::ComplexCreateOp::verify() { 2035 if (getType().getElementType() != getReal().getType()) { 2036 emitOpError() 2037 << "operand type of cir.complex.create does not match its result type"; 2038 return failure(); 2039 } 2040 2041 return success(); 2042 } 2043 2044 OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) { 2045 mlir::Attribute real = adaptor.getReal(); 2046 mlir::Attribute imag = adaptor.getImag(); 2047 if (!real || !imag) 2048 return {}; 2049 2050 // When both of real and imag are constants, we can fold the operation into an 2051 // `#cir.const_complex` operation. 2052 auto realAttr = mlir::cast<mlir::TypedAttr>(real); 2053 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag); 2054 return cir::ConstComplexAttr::get(realAttr, imagAttr); 2055 } 2056 2057 //===----------------------------------------------------------------------===// 2058 // ComplexRealOp 2059 //===----------------------------------------------------------------------===// 2060 2061 LogicalResult cir::ComplexRealOp::verify() { 2062 if (getType() != getOperand().getType().getElementType()) { 2063 emitOpError() << ": result type does not match operand type"; 2064 return failure(); 2065 } 2066 return success(); 2067 } 2068 2069 OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) { 2070 if (auto complexCreateOp = 2071 dyn_cast_or_null<cir::ComplexCreateOp>(getOperand().getDefiningOp())) 2072 return complexCreateOp.getOperand(0); 2073 2074 auto complex = 2075 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand()); 2076 return complex ? complex.getReal() : nullptr; 2077 } 2078 2079 //===----------------------------------------------------------------------===// 2080 // ComplexImagOp 2081 //===----------------------------------------------------------------------===// 2082 2083 LogicalResult cir::ComplexImagOp::verify() { 2084 if (getType() != getOperand().getType().getElementType()) { 2085 emitOpError() << ": result type does not match operand type"; 2086 return failure(); 2087 } 2088 return success(); 2089 } 2090 2091 OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) { 2092 if (auto complexCreateOp = 2093 dyn_cast_or_null<cir::ComplexCreateOp>(getOperand().getDefiningOp())) 2094 return complexCreateOp.getOperand(1); 2095 2096 auto complex = 2097 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand()); 2098 return complex ? complex.getImag() : nullptr; 2099 } 2100 2101 //===----------------------------------------------------------------------===// 2102 // ComplexRealPtrOp 2103 //===----------------------------------------------------------------------===// 2104 2105 LogicalResult cir::ComplexRealPtrOp::verify() { 2106 mlir::Type resultPointeeTy = getType().getPointee(); 2107 cir::PointerType operandPtrTy = getOperand().getType(); 2108 auto operandPointeeTy = 2109 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee()); 2110 2111 if (resultPointeeTy != operandPointeeTy.getElementType()) { 2112 return emitOpError() << ": result type does not match operand type"; 2113 } 2114 2115 return success(); 2116 } 2117 2118 //===----------------------------------------------------------------------===// 2119 // ComplexImagPtrOp 2120 //===----------------------------------------------------------------------===// 2121 2122 LogicalResult cir::ComplexImagPtrOp::verify() { 2123 mlir::Type resultPointeeTy = getType().getPointee(); 2124 cir::PointerType operandPtrTy = getOperand().getType(); 2125 auto operandPointeeTy = 2126 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee()); 2127 2128 if (resultPointeeTy != operandPointeeTy.getElementType()) { 2129 return emitOpError() 2130 << "cir.complex.imag_ptr result type does not match operand type"; 2131 } 2132 return success(); 2133 } 2134 2135 //===----------------------------------------------------------------------===// 2136 // TableGen'd op method definitions 2137 //===----------------------------------------------------------------------===// 2138 2139 #define GET_OP_CLASSES 2140 #include "clang/CIR/Dialect/IR/CIROps.cpp.inc" 2141