1 //====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering of CIR operations to LLVMIR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "LowerToLLVM.h" 14 15 #include <deque> 16 #include <optional> 17 18 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 19 #include "mlir/Dialect/DLTI/DLTI.h" 20 #include "mlir/Dialect/Func/IR/FuncOps.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 23 #include "mlir/IR/BuiltinAttributes.h" 24 #include "mlir/IR/BuiltinDialect.h" 25 #include "mlir/IR/BuiltinOps.h" 26 #include "mlir/IR/Types.h" 27 #include "mlir/Pass/Pass.h" 28 #include "mlir/Pass/PassManager.h" 29 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" 30 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 31 #include "mlir/Target/LLVMIR/Export.h" 32 #include "mlir/Transforms/DialectConversion.h" 33 #include "clang/CIR/Dialect/IR/CIRAttrs.h" 34 #include "clang/CIR/Dialect/IR/CIRDialect.h" 35 #include "clang/CIR/Dialect/Passes.h" 36 #include "clang/CIR/LoweringHelpers.h" 37 #include "clang/CIR/MissingFeatures.h" 38 #include "clang/CIR/Passes.h" 39 #include "llvm/ADT/TypeSwitch.h" 40 #include "llvm/IR/Module.h" 41 #include "llvm/Support/ErrorHandling.h" 42 #include "llvm/Support/TimeProfiler.h" 43 44 using namespace cir; 45 using namespace llvm; 46 47 namespace cir { 48 namespace direct { 49 50 //===----------------------------------------------------------------------===// 51 // Helper Methods 52 //===----------------------------------------------------------------------===// 53 54 namespace { 55 /// If the given type is a vector type, return the vector's element type. 56 /// Otherwise return the given type unchanged. 57 mlir::Type elementTypeIfVector(mlir::Type type) { 58 return llvm::TypeSwitch<mlir::Type, mlir::Type>(type) 59 .Case<cir::VectorType, mlir::VectorType>( 60 [](auto p) { return p.getElementType(); }) 61 .Default([](mlir::Type p) { return p; }); 62 } 63 } // namespace 64 65 /// Given a type convertor and a data layout, convert the given type to a type 66 /// that is suitable for memory operations. For example, this can be used to 67 /// lower cir.bool accesses to i8. 68 static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter, 69 mlir::DataLayout const &dataLayout, 70 mlir::Type type) { 71 // TODO(cir): Handle other types similarly to clang's codegen 72 // convertTypeForMemory 73 if (isa<cir::BoolType>(type)) { 74 return mlir::IntegerType::get(type.getContext(), 75 dataLayout.getTypeSizeInBits(type)); 76 } 77 78 return converter.convertType(type); 79 } 80 81 static mlir::Value createIntCast(mlir::OpBuilder &bld, mlir::Value src, 82 mlir::IntegerType dstTy, 83 bool isSigned = false) { 84 mlir::Type srcTy = src.getType(); 85 assert(mlir::isa<mlir::IntegerType>(srcTy)); 86 87 unsigned srcWidth = mlir::cast<mlir::IntegerType>(srcTy).getWidth(); 88 unsigned dstWidth = mlir::cast<mlir::IntegerType>(dstTy).getWidth(); 89 mlir::Location loc = src.getLoc(); 90 91 if (dstWidth > srcWidth && isSigned) 92 return bld.create<mlir::LLVM::SExtOp>(loc, dstTy, src); 93 if (dstWidth > srcWidth) 94 return bld.create<mlir::LLVM::ZExtOp>(loc, dstTy, src); 95 if (dstWidth < srcWidth) 96 return bld.create<mlir::LLVM::TruncOp>(loc, dstTy, src); 97 return bld.create<mlir::LLVM::BitcastOp>(loc, dstTy, src); 98 } 99 100 static mlir::LLVM::Visibility 101 lowerCIRVisibilityToLLVMVisibility(cir::VisibilityKind visibilityKind) { 102 switch (visibilityKind) { 103 case cir::VisibilityKind::Default: 104 return ::mlir::LLVM::Visibility::Default; 105 case cir::VisibilityKind::Hidden: 106 return ::mlir::LLVM::Visibility::Hidden; 107 case cir::VisibilityKind::Protected: 108 return ::mlir::LLVM::Visibility::Protected; 109 } 110 } 111 112 /// Emits the value from memory as expected by its users. Should be called when 113 /// the memory represetnation of a CIR type is not equal to its scalar 114 /// representation. 115 static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter, 116 mlir::DataLayout const &dataLayout, 117 cir::LoadOp op, mlir::Value value) { 118 119 // TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory 120 if (auto boolTy = mlir::dyn_cast<cir::BoolType>(op.getType())) { 121 // Create a cast value from specified size in datalayout to i1 122 assert(value.getType().isInteger(dataLayout.getTypeSizeInBits(boolTy))); 123 return createIntCast(rewriter, value, rewriter.getI1Type()); 124 } 125 126 return value; 127 } 128 129 /// Emits a value to memory with the expected scalar type. Should be called when 130 /// the memory represetnation of a CIR type is not equal to its scalar 131 /// representation. 132 static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter, 133 mlir::DataLayout const &dataLayout, 134 mlir::Type origType, mlir::Value value) { 135 136 // TODO(cir): Handle other types similarly to clang's codegen EmitToMemory 137 if (auto boolTy = mlir::dyn_cast<cir::BoolType>(origType)) { 138 // Create zext of value from i1 to i8 139 mlir::IntegerType memType = 140 rewriter.getIntegerType(dataLayout.getTypeSizeInBits(boolTy)); 141 return createIntCast(rewriter, value, memType); 142 } 143 144 return value; 145 } 146 147 mlir::LLVM::Linkage convertLinkage(cir::GlobalLinkageKind linkage) { 148 using CIR = cir::GlobalLinkageKind; 149 using LLVM = mlir::LLVM::Linkage; 150 151 switch (linkage) { 152 case CIR::AvailableExternallyLinkage: 153 return LLVM::AvailableExternally; 154 case CIR::CommonLinkage: 155 return LLVM::Common; 156 case CIR::ExternalLinkage: 157 return LLVM::External; 158 case CIR::ExternalWeakLinkage: 159 return LLVM::ExternWeak; 160 case CIR::InternalLinkage: 161 return LLVM::Internal; 162 case CIR::LinkOnceAnyLinkage: 163 return LLVM::Linkonce; 164 case CIR::LinkOnceODRLinkage: 165 return LLVM::LinkonceODR; 166 case CIR::PrivateLinkage: 167 return LLVM::Private; 168 case CIR::WeakAnyLinkage: 169 return LLVM::Weak; 170 case CIR::WeakODRLinkage: 171 return LLVM::WeakODR; 172 }; 173 llvm_unreachable("Unknown CIR linkage type"); 174 } 175 176 static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter, 177 mlir::Value llvmSrc, mlir::Type llvmDstIntTy, 178 bool isUnsigned, uint64_t cirSrcWidth, 179 uint64_t cirDstIntWidth) { 180 if (cirSrcWidth == cirDstIntWidth) 181 return llvmSrc; 182 183 auto loc = llvmSrc.getLoc(); 184 if (cirSrcWidth < cirDstIntWidth) { 185 if (isUnsigned) 186 return rewriter.create<mlir::LLVM::ZExtOp>(loc, llvmDstIntTy, llvmSrc); 187 return rewriter.create<mlir::LLVM::SExtOp>(loc, llvmDstIntTy, llvmSrc); 188 } 189 190 // Otherwise truncate 191 return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc); 192 } 193 194 class CIRAttrToValue { 195 public: 196 CIRAttrToValue(mlir::Operation *parentOp, 197 mlir::ConversionPatternRewriter &rewriter, 198 const mlir::TypeConverter *converter) 199 : parentOp(parentOp), rewriter(rewriter), converter(converter) {} 200 201 mlir::Value visit(mlir::Attribute attr) { 202 return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr) 203 .Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr, 204 cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr, 205 cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); }) 206 .Default([&](auto attrT) { return mlir::Value(); }); 207 } 208 209 mlir::Value visitCirAttr(cir::IntAttr intAttr); 210 mlir::Value visitCirAttr(cir::FPAttr fltAttr); 211 mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr); 212 mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr); 213 mlir::Value visitCirAttr(cir::ConstArrayAttr attr); 214 mlir::Value visitCirAttr(cir::ConstVectorAttr attr); 215 mlir::Value visitCirAttr(cir::ZeroAttr attr); 216 217 private: 218 mlir::Operation *parentOp; 219 mlir::ConversionPatternRewriter &rewriter; 220 const mlir::TypeConverter *converter; 221 }; 222 223 /// Switches on the type of attribute and calls the appropriate conversion. 224 mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, 225 const mlir::Attribute attr, 226 mlir::ConversionPatternRewriter &rewriter, 227 const mlir::TypeConverter *converter) { 228 CIRAttrToValue valueConverter(parentOp, rewriter, converter); 229 mlir::Value value = valueConverter.visit(attr); 230 if (!value) 231 llvm_unreachable("unhandled attribute type"); 232 return value; 233 } 234 235 void convertSideEffectForCall(mlir::Operation *callOp, bool isNothrow, 236 cir::SideEffect sideEffect, 237 mlir::LLVM::MemoryEffectsAttr &memoryEffect, 238 bool &noUnwind, bool &willReturn) { 239 using mlir::LLVM::ModRefInfo; 240 241 switch (sideEffect) { 242 case cir::SideEffect::All: 243 memoryEffect = {}; 244 noUnwind = isNothrow; 245 willReturn = false; 246 break; 247 248 case cir::SideEffect::Pure: 249 memoryEffect = mlir::LLVM::MemoryEffectsAttr::get( 250 callOp->getContext(), /*other=*/ModRefInfo::Ref, 251 /*argMem=*/ModRefInfo::Ref, 252 /*inaccessibleMem=*/ModRefInfo::Ref); 253 noUnwind = true; 254 willReturn = true; 255 break; 256 257 case cir::SideEffect::Const: 258 memoryEffect = mlir::LLVM::MemoryEffectsAttr::get( 259 callOp->getContext(), /*other=*/ModRefInfo::NoModRef, 260 /*argMem=*/ModRefInfo::NoModRef, 261 /*inaccessibleMem=*/ModRefInfo::NoModRef); 262 noUnwind = true; 263 willReturn = true; 264 break; 265 } 266 } 267 268 /// IntAttr visitor. 269 mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) { 270 mlir::Location loc = parentOp->getLoc(); 271 return rewriter.create<mlir::LLVM::ConstantOp>( 272 loc, converter->convertType(intAttr.getType()), intAttr.getValue()); 273 } 274 275 /// FPAttr visitor. 276 mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) { 277 mlir::Location loc = parentOp->getLoc(); 278 return rewriter.create<mlir::LLVM::ConstantOp>( 279 loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); 280 } 281 282 /// ConstComplexAttr visitor. 283 mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) { 284 auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType()); 285 mlir::Type complexElemTy = complexType.getElementType(); 286 mlir::Type complexElemLLVMTy = converter->convertType(complexElemTy); 287 288 mlir::Attribute components[2]; 289 if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) { 290 components[0] = rewriter.getIntegerAttr( 291 complexElemLLVMTy, 292 mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue()); 293 components[1] = rewriter.getIntegerAttr( 294 complexElemLLVMTy, 295 mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue()); 296 } else { 297 components[0] = rewriter.getFloatAttr( 298 complexElemLLVMTy, 299 mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue()); 300 components[1] = rewriter.getFloatAttr( 301 complexElemLLVMTy, 302 mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue()); 303 } 304 305 mlir::Location loc = parentOp->getLoc(); 306 return rewriter.create<mlir::LLVM::ConstantOp>( 307 loc, converter->convertType(complexAttr.getType()), 308 rewriter.getArrayAttr(components)); 309 } 310 311 /// ConstPtrAttr visitor. 312 mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) { 313 mlir::Location loc = parentOp->getLoc(); 314 if (ptrAttr.isNullValue()) { 315 return rewriter.create<mlir::LLVM::ZeroOp>( 316 loc, converter->convertType(ptrAttr.getType())); 317 } 318 mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>()); 319 mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>( 320 loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), 321 ptrAttr.getValue().getInt()); 322 return rewriter.create<mlir::LLVM::IntToPtrOp>( 323 loc, converter->convertType(ptrAttr.getType()), ptrVal); 324 } 325 326 // ConstArrayAttr visitor 327 mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { 328 mlir::Type llvmTy = converter->convertType(attr.getType()); 329 mlir::Location loc = parentOp->getLoc(); 330 mlir::Value result; 331 332 if (attr.hasTrailingZeros()) { 333 mlir::Type arrayTy = attr.getType(); 334 result = rewriter.create<mlir::LLVM::ZeroOp>( 335 loc, converter->convertType(arrayTy)); 336 } else { 337 result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); 338 } 339 340 // Iteratively lower each constant element of the array. 341 if (auto arrayAttr = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts())) { 342 for (auto [idx, elt] : llvm::enumerate(arrayAttr)) { 343 mlir::DataLayout dataLayout(parentOp->getParentOfType<mlir::ModuleOp>()); 344 mlir::Value init = visit(elt); 345 result = 346 rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); 347 } 348 } else if (auto strAttr = mlir::dyn_cast<mlir::StringAttr>(attr.getElts())) { 349 // TODO(cir): this diverges from traditional lowering. Normally the string 350 // would be a global constant that is memcopied. 351 auto arrayTy = mlir::dyn_cast<cir::ArrayType>(strAttr.getType()); 352 assert(arrayTy && "String attribute must have an array type"); 353 mlir::Type eltTy = arrayTy.getElementType(); 354 for (auto [idx, elt] : llvm::enumerate(strAttr)) { 355 auto init = rewriter.create<mlir::LLVM::ConstantOp>( 356 loc, converter->convertType(eltTy), elt); 357 result = 358 rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); 359 } 360 } else { 361 llvm_unreachable("unexpected ConstArrayAttr elements"); 362 } 363 364 return result; 365 } 366 367 /// ConstVectorAttr visitor. 368 mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) { 369 const mlir::Type llvmTy = converter->convertType(attr.getType()); 370 const mlir::Location loc = parentOp->getLoc(); 371 372 SmallVector<mlir::Attribute> mlirValues; 373 for (const mlir::Attribute elementAttr : attr.getElts()) { 374 mlir::Attribute mlirAttr; 375 if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) { 376 mlirAttr = rewriter.getIntegerAttr( 377 converter->convertType(intAttr.getType()), intAttr.getValue()); 378 } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) { 379 mlirAttr = rewriter.getFloatAttr( 380 converter->convertType(floatAttr.getType()), floatAttr.getValue()); 381 } else { 382 llvm_unreachable( 383 "vector constant with an element that is neither an int nor a float"); 384 } 385 mlirValues.push_back(mlirAttr); 386 } 387 388 return rewriter.create<mlir::LLVM::ConstantOp>( 389 loc, llvmTy, 390 mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy), 391 mlirValues)); 392 } 393 394 /// ZeroAttr visitor. 395 mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) { 396 mlir::Location loc = parentOp->getLoc(); 397 return rewriter.create<mlir::LLVM::ZeroOp>( 398 loc, converter->convertType(attr.getType())); 399 } 400 401 // This class handles rewriting initializer attributes for types that do not 402 // require region initialization. 403 class GlobalInitAttrRewriter { 404 public: 405 GlobalInitAttrRewriter(mlir::Type type, 406 mlir::ConversionPatternRewriter &rewriter) 407 : llvmType(type), rewriter(rewriter) {} 408 409 mlir::Attribute visit(mlir::Attribute attr) { 410 return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr) 411 .Case<cir::IntAttr, cir::FPAttr, cir::BoolAttr>( 412 [&](auto attrT) { return visitCirAttr(attrT); }) 413 .Default([&](auto attrT) { return mlir::Attribute(); }); 414 } 415 416 mlir::Attribute visitCirAttr(cir::IntAttr attr) { 417 return rewriter.getIntegerAttr(llvmType, attr.getValue()); 418 } 419 420 mlir::Attribute visitCirAttr(cir::FPAttr attr) { 421 return rewriter.getFloatAttr(llvmType, attr.getValue()); 422 } 423 424 mlir::Attribute visitCirAttr(cir::BoolAttr attr) { 425 return rewriter.getBoolAttr(attr.getValue()); 426 } 427 428 private: 429 mlir::Type llvmType; 430 mlir::ConversionPatternRewriter &rewriter; 431 }; 432 433 // This pass requires the CIR to be in a "flat" state. All blocks in each 434 // function must belong to the parent region. Once scopes and control flow 435 // are implemented in CIR, a pass will be run before this one to flatten 436 // the CIR and get it into the state that this pass requires. 437 struct ConvertCIRToLLVMPass 438 : public mlir::PassWrapper<ConvertCIRToLLVMPass, 439 mlir::OperationPass<mlir::ModuleOp>> { 440 void getDependentDialects(mlir::DialectRegistry ®istry) const override { 441 registry.insert<mlir::BuiltinDialect, mlir::DLTIDialect, 442 mlir::LLVM::LLVMDialect, mlir::func::FuncDialect>(); 443 } 444 void runOnOperation() final; 445 446 void processCIRAttrs(mlir::ModuleOp module); 447 448 StringRef getDescription() const override { 449 return "Convert the prepared CIR dialect module to LLVM dialect"; 450 } 451 452 StringRef getArgument() const override { return "cir-flat-to-llvm"; } 453 }; 454 455 mlir::LogicalResult CIRToLLVMAssumeOpLowering::matchAndRewrite( 456 cir::AssumeOp op, OpAdaptor adaptor, 457 mlir::ConversionPatternRewriter &rewriter) const { 458 auto cond = adaptor.getPredicate(); 459 rewriter.replaceOpWithNewOp<mlir::LLVM::AssumeOp>(op, cond); 460 return mlir::success(); 461 } 462 463 mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite( 464 cir::BitClrsbOp op, OpAdaptor adaptor, 465 mlir::ConversionPatternRewriter &rewriter) const { 466 auto zero = rewriter.create<mlir::LLVM::ConstantOp>( 467 op.getLoc(), adaptor.getInput().getType(), 0); 468 auto isNeg = rewriter.create<mlir::LLVM::ICmpOp>( 469 op.getLoc(), 470 mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), 471 mlir::LLVM::ICmpPredicate::slt), 472 adaptor.getInput(), zero); 473 474 auto negOne = rewriter.create<mlir::LLVM::ConstantOp>( 475 op.getLoc(), adaptor.getInput().getType(), -1); 476 auto flipped = rewriter.create<mlir::LLVM::XOrOp>(op.getLoc(), 477 adaptor.getInput(), negOne); 478 479 auto select = rewriter.create<mlir::LLVM::SelectOp>( 480 op.getLoc(), isNeg, flipped, adaptor.getInput()); 481 482 auto resTy = getTypeConverter()->convertType(op.getType()); 483 auto clz = rewriter.create<mlir::LLVM::CountLeadingZerosOp>( 484 op.getLoc(), resTy, select, /*is_zero_poison=*/false); 485 486 auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1); 487 auto res = rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), clz, one); 488 rewriter.replaceOp(op, res); 489 490 return mlir::LogicalResult::success(); 491 } 492 493 mlir::LogicalResult CIRToLLVMBitClzOpLowering::matchAndRewrite( 494 cir::BitClzOp op, OpAdaptor adaptor, 495 mlir::ConversionPatternRewriter &rewriter) const { 496 auto resTy = getTypeConverter()->convertType(op.getType()); 497 auto llvmOp = rewriter.create<mlir::LLVM::CountLeadingZerosOp>( 498 op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero()); 499 rewriter.replaceOp(op, llvmOp); 500 return mlir::LogicalResult::success(); 501 } 502 503 mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite( 504 cir::BitCtzOp op, OpAdaptor adaptor, 505 mlir::ConversionPatternRewriter &rewriter) const { 506 auto resTy = getTypeConverter()->convertType(op.getType()); 507 auto llvmOp = rewriter.create<mlir::LLVM::CountTrailingZerosOp>( 508 op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero()); 509 rewriter.replaceOp(op, llvmOp); 510 return mlir::LogicalResult::success(); 511 } 512 513 mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite( 514 cir::BitParityOp op, OpAdaptor adaptor, 515 mlir::ConversionPatternRewriter &rewriter) const { 516 auto resTy = getTypeConverter()->convertType(op.getType()); 517 auto popcnt = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy, 518 adaptor.getInput()); 519 520 auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1); 521 auto popcntMod2 = 522 rewriter.create<mlir::LLVM::AndOp>(op.getLoc(), popcnt, one); 523 rewriter.replaceOp(op, popcntMod2); 524 525 return mlir::LogicalResult::success(); 526 } 527 528 mlir::LogicalResult CIRToLLVMBitPopcountOpLowering::matchAndRewrite( 529 cir::BitPopcountOp op, OpAdaptor adaptor, 530 mlir::ConversionPatternRewriter &rewriter) const { 531 auto resTy = getTypeConverter()->convertType(op.getType()); 532 auto llvmOp = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy, 533 adaptor.getInput()); 534 rewriter.replaceOp(op, llvmOp); 535 return mlir::LogicalResult::success(); 536 } 537 538 mlir::LogicalResult CIRToLLVMBitReverseOpLowering::matchAndRewrite( 539 cir::BitReverseOp op, OpAdaptor adaptor, 540 mlir::ConversionPatternRewriter &rewriter) const { 541 rewriter.replaceOpWithNewOp<mlir::LLVM::BitReverseOp>(op, adaptor.getInput()); 542 return mlir::success(); 543 } 544 545 mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite( 546 cir::BrCondOp brOp, OpAdaptor adaptor, 547 mlir::ConversionPatternRewriter &rewriter) const { 548 // When ZExtOp is implemented, we'll need to check if the condition is a 549 // ZExtOp and if so, delete it if it has a single use. 550 assert(!cir::MissingFeatures::zextOp()); 551 552 mlir::Value i1Condition = adaptor.getCond(); 553 554 rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>( 555 brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(), 556 brOp.getDestFalse(), adaptor.getDestOperandsFalse()); 557 558 return mlir::success(); 559 } 560 561 mlir::LogicalResult CIRToLLVMByteSwapOpLowering::matchAndRewrite( 562 cir::ByteSwapOp op, OpAdaptor adaptor, 563 mlir::ConversionPatternRewriter &rewriter) const { 564 rewriter.replaceOpWithNewOp<mlir::LLVM::ByteSwapOp>(op, adaptor.getInput()); 565 return mlir::LogicalResult::success(); 566 } 567 568 mlir::Type CIRToLLVMCastOpLowering::convertTy(mlir::Type ty) const { 569 return getTypeConverter()->convertType(ty); 570 } 571 572 mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( 573 cir::CastOp castOp, OpAdaptor adaptor, 574 mlir::ConversionPatternRewriter &rewriter) const { 575 // For arithmetic conversions, LLVM IR uses the same instruction to convert 576 // both individual scalars and entire vectors. This lowering pass handles 577 // both situations. 578 579 switch (castOp.getKind()) { 580 case cir::CastKind::array_to_ptrdecay: { 581 const auto ptrTy = mlir::cast<cir::PointerType>(castOp.getType()); 582 mlir::Value sourceValue = adaptor.getSrc(); 583 mlir::Type targetType = convertTy(ptrTy); 584 mlir::Type elementTy = convertTypeForMemory(*getTypeConverter(), dataLayout, 585 ptrTy.getPointee()); 586 llvm::SmallVector<mlir::LLVM::GEPArg> offset{0}; 587 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( 588 castOp, targetType, elementTy, sourceValue, offset); 589 break; 590 } 591 case cir::CastKind::int_to_bool: { 592 mlir::Value llvmSrcVal = adaptor.getSrc(); 593 mlir::Value zeroInt = rewriter.create<mlir::LLVM::ConstantOp>( 594 castOp.getLoc(), llvmSrcVal.getType(), 0); 595 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( 596 castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroInt); 597 break; 598 } 599 case cir::CastKind::integral: { 600 mlir::Type srcType = castOp.getSrc().getType(); 601 mlir::Type dstType = castOp.getType(); 602 mlir::Value llvmSrcVal = adaptor.getSrc(); 603 mlir::Type llvmDstType = getTypeConverter()->convertType(dstType); 604 cir::IntType srcIntType = 605 mlir::cast<cir::IntType>(elementTypeIfVector(srcType)); 606 cir::IntType dstIntType = 607 mlir::cast<cir::IntType>(elementTypeIfVector(dstType)); 608 rewriter.replaceOp(castOp, getLLVMIntCast(rewriter, llvmSrcVal, llvmDstType, 609 srcIntType.isUnsigned(), 610 srcIntType.getWidth(), 611 dstIntType.getWidth())); 612 break; 613 } 614 case cir::CastKind::floating: { 615 mlir::Value llvmSrcVal = adaptor.getSrc(); 616 mlir::Type llvmDstTy = getTypeConverter()->convertType(castOp.getType()); 617 618 mlir::Type srcTy = elementTypeIfVector(castOp.getSrc().getType()); 619 mlir::Type dstTy = elementTypeIfVector(castOp.getType()); 620 621 if (!mlir::isa<cir::FPTypeInterface>(dstTy) || 622 !mlir::isa<cir::FPTypeInterface>(srcTy)) 623 return castOp.emitError() << "NYI cast from " << srcTy << " to " << dstTy; 624 625 auto getFloatWidth = [](mlir::Type ty) -> unsigned { 626 return mlir::cast<cir::FPTypeInterface>(ty).getWidth(); 627 }; 628 629 if (getFloatWidth(srcTy) > getFloatWidth(dstTy)) 630 rewriter.replaceOpWithNewOp<mlir::LLVM::FPTruncOp>(castOp, llvmDstTy, 631 llvmSrcVal); 632 else 633 rewriter.replaceOpWithNewOp<mlir::LLVM::FPExtOp>(castOp, llvmDstTy, 634 llvmSrcVal); 635 return mlir::success(); 636 } 637 case cir::CastKind::int_to_ptr: { 638 auto dstTy = mlir::cast<cir::PointerType>(castOp.getType()); 639 mlir::Value llvmSrcVal = adaptor.getSrc(); 640 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); 641 rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(castOp, llvmDstTy, 642 llvmSrcVal); 643 return mlir::success(); 644 } 645 case cir::CastKind::ptr_to_int: { 646 auto dstTy = mlir::cast<cir::IntType>(castOp.getType()); 647 mlir::Value llvmSrcVal = adaptor.getSrc(); 648 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); 649 rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(castOp, llvmDstTy, 650 llvmSrcVal); 651 return mlir::success(); 652 } 653 case cir::CastKind::float_to_bool: { 654 mlir::Value llvmSrcVal = adaptor.getSrc(); 655 auto kind = mlir::LLVM::FCmpPredicate::une; 656 657 // Check if float is not equal to zero. 658 auto zeroFloat = rewriter.create<mlir::LLVM::ConstantOp>( 659 castOp.getLoc(), llvmSrcVal.getType(), 660 mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0)); 661 662 // Extend comparison result to either bool (C++) or int (C). 663 rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(castOp, kind, llvmSrcVal, 664 zeroFloat); 665 666 return mlir::success(); 667 } 668 case cir::CastKind::bool_to_int: { 669 auto dstTy = mlir::cast<cir::IntType>(castOp.getType()); 670 mlir::Value llvmSrcVal = adaptor.getSrc(); 671 auto llvmSrcTy = mlir::cast<mlir::IntegerType>(llvmSrcVal.getType()); 672 auto llvmDstTy = 673 mlir::cast<mlir::IntegerType>(getTypeConverter()->convertType(dstTy)); 674 if (llvmSrcTy.getWidth() == llvmDstTy.getWidth()) 675 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy, 676 llvmSrcVal); 677 else 678 rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, llvmDstTy, 679 llvmSrcVal); 680 return mlir::success(); 681 } 682 case cir::CastKind::bool_to_float: { 683 mlir::Type dstTy = castOp.getType(); 684 mlir::Value llvmSrcVal = adaptor.getSrc(); 685 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); 686 rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy, 687 llvmSrcVal); 688 return mlir::success(); 689 } 690 case cir::CastKind::int_to_float: { 691 mlir::Type dstTy = castOp.getType(); 692 mlir::Value llvmSrcVal = adaptor.getSrc(); 693 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); 694 if (mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getSrc().getType())) 695 .isSigned()) 696 rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(castOp, llvmDstTy, 697 llvmSrcVal); 698 else 699 rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy, 700 llvmSrcVal); 701 return mlir::success(); 702 } 703 case cir::CastKind::float_to_int: { 704 mlir::Type dstTy = castOp.getType(); 705 mlir::Value llvmSrcVal = adaptor.getSrc(); 706 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); 707 if (mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getType())) 708 .isSigned()) 709 rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(castOp, llvmDstTy, 710 llvmSrcVal); 711 else 712 rewriter.replaceOpWithNewOp<mlir::LLVM::FPToUIOp>(castOp, llvmDstTy, 713 llvmSrcVal); 714 return mlir::success(); 715 } 716 case cir::CastKind::bitcast: { 717 mlir::Type dstTy = castOp.getType(); 718 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); 719 720 assert(!MissingFeatures::cxxABI()); 721 assert(!MissingFeatures::dataMemberType()); 722 723 mlir::Value llvmSrcVal = adaptor.getSrc(); 724 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy, 725 llvmSrcVal); 726 return mlir::success(); 727 } 728 case cir::CastKind::ptr_to_bool: { 729 mlir::Value llvmSrcVal = adaptor.getSrc(); 730 mlir::Value zeroPtr = rewriter.create<mlir::LLVM::ZeroOp>( 731 castOp.getLoc(), llvmSrcVal.getType()); 732 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( 733 castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroPtr); 734 break; 735 } 736 case cir::CastKind::address_space: { 737 mlir::Type dstTy = castOp.getType(); 738 mlir::Value llvmSrcVal = adaptor.getSrc(); 739 mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); 740 rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(castOp, llvmDstTy, 741 llvmSrcVal); 742 break; 743 } 744 case cir::CastKind::member_ptr_to_bool: 745 assert(!MissingFeatures::cxxABI()); 746 assert(!MissingFeatures::methodType()); 747 break; 748 default: { 749 return castOp.emitError("Unhandled cast kind: ") 750 << castOp.getKindAttrName(); 751 } 752 } 753 754 return mlir::success(); 755 } 756 757 mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite( 758 cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor, 759 mlir::ConversionPatternRewriter &rewriter) const { 760 761 const mlir::TypeConverter *tc = getTypeConverter(); 762 const mlir::Type resultTy = tc->convertType(ptrStrideOp.getType()); 763 764 mlir::Type elementTy = 765 convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementTy()); 766 mlir::MLIRContext *ctx = elementTy.getContext(); 767 768 // void and function types doesn't really have a layout to use in GEPs, 769 // make it i8 instead. 770 if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) || 771 mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy)) 772 elementTy = mlir::IntegerType::get(elementTy.getContext(), 8, 773 mlir::IntegerType::Signless); 774 // Zero-extend, sign-extend or trunc the pointer value. 775 mlir::Value index = adaptor.getStride(); 776 const unsigned width = 777 mlir::cast<mlir::IntegerType>(index.getType()).getWidth(); 778 const std::optional<std::uint64_t> layoutWidth = 779 dataLayout.getTypeIndexBitwidth(adaptor.getBase().getType()); 780 781 mlir::Operation *indexOp = index.getDefiningOp(); 782 if (indexOp && layoutWidth && width != *layoutWidth) { 783 // If the index comes from a subtraction, make sure the extension happens 784 // before it. To achieve that, look at unary minus, which already got 785 // lowered to "sub 0, x". 786 const auto sub = dyn_cast<mlir::LLVM::SubOp>(indexOp); 787 auto unary = dyn_cast_if_present<cir::UnaryOp>( 788 ptrStrideOp.getStride().getDefiningOp()); 789 bool rewriteSub = 790 unary && unary.getKind() == cir::UnaryOpKind::Minus && sub; 791 if (rewriteSub) 792 index = indexOp->getOperand(1); 793 794 // Handle the cast 795 const auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth); 796 index = getLLVMIntCast(rewriter, index, llvmDstType, 797 ptrStrideOp.getStride().getType().isUnsigned(), 798 width, *layoutWidth); 799 800 // Rewrite the sub in front of extensions/trunc 801 if (rewriteSub) { 802 index = rewriter.create<mlir::LLVM::SubOp>( 803 index.getLoc(), index.getType(), 804 rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(), 805 index.getType(), 0), 806 index); 807 rewriter.eraseOp(sub); 808 } 809 } 810 811 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( 812 ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index); 813 return mlir::success(); 814 } 815 816 mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite( 817 cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor, 818 mlir::ConversionPatternRewriter &rewriter) const { 819 const mlir::Type resultType = 820 getTypeConverter()->convertType(baseClassOp.getType()); 821 mlir::Value derivedAddr = adaptor.getDerivedAddr(); 822 llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = { 823 adaptor.getOffset().getZExtValue()}; 824 mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8, 825 mlir::IntegerType::Signless); 826 if (adaptor.getOffset().getZExtValue() == 0) { 827 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>( 828 baseClassOp, resultType, adaptor.getDerivedAddr()); 829 return mlir::success(); 830 } 831 832 if (baseClassOp.getAssumeNotNull()) { 833 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( 834 baseClassOp, resultType, byteType, derivedAddr, offset); 835 } else { 836 auto loc = baseClassOp.getLoc(); 837 mlir::Value isNull = rewriter.create<mlir::LLVM::ICmpOp>( 838 loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr, 839 rewriter.create<mlir::LLVM::ZeroOp>(loc, derivedAddr.getType())); 840 mlir::Value adjusted = rewriter.create<mlir::LLVM::GEPOp>( 841 loc, resultType, byteType, derivedAddr, offset); 842 rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(baseClassOp, isNull, 843 derivedAddr, adjusted); 844 } 845 return mlir::success(); 846 } 847 848 mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite( 849 cir::AllocaOp op, OpAdaptor adaptor, 850 mlir::ConversionPatternRewriter &rewriter) const { 851 assert(!cir::MissingFeatures::opAllocaDynAllocSize()); 852 mlir::Value size = rewriter.create<mlir::LLVM::ConstantOp>( 853 op.getLoc(), typeConverter->convertType(rewriter.getIndexType()), 1); 854 mlir::Type elementTy = 855 convertTypeForMemory(*getTypeConverter(), dataLayout, op.getAllocaType()); 856 mlir::Type resultTy = 857 convertTypeForMemory(*getTypeConverter(), dataLayout, op.getType()); 858 859 assert(!cir::MissingFeatures::addressSpace()); 860 assert(!cir::MissingFeatures::opAllocaAnnotations()); 861 862 rewriter.replaceOpWithNewOp<mlir::LLVM::AllocaOp>( 863 op, resultTy, elementTy, size, op.getAlignmentAttr().getInt()); 864 865 return mlir::success(); 866 } 867 868 mlir::LogicalResult CIRToLLVMReturnOpLowering::matchAndRewrite( 869 cir::ReturnOp op, OpAdaptor adaptor, 870 mlir::ConversionPatternRewriter &rewriter) const { 871 rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op, adaptor.getOperands()); 872 return mlir::LogicalResult::success(); 873 } 874 875 static mlir::LogicalResult 876 rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands, 877 mlir::ConversionPatternRewriter &rewriter, 878 const mlir::TypeConverter *converter, 879 mlir::FlatSymbolRefAttr calleeAttr) { 880 llvm::SmallVector<mlir::Type, 8> llvmResults; 881 mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes(); 882 auto call = cast<cir::CIRCallOpInterface>(op); 883 884 if (converter->convertTypes(cirResults, llvmResults).failed()) 885 return mlir::failure(); 886 887 assert(!cir::MissingFeatures::opCallCallConv()); 888 889 mlir::LLVM::MemoryEffectsAttr memoryEffects; 890 bool noUnwind = false; 891 bool willReturn = false; 892 convertSideEffectForCall(op, call.getNothrow(), call.getSideEffect(), 893 memoryEffects, noUnwind, willReturn); 894 895 mlir::LLVM::LLVMFunctionType llvmFnTy; 896 if (calleeAttr) { // direct call 897 mlir::FunctionOpInterface fn = 898 mlir::SymbolTable::lookupNearestSymbolFrom<mlir::FunctionOpInterface>( 899 op, calleeAttr); 900 assert(fn && "Did not find function for call"); 901 llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>( 902 converter->convertType(fn.getFunctionType())); 903 } else { // indirect call 904 assert(!op->getOperands().empty() && 905 "operands list must no be empty for the indirect call"); 906 auto calleeTy = op->getOperands().front().getType(); 907 auto calleePtrTy = cast<cir::PointerType>(calleeTy); 908 auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee()); 909 calleeFuncTy.dump(); 910 converter->convertType(calleeFuncTy).dump(); 911 llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>( 912 converter->convertType(calleeFuncTy)); 913 } 914 915 assert(!cir::MissingFeatures::opCallLandingPad()); 916 assert(!cir::MissingFeatures::opCallContinueBlock()); 917 assert(!cir::MissingFeatures::opCallCallConv()); 918 919 auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( 920 op, llvmFnTy, calleeAttr, callOperands); 921 if (memoryEffects) 922 newOp.setMemoryEffectsAttr(memoryEffects); 923 newOp.setNoUnwind(noUnwind); 924 newOp.setWillReturn(willReturn); 925 926 return mlir::success(); 927 } 928 929 mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite( 930 cir::CallOp op, OpAdaptor adaptor, 931 mlir::ConversionPatternRewriter &rewriter) const { 932 return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter, 933 getTypeConverter(), op.getCalleeAttr()); 934 } 935 936 mlir::LogicalResult CIRToLLVMLoadOpLowering::matchAndRewrite( 937 cir::LoadOp op, OpAdaptor adaptor, 938 mlir::ConversionPatternRewriter &rewriter) const { 939 const mlir::Type llvmTy = 940 convertTypeForMemory(*getTypeConverter(), dataLayout, op.getType()); 941 assert(!cir::MissingFeatures::opLoadStoreMemOrder()); 942 std::optional<size_t> opAlign = op.getAlignment(); 943 unsigned alignment = 944 (unsigned)opAlign.value_or(dataLayout.getTypeABIAlignment(llvmTy)); 945 946 assert(!cir::MissingFeatures::lowerModeOptLevel()); 947 948 // TODO: nontemporal, syncscope. 949 assert(!cir::MissingFeatures::opLoadStoreVolatile()); 950 mlir::LLVM::LoadOp newLoad = rewriter.create<mlir::LLVM::LoadOp>( 951 op->getLoc(), llvmTy, adaptor.getAddr(), alignment, 952 /*volatile=*/false, /*nontemporal=*/false, 953 /*invariant=*/false, /*invariantGroup=*/false, 954 mlir::LLVM::AtomicOrdering::not_atomic); 955 956 // Convert adapted result to its original type if needed. 957 mlir::Value result = 958 emitFromMemory(rewriter, dataLayout, op, newLoad.getResult()); 959 rewriter.replaceOp(op, result); 960 assert(!cir::MissingFeatures::opLoadStoreTbaa()); 961 return mlir::LogicalResult::success(); 962 } 963 964 mlir::LogicalResult CIRToLLVMStoreOpLowering::matchAndRewrite( 965 cir::StoreOp op, OpAdaptor adaptor, 966 mlir::ConversionPatternRewriter &rewriter) const { 967 assert(!cir::MissingFeatures::opLoadStoreMemOrder()); 968 const mlir::Type llvmTy = 969 getTypeConverter()->convertType(op.getValue().getType()); 970 std::optional<size_t> opAlign = op.getAlignment(); 971 unsigned alignment = 972 (unsigned)opAlign.value_or(dataLayout.getTypeABIAlignment(llvmTy)); 973 974 assert(!cir::MissingFeatures::lowerModeOptLevel()); 975 976 // Convert adapted value to its memory type if needed. 977 mlir::Value value = emitToMemory(rewriter, dataLayout, 978 op.getValue().getType(), adaptor.getValue()); 979 // TODO: nontemporal, syncscope. 980 assert(!cir::MissingFeatures::opLoadStoreVolatile()); 981 mlir::LLVM::StoreOp storeOp = rewriter.create<mlir::LLVM::StoreOp>( 982 op->getLoc(), value, adaptor.getAddr(), alignment, /*volatile=*/false, 983 /*nontemporal=*/false, /*invariantGroup=*/false, 984 mlir::LLVM::AtomicOrdering::not_atomic); 985 rewriter.replaceOp(op, storeOp); 986 assert(!cir::MissingFeatures::opLoadStoreTbaa()); 987 return mlir::LogicalResult::success(); 988 } 989 990 bool hasTrailingZeros(cir::ConstArrayAttr attr) { 991 auto array = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts()); 992 return attr.hasTrailingZeros() || 993 (array && std::count_if(array.begin(), array.end(), [](auto elt) { 994 auto ar = dyn_cast<cir::ConstArrayAttr>(elt); 995 return ar && hasTrailingZeros(ar); 996 })); 997 } 998 999 mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( 1000 cir::ConstantOp op, OpAdaptor adaptor, 1001 mlir::ConversionPatternRewriter &rewriter) const { 1002 mlir::Attribute attr = op.getValue(); 1003 1004 if (mlir::isa<mlir::IntegerType>(op.getType())) { 1005 // Verified cir.const operations cannot actually be of these types, but the 1006 // lowering pass may generate temporary cir.const operations with these 1007 // types. This is OK since MLIR allows unverified operations to be alive 1008 // during a pass as long as they don't live past the end of the pass. 1009 attr = op.getValue(); 1010 } else if (mlir::isa<cir::BoolType>(op.getType())) { 1011 int value = mlir::cast<cir::BoolAttr>(op.getValue()).getValue(); 1012 attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()), 1013 value); 1014 } else if (mlir::isa<cir::IntType>(op.getType())) { 1015 assert(!cir::MissingFeatures::opGlobalViewAttr()); 1016 1017 attr = rewriter.getIntegerAttr( 1018 typeConverter->convertType(op.getType()), 1019 mlir::cast<cir::IntAttr>(op.getValue()).getValue()); 1020 } else if (mlir::isa<cir::FPTypeInterface>(op.getType())) { 1021 attr = rewriter.getFloatAttr( 1022 typeConverter->convertType(op.getType()), 1023 mlir::cast<cir::FPAttr>(op.getValue()).getValue()); 1024 } else if (mlir::isa<cir::PointerType>(op.getType())) { 1025 // Optimize with dedicated LLVM op for null pointers. 1026 if (mlir::isa<cir::ConstPtrAttr>(op.getValue())) { 1027 if (mlir::cast<cir::ConstPtrAttr>(op.getValue()).isNullValue()) { 1028 rewriter.replaceOpWithNewOp<mlir::LLVM::ZeroOp>( 1029 op, typeConverter->convertType(op.getType())); 1030 return mlir::success(); 1031 } 1032 } 1033 assert(!cir::MissingFeatures::opGlobalViewAttr()); 1034 attr = op.getValue(); 1035 } else if (const auto arrTy = mlir::dyn_cast<cir::ArrayType>(op.getType())) { 1036 const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(op.getValue()); 1037 if (!constArr && !isa<cir::ZeroAttr, cir::UndefAttr>(op.getValue())) 1038 return op.emitError() << "array does not have a constant initializer"; 1039 1040 std::optional<mlir::Attribute> denseAttr; 1041 if (constArr && hasTrailingZeros(constArr)) { 1042 const mlir::Value newOp = 1043 lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter()); 1044 rewriter.replaceOp(op, newOp); 1045 return mlir::success(); 1046 } else if (constArr && 1047 (denseAttr = lowerConstArrayAttr(constArr, typeConverter))) { 1048 attr = denseAttr.value(); 1049 } else { 1050 const mlir::Value initVal = 1051 lowerCirAttrAsValue(op, op.getValue(), rewriter, typeConverter); 1052 rewriter.replaceAllUsesWith(op, initVal); 1053 rewriter.eraseOp(op); 1054 return mlir::success(); 1055 } 1056 } else if (const auto vecTy = mlir::dyn_cast<cir::VectorType>(op.getType())) { 1057 rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter, 1058 getTypeConverter())); 1059 return mlir::success(); 1060 } else if (auto complexTy = mlir::dyn_cast<cir::ComplexType>(op.getType())) { 1061 mlir::Type complexElemTy = complexTy.getElementType(); 1062 mlir::Type complexElemLLVMTy = typeConverter->convertType(complexElemTy); 1063 1064 if (auto zeroInitAttr = mlir::dyn_cast<cir::ZeroAttr>(op.getValue())) { 1065 mlir::TypedAttr zeroAttr = rewriter.getZeroAttr(complexElemLLVMTy); 1066 mlir::ArrayAttr array = rewriter.getArrayAttr({zeroAttr, zeroAttr}); 1067 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 1068 op, getTypeConverter()->convertType(op.getType()), array); 1069 return mlir::success(); 1070 } 1071 1072 auto complexAttr = mlir::cast<cir::ConstComplexAttr>(op.getValue()); 1073 1074 mlir::Attribute components[2]; 1075 if (mlir::isa<cir::IntType>(complexElemTy)) { 1076 components[0] = rewriter.getIntegerAttr( 1077 complexElemLLVMTy, 1078 mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue()); 1079 components[1] = rewriter.getIntegerAttr( 1080 complexElemLLVMTy, 1081 mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue()); 1082 } else { 1083 components[0] = rewriter.getFloatAttr( 1084 complexElemLLVMTy, 1085 mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue()); 1086 components[1] = rewriter.getFloatAttr( 1087 complexElemLLVMTy, 1088 mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue()); 1089 } 1090 1091 attr = rewriter.getArrayAttr(components); 1092 } else { 1093 return op.emitError() << "unsupported constant type " << op.getType(); 1094 } 1095 1096 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 1097 op, getTypeConverter()->convertType(op.getType()), attr); 1098 1099 return mlir::success(); 1100 } 1101 1102 mlir::LogicalResult CIRToLLVMExpectOpLowering::matchAndRewrite( 1103 cir::ExpectOp op, OpAdaptor adaptor, 1104 mlir::ConversionPatternRewriter &rewriter) const { 1105 // TODO(cir): do not generate LLVM intrinsics under -O0 1106 assert(!cir::MissingFeatures::optInfoAttr()); 1107 1108 std::optional<llvm::APFloat> prob = op.getProb(); 1109 if (prob) 1110 rewriter.replaceOpWithNewOp<mlir::LLVM::ExpectWithProbabilityOp>( 1111 op, adaptor.getVal(), adaptor.getExpected(), prob.value()); 1112 else 1113 rewriter.replaceOpWithNewOp<mlir::LLVM::ExpectOp>(op, adaptor.getVal(), 1114 adaptor.getExpected()); 1115 return mlir::success(); 1116 } 1117 1118 /// Convert the `cir.func` attributes to `llvm.func` attributes. 1119 /// Only retain those attributes that are not constructed by 1120 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out 1121 /// argument attributes. 1122 void CIRToLLVMFuncOpLowering::lowerFuncAttributes( 1123 cir::FuncOp func, bool filterArgAndResAttrs, 1124 SmallVectorImpl<mlir::NamedAttribute> &result) const { 1125 assert(!cir::MissingFeatures::opFuncCallingConv()); 1126 for (mlir::NamedAttribute attr : func->getAttrs()) { 1127 assert(!cir::MissingFeatures::opFuncCallingConv()); 1128 if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() || 1129 attr.getName() == func.getFunctionTypeAttrName() || 1130 attr.getName() == getLinkageAttrNameString() || 1131 attr.getName() == func.getGlobalVisibilityAttrName() || 1132 attr.getName() == func.getDsoLocalAttrName() || 1133 (filterArgAndResAttrs && 1134 (attr.getName() == func.getArgAttrsAttrName() || 1135 attr.getName() == func.getResAttrsAttrName()))) 1136 continue; 1137 1138 assert(!cir::MissingFeatures::opFuncExtraAttrs()); 1139 result.push_back(attr); 1140 } 1141 } 1142 1143 mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite( 1144 cir::FuncOp op, OpAdaptor adaptor, 1145 mlir::ConversionPatternRewriter &rewriter) const { 1146 1147 cir::FuncType fnType = op.getFunctionType(); 1148 bool isDsoLocal = op.getDsoLocal(); 1149 mlir::TypeConverter::SignatureConversion signatureConversion( 1150 fnType.getNumInputs()); 1151 1152 for (const auto &argType : llvm::enumerate(fnType.getInputs())) { 1153 mlir::Type convertedType = typeConverter->convertType(argType.value()); 1154 if (!convertedType) 1155 return mlir::failure(); 1156 signatureConversion.addInputs(argType.index(), convertedType); 1157 } 1158 1159 mlir::Type resultType = 1160 getTypeConverter()->convertType(fnType.getReturnType()); 1161 1162 // Create the LLVM function operation. 1163 mlir::Type llvmFnTy = mlir::LLVM::LLVMFunctionType::get( 1164 resultType ? resultType : mlir::LLVM::LLVMVoidType::get(getContext()), 1165 signatureConversion.getConvertedTypes(), 1166 /*isVarArg=*/fnType.isVarArg()); 1167 // LLVMFuncOp expects a single FileLine Location instead of a fused 1168 // location. 1169 mlir::Location loc = op.getLoc(); 1170 if (mlir::FusedLoc fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(loc)) 1171 loc = fusedLoc.getLocations()[0]; 1172 assert((mlir::isa<mlir::FileLineColLoc>(loc) || 1173 mlir::isa<mlir::UnknownLoc>(loc)) && 1174 "expected single location or unknown location here"); 1175 1176 mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage()); 1177 assert(!cir::MissingFeatures::opFuncCallingConv()); 1178 mlir::LLVM::CConv cconv = mlir::LLVM::CConv::C; 1179 SmallVector<mlir::NamedAttribute, 4> attributes; 1180 lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes); 1181 1182 mlir::LLVM::LLVMFuncOp fn = rewriter.create<mlir::LLVM::LLVMFuncOp>( 1183 loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv, 1184 mlir::SymbolRefAttr(), attributes); 1185 1186 assert(!cir::MissingFeatures::opFuncMultipleReturnVals()); 1187 1188 fn.setVisibility_Attr(mlir::LLVM::VisibilityAttr::get( 1189 getContext(), lowerCIRVisibilityToLLVMVisibility( 1190 op.getGlobalVisibilityAttr().getValue()))); 1191 1192 rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); 1193 if (failed(rewriter.convertRegionTypes(&fn.getBody(), *typeConverter, 1194 &signatureConversion))) 1195 return mlir::failure(); 1196 1197 rewriter.eraseOp(op); 1198 1199 return mlir::LogicalResult::success(); 1200 } 1201 1202 mlir::LogicalResult CIRToLLVMGetGlobalOpLowering::matchAndRewrite( 1203 cir::GetGlobalOp op, OpAdaptor adaptor, 1204 mlir::ConversionPatternRewriter &rewriter) const { 1205 // FIXME(cir): Premature DCE to avoid lowering stuff we're not using. 1206 // CIRGen should mitigate this and not emit the get_global. 1207 if (op->getUses().empty()) { 1208 rewriter.eraseOp(op); 1209 return mlir::success(); 1210 } 1211 1212 mlir::Type type = getTypeConverter()->convertType(op.getType()); 1213 mlir::Operation *newop = 1214 rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), type, op.getName()); 1215 1216 assert(!cir::MissingFeatures::opGlobalThreadLocal()); 1217 1218 rewriter.replaceOp(op, newop); 1219 return mlir::success(); 1220 } 1221 1222 /// Replace CIR global with a region initialized LLVM global and update 1223 /// insertion point to the end of the initializer block. 1224 void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp( 1225 cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const { 1226 const mlir::Type llvmType = 1227 convertTypeForMemory(*getTypeConverter(), dataLayout, op.getSymType()); 1228 1229 // FIXME: These default values are placeholders until the the equivalent 1230 // attributes are available on cir.global ops. This duplicates code 1231 // in CIRToLLVMGlobalOpLowering::matchAndRewrite() but that will go 1232 // away when the placeholders are no longer needed. 1233 assert(!cir::MissingFeatures::opGlobalConstant()); 1234 const bool isConst = false; 1235 assert(!cir::MissingFeatures::addressSpace()); 1236 const unsigned addrSpace = 0; 1237 const bool isDsoLocal = op.getDsoLocal(); 1238 assert(!cir::MissingFeatures::opGlobalThreadLocal()); 1239 const bool isThreadLocal = false; 1240 const uint64_t alignment = op.getAlignment().value_or(0); 1241 const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage()); 1242 const StringRef symbol = op.getSymName(); 1243 mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter); 1244 1245 SmallVector<mlir::NamedAttribute> attributes; 1246 mlir::LLVM::GlobalOp newGlobalOp = 1247 rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>( 1248 op, llvmType, isConst, linkage, symbol, nullptr, alignment, addrSpace, 1249 isDsoLocal, isThreadLocal, comdatAttr, attributes); 1250 newGlobalOp.getRegion().emplaceBlock(); 1251 rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock()); 1252 } 1253 1254 mlir::LogicalResult 1255 CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( 1256 cir::GlobalOp op, mlir::Attribute init, 1257 mlir::ConversionPatternRewriter &rewriter) const { 1258 // TODO: Generalize this handling when more types are needed here. 1259 assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr, 1260 cir::ConstComplexAttr, cir::ZeroAttr>(init))); 1261 1262 // TODO(cir): once LLVM's dialect has proper equivalent attributes this 1263 // should be updated. For now, we use a custom op to initialize globals 1264 // to the appropriate value. 1265 const mlir::Location loc = op.getLoc(); 1266 setupRegionInitializedLLVMGlobalOp(op, rewriter); 1267 CIRAttrToValue valueConverter(op, rewriter, typeConverter); 1268 mlir::Value value = valueConverter.visit(init); 1269 rewriter.create<mlir::LLVM::ReturnOp>(loc, value); 1270 return mlir::success(); 1271 } 1272 1273 mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( 1274 cir::GlobalOp op, OpAdaptor adaptor, 1275 mlir::ConversionPatternRewriter &rewriter) const { 1276 1277 std::optional<mlir::Attribute> init = op.getInitialValue(); 1278 1279 // Fetch required values to create LLVM op. 1280 const mlir::Type cirSymType = op.getSymType(); 1281 1282 // This is the LLVM dialect type. 1283 const mlir::Type llvmType = 1284 convertTypeForMemory(*getTypeConverter(), dataLayout, cirSymType); 1285 // FIXME: These default values are placeholders until the the equivalent 1286 // attributes are available on cir.global ops. 1287 assert(!cir::MissingFeatures::opGlobalConstant()); 1288 const bool isConst = false; 1289 assert(!cir::MissingFeatures::addressSpace()); 1290 const unsigned addrSpace = 0; 1291 const bool isDsoLocal = op.getDsoLocal(); 1292 assert(!cir::MissingFeatures::opGlobalThreadLocal()); 1293 const bool isThreadLocal = false; 1294 const uint64_t alignment = op.getAlignment().value_or(0); 1295 const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage()); 1296 const StringRef symbol = op.getSymName(); 1297 SmallVector<mlir::NamedAttribute> attributes; 1298 mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter); 1299 1300 if (init.has_value()) { 1301 if (mlir::isa<cir::FPAttr, cir::IntAttr, cir::BoolAttr>(init.value())) { 1302 GlobalInitAttrRewriter initRewriter(llvmType, rewriter); 1303 init = initRewriter.visit(init.value()); 1304 // If initRewriter returned a null attribute, init will have a value but 1305 // the value will be null. If that happens, initRewriter didn't handle the 1306 // attribute type. It probably needs to be added to 1307 // GlobalInitAttrRewriter. 1308 if (!init.value()) { 1309 op.emitError() << "unsupported initializer '" << init.value() << "'"; 1310 return mlir::failure(); 1311 } 1312 } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr, 1313 cir::ConstPtrAttr, cir::ConstComplexAttr, 1314 cir::ZeroAttr>(init.value())) { 1315 // TODO(cir): once LLVM's dialect has proper equivalent attributes this 1316 // should be updated. For now, we use a custom op to initialize globals 1317 // to the appropriate value. 1318 return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter); 1319 } else { 1320 // We will only get here if new initializer types are added and this 1321 // code is not updated to handle them. 1322 op.emitError() << "unsupported initializer '" << init.value() << "'"; 1323 return mlir::failure(); 1324 } 1325 } 1326 1327 // Rewrite op. 1328 rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>( 1329 op, llvmType, isConst, linkage, symbol, init.value_or(mlir::Attribute()), 1330 alignment, addrSpace, isDsoLocal, isThreadLocal, comdatAttr, attributes); 1331 return mlir::success(); 1332 } 1333 1334 mlir::SymbolRefAttr 1335 CIRToLLVMGlobalOpLowering::getComdatAttr(cir::GlobalOp &op, 1336 mlir::OpBuilder &builder) const { 1337 if (!op.getComdat()) 1338 return mlir::SymbolRefAttr{}; 1339 1340 mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>(); 1341 mlir::OpBuilder::InsertionGuard guard(builder); 1342 StringRef comdatName("__llvm_comdat_globals"); 1343 if (!comdatOp) { 1344 builder.setInsertionPointToStart(module.getBody()); 1345 comdatOp = 1346 builder.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName); 1347 } 1348 1349 builder.setInsertionPointToStart(&comdatOp.getBody().back()); 1350 auto selectorOp = builder.create<mlir::LLVM::ComdatSelectorOp>( 1351 comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any); 1352 return mlir::SymbolRefAttr::get( 1353 builder.getContext(), comdatName, 1354 mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr())); 1355 } 1356 1357 mlir::LogicalResult CIRToLLVMSwitchFlatOpLowering::matchAndRewrite( 1358 cir::SwitchFlatOp op, OpAdaptor adaptor, 1359 mlir::ConversionPatternRewriter &rewriter) const { 1360 1361 llvm::SmallVector<mlir::APInt, 8> caseValues; 1362 for (mlir::Attribute val : op.getCaseValues()) { 1363 auto intAttr = cast<cir::IntAttr>(val); 1364 caseValues.push_back(intAttr.getValue()); 1365 } 1366 1367 llvm::SmallVector<mlir::Block *, 8> caseDestinations; 1368 llvm::SmallVector<mlir::ValueRange, 8> caseOperands; 1369 1370 for (mlir::Block *x : op.getCaseDestinations()) 1371 caseDestinations.push_back(x); 1372 1373 for (mlir::OperandRange x : op.getCaseOperands()) 1374 caseOperands.push_back(x); 1375 1376 // Set switch op to branch to the newly created blocks. 1377 rewriter.setInsertionPoint(op); 1378 rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>( 1379 op, adaptor.getCondition(), op.getDefaultDestination(), 1380 op.getDefaultOperands(), caseValues, caseDestinations, caseOperands); 1381 return mlir::success(); 1382 } 1383 1384 mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite( 1385 cir::UnaryOp op, OpAdaptor adaptor, 1386 mlir::ConversionPatternRewriter &rewriter) const { 1387 assert(op.getType() == op.getInput().getType() && 1388 "Unary operation's operand type and result type are different"); 1389 mlir::Type type = op.getType(); 1390 mlir::Type elementType = elementTypeIfVector(type); 1391 bool isVector = mlir::isa<cir::VectorType>(type); 1392 mlir::Type llvmType = getTypeConverter()->convertType(type); 1393 mlir::Location loc = op.getLoc(); 1394 1395 // Integer unary operations: + - ~ ++ -- 1396 if (mlir::isa<cir::IntType>(elementType)) { 1397 mlir::LLVM::IntegerOverflowFlags maybeNSW = 1398 op.getNoSignedWrap() ? mlir::LLVM::IntegerOverflowFlags::nsw 1399 : mlir::LLVM::IntegerOverflowFlags::none; 1400 switch (op.getKind()) { 1401 case cir::UnaryOpKind::Inc: { 1402 assert(!isVector && "++ not allowed on vector types"); 1403 auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1); 1404 rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>( 1405 op, llvmType, adaptor.getInput(), one, maybeNSW); 1406 return mlir::success(); 1407 } 1408 case cir::UnaryOpKind::Dec: { 1409 assert(!isVector && "-- not allowed on vector types"); 1410 auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1); 1411 rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, adaptor.getInput(), 1412 one, maybeNSW); 1413 return mlir::success(); 1414 } 1415 case cir::UnaryOpKind::Plus: 1416 rewriter.replaceOp(op, adaptor.getInput()); 1417 return mlir::success(); 1418 case cir::UnaryOpKind::Minus: { 1419 mlir::Value zero; 1420 if (isVector) 1421 zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType); 1422 else 1423 zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 0); 1424 rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>( 1425 op, zero, adaptor.getInput(), maybeNSW); 1426 return mlir::success(); 1427 } 1428 case cir::UnaryOpKind::Not: { 1429 // bit-wise compliment operator, implemented as an XOR with -1. 1430 mlir::Value minusOne; 1431 if (isVector) { 1432 const uint64_t numElements = 1433 mlir::dyn_cast<cir::VectorType>(type).getSize(); 1434 std::vector<int32_t> values(numElements, -1); 1435 mlir::DenseIntElementsAttr denseVec = rewriter.getI32VectorAttr(values); 1436 minusOne = 1437 rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, denseVec); 1438 } else { 1439 minusOne = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, -1); 1440 } 1441 rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(), 1442 minusOne); 1443 return mlir::success(); 1444 } 1445 } 1446 llvm_unreachable("Unexpected unary op for int"); 1447 } 1448 1449 // Floating point unary operations: + - ++ -- 1450 if (mlir::isa<cir::FPTypeInterface>(elementType)) { 1451 switch (op.getKind()) { 1452 case cir::UnaryOpKind::Inc: { 1453 assert(!isVector && "++ not allowed on vector types"); 1454 mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>( 1455 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 1456 rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, one, 1457 adaptor.getInput()); 1458 return mlir::success(); 1459 } 1460 case cir::UnaryOpKind::Dec: { 1461 assert(!isVector && "-- not allowed on vector types"); 1462 mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>( 1463 loc, llvmType, rewriter.getFloatAttr(llvmType, -1.0)); 1464 rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, minusOne, 1465 adaptor.getInput()); 1466 return mlir::success(); 1467 } 1468 case cir::UnaryOpKind::Plus: 1469 rewriter.replaceOp(op, adaptor.getInput()); 1470 return mlir::success(); 1471 case cir::UnaryOpKind::Minus: 1472 rewriter.replaceOpWithNewOp<mlir::LLVM::FNegOp>(op, llvmType, 1473 adaptor.getInput()); 1474 return mlir::success(); 1475 case cir::UnaryOpKind::Not: 1476 return op.emitError() << "Unary not is invalid for floating-point types"; 1477 } 1478 llvm_unreachable("Unexpected unary op for float"); 1479 } 1480 1481 // Boolean unary operations: ! only. (For all others, the operand has 1482 // already been promoted to int.) 1483 if (mlir::isa<cir::BoolType>(elementType)) { 1484 switch (op.getKind()) { 1485 case cir::UnaryOpKind::Inc: 1486 case cir::UnaryOpKind::Dec: 1487 case cir::UnaryOpKind::Plus: 1488 case cir::UnaryOpKind::Minus: 1489 // Some of these are allowed in source code, but we shouldn't get here 1490 // with a boolean type. 1491 return op.emitError() << "Unsupported unary operation on boolean type"; 1492 case cir::UnaryOpKind::Not: { 1493 assert(!isVector && "NYI: op! on vector mask"); 1494 auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1); 1495 rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(), 1496 one); 1497 return mlir::success(); 1498 } 1499 } 1500 llvm_unreachable("Unexpected unary op for bool"); 1501 } 1502 1503 // Pointer unary operations: + only. (++ and -- of pointers are implemented 1504 // with cir.ptr_stride, not cir.unary.) 1505 if (mlir::isa<cir::PointerType>(elementType)) { 1506 return op.emitError() 1507 << "Unary operation on pointer types is not yet implemented"; 1508 } 1509 1510 return op.emitError() << "Unary operation has unsupported type: " 1511 << elementType; 1512 } 1513 1514 mlir::LLVM::IntegerOverflowFlags 1515 CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const { 1516 if (op.getNoUnsignedWrap()) 1517 return mlir::LLVM::IntegerOverflowFlags::nuw; 1518 1519 if (op.getNoSignedWrap()) 1520 return mlir::LLVM::IntegerOverflowFlags::nsw; 1521 1522 return mlir::LLVM::IntegerOverflowFlags::none; 1523 } 1524 1525 static bool isIntTypeUnsigned(mlir::Type type) { 1526 // TODO: Ideally, we should only need to check cir::IntType here. 1527 return mlir::isa<cir::IntType>(type) 1528 ? mlir::cast<cir::IntType>(type).isUnsigned() 1529 : mlir::cast<mlir::IntegerType>(type).isUnsigned(); 1530 } 1531 1532 mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite( 1533 cir::BinOp op, OpAdaptor adaptor, 1534 mlir::ConversionPatternRewriter &rewriter) const { 1535 if (adaptor.getLhs().getType() != adaptor.getRhs().getType()) 1536 return op.emitError() << "inconsistent operands' types not supported yet"; 1537 1538 mlir::Type type = op.getRhs().getType(); 1539 if (!mlir::isa<cir::IntType, cir::BoolType, cir::FPTypeInterface, 1540 mlir::IntegerType, cir::VectorType>(type)) 1541 return op.emitError() << "operand type not supported yet"; 1542 1543 const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType()); 1544 const mlir::Type llvmEltTy = elementTypeIfVector(llvmTy); 1545 1546 const mlir::Value rhs = adaptor.getRhs(); 1547 const mlir::Value lhs = adaptor.getLhs(); 1548 type = elementTypeIfVector(type); 1549 1550 switch (op.getKind()) { 1551 case cir::BinOpKind::Add: 1552 if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { 1553 if (op.getSaturated()) { 1554 if (isIntTypeUnsigned(type)) { 1555 rewriter.replaceOpWithNewOp<mlir::LLVM::UAddSat>(op, lhs, rhs); 1556 break; 1557 } 1558 rewriter.replaceOpWithNewOp<mlir::LLVM::SAddSat>(op, lhs, rhs); 1559 break; 1560 } 1561 rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs, 1562 getIntOverflowFlag(op)); 1563 } else { 1564 rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, lhs, rhs); 1565 } 1566 break; 1567 case cir::BinOpKind::Sub: 1568 if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { 1569 if (op.getSaturated()) { 1570 if (isIntTypeUnsigned(type)) { 1571 rewriter.replaceOpWithNewOp<mlir::LLVM::USubSat>(op, lhs, rhs); 1572 break; 1573 } 1574 rewriter.replaceOpWithNewOp<mlir::LLVM::SSubSat>(op, lhs, rhs); 1575 break; 1576 } 1577 rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs, 1578 getIntOverflowFlag(op)); 1579 } else { 1580 rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, lhs, rhs); 1581 } 1582 break; 1583 case cir::BinOpKind::Mul: 1584 if (mlir::isa<mlir::IntegerType>(llvmEltTy)) 1585 rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs, 1586 getIntOverflowFlag(op)); 1587 else 1588 rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, lhs, rhs); 1589 break; 1590 case cir::BinOpKind::Div: 1591 if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { 1592 auto isUnsigned = isIntTypeUnsigned(type); 1593 if (isUnsigned) 1594 rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, lhs, rhs); 1595 else 1596 rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, lhs, rhs); 1597 } else { 1598 rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, lhs, rhs); 1599 } 1600 break; 1601 case cir::BinOpKind::Rem: 1602 if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { 1603 auto isUnsigned = isIntTypeUnsigned(type); 1604 if (isUnsigned) 1605 rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, lhs, rhs); 1606 else 1607 rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, lhs, rhs); 1608 } else { 1609 rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, lhs, rhs); 1610 } 1611 break; 1612 case cir::BinOpKind::And: 1613 rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, lhs, rhs); 1614 break; 1615 case cir::BinOpKind::Or: 1616 rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, lhs, rhs); 1617 break; 1618 case cir::BinOpKind::Xor: 1619 rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, lhs, rhs); 1620 break; 1621 case cir::BinOpKind::Max: 1622 if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { 1623 auto isUnsigned = isIntTypeUnsigned(type); 1624 if (isUnsigned) 1625 rewriter.replaceOpWithNewOp<mlir::LLVM::UMaxOp>(op, llvmTy, lhs, rhs); 1626 else 1627 rewriter.replaceOpWithNewOp<mlir::LLVM::SMaxOp>(op, llvmTy, lhs, rhs); 1628 } 1629 break; 1630 } 1631 return mlir::LogicalResult::success(); 1632 } 1633 1634 /// Convert from a CIR comparison kind to an LLVM IR integral comparison kind. 1635 static mlir::LLVM::ICmpPredicate 1636 convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) { 1637 using CIR = cir::CmpOpKind; 1638 using LLVMICmp = mlir::LLVM::ICmpPredicate; 1639 switch (kind) { 1640 case CIR::eq: 1641 return LLVMICmp::eq; 1642 case CIR::ne: 1643 return LLVMICmp::ne; 1644 case CIR::lt: 1645 return (isSigned ? LLVMICmp::slt : LLVMICmp::ult); 1646 case CIR::le: 1647 return (isSigned ? LLVMICmp::sle : LLVMICmp::ule); 1648 case CIR::gt: 1649 return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt); 1650 case CIR::ge: 1651 return (isSigned ? LLVMICmp::sge : LLVMICmp::uge); 1652 } 1653 llvm_unreachable("Unknown CmpOpKind"); 1654 } 1655 1656 /// Convert from a CIR comparison kind to an LLVM IR floating-point comparison 1657 /// kind. 1658 static mlir::LLVM::FCmpPredicate 1659 convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) { 1660 using CIR = cir::CmpOpKind; 1661 using LLVMFCmp = mlir::LLVM::FCmpPredicate; 1662 switch (kind) { 1663 case CIR::eq: 1664 return LLVMFCmp::oeq; 1665 case CIR::ne: 1666 return LLVMFCmp::une; 1667 case CIR::lt: 1668 return LLVMFCmp::olt; 1669 case CIR::le: 1670 return LLVMFCmp::ole; 1671 case CIR::gt: 1672 return LLVMFCmp::ogt; 1673 case CIR::ge: 1674 return LLVMFCmp::oge; 1675 } 1676 llvm_unreachable("Unknown CmpOpKind"); 1677 } 1678 1679 mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( 1680 cir::CmpOp cmpOp, OpAdaptor adaptor, 1681 mlir::ConversionPatternRewriter &rewriter) const { 1682 mlir::Type type = cmpOp.getLhs().getType(); 1683 1684 assert(!cir::MissingFeatures::dataMemberType()); 1685 assert(!cir::MissingFeatures::methodType()); 1686 1687 if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) { 1688 bool isSigned = mlir::isa<cir::IntType>(type) 1689 ? mlir::cast<cir::IntType>(type).isSigned() 1690 : mlir::cast<mlir::IntegerType>(type).isSigned(); 1691 mlir::LLVM::ICmpPredicate kind = 1692 convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned); 1693 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( 1694 cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); 1695 return mlir::success(); 1696 } 1697 1698 if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) { 1699 mlir::LLVM::ICmpPredicate kind = 1700 convertCmpKindToICmpPredicate(cmpOp.getKind(), 1701 /* isSigned=*/false); 1702 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( 1703 cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); 1704 return mlir::success(); 1705 } 1706 1707 if (mlir::isa<cir::FPTypeInterface>(type)) { 1708 mlir::LLVM::FCmpPredicate kind = 1709 convertCmpKindToFCmpPredicate(cmpOp.getKind()); 1710 rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>( 1711 cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); 1712 return mlir::success(); 1713 } 1714 1715 if (mlir::isa<cir::ComplexType>(type)) { 1716 mlir::Value lhs = adaptor.getLhs(); 1717 mlir::Value rhs = adaptor.getRhs(); 1718 mlir::Location loc = cmpOp.getLoc(); 1719 1720 auto complexType = mlir::cast<cir::ComplexType>(cmpOp.getLhs().getType()); 1721 mlir::Type complexElemTy = 1722 getTypeConverter()->convertType(complexType.getElementType()); 1723 1724 auto lhsReal = 1725 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0); 1726 auto lhsImag = 1727 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1); 1728 auto rhsReal = 1729 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0); 1730 auto rhsImag = 1731 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1); 1732 1733 if (cmpOp.getKind() == cir::CmpOpKind::eq) { 1734 if (complexElemTy.isInteger()) { 1735 auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>( 1736 loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal); 1737 auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>( 1738 loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag); 1739 rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp); 1740 return mlir::success(); 1741 } 1742 1743 auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>( 1744 loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal); 1745 auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>( 1746 loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag); 1747 rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp); 1748 return mlir::success(); 1749 } 1750 1751 if (cmpOp.getKind() == cir::CmpOpKind::ne) { 1752 if (complexElemTy.isInteger()) { 1753 auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>( 1754 loc, mlir::LLVM::ICmpPredicate::ne, lhsReal, rhsReal); 1755 auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>( 1756 loc, mlir::LLVM::ICmpPredicate::ne, lhsImag, rhsImag); 1757 rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp); 1758 return mlir::success(); 1759 } 1760 1761 auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>( 1762 loc, mlir::LLVM::FCmpPredicate::une, lhsReal, rhsReal); 1763 auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>( 1764 loc, mlir::LLVM::FCmpPredicate::une, lhsImag, rhsImag); 1765 rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp); 1766 return mlir::success(); 1767 } 1768 } 1769 1770 return cmpOp.emitError() << "unsupported type for CmpOp: " << type; 1771 } 1772 1773 mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite( 1774 cir::ShiftOp op, OpAdaptor adaptor, 1775 mlir::ConversionPatternRewriter &rewriter) const { 1776 assert((op.getValue().getType() == op.getType()) && 1777 "inconsistent operands' types NYI"); 1778 1779 const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType()); 1780 mlir::Value amt = adaptor.getAmount(); 1781 mlir::Value val = adaptor.getValue(); 1782 1783 auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType()); 1784 bool isUnsigned; 1785 if (cirAmtTy) { 1786 auto cirValTy = mlir::cast<cir::IntType>(op.getValue().getType()); 1787 isUnsigned = cirValTy.isUnsigned(); 1788 1789 // Ensure shift amount is the same type as the value. Some undefined 1790 // behavior might occur in the casts below as per [C99 6.5.7.3]. 1791 // Vector type shift amount needs no cast as type consistency is expected to 1792 // be already be enforced at CIRGen. 1793 if (cirAmtTy) 1794 amt = getLLVMIntCast(rewriter, amt, llvmTy, true, cirAmtTy.getWidth(), 1795 cirValTy.getWidth()); 1796 } else { 1797 auto cirValVTy = mlir::cast<cir::VectorType>(op.getValue().getType()); 1798 isUnsigned = 1799 mlir::cast<cir::IntType>(cirValVTy.getElementType()).isUnsigned(); 1800 } 1801 1802 // Lower to the proper LLVM shift operation. 1803 if (op.getIsShiftleft()) { 1804 rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt); 1805 return mlir::success(); 1806 } 1807 1808 if (isUnsigned) 1809 rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt); 1810 else 1811 rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt); 1812 return mlir::success(); 1813 } 1814 1815 mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite( 1816 cir::SelectOp op, OpAdaptor adaptor, 1817 mlir::ConversionPatternRewriter &rewriter) const { 1818 auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr { 1819 auto definingOp = 1820 mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp()); 1821 if (!definingOp) 1822 return {}; 1823 1824 auto constValue = mlir::dyn_cast<cir::BoolAttr>(definingOp.getValue()); 1825 if (!constValue) 1826 return {}; 1827 1828 return constValue; 1829 }; 1830 1831 // Two special cases in the LLVMIR codegen of select op: 1832 // - select %0, %1, false => and %0, %1 1833 // - select %0, true, %1 => or %0, %1 1834 if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) { 1835 cir::BoolAttr trueValue = getConstantBool(op.getTrueValue()); 1836 cir::BoolAttr falseValue = getConstantBool(op.getFalseValue()); 1837 if (falseValue && !falseValue.getValue()) { 1838 // select %0, %1, false => and %0, %1 1839 rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(), 1840 adaptor.getTrueValue()); 1841 return mlir::success(); 1842 } 1843 if (trueValue && trueValue.getValue()) { 1844 // select %0, true, %1 => or %0, %1 1845 rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(), 1846 adaptor.getFalseValue()); 1847 return mlir::success(); 1848 } 1849 } 1850 1851 mlir::Value llvmCondition = adaptor.getCondition(); 1852 rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>( 1853 op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue()); 1854 1855 return mlir::success(); 1856 } 1857 1858 static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, 1859 mlir::DataLayout &dataLayout) { 1860 converter.addConversion([&](cir::PointerType type) -> mlir::Type { 1861 // Drop pointee type since LLVM dialect only allows opaque pointers. 1862 assert(!cir::MissingFeatures::addressSpace()); 1863 unsigned targetAS = 0; 1864 1865 return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS); 1866 }); 1867 converter.addConversion([&](cir::ArrayType type) -> mlir::Type { 1868 mlir::Type ty = 1869 convertTypeForMemory(converter, dataLayout, type.getElementType()); 1870 return mlir::LLVM::LLVMArrayType::get(ty, type.getSize()); 1871 }); 1872 converter.addConversion([&](cir::VectorType type) -> mlir::Type { 1873 const mlir::Type ty = converter.convertType(type.getElementType()); 1874 return mlir::VectorType::get(type.getSize(), ty); 1875 }); 1876 converter.addConversion([&](cir::BoolType type) -> mlir::Type { 1877 return mlir::IntegerType::get(type.getContext(), 1, 1878 mlir::IntegerType::Signless); 1879 }); 1880 converter.addConversion([&](cir::IntType type) -> mlir::Type { 1881 // LLVM doesn't work with signed types, so we drop the CIR signs here. 1882 return mlir::IntegerType::get(type.getContext(), type.getWidth()); 1883 }); 1884 converter.addConversion([&](cir::SingleType type) -> mlir::Type { 1885 return mlir::Float32Type::get(type.getContext()); 1886 }); 1887 converter.addConversion([&](cir::DoubleType type) -> mlir::Type { 1888 return mlir::Float64Type::get(type.getContext()); 1889 }); 1890 converter.addConversion([&](cir::FP80Type type) -> mlir::Type { 1891 return mlir::Float80Type::get(type.getContext()); 1892 }); 1893 converter.addConversion([&](cir::FP128Type type) -> mlir::Type { 1894 return mlir::Float128Type::get(type.getContext()); 1895 }); 1896 converter.addConversion([&](cir::LongDoubleType type) -> mlir::Type { 1897 return converter.convertType(type.getUnderlying()); 1898 }); 1899 converter.addConversion([&](cir::FP16Type type) -> mlir::Type { 1900 return mlir::Float16Type::get(type.getContext()); 1901 }); 1902 converter.addConversion([&](cir::BF16Type type) -> mlir::Type { 1903 return mlir::BFloat16Type::get(type.getContext()); 1904 }); 1905 converter.addConversion([&](cir::ComplexType type) -> mlir::Type { 1906 // A complex type is lowered to an LLVM struct that contains the real and 1907 // imaginary part as data fields. 1908 mlir::Type elementTy = converter.convertType(type.getElementType()); 1909 mlir::Type structFields[2] = {elementTy, elementTy}; 1910 return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(), 1911 structFields); 1912 }); 1913 converter.addConversion([&](cir::FuncType type) -> std::optional<mlir::Type> { 1914 auto result = converter.convertType(type.getReturnType()); 1915 llvm::SmallVector<mlir::Type> arguments; 1916 arguments.reserve(type.getNumInputs()); 1917 if (converter.convertTypes(type.getInputs(), arguments).failed()) 1918 return std::nullopt; 1919 auto varArg = type.isVarArg(); 1920 return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg); 1921 }); 1922 converter.addConversion([&](cir::RecordType type) -> mlir::Type { 1923 // Convert struct members. 1924 llvm::SmallVector<mlir::Type> llvmMembers; 1925 switch (type.getKind()) { 1926 case cir::RecordType::Class: 1927 case cir::RecordType::Struct: 1928 for (mlir::Type ty : type.getMembers()) 1929 llvmMembers.push_back(convertTypeForMemory(converter, dataLayout, ty)); 1930 break; 1931 // Unions are lowered as only the largest member. 1932 case cir::RecordType::Union: 1933 if (auto largestMember = type.getLargestMember(dataLayout)) 1934 llvmMembers.push_back( 1935 convertTypeForMemory(converter, dataLayout, largestMember)); 1936 if (type.getPadded()) { 1937 auto last = *type.getMembers().rbegin(); 1938 llvmMembers.push_back( 1939 convertTypeForMemory(converter, dataLayout, last)); 1940 } 1941 break; 1942 } 1943 1944 // Record has a name: lower as an identified record. 1945 mlir::LLVM::LLVMStructType llvmStruct; 1946 if (type.getName()) { 1947 llvmStruct = mlir::LLVM::LLVMStructType::getIdentified( 1948 type.getContext(), type.getPrefixedName()); 1949 if (llvmStruct.setBody(llvmMembers, type.getPacked()).failed()) 1950 llvm_unreachable("Failed to set body of record"); 1951 } else { // Record has no name: lower as literal record. 1952 llvmStruct = mlir::LLVM::LLVMStructType::getLiteral( 1953 type.getContext(), llvmMembers, type.getPacked()); 1954 } 1955 1956 return llvmStruct; 1957 }); 1958 } 1959 1960 // The applyPartialConversion function traverses blocks in the dominance order, 1961 // so it does not lower and operations that are not reachachable from the 1962 // operations passed in as arguments. Since we do need to lower such code in 1963 // order to avoid verification errors occur, we cannot just pass the module op 1964 // to applyPartialConversion. We must build a set of unreachable ops and 1965 // explicitly add them, along with the module, to the vector we pass to 1966 // applyPartialConversion. 1967 // 1968 // For instance, this CIR code: 1969 // 1970 // cir.func @foo(%arg0: !s32i) -> !s32i { 1971 // %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool 1972 // cir.if %4 { 1973 // %5 = cir.const #cir.int<1> : !s32i 1974 // cir.return %5 : !s32i 1975 // } else { 1976 // %5 = cir.const #cir.int<0> : !s32i 1977 // cir.return %5 : !s32i 1978 // } 1979 // cir.return %arg0 : !s32i 1980 // } 1981 // 1982 // contains an unreachable return operation (the last one). After the flattening 1983 // pass it will be placed into the unreachable block. The possible error 1984 // after the lowering pass is: error: 'cir.return' op expects parent op to be 1985 // one of 'cir.func, cir.scope, cir.if ... The reason that this operation was 1986 // not lowered and the new parent is llvm.func. 1987 // 1988 // In the future we may want to get rid of this function and use a DCE pass or 1989 // something similar. But for now we need to guarantee the absence of the 1990 // dialect verification errors. 1991 static void collectUnreachable(mlir::Operation *parent, 1992 llvm::SmallVector<mlir::Operation *> &ops) { 1993 1994 llvm::SmallVector<mlir::Block *> unreachableBlocks; 1995 parent->walk([&](mlir::Block *blk) { // check 1996 if (blk->hasNoPredecessors() && !blk->isEntryBlock()) 1997 unreachableBlocks.push_back(blk); 1998 }); 1999 2000 std::set<mlir::Block *> visited; 2001 for (mlir::Block *root : unreachableBlocks) { 2002 // We create a work list for each unreachable block. 2003 // Thus we traverse operations in some order. 2004 std::deque<mlir::Block *> workList; 2005 workList.push_back(root); 2006 2007 while (!workList.empty()) { 2008 mlir::Block *blk = workList.back(); 2009 workList.pop_back(); 2010 if (visited.count(blk)) 2011 continue; 2012 visited.emplace(blk); 2013 2014 for (mlir::Operation &op : *blk) 2015 ops.push_back(&op); 2016 2017 for (mlir::Block *succ : blk->getSuccessors()) 2018 workList.push_back(succ); 2019 } 2020 } 2021 } 2022 2023 void ConvertCIRToLLVMPass::processCIRAttrs(mlir::ModuleOp module) { 2024 // Lower the module attributes to LLVM equivalents. 2025 if (mlir::Attribute tripleAttr = 2026 module->getAttr(cir::CIRDialect::getTripleAttrName())) 2027 module->setAttr(mlir::LLVM::LLVMDialect::getTargetTripleAttrName(), 2028 tripleAttr); 2029 } 2030 2031 void ConvertCIRToLLVMPass::runOnOperation() { 2032 llvm::TimeTraceScope scope("Convert CIR to LLVM Pass"); 2033 2034 mlir::ModuleOp module = getOperation(); 2035 mlir::DataLayout dl(module); 2036 mlir::LLVMTypeConverter converter(&getContext()); 2037 prepareTypeConverter(converter, dl); 2038 2039 mlir::RewritePatternSet patterns(&getContext()); 2040 2041 patterns.add<CIRToLLVMReturnOpLowering>(patterns.getContext()); 2042 // This could currently be merged with the group below, but it will get more 2043 // arguments later, so we'll keep it separate for now. 2044 patterns.add<CIRToLLVMAllocaOpLowering>(converter, patterns.getContext(), dl); 2045 patterns.add<CIRToLLVMLoadOpLowering>(converter, patterns.getContext(), dl); 2046 patterns.add<CIRToLLVMStoreOpLowering>(converter, patterns.getContext(), dl); 2047 patterns.add<CIRToLLVMGlobalOpLowering>(converter, patterns.getContext(), dl); 2048 patterns.add<CIRToLLVMCastOpLowering>(converter, patterns.getContext(), dl); 2049 patterns.add<CIRToLLVMPtrStrideOpLowering>(converter, patterns.getContext(), 2050 dl); 2051 patterns.add< 2052 // clang-format off 2053 CIRToLLVMAssumeOpLowering, 2054 CIRToLLVMBaseClassAddrOpLowering, 2055 CIRToLLVMBinOpLowering, 2056 CIRToLLVMBitClrsbOpLowering, 2057 CIRToLLVMBitClzOpLowering, 2058 CIRToLLVMBitCtzOpLowering, 2059 CIRToLLVMBitParityOpLowering, 2060 CIRToLLVMBitPopcountOpLowering, 2061 CIRToLLVMBitReverseOpLowering, 2062 CIRToLLVMBrCondOpLowering, 2063 CIRToLLVMBrOpLowering, 2064 CIRToLLVMByteSwapOpLowering, 2065 CIRToLLVMCallOpLowering, 2066 CIRToLLVMCmpOpLowering, 2067 CIRToLLVMComplexAddOpLowering, 2068 CIRToLLVMComplexCreateOpLowering, 2069 CIRToLLVMComplexImagOpLowering, 2070 CIRToLLVMComplexImagPtrOpLowering, 2071 CIRToLLVMComplexRealOpLowering, 2072 CIRToLLVMComplexRealPtrOpLowering, 2073 CIRToLLVMComplexSubOpLowering, 2074 CIRToLLVMConstantOpLowering, 2075 CIRToLLVMExpectOpLowering, 2076 CIRToLLVMFuncOpLowering, 2077 CIRToLLVMGetBitfieldOpLowering, 2078 CIRToLLVMGetGlobalOpLowering, 2079 CIRToLLVMGetMemberOpLowering, 2080 CIRToLLVMSelectOpLowering, 2081 CIRToLLVMSetBitfieldOpLowering, 2082 CIRToLLVMShiftOpLowering, 2083 CIRToLLVMStackRestoreOpLowering, 2084 CIRToLLVMStackSaveOpLowering, 2085 CIRToLLVMSwitchFlatOpLowering, 2086 CIRToLLVMTrapOpLowering, 2087 CIRToLLVMUnaryOpLowering, 2088 CIRToLLVMVecCmpOpLowering, 2089 CIRToLLVMVecCreateOpLowering, 2090 CIRToLLVMVecExtractOpLowering, 2091 CIRToLLVMVecInsertOpLowering, 2092 CIRToLLVMVecShuffleDynamicOpLowering, 2093 CIRToLLVMVecShuffleOpLowering, 2094 CIRToLLVMVecSplatOpLowering, 2095 CIRToLLVMVecTernaryOpLowering 2096 // clang-format on 2097 >(converter, patterns.getContext()); 2098 2099 processCIRAttrs(module); 2100 2101 mlir::ConversionTarget target(getContext()); 2102 target.addLegalOp<mlir::ModuleOp>(); 2103 target.addLegalDialect<mlir::LLVM::LLVMDialect>(); 2104 target.addIllegalDialect<mlir::BuiltinDialect, cir::CIRDialect, 2105 mlir::func::FuncDialect>(); 2106 2107 llvm::SmallVector<mlir::Operation *> ops; 2108 ops.push_back(module); 2109 collectUnreachable(module, ops); 2110 2111 if (failed(applyPartialConversion(ops, target, std::move(patterns)))) 2112 signalPassFailure(); 2113 } 2114 2115 mlir::LogicalResult CIRToLLVMBrOpLowering::matchAndRewrite( 2116 cir::BrOp op, OpAdaptor adaptor, 2117 mlir::ConversionPatternRewriter &rewriter) const { 2118 rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(op, adaptor.getOperands(), 2119 op.getDest()); 2120 return mlir::LogicalResult::success(); 2121 } 2122 2123 mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite( 2124 cir::GetMemberOp op, OpAdaptor adaptor, 2125 mlir::ConversionPatternRewriter &rewriter) const { 2126 mlir::Type llResTy = getTypeConverter()->convertType(op.getType()); 2127 const auto recordTy = 2128 mlir::cast<cir::RecordType>(op.getAddrTy().getPointee()); 2129 assert(recordTy && "expected record type"); 2130 2131 switch (recordTy.getKind()) { 2132 case cir::RecordType::Class: 2133 case cir::RecordType::Struct: { 2134 // Since the base address is a pointer to an aggregate, the first offset 2135 // is always zero. The second offset tell us which member it will access. 2136 llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()}; 2137 const mlir::Type elementTy = getTypeConverter()->convertType(recordTy); 2138 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResTy, elementTy, 2139 adaptor.getAddr(), offset); 2140 return mlir::success(); 2141 } 2142 case cir::RecordType::Union: 2143 // Union members share the address space, so we just need a bitcast to 2144 // conform to type-checking. 2145 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy, 2146 adaptor.getAddr()); 2147 return mlir::success(); 2148 } 2149 } 2150 2151 mlir::LogicalResult CIRToLLVMTrapOpLowering::matchAndRewrite( 2152 cir::TrapOp op, OpAdaptor adaptor, 2153 mlir::ConversionPatternRewriter &rewriter) const { 2154 mlir::Location loc = op->getLoc(); 2155 rewriter.eraseOp(op); 2156 2157 rewriter.create<mlir::LLVM::Trap>(loc); 2158 2159 // Note that the call to llvm.trap is not a terminator in LLVM dialect. 2160 // So we must emit an additional llvm.unreachable to terminate the current 2161 // block. 2162 rewriter.create<mlir::LLVM::UnreachableOp>(loc); 2163 2164 return mlir::success(); 2165 } 2166 2167 mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite( 2168 cir::StackSaveOp op, OpAdaptor adaptor, 2169 mlir::ConversionPatternRewriter &rewriter) const { 2170 const mlir::Type ptrTy = getTypeConverter()->convertType(op.getType()); 2171 rewriter.replaceOpWithNewOp<mlir::LLVM::StackSaveOp>(op, ptrTy); 2172 return mlir::success(); 2173 } 2174 2175 mlir::LogicalResult CIRToLLVMStackRestoreOpLowering::matchAndRewrite( 2176 cir::StackRestoreOp op, OpAdaptor adaptor, 2177 mlir::ConversionPatternRewriter &rewriter) const { 2178 rewriter.replaceOpWithNewOp<mlir::LLVM::StackRestoreOp>(op, adaptor.getPtr()); 2179 return mlir::success(); 2180 } 2181 2182 mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite( 2183 cir::VecCreateOp op, OpAdaptor adaptor, 2184 mlir::ConversionPatternRewriter &rewriter) const { 2185 // Start with an 'undef' value for the vector. Then 'insertelement' for 2186 // each of the vector elements. 2187 const auto vecTy = mlir::cast<cir::VectorType>(op.getType()); 2188 const mlir::Type llvmTy = typeConverter->convertType(vecTy); 2189 const mlir::Location loc = op.getLoc(); 2190 mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy); 2191 assert(vecTy.getSize() == op.getElements().size() && 2192 "cir.vec.create op count doesn't match vector type elements count"); 2193 2194 for (uint64_t i = 0; i < vecTy.getSize(); ++i) { 2195 const mlir::Value indexValue = 2196 rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i); 2197 result = rewriter.create<mlir::LLVM::InsertElementOp>( 2198 loc, result, adaptor.getElements()[i], indexValue); 2199 } 2200 2201 rewriter.replaceOp(op, result); 2202 return mlir::success(); 2203 } 2204 2205 mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite( 2206 cir::VecExtractOp op, OpAdaptor adaptor, 2207 mlir::ConversionPatternRewriter &rewriter) const { 2208 rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>( 2209 op, adaptor.getVec(), adaptor.getIndex()); 2210 return mlir::success(); 2211 } 2212 2213 mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite( 2214 cir::VecInsertOp op, OpAdaptor adaptor, 2215 mlir::ConversionPatternRewriter &rewriter) const { 2216 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>( 2217 op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex()); 2218 return mlir::success(); 2219 } 2220 2221 mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite( 2222 cir::VecCmpOp op, OpAdaptor adaptor, 2223 mlir::ConversionPatternRewriter &rewriter) const { 2224 mlir::Type elementType = elementTypeIfVector(op.getLhs().getType()); 2225 mlir::Value bitResult; 2226 if (auto intType = mlir::dyn_cast<cir::IntType>(elementType)) { 2227 bitResult = rewriter.create<mlir::LLVM::ICmpOp>( 2228 op.getLoc(), 2229 convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()), 2230 adaptor.getLhs(), adaptor.getRhs()); 2231 } else if (mlir::isa<cir::FPTypeInterface>(elementType)) { 2232 bitResult = rewriter.create<mlir::LLVM::FCmpOp>( 2233 op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()), 2234 adaptor.getLhs(), adaptor.getRhs()); 2235 } else { 2236 return op.emitError() << "unsupported type for VecCmpOp: " << elementType; 2237 } 2238 2239 // LLVM IR vector comparison returns a vector of i1. This one-bit vector 2240 // must be sign-extended to the correct result type. 2241 rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>( 2242 op, typeConverter->convertType(op.getType()), bitResult); 2243 return mlir::success(); 2244 } 2245 2246 mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite( 2247 cir::VecSplatOp op, OpAdaptor adaptor, 2248 mlir::ConversionPatternRewriter &rewriter) const { 2249 // Vector splat can be implemented with an `insertelement` and a 2250 // `shufflevector`, which is better than an `insertelement` for each 2251 // element in the vector. Start with an undef vector. Insert the value into 2252 // the first element. Then use a `shufflevector` with a mask of all 0 to 2253 // fill out the entire vector with that value. 2254 cir::VectorType vecTy = op.getType(); 2255 mlir::Type llvmTy = typeConverter->convertType(vecTy); 2256 mlir::Location loc = op.getLoc(); 2257 mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy); 2258 2259 mlir::Value elementValue = adaptor.getValue(); 2260 if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) { 2261 // If the splat value is poison, then we can just use poison value 2262 // for the entire vector. 2263 rewriter.replaceOp(op, poison); 2264 return mlir::success(); 2265 } 2266 2267 if (auto constValue = 2268 dyn_cast<mlir::LLVM::ConstantOp>(elementValue.getDefiningOp())) { 2269 if (auto intAttr = dyn_cast<mlir::IntegerAttr>(constValue.getValue())) { 2270 mlir::DenseIntElementsAttr denseVec = mlir::DenseIntElementsAttr::get( 2271 mlir::cast<mlir::ShapedType>(llvmTy), intAttr.getValue()); 2272 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 2273 op, denseVec.getType(), denseVec); 2274 return mlir::success(); 2275 } 2276 2277 if (auto fpAttr = dyn_cast<mlir::FloatAttr>(constValue.getValue())) { 2278 mlir::DenseFPElementsAttr denseVec = mlir::DenseFPElementsAttr::get( 2279 mlir::cast<mlir::ShapedType>(llvmTy), fpAttr.getValue()); 2280 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( 2281 op, denseVec.getType(), denseVec); 2282 return mlir::success(); 2283 } 2284 } 2285 2286 mlir::Value indexValue = 2287 rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0); 2288 mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>( 2289 loc, poison, elementValue, indexValue); 2290 SmallVector<int32_t> zeroValues(vecTy.getSize(), 0); 2291 rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement, 2292 poison, zeroValues); 2293 return mlir::success(); 2294 } 2295 2296 mlir::LogicalResult CIRToLLVMVecShuffleOpLowering::matchAndRewrite( 2297 cir::VecShuffleOp op, OpAdaptor adaptor, 2298 mlir::ConversionPatternRewriter &rewriter) const { 2299 // LLVM::ShuffleVectorOp takes an ArrayRef of int for the list of indices. 2300 // Convert the ClangIR ArrayAttr of IntAttr constants into a 2301 // SmallVector<int>. 2302 SmallVector<int, 8> indices; 2303 std::transform( 2304 op.getIndices().begin(), op.getIndices().end(), 2305 std::back_inserter(indices), [](mlir::Attribute intAttr) { 2306 return mlir::cast<cir::IntAttr>(intAttr).getValue().getSExtValue(); 2307 }); 2308 rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>( 2309 op, adaptor.getVec1(), adaptor.getVec2(), indices); 2310 return mlir::success(); 2311 } 2312 2313 mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite( 2314 cir::VecShuffleDynamicOp op, OpAdaptor adaptor, 2315 mlir::ConversionPatternRewriter &rewriter) const { 2316 // LLVM IR does not have an operation that corresponds to this form of 2317 // the built-in. 2318 // __builtin_shufflevector(V, I) 2319 // is implemented as this pseudocode, where the for loop is unrolled 2320 // and N is the number of elements: 2321 // 2322 // result = undef 2323 // maskbits = NextPowerOf2(N - 1) 2324 // masked = I & maskbits 2325 // for (i in 0 <= i < N) 2326 // result[i] = V[masked[i]] 2327 mlir::Location loc = op.getLoc(); 2328 mlir::Value input = adaptor.getVec(); 2329 mlir::Type llvmIndexVecType = 2330 getTypeConverter()->convertType(op.getIndices().getType()); 2331 mlir::Type llvmIndexType = getTypeConverter()->convertType( 2332 elementTypeIfVector(op.getIndices().getType())); 2333 uint64_t numElements = 2334 mlir::cast<cir::VectorType>(op.getVec().getType()).getSize(); 2335 2336 uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1; 2337 mlir::Value maskValue = rewriter.create<mlir::LLVM::ConstantOp>( 2338 loc, llvmIndexType, rewriter.getIntegerAttr(llvmIndexType, maskBits)); 2339 mlir::Value maskVector = 2340 rewriter.create<mlir::LLVM::UndefOp>(loc, llvmIndexVecType); 2341 2342 for (uint64_t i = 0; i < numElements; ++i) { 2343 mlir::Value idxValue = 2344 rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i); 2345 maskVector = rewriter.create<mlir::LLVM::InsertElementOp>( 2346 loc, maskVector, maskValue, idxValue); 2347 } 2348 2349 mlir::Value maskedIndices = rewriter.create<mlir::LLVM::AndOp>( 2350 loc, llvmIndexVecType, adaptor.getIndices(), maskVector); 2351 mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>( 2352 loc, getTypeConverter()->convertType(op.getVec().getType())); 2353 for (uint64_t i = 0; i < numElements; ++i) { 2354 mlir::Value iValue = 2355 rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i); 2356 mlir::Value indexValue = rewriter.create<mlir::LLVM::ExtractElementOp>( 2357 loc, maskedIndices, iValue); 2358 mlir::Value valueAtIndex = 2359 rewriter.create<mlir::LLVM::ExtractElementOp>(loc, input, indexValue); 2360 result = rewriter.create<mlir::LLVM::InsertElementOp>(loc, result, 2361 valueAtIndex, iValue); 2362 } 2363 rewriter.replaceOp(op, result); 2364 return mlir::success(); 2365 } 2366 2367 mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite( 2368 cir::VecTernaryOp op, OpAdaptor adaptor, 2369 mlir::ConversionPatternRewriter &rewriter) const { 2370 // Convert `cond` into a vector of i1, then use that in a `select` op. 2371 mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>( 2372 op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(), 2373 rewriter.create<mlir::LLVM::ZeroOp>( 2374 op.getCond().getLoc(), 2375 typeConverter->convertType(op.getCond().getType()))); 2376 rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>( 2377 op, bitVec, adaptor.getLhs(), adaptor.getRhs()); 2378 return mlir::success(); 2379 } 2380 2381 mlir::LogicalResult CIRToLLVMComplexAddOpLowering::matchAndRewrite( 2382 cir::ComplexAddOp op, OpAdaptor adaptor, 2383 mlir::ConversionPatternRewriter &rewriter) const { 2384 mlir::Value lhs = adaptor.getLhs(); 2385 mlir::Value rhs = adaptor.getRhs(); 2386 mlir::Location loc = op.getLoc(); 2387 2388 auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType()); 2389 mlir::Type complexElemTy = 2390 getTypeConverter()->convertType(complexType.getElementType()); 2391 auto lhsReal = 2392 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0); 2393 auto lhsImag = 2394 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1); 2395 auto rhsReal = 2396 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0); 2397 auto rhsImag = 2398 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1); 2399 2400 mlir::Value newReal; 2401 mlir::Value newImag; 2402 if (complexElemTy.isInteger()) { 2403 newReal = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsReal, 2404 rhsReal); 2405 newImag = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsImag, 2406 rhsImag); 2407 } else { 2408 assert(!cir::MissingFeatures::fastMathFlags()); 2409 assert(!cir::MissingFeatures::fpConstraints()); 2410 newReal = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsReal, 2411 rhsReal); 2412 newImag = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsImag, 2413 rhsImag); 2414 } 2415 2416 mlir::Type complexLLVMTy = 2417 getTypeConverter()->convertType(op.getResult().getType()); 2418 auto initialComplex = 2419 rewriter.create<mlir::LLVM::PoisonOp>(op->getLoc(), complexLLVMTy); 2420 2421 auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>( 2422 op->getLoc(), initialComplex, newReal, 0); 2423 2424 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(op, realComplex, 2425 newImag, 1); 2426 2427 return mlir::success(); 2428 } 2429 2430 mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite( 2431 cir::ComplexCreateOp op, OpAdaptor adaptor, 2432 mlir::ConversionPatternRewriter &rewriter) const { 2433 mlir::Type complexLLVMTy = 2434 getTypeConverter()->convertType(op.getResult().getType()); 2435 auto initialComplex = 2436 rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy); 2437 2438 auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>( 2439 op->getLoc(), initialComplex, adaptor.getReal(), 0); 2440 2441 auto complex = rewriter.create<mlir::LLVM::InsertValueOp>( 2442 op->getLoc(), realComplex, adaptor.getImag(), 1); 2443 2444 rewriter.replaceOp(op, complex); 2445 return mlir::success(); 2446 } 2447 2448 mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite( 2449 cir::ComplexRealOp op, OpAdaptor adaptor, 2450 mlir::ConversionPatternRewriter &rewriter) const { 2451 mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType()); 2452 rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>( 2453 op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{0}); 2454 return mlir::success(); 2455 } 2456 2457 mlir::LogicalResult CIRToLLVMComplexSubOpLowering::matchAndRewrite( 2458 cir::ComplexSubOp op, OpAdaptor adaptor, 2459 mlir::ConversionPatternRewriter &rewriter) const { 2460 mlir::Value lhs = adaptor.getLhs(); 2461 mlir::Value rhs = adaptor.getRhs(); 2462 mlir::Location loc = op.getLoc(); 2463 2464 auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType()); 2465 mlir::Type complexElemTy = 2466 getTypeConverter()->convertType(complexType.getElementType()); 2467 auto lhsReal = 2468 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0); 2469 auto lhsImag = 2470 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1); 2471 auto rhsReal = 2472 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0); 2473 auto rhsImag = 2474 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1); 2475 2476 mlir::Value newReal; 2477 mlir::Value newImag; 2478 if (complexElemTy.isInteger()) { 2479 newReal = rewriter.create<mlir::LLVM::SubOp>(loc, complexElemTy, lhsReal, 2480 rhsReal); 2481 newImag = rewriter.create<mlir::LLVM::SubOp>(loc, complexElemTy, lhsImag, 2482 rhsImag); 2483 } else { 2484 assert(!cir::MissingFeatures::fastMathFlags()); 2485 assert(!cir::MissingFeatures::fpConstraints()); 2486 newReal = rewriter.create<mlir::LLVM::FSubOp>(loc, complexElemTy, lhsReal, 2487 rhsReal); 2488 newImag = rewriter.create<mlir::LLVM::FSubOp>(loc, complexElemTy, lhsImag, 2489 rhsImag); 2490 } 2491 2492 mlir::Type complexLLVMTy = 2493 getTypeConverter()->convertType(op.getResult().getType()); 2494 auto initialComplex = 2495 rewriter.create<mlir::LLVM::PoisonOp>(op->getLoc(), complexLLVMTy); 2496 2497 auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>( 2498 op->getLoc(), initialComplex, newReal, 0); 2499 2500 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(op, realComplex, 2501 newImag, 1); 2502 2503 return mlir::success(); 2504 } 2505 2506 mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite( 2507 cir::ComplexImagOp op, OpAdaptor adaptor, 2508 mlir::ConversionPatternRewriter &rewriter) const { 2509 mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType()); 2510 rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>( 2511 op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{1}); 2512 return mlir::success(); 2513 } 2514 2515 mlir::IntegerType computeBitfieldIntType(mlir::Type storageType, 2516 mlir::MLIRContext *context, 2517 unsigned &storageSize) { 2518 return TypeSwitch<mlir::Type, mlir::IntegerType>(storageType) 2519 .Case<cir::ArrayType>([&](cir::ArrayType atTy) { 2520 storageSize = atTy.getSize() * 8; 2521 return mlir::IntegerType::get(context, storageSize); 2522 }) 2523 .Case<cir::IntType>([&](cir::IntType intTy) { 2524 storageSize = intTy.getWidth(); 2525 return mlir::IntegerType::get(context, storageSize); 2526 }) 2527 .Default([](mlir::Type) -> mlir::IntegerType { 2528 llvm_unreachable( 2529 "Either ArrayType or IntType expected for bitfields storage"); 2530 }); 2531 } 2532 2533 mlir::LogicalResult CIRToLLVMSetBitfieldOpLowering::matchAndRewrite( 2534 cir::SetBitfieldOp op, OpAdaptor adaptor, 2535 mlir::ConversionPatternRewriter &rewriter) const { 2536 mlir::OpBuilder::InsertionGuard guard(rewriter); 2537 rewriter.setInsertionPoint(op); 2538 2539 cir::BitfieldInfoAttr info = op.getBitfieldInfo(); 2540 uint64_t size = info.getSize(); 2541 uint64_t offset = info.getOffset(); 2542 mlir::Type storageType = info.getStorageType(); 2543 mlir::MLIRContext *context = storageType.getContext(); 2544 2545 unsigned storageSize = 0; 2546 2547 mlir::IntegerType intType = 2548 computeBitfieldIntType(storageType, context, storageSize); 2549 2550 mlir::Value srcVal = createIntCast(rewriter, adaptor.getSrc(), intType); 2551 unsigned srcWidth = storageSize; 2552 mlir::Value resultVal = srcVal; 2553 2554 if (storageSize != size) { 2555 assert(storageSize > size && "Invalid bitfield size."); 2556 2557 mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>( 2558 op.getLoc(), intType, adaptor.getAddr(), /* alignment */ 0, 2559 op.getIsVolatile()); 2560 2561 srcVal = 2562 createAnd(rewriter, srcVal, llvm::APInt::getLowBitsSet(srcWidth, size)); 2563 resultVal = srcVal; 2564 srcVal = createShL(rewriter, srcVal, offset); 2565 2566 // Mask out the original value. 2567 val = createAnd(rewriter, val, 2568 ~llvm::APInt::getBitsSet(srcWidth, offset, offset + size)); 2569 2570 // Or together the unchanged values and the source value. 2571 srcVal = rewriter.create<mlir::LLVM::OrOp>(op.getLoc(), val, srcVal); 2572 } 2573 2574 rewriter.create<mlir::LLVM::StoreOp>(op.getLoc(), srcVal, adaptor.getAddr(), 2575 /* alignment */ 0, op.getIsVolatile()); 2576 2577 mlir::Type resultTy = getTypeConverter()->convertType(op.getType()); 2578 2579 if (info.getIsSigned()) { 2580 assert(size <= storageSize); 2581 unsigned highBits = storageSize - size; 2582 2583 if (highBits) { 2584 resultVal = createShL(rewriter, resultVal, highBits); 2585 resultVal = createAShR(rewriter, resultVal, highBits); 2586 } 2587 } 2588 2589 resultVal = createIntCast(rewriter, resultVal, 2590 mlir::cast<mlir::IntegerType>(resultTy), 2591 info.getIsSigned()); 2592 2593 rewriter.replaceOp(op, resultVal); 2594 return mlir::success(); 2595 } 2596 2597 mlir::LogicalResult CIRToLLVMComplexImagPtrOpLowering::matchAndRewrite( 2598 cir::ComplexImagPtrOp op, OpAdaptor adaptor, 2599 mlir::ConversionPatternRewriter &rewriter) const { 2600 cir::PointerType operandTy = op.getOperand().getType(); 2601 mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType()); 2602 mlir::Type elementLLVMTy = 2603 getTypeConverter()->convertType(operandTy.getPointee()); 2604 2605 mlir::LLVM::GEPArg gepIndices[2] = {{0}, {1}}; 2606 mlir::LLVM::GEPNoWrapFlags inboundsNuw = 2607 mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw; 2608 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( 2609 op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices, 2610 inboundsNuw); 2611 return mlir::success(); 2612 } 2613 2614 mlir::LogicalResult CIRToLLVMComplexRealPtrOpLowering::matchAndRewrite( 2615 cir::ComplexRealPtrOp op, OpAdaptor adaptor, 2616 mlir::ConversionPatternRewriter &rewriter) const { 2617 cir::PointerType operandTy = op.getOperand().getType(); 2618 mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType()); 2619 mlir::Type elementLLVMTy = 2620 getTypeConverter()->convertType(operandTy.getPointee()); 2621 2622 mlir::LLVM::GEPArg gepIndices[2] = {0, 0}; 2623 mlir::LLVM::GEPNoWrapFlags inboundsNuw = 2624 mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw; 2625 rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( 2626 op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices, 2627 inboundsNuw); 2628 return mlir::success(); 2629 } 2630 2631 mlir::LogicalResult CIRToLLVMGetBitfieldOpLowering::matchAndRewrite( 2632 cir::GetBitfieldOp op, OpAdaptor adaptor, 2633 mlir::ConversionPatternRewriter &rewriter) const { 2634 2635 mlir::OpBuilder::InsertionGuard guard(rewriter); 2636 rewriter.setInsertionPoint(op); 2637 2638 cir::BitfieldInfoAttr info = op.getBitfieldInfo(); 2639 uint64_t size = info.getSize(); 2640 uint64_t offset = info.getOffset(); 2641 mlir::Type storageType = info.getStorageType(); 2642 mlir::MLIRContext *context = storageType.getContext(); 2643 unsigned storageSize = 0; 2644 2645 mlir::IntegerType intType = 2646 computeBitfieldIntType(storageType, context, storageSize); 2647 2648 mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>( 2649 op.getLoc(), intType, adaptor.getAddr(), 0, op.getIsVolatile()); 2650 val = rewriter.create<mlir::LLVM::BitcastOp>(op.getLoc(), intType, val); 2651 2652 if (info.getIsSigned()) { 2653 assert(static_cast<unsigned>(offset + size) <= storageSize); 2654 unsigned highBits = storageSize - offset - size; 2655 val = createShL(rewriter, val, highBits); 2656 val = createAShR(rewriter, val, offset + highBits); 2657 } else { 2658 val = createLShR(rewriter, val, offset); 2659 2660 if (static_cast<unsigned>(offset) + size < storageSize) 2661 val = createAnd(rewriter, val, 2662 llvm::APInt::getLowBitsSet(storageSize, size)); 2663 } 2664 2665 mlir::Type resTy = getTypeConverter()->convertType(op.getType()); 2666 mlir::Value newOp = createIntCast( 2667 rewriter, val, mlir::cast<mlir::IntegerType>(resTy), info.getIsSigned()); 2668 rewriter.replaceOp(op, newOp); 2669 return mlir::success(); 2670 } 2671 2672 std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() { 2673 return std::make_unique<ConvertCIRToLLVMPass>(); 2674 } 2675 2676 void populateCIRToLLVMPasses(mlir::OpPassManager &pm) { 2677 mlir::populateCIRPreLoweringPasses(pm); 2678 pm.addPass(createConvertCIRToLLVMPass()); 2679 } 2680 2681 std::unique_ptr<llvm::Module> 2682 lowerDirectlyFromCIRToLLVMIR(mlir::ModuleOp mlirModule, LLVMContext &llvmCtx) { 2683 llvm::TimeTraceScope scope("lower from CIR to LLVM directly"); 2684 2685 mlir::MLIRContext *mlirCtx = mlirModule.getContext(); 2686 2687 mlir::PassManager pm(mlirCtx); 2688 populateCIRToLLVMPasses(pm); 2689 2690 (void)mlir::applyPassManagerCLOptions(pm); 2691 2692 if (mlir::failed(pm.run(mlirModule))) { 2693 // FIXME: Handle any errors where they occurs and return a nullptr here. 2694 report_fatal_error( 2695 "The pass manager failed to lower CIR to LLVMIR dialect!"); 2696 } 2697 2698 mlir::registerBuiltinDialectTranslation(*mlirCtx); 2699 mlir::registerLLVMDialectTranslation(*mlirCtx); 2700 mlir::registerCIRDialectTranslation(*mlirCtx); 2701 2702 llvm::TimeTraceScope translateScope("translateModuleToLLVMIR"); 2703 2704 StringRef moduleName = mlirModule.getName().value_or("CIRToLLVMModule"); 2705 std::unique_ptr<llvm::Module> llvmModule = 2706 mlir::translateModuleToLLVMIR(mlirModule, llvmCtx, moduleName); 2707 2708 if (!llvmModule) { 2709 // FIXME: Handle any errors where they occurs and return a nullptr here. 2710 report_fatal_error("Lowering from LLVMIR dialect to llvm IR failed!"); 2711 } 2712 2713 return llvmModule; 2714 } 2715 } // namespace direct 2716 } // namespace cir 2717