xref: /freebsd/contrib/llvm-project/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of CIR operations to LLVMIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "LowerToLLVM.h"
14 
15 #include <deque>
16 #include <optional>
17 
18 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
19 #include "mlir/Dialect/DLTI/DLTI.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/IR/BuiltinDialect.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/Types.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
30 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
31 #include "mlir/Target/LLVMIR/Export.h"
32 #include "mlir/Transforms/DialectConversion.h"
33 #include "clang/CIR/Dialect/IR/CIRAttrs.h"
34 #include "clang/CIR/Dialect/IR/CIRDialect.h"
35 #include "clang/CIR/Dialect/Passes.h"
36 #include "clang/CIR/LoweringHelpers.h"
37 #include "clang/CIR/MissingFeatures.h"
38 #include "clang/CIR/Passes.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/IR/Module.h"
41 #include "llvm/Support/ErrorHandling.h"
42 #include "llvm/Support/TimeProfiler.h"
43 
44 using namespace cir;
45 using namespace llvm;
46 
47 namespace cir {
48 namespace direct {
49 
50 //===----------------------------------------------------------------------===//
51 // Helper Methods
52 //===----------------------------------------------------------------------===//
53 
54 namespace {
55 /// If the given type is a vector type, return the vector's element type.
56 /// Otherwise return the given type unchanged.
elementTypeIfVector(mlir::Type type)57 mlir::Type elementTypeIfVector(mlir::Type type) {
58   return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
59       .Case<cir::VectorType, mlir::VectorType>(
60           [](auto p) { return p.getElementType(); })
61       .Default([](mlir::Type p) { return p; });
62 }
63 } // namespace
64 
65 /// Given a type convertor and a data layout, convert the given type to a type
66 /// that is suitable for memory operations. For example, this can be used to
67 /// lower cir.bool accesses to i8.
convertTypeForMemory(const mlir::TypeConverter & converter,mlir::DataLayout const & dataLayout,mlir::Type type)68 static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
69                                        mlir::DataLayout const &dataLayout,
70                                        mlir::Type type) {
71   // TODO(cir): Handle other types similarly to clang's codegen
72   // convertTypeForMemory
73   if (isa<cir::BoolType>(type)) {
74     return mlir::IntegerType::get(type.getContext(),
75                                   dataLayout.getTypeSizeInBits(type));
76   }
77 
78   return converter.convertType(type);
79 }
80 
createIntCast(mlir::OpBuilder & bld,mlir::Value src,mlir::IntegerType dstTy,bool isSigned=false)81 static mlir::Value createIntCast(mlir::OpBuilder &bld, mlir::Value src,
82                                  mlir::IntegerType dstTy,
83                                  bool isSigned = false) {
84   mlir::Type srcTy = src.getType();
85   assert(mlir::isa<mlir::IntegerType>(srcTy));
86 
87   unsigned srcWidth = mlir::cast<mlir::IntegerType>(srcTy).getWidth();
88   unsigned dstWidth = mlir::cast<mlir::IntegerType>(dstTy).getWidth();
89   mlir::Location loc = src.getLoc();
90 
91   if (dstWidth > srcWidth && isSigned)
92     return bld.create<mlir::LLVM::SExtOp>(loc, dstTy, src);
93   if (dstWidth > srcWidth)
94     return bld.create<mlir::LLVM::ZExtOp>(loc, dstTy, src);
95   if (dstWidth < srcWidth)
96     return bld.create<mlir::LLVM::TruncOp>(loc, dstTy, src);
97   return bld.create<mlir::LLVM::BitcastOp>(loc, dstTy, src);
98 }
99 
100 static mlir::LLVM::Visibility
lowerCIRVisibilityToLLVMVisibility(cir::VisibilityKind visibilityKind)101 lowerCIRVisibilityToLLVMVisibility(cir::VisibilityKind visibilityKind) {
102   switch (visibilityKind) {
103   case cir::VisibilityKind::Default:
104     return ::mlir::LLVM::Visibility::Default;
105   case cir::VisibilityKind::Hidden:
106     return ::mlir::LLVM::Visibility::Hidden;
107   case cir::VisibilityKind::Protected:
108     return ::mlir::LLVM::Visibility::Protected;
109   }
110 }
111 
112 /// Emits the value from memory as expected by its users. Should be called when
113 /// the memory represetnation of a CIR type is not equal to its scalar
114 /// representation.
emitFromMemory(mlir::ConversionPatternRewriter & rewriter,mlir::DataLayout const & dataLayout,cir::LoadOp op,mlir::Value value)115 static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter,
116                                   mlir::DataLayout const &dataLayout,
117                                   cir::LoadOp op, mlir::Value value) {
118 
119   // TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory
120   if (auto boolTy = mlir::dyn_cast<cir::BoolType>(op.getType())) {
121     // Create a cast value from specified size in datalayout to i1
122     assert(value.getType().isInteger(dataLayout.getTypeSizeInBits(boolTy)));
123     return createIntCast(rewriter, value, rewriter.getI1Type());
124   }
125 
126   return value;
127 }
128 
129 /// Emits a value to memory with the expected scalar type. Should be called when
130 /// the memory represetnation of a CIR type is not equal to its scalar
131 /// representation.
emitToMemory(mlir::ConversionPatternRewriter & rewriter,mlir::DataLayout const & dataLayout,mlir::Type origType,mlir::Value value)132 static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter,
133                                 mlir::DataLayout const &dataLayout,
134                                 mlir::Type origType, mlir::Value value) {
135 
136   // TODO(cir): Handle other types similarly to clang's codegen EmitToMemory
137   if (auto boolTy = mlir::dyn_cast<cir::BoolType>(origType)) {
138     // Create zext of value from i1 to i8
139     mlir::IntegerType memType =
140         rewriter.getIntegerType(dataLayout.getTypeSizeInBits(boolTy));
141     return createIntCast(rewriter, value, memType);
142   }
143 
144   return value;
145 }
146 
convertLinkage(cir::GlobalLinkageKind linkage)147 mlir::LLVM::Linkage convertLinkage(cir::GlobalLinkageKind linkage) {
148   using CIR = cir::GlobalLinkageKind;
149   using LLVM = mlir::LLVM::Linkage;
150 
151   switch (linkage) {
152   case CIR::AvailableExternallyLinkage:
153     return LLVM::AvailableExternally;
154   case CIR::CommonLinkage:
155     return LLVM::Common;
156   case CIR::ExternalLinkage:
157     return LLVM::External;
158   case CIR::ExternalWeakLinkage:
159     return LLVM::ExternWeak;
160   case CIR::InternalLinkage:
161     return LLVM::Internal;
162   case CIR::LinkOnceAnyLinkage:
163     return LLVM::Linkonce;
164   case CIR::LinkOnceODRLinkage:
165     return LLVM::LinkonceODR;
166   case CIR::PrivateLinkage:
167     return LLVM::Private;
168   case CIR::WeakAnyLinkage:
169     return LLVM::Weak;
170   case CIR::WeakODRLinkage:
171     return LLVM::WeakODR;
172   };
173   llvm_unreachable("Unknown CIR linkage type");
174 }
175 
getLLVMIntCast(mlir::ConversionPatternRewriter & rewriter,mlir::Value llvmSrc,mlir::Type llvmDstIntTy,bool isUnsigned,uint64_t cirSrcWidth,uint64_t cirDstIntWidth)176 static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter,
177                                   mlir::Value llvmSrc, mlir::Type llvmDstIntTy,
178                                   bool isUnsigned, uint64_t cirSrcWidth,
179                                   uint64_t cirDstIntWidth) {
180   if (cirSrcWidth == cirDstIntWidth)
181     return llvmSrc;
182 
183   auto loc = llvmSrc.getLoc();
184   if (cirSrcWidth < cirDstIntWidth) {
185     if (isUnsigned)
186       return rewriter.create<mlir::LLVM::ZExtOp>(loc, llvmDstIntTy, llvmSrc);
187     return rewriter.create<mlir::LLVM::SExtOp>(loc, llvmDstIntTy, llvmSrc);
188   }
189 
190   // Otherwise truncate
191   return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc);
192 }
193 
194 class CIRAttrToValue {
195 public:
CIRAttrToValue(mlir::Operation * parentOp,mlir::ConversionPatternRewriter & rewriter,const mlir::TypeConverter * converter)196   CIRAttrToValue(mlir::Operation *parentOp,
197                  mlir::ConversionPatternRewriter &rewriter,
198                  const mlir::TypeConverter *converter)
199       : parentOp(parentOp), rewriter(rewriter), converter(converter) {}
200 
visit(mlir::Attribute attr)201   mlir::Value visit(mlir::Attribute attr) {
202     return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
203         .Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
204               cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
205               cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
206         .Default([&](auto attrT) { return mlir::Value(); });
207   }
208 
209   mlir::Value visitCirAttr(cir::IntAttr intAttr);
210   mlir::Value visitCirAttr(cir::FPAttr fltAttr);
211   mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr);
212   mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
213   mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
214   mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
215   mlir::Value visitCirAttr(cir::ZeroAttr attr);
216 
217 private:
218   mlir::Operation *parentOp;
219   mlir::ConversionPatternRewriter &rewriter;
220   const mlir::TypeConverter *converter;
221 };
222 
223 /// Switches on the type of attribute and calls the appropriate conversion.
lowerCirAttrAsValue(mlir::Operation * parentOp,const mlir::Attribute attr,mlir::ConversionPatternRewriter & rewriter,const mlir::TypeConverter * converter)224 mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
225                                 const mlir::Attribute attr,
226                                 mlir::ConversionPatternRewriter &rewriter,
227                                 const mlir::TypeConverter *converter) {
228   CIRAttrToValue valueConverter(parentOp, rewriter, converter);
229   mlir::Value value = valueConverter.visit(attr);
230   if (!value)
231     llvm_unreachable("unhandled attribute type");
232   return value;
233 }
234 
convertSideEffectForCall(mlir::Operation * callOp,bool isNothrow,cir::SideEffect sideEffect,mlir::LLVM::MemoryEffectsAttr & memoryEffect,bool & noUnwind,bool & willReturn)235 void convertSideEffectForCall(mlir::Operation *callOp, bool isNothrow,
236                               cir::SideEffect sideEffect,
237                               mlir::LLVM::MemoryEffectsAttr &memoryEffect,
238                               bool &noUnwind, bool &willReturn) {
239   using mlir::LLVM::ModRefInfo;
240 
241   switch (sideEffect) {
242   case cir::SideEffect::All:
243     memoryEffect = {};
244     noUnwind = isNothrow;
245     willReturn = false;
246     break;
247 
248   case cir::SideEffect::Pure:
249     memoryEffect = mlir::LLVM::MemoryEffectsAttr::get(
250         callOp->getContext(), /*other=*/ModRefInfo::Ref,
251         /*argMem=*/ModRefInfo::Ref,
252         /*inaccessibleMem=*/ModRefInfo::Ref);
253     noUnwind = true;
254     willReturn = true;
255     break;
256 
257   case cir::SideEffect::Const:
258     memoryEffect = mlir::LLVM::MemoryEffectsAttr::get(
259         callOp->getContext(), /*other=*/ModRefInfo::NoModRef,
260         /*argMem=*/ModRefInfo::NoModRef,
261         /*inaccessibleMem=*/ModRefInfo::NoModRef);
262     noUnwind = true;
263     willReturn = true;
264     break;
265   }
266 }
267 
268 /// IntAttr visitor.
visitCirAttr(cir::IntAttr intAttr)269 mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
270   mlir::Location loc = parentOp->getLoc();
271   return rewriter.create<mlir::LLVM::ConstantOp>(
272       loc, converter->convertType(intAttr.getType()), intAttr.getValue());
273 }
274 
275 /// FPAttr visitor.
visitCirAttr(cir::FPAttr fltAttr)276 mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
277   mlir::Location loc = parentOp->getLoc();
278   return rewriter.create<mlir::LLVM::ConstantOp>(
279       loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
280 }
281 
282 /// ConstComplexAttr visitor.
visitCirAttr(cir::ConstComplexAttr complexAttr)283 mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) {
284   auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType());
285   mlir::Type complexElemTy = complexType.getElementType();
286   mlir::Type complexElemLLVMTy = converter->convertType(complexElemTy);
287 
288   mlir::Attribute components[2];
289   if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) {
290     components[0] = rewriter.getIntegerAttr(
291         complexElemLLVMTy,
292         mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
293     components[1] = rewriter.getIntegerAttr(
294         complexElemLLVMTy,
295         mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
296   } else {
297     components[0] = rewriter.getFloatAttr(
298         complexElemLLVMTy,
299         mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
300     components[1] = rewriter.getFloatAttr(
301         complexElemLLVMTy,
302         mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
303   }
304 
305   mlir::Location loc = parentOp->getLoc();
306   return rewriter.create<mlir::LLVM::ConstantOp>(
307       loc, converter->convertType(complexAttr.getType()),
308       rewriter.getArrayAttr(components));
309 }
310 
311 /// ConstPtrAttr visitor.
visitCirAttr(cir::ConstPtrAttr ptrAttr)312 mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
313   mlir::Location loc = parentOp->getLoc();
314   if (ptrAttr.isNullValue()) {
315     return rewriter.create<mlir::LLVM::ZeroOp>(
316         loc, converter->convertType(ptrAttr.getType()));
317   }
318   mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
319   mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
320       loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
321       ptrAttr.getValue().getInt());
322   return rewriter.create<mlir::LLVM::IntToPtrOp>(
323       loc, converter->convertType(ptrAttr.getType()), ptrVal);
324 }
325 
326 // ConstArrayAttr visitor
visitCirAttr(cir::ConstArrayAttr attr)327 mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
328   mlir::Type llvmTy = converter->convertType(attr.getType());
329   mlir::Location loc = parentOp->getLoc();
330   mlir::Value result;
331 
332   if (attr.hasTrailingZeros()) {
333     mlir::Type arrayTy = attr.getType();
334     result = rewriter.create<mlir::LLVM::ZeroOp>(
335         loc, converter->convertType(arrayTy));
336   } else {
337     result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
338   }
339 
340   // Iteratively lower each constant element of the array.
341   if (auto arrayAttr = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts())) {
342     for (auto [idx, elt] : llvm::enumerate(arrayAttr)) {
343       mlir::DataLayout dataLayout(parentOp->getParentOfType<mlir::ModuleOp>());
344       mlir::Value init = visit(elt);
345       result =
346           rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
347     }
348   } else if (auto strAttr = mlir::dyn_cast<mlir::StringAttr>(attr.getElts())) {
349     // TODO(cir): this diverges from traditional lowering. Normally the string
350     // would be a global constant that is memcopied.
351     auto arrayTy = mlir::dyn_cast<cir::ArrayType>(strAttr.getType());
352     assert(arrayTy && "String attribute must have an array type");
353     mlir::Type eltTy = arrayTy.getElementType();
354     for (auto [idx, elt] : llvm::enumerate(strAttr)) {
355       auto init = rewriter.create<mlir::LLVM::ConstantOp>(
356           loc, converter->convertType(eltTy), elt);
357       result =
358           rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
359     }
360   } else {
361     llvm_unreachable("unexpected ConstArrayAttr elements");
362   }
363 
364   return result;
365 }
366 
367 /// ConstVectorAttr visitor.
visitCirAttr(cir::ConstVectorAttr attr)368 mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
369   const mlir::Type llvmTy = converter->convertType(attr.getType());
370   const mlir::Location loc = parentOp->getLoc();
371 
372   SmallVector<mlir::Attribute> mlirValues;
373   for (const mlir::Attribute elementAttr : attr.getElts()) {
374     mlir::Attribute mlirAttr;
375     if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) {
376       mlirAttr = rewriter.getIntegerAttr(
377           converter->convertType(intAttr.getType()), intAttr.getValue());
378     } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) {
379       mlirAttr = rewriter.getFloatAttr(
380           converter->convertType(floatAttr.getType()), floatAttr.getValue());
381     } else {
382       llvm_unreachable(
383           "vector constant with an element that is neither an int nor a float");
384     }
385     mlirValues.push_back(mlirAttr);
386   }
387 
388   return rewriter.create<mlir::LLVM::ConstantOp>(
389       loc, llvmTy,
390       mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy),
391                                    mlirValues));
392 }
393 
394 /// ZeroAttr visitor.
visitCirAttr(cir::ZeroAttr attr)395 mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) {
396   mlir::Location loc = parentOp->getLoc();
397   return rewriter.create<mlir::LLVM::ZeroOp>(
398       loc, converter->convertType(attr.getType()));
399 }
400 
401 // This class handles rewriting initializer attributes for types that do not
402 // require region initialization.
403 class GlobalInitAttrRewriter {
404 public:
GlobalInitAttrRewriter(mlir::Type type,mlir::ConversionPatternRewriter & rewriter)405   GlobalInitAttrRewriter(mlir::Type type,
406                          mlir::ConversionPatternRewriter &rewriter)
407       : llvmType(type), rewriter(rewriter) {}
408 
visit(mlir::Attribute attr)409   mlir::Attribute visit(mlir::Attribute attr) {
410     return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
411         .Case<cir::IntAttr, cir::FPAttr, cir::BoolAttr>(
412             [&](auto attrT) { return visitCirAttr(attrT); })
413         .Default([&](auto attrT) { return mlir::Attribute(); });
414   }
415 
visitCirAttr(cir::IntAttr attr)416   mlir::Attribute visitCirAttr(cir::IntAttr attr) {
417     return rewriter.getIntegerAttr(llvmType, attr.getValue());
418   }
419 
visitCirAttr(cir::FPAttr attr)420   mlir::Attribute visitCirAttr(cir::FPAttr attr) {
421     return rewriter.getFloatAttr(llvmType, attr.getValue());
422   }
423 
visitCirAttr(cir::BoolAttr attr)424   mlir::Attribute visitCirAttr(cir::BoolAttr attr) {
425     return rewriter.getBoolAttr(attr.getValue());
426   }
427 
428 private:
429   mlir::Type llvmType;
430   mlir::ConversionPatternRewriter &rewriter;
431 };
432 
433 // This pass requires the CIR to be in a "flat" state. All blocks in each
434 // function must belong to the parent region. Once scopes and control flow
435 // are implemented in CIR, a pass will be run before this one to flatten
436 // the CIR and get it into the state that this pass requires.
437 struct ConvertCIRToLLVMPass
438     : public mlir::PassWrapper<ConvertCIRToLLVMPass,
439                                mlir::OperationPass<mlir::ModuleOp>> {
getDependentDialectscir::direct::ConvertCIRToLLVMPass440   void getDependentDialects(mlir::DialectRegistry &registry) const override {
441     registry.insert<mlir::BuiltinDialect, mlir::DLTIDialect,
442                     mlir::LLVM::LLVMDialect, mlir::func::FuncDialect>();
443   }
444   void runOnOperation() final;
445 
446   void processCIRAttrs(mlir::ModuleOp module);
447 
getDescriptioncir::direct::ConvertCIRToLLVMPass448   StringRef getDescription() const override {
449     return "Convert the prepared CIR dialect module to LLVM dialect";
450   }
451 
getArgumentcir::direct::ConvertCIRToLLVMPass452   StringRef getArgument() const override { return "cir-flat-to-llvm"; }
453 };
454 
matchAndRewrite(cir::AssumeOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const455 mlir::LogicalResult CIRToLLVMAssumeOpLowering::matchAndRewrite(
456     cir::AssumeOp op, OpAdaptor adaptor,
457     mlir::ConversionPatternRewriter &rewriter) const {
458   auto cond = adaptor.getPredicate();
459   rewriter.replaceOpWithNewOp<mlir::LLVM::AssumeOp>(op, cond);
460   return mlir::success();
461 }
462 
matchAndRewrite(cir::BitClrsbOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const463 mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite(
464     cir::BitClrsbOp op, OpAdaptor adaptor,
465     mlir::ConversionPatternRewriter &rewriter) const {
466   auto zero = rewriter.create<mlir::LLVM::ConstantOp>(
467       op.getLoc(), adaptor.getInput().getType(), 0);
468   auto isNeg = rewriter.create<mlir::LLVM::ICmpOp>(
469       op.getLoc(),
470       mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(),
471                                          mlir::LLVM::ICmpPredicate::slt),
472       adaptor.getInput(), zero);
473 
474   auto negOne = rewriter.create<mlir::LLVM::ConstantOp>(
475       op.getLoc(), adaptor.getInput().getType(), -1);
476   auto flipped = rewriter.create<mlir::LLVM::XOrOp>(op.getLoc(),
477                                                     adaptor.getInput(), negOne);
478 
479   auto select = rewriter.create<mlir::LLVM::SelectOp>(
480       op.getLoc(), isNeg, flipped, adaptor.getInput());
481 
482   auto resTy = getTypeConverter()->convertType(op.getType());
483   auto clz = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
484       op.getLoc(), resTy, select, /*is_zero_poison=*/false);
485 
486   auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
487   auto res = rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), clz, one);
488   rewriter.replaceOp(op, res);
489 
490   return mlir::LogicalResult::success();
491 }
492 
matchAndRewrite(cir::BitClzOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const493 mlir::LogicalResult CIRToLLVMBitClzOpLowering::matchAndRewrite(
494     cir::BitClzOp op, OpAdaptor adaptor,
495     mlir::ConversionPatternRewriter &rewriter) const {
496   auto resTy = getTypeConverter()->convertType(op.getType());
497   auto llvmOp = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
498       op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero());
499   rewriter.replaceOp(op, llvmOp);
500   return mlir::LogicalResult::success();
501 }
502 
matchAndRewrite(cir::BitCtzOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const503 mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite(
504     cir::BitCtzOp op, OpAdaptor adaptor,
505     mlir::ConversionPatternRewriter &rewriter) const {
506   auto resTy = getTypeConverter()->convertType(op.getType());
507   auto llvmOp = rewriter.create<mlir::LLVM::CountTrailingZerosOp>(
508       op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero());
509   rewriter.replaceOp(op, llvmOp);
510   return mlir::LogicalResult::success();
511 }
512 
matchAndRewrite(cir::BitParityOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const513 mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite(
514     cir::BitParityOp op, OpAdaptor adaptor,
515     mlir::ConversionPatternRewriter &rewriter) const {
516   auto resTy = getTypeConverter()->convertType(op.getType());
517   auto popcnt = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
518                                                      adaptor.getInput());
519 
520   auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
521   auto popcntMod2 =
522       rewriter.create<mlir::LLVM::AndOp>(op.getLoc(), popcnt, one);
523   rewriter.replaceOp(op, popcntMod2);
524 
525   return mlir::LogicalResult::success();
526 }
527 
matchAndRewrite(cir::BitPopcountOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const528 mlir::LogicalResult CIRToLLVMBitPopcountOpLowering::matchAndRewrite(
529     cir::BitPopcountOp op, OpAdaptor adaptor,
530     mlir::ConversionPatternRewriter &rewriter) const {
531   auto resTy = getTypeConverter()->convertType(op.getType());
532   auto llvmOp = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
533                                                      adaptor.getInput());
534   rewriter.replaceOp(op, llvmOp);
535   return mlir::LogicalResult::success();
536 }
537 
matchAndRewrite(cir::BitReverseOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const538 mlir::LogicalResult CIRToLLVMBitReverseOpLowering::matchAndRewrite(
539     cir::BitReverseOp op, OpAdaptor adaptor,
540     mlir::ConversionPatternRewriter &rewriter) const {
541   rewriter.replaceOpWithNewOp<mlir::LLVM::BitReverseOp>(op, adaptor.getInput());
542   return mlir::success();
543 }
544 
matchAndRewrite(cir::BrCondOp brOp,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const545 mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite(
546     cir::BrCondOp brOp, OpAdaptor adaptor,
547     mlir::ConversionPatternRewriter &rewriter) const {
548   // When ZExtOp is implemented, we'll need to check if the condition is a
549   // ZExtOp and if so, delete it if it has a single use.
550   assert(!cir::MissingFeatures::zextOp());
551 
552   mlir::Value i1Condition = adaptor.getCond();
553 
554   rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>(
555       brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(),
556       brOp.getDestFalse(), adaptor.getDestOperandsFalse());
557 
558   return mlir::success();
559 }
560 
matchAndRewrite(cir::ByteSwapOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const561 mlir::LogicalResult CIRToLLVMByteSwapOpLowering::matchAndRewrite(
562     cir::ByteSwapOp op, OpAdaptor adaptor,
563     mlir::ConversionPatternRewriter &rewriter) const {
564   rewriter.replaceOpWithNewOp<mlir::LLVM::ByteSwapOp>(op, adaptor.getInput());
565   return mlir::LogicalResult::success();
566 }
567 
convertTy(mlir::Type ty) const568 mlir::Type CIRToLLVMCastOpLowering::convertTy(mlir::Type ty) const {
569   return getTypeConverter()->convertType(ty);
570 }
571 
matchAndRewrite(cir::CastOp castOp,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const572 mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
573     cir::CastOp castOp, OpAdaptor adaptor,
574     mlir::ConversionPatternRewriter &rewriter) const {
575   // For arithmetic conversions, LLVM IR uses the same instruction to convert
576   // both individual scalars and entire vectors. This lowering pass handles
577   // both situations.
578 
579   switch (castOp.getKind()) {
580   case cir::CastKind::array_to_ptrdecay: {
581     const auto ptrTy = mlir::cast<cir::PointerType>(castOp.getType());
582     mlir::Value sourceValue = adaptor.getSrc();
583     mlir::Type targetType = convertTy(ptrTy);
584     mlir::Type elementTy = convertTypeForMemory(*getTypeConverter(), dataLayout,
585                                                 ptrTy.getPointee());
586     llvm::SmallVector<mlir::LLVM::GEPArg> offset{0};
587     rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
588         castOp, targetType, elementTy, sourceValue, offset);
589     break;
590   }
591   case cir::CastKind::int_to_bool: {
592     mlir::Value llvmSrcVal = adaptor.getSrc();
593     mlir::Value zeroInt = rewriter.create<mlir::LLVM::ConstantOp>(
594         castOp.getLoc(), llvmSrcVal.getType(), 0);
595     rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
596         castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroInt);
597     break;
598   }
599   case cir::CastKind::integral: {
600     mlir::Type srcType = castOp.getSrc().getType();
601     mlir::Type dstType = castOp.getType();
602     mlir::Value llvmSrcVal = adaptor.getSrc();
603     mlir::Type llvmDstType = getTypeConverter()->convertType(dstType);
604     cir::IntType srcIntType =
605         mlir::cast<cir::IntType>(elementTypeIfVector(srcType));
606     cir::IntType dstIntType =
607         mlir::cast<cir::IntType>(elementTypeIfVector(dstType));
608     rewriter.replaceOp(castOp, getLLVMIntCast(rewriter, llvmSrcVal, llvmDstType,
609                                               srcIntType.isUnsigned(),
610                                               srcIntType.getWidth(),
611                                               dstIntType.getWidth()));
612     break;
613   }
614   case cir::CastKind::floating: {
615     mlir::Value llvmSrcVal = adaptor.getSrc();
616     mlir::Type llvmDstTy = getTypeConverter()->convertType(castOp.getType());
617 
618     mlir::Type srcTy = elementTypeIfVector(castOp.getSrc().getType());
619     mlir::Type dstTy = elementTypeIfVector(castOp.getType());
620 
621     if (!mlir::isa<cir::FPTypeInterface>(dstTy) ||
622         !mlir::isa<cir::FPTypeInterface>(srcTy))
623       return castOp.emitError() << "NYI cast from " << srcTy << " to " << dstTy;
624 
625     auto getFloatWidth = [](mlir::Type ty) -> unsigned {
626       return mlir::cast<cir::FPTypeInterface>(ty).getWidth();
627     };
628 
629     if (getFloatWidth(srcTy) > getFloatWidth(dstTy))
630       rewriter.replaceOpWithNewOp<mlir::LLVM::FPTruncOp>(castOp, llvmDstTy,
631                                                          llvmSrcVal);
632     else
633       rewriter.replaceOpWithNewOp<mlir::LLVM::FPExtOp>(castOp, llvmDstTy,
634                                                        llvmSrcVal);
635     return mlir::success();
636   }
637   case cir::CastKind::int_to_ptr: {
638     auto dstTy = mlir::cast<cir::PointerType>(castOp.getType());
639     mlir::Value llvmSrcVal = adaptor.getSrc();
640     mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
641     rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(castOp, llvmDstTy,
642                                                         llvmSrcVal);
643     return mlir::success();
644   }
645   case cir::CastKind::ptr_to_int: {
646     auto dstTy = mlir::cast<cir::IntType>(castOp.getType());
647     mlir::Value llvmSrcVal = adaptor.getSrc();
648     mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
649     rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(castOp, llvmDstTy,
650                                                         llvmSrcVal);
651     return mlir::success();
652   }
653   case cir::CastKind::float_to_bool: {
654     mlir::Value llvmSrcVal = adaptor.getSrc();
655     auto kind = mlir::LLVM::FCmpPredicate::une;
656 
657     // Check if float is not equal to zero.
658     auto zeroFloat = rewriter.create<mlir::LLVM::ConstantOp>(
659         castOp.getLoc(), llvmSrcVal.getType(),
660         mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0));
661 
662     // Extend comparison result to either bool (C++) or int (C).
663     rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(castOp, kind, llvmSrcVal,
664                                                     zeroFloat);
665 
666     return mlir::success();
667   }
668   case cir::CastKind::bool_to_int: {
669     auto dstTy = mlir::cast<cir::IntType>(castOp.getType());
670     mlir::Value llvmSrcVal = adaptor.getSrc();
671     auto llvmSrcTy = mlir::cast<mlir::IntegerType>(llvmSrcVal.getType());
672     auto llvmDstTy =
673         mlir::cast<mlir::IntegerType>(getTypeConverter()->convertType(dstTy));
674     if (llvmSrcTy.getWidth() == llvmDstTy.getWidth())
675       rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
676                                                          llvmSrcVal);
677     else
678       rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, llvmDstTy,
679                                                       llvmSrcVal);
680     return mlir::success();
681   }
682   case cir::CastKind::bool_to_float: {
683     mlir::Type dstTy = castOp.getType();
684     mlir::Value llvmSrcVal = adaptor.getSrc();
685     mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
686     rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy,
687                                                       llvmSrcVal);
688     return mlir::success();
689   }
690   case cir::CastKind::int_to_float: {
691     mlir::Type dstTy = castOp.getType();
692     mlir::Value llvmSrcVal = adaptor.getSrc();
693     mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
694     if (mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getSrc().getType()))
695             .isSigned())
696       rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(castOp, llvmDstTy,
697                                                         llvmSrcVal);
698     else
699       rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy,
700                                                         llvmSrcVal);
701     return mlir::success();
702   }
703   case cir::CastKind::float_to_int: {
704     mlir::Type dstTy = castOp.getType();
705     mlir::Value llvmSrcVal = adaptor.getSrc();
706     mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
707     if (mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getType()))
708             .isSigned())
709       rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(castOp, llvmDstTy,
710                                                         llvmSrcVal);
711     else
712       rewriter.replaceOpWithNewOp<mlir::LLVM::FPToUIOp>(castOp, llvmDstTy,
713                                                         llvmSrcVal);
714     return mlir::success();
715   }
716   case cir::CastKind::bitcast: {
717     mlir::Type dstTy = castOp.getType();
718     mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
719 
720     assert(!MissingFeatures::cxxABI());
721     assert(!MissingFeatures::dataMemberType());
722 
723     mlir::Value llvmSrcVal = adaptor.getSrc();
724     rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
725                                                        llvmSrcVal);
726     return mlir::success();
727   }
728   case cir::CastKind::ptr_to_bool: {
729     mlir::Value llvmSrcVal = adaptor.getSrc();
730     mlir::Value zeroPtr = rewriter.create<mlir::LLVM::ZeroOp>(
731         castOp.getLoc(), llvmSrcVal.getType());
732     rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
733         castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroPtr);
734     break;
735   }
736   case cir::CastKind::address_space: {
737     mlir::Type dstTy = castOp.getType();
738     mlir::Value llvmSrcVal = adaptor.getSrc();
739     mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
740     rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(castOp, llvmDstTy,
741                                                              llvmSrcVal);
742     break;
743   }
744   case cir::CastKind::member_ptr_to_bool:
745     assert(!MissingFeatures::cxxABI());
746     assert(!MissingFeatures::methodType());
747     break;
748   default: {
749     return castOp.emitError("Unhandled cast kind: ")
750            << castOp.getKindAttrName();
751   }
752   }
753 
754   return mlir::success();
755 }
756 
/// Lower cir.ptr_stride to an llvm.getelementptr over the in-memory element
/// type, using the (possibly resized) stride as the single GEP index.
///
/// void and function element types are replaced by i8 for the GEP. When the
/// stride has a defining op, the data layout reports an index bitwidth for
/// the base pointer type, and the stride's integer width differs from it, the
/// stride is resized with getLLVMIntCast. The unsigned flag passed to that
/// helper comes from the CIR stride type.
mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
    cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {

  const mlir::TypeConverter *tc = getTypeConverter();
  const mlir::Type resultTy = tc->convertType(ptrStrideOp.getType());

  mlir::Type elementTy =
      convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementTy());
  mlir::MLIRContext *ctx = elementTy.getContext();

  // void and function types doesn't really have a layout to use in GEPs,
  // make it i8 instead.
  if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) ||
      mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy))
    elementTy = mlir::IntegerType::get(elementTy.getContext(), 8,
                                       mlir::IntegerType::Signless);
  // Zero-extend, sign-extend or truncate the stride (the GEP index), not the
  // pointer, to the index bitwidth the data layout reports for the base
  // pointer type.
  mlir::Value index = adaptor.getStride();
  const unsigned width =
      mlir::cast<mlir::IntegerType>(index.getType()).getWidth();
  const std::optional<std::uint64_t> layoutWidth =
      dataLayout.getTypeIndexBitwidth(adaptor.getBase().getType());

  // NOTE(review): a stride with no defining op (e.g. a block argument) skips
  // the resize below and reaches the GEP at its original width; confirm this
  // is intended.
  mlir::Operation *indexOp = index.getDefiningOp();
  if (indexOp && layoutWidth && width != *layoutWidth) {
    // If the index comes from a subtraction, make sure the extension happens
    // before it. To achieve that, look at unary minus, which already got
    // lowered to "sub 0, x".
    const auto sub = dyn_cast<mlir::LLVM::SubOp>(indexOp);
    auto unary = dyn_cast_if_present<cir::UnaryOp>(
        ptrStrideOp.getStride().getDefiningOp());
    bool rewriteSub =
        unary && unary.getKind() == cir::UnaryOpKind::Minus && sub;
    // Operand 1 of "sub 0, x" is x: resize x rather than the negated value.
    if (rewriteSub)
      index = indexOp->getOperand(1);

    // Handle the cast. The extension kind is decided inside getLLVMIntCast
    // from the unsigned flag (presumably zext vs. sext; see that helper).
    const auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth);
    index = getLLVMIntCast(rewriter, index, llvmDstType,
                           ptrStrideOp.getStride().getType().isUnsigned(),
                           width, *layoutWidth);

    // Rewrite the sub in front of extensions/trunc: re-emit "sub 0, x'" at
    // the new width and drop the original narrow sub.
    // NOTE(review): the old sub is erased unconditionally, which assumes this
    // stride is its only user; confirm.
    if (rewriteSub) {
      index = rewriter.create<mlir::LLVM::SubOp>(
          index.getLoc(), index.getType(),
          rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
                                                  index.getType(), 0),
          index);
      rewriter.eraseOp(sub);
    }
  }

  rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
      ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index);
  return mlir::success();
}
815 
matchAndRewrite(cir::BaseClassAddrOp baseClassOp,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const816 mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
817     cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor,
818     mlir::ConversionPatternRewriter &rewriter) const {
819   const mlir::Type resultType =
820       getTypeConverter()->convertType(baseClassOp.getType());
821   mlir::Value derivedAddr = adaptor.getDerivedAddr();
822   llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = {
823       adaptor.getOffset().getZExtValue()};
824   mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8,
825                                                mlir::IntegerType::Signless);
826   if (adaptor.getOffset().getZExtValue() == 0) {
827     rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(
828         baseClassOp, resultType, adaptor.getDerivedAddr());
829     return mlir::success();
830   }
831 
832   if (baseClassOp.getAssumeNotNull()) {
833     rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
834         baseClassOp, resultType, byteType, derivedAddr, offset);
835   } else {
836     auto loc = baseClassOp.getLoc();
837     mlir::Value isNull = rewriter.create<mlir::LLVM::ICmpOp>(
838         loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr,
839         rewriter.create<mlir::LLVM::ZeroOp>(loc, derivedAddr.getType()));
840     mlir::Value adjusted = rewriter.create<mlir::LLVM::GEPOp>(
841         loc, resultType, byteType, derivedAddr, offset);
842     rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(baseClassOp, isNull,
843                                                       derivedAddr, adjusted);
844   }
845   return mlir::success();
846 }
847 
matchAndRewrite(cir::AllocaOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const848 mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite(
849     cir::AllocaOp op, OpAdaptor adaptor,
850     mlir::ConversionPatternRewriter &rewriter) const {
851   assert(!cir::MissingFeatures::opAllocaDynAllocSize());
852   mlir::Value size = rewriter.create<mlir::LLVM::ConstantOp>(
853       op.getLoc(), typeConverter->convertType(rewriter.getIndexType()), 1);
854   mlir::Type elementTy =
855       convertTypeForMemory(*getTypeConverter(), dataLayout, op.getAllocaType());
856   mlir::Type resultTy =
857       convertTypeForMemory(*getTypeConverter(), dataLayout, op.getType());
858 
859   assert(!cir::MissingFeatures::addressSpace());
860   assert(!cir::MissingFeatures::opAllocaAnnotations());
861 
862   rewriter.replaceOpWithNewOp<mlir::LLVM::AllocaOp>(
863       op, resultTy, elementTy, size, op.getAlignmentAttr().getInt());
864 
865   return mlir::success();
866 }
867 
matchAndRewrite(cir::ReturnOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const868 mlir::LogicalResult CIRToLLVMReturnOpLowering::matchAndRewrite(
869     cir::ReturnOp op, OpAdaptor adaptor,
870     mlir::ConversionPatternRewriter &rewriter) const {
871   rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op, adaptor.getOperands());
872   return mlir::LogicalResult::success();
873 }
874 
875 static mlir::LogicalResult
rewriteCallOrInvoke(mlir::Operation * op,mlir::ValueRange callOperands,mlir::ConversionPatternRewriter & rewriter,const mlir::TypeConverter * converter,mlir::FlatSymbolRefAttr calleeAttr)876 rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
877                     mlir::ConversionPatternRewriter &rewriter,
878                     const mlir::TypeConverter *converter,
879                     mlir::FlatSymbolRefAttr calleeAttr) {
880   llvm::SmallVector<mlir::Type, 8> llvmResults;
881   mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes();
882   auto call = cast<cir::CIRCallOpInterface>(op);
883 
884   if (converter->convertTypes(cirResults, llvmResults).failed())
885     return mlir::failure();
886 
887   assert(!cir::MissingFeatures::opCallCallConv());
888 
889   mlir::LLVM::MemoryEffectsAttr memoryEffects;
890   bool noUnwind = false;
891   bool willReturn = false;
892   convertSideEffectForCall(op, call.getNothrow(), call.getSideEffect(),
893                            memoryEffects, noUnwind, willReturn);
894 
895   mlir::LLVM::LLVMFunctionType llvmFnTy;
896   if (calleeAttr) { // direct call
897     mlir::FunctionOpInterface fn =
898         mlir::SymbolTable::lookupNearestSymbolFrom<mlir::FunctionOpInterface>(
899             op, calleeAttr);
900     assert(fn && "Did not find function for call");
901     llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
902         converter->convertType(fn.getFunctionType()));
903   } else { // indirect call
904     assert(!op->getOperands().empty() &&
905            "operands list must no be empty for the indirect call");
906     auto calleeTy = op->getOperands().front().getType();
907     auto calleePtrTy = cast<cir::PointerType>(calleeTy);
908     auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee());
909     calleeFuncTy.dump();
910     converter->convertType(calleeFuncTy).dump();
911     llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
912         converter->convertType(calleeFuncTy));
913   }
914 
915   assert(!cir::MissingFeatures::opCallLandingPad());
916   assert(!cir::MissingFeatures::opCallContinueBlock());
917   assert(!cir::MissingFeatures::opCallCallConv());
918 
919   auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
920       op, llvmFnTy, calleeAttr, callOperands);
921   if (memoryEffects)
922     newOp.setMemoryEffectsAttr(memoryEffects);
923   newOp.setNoUnwind(noUnwind);
924   newOp.setWillReturn(willReturn);
925 
926   return mlir::success();
927 }
928 
matchAndRewrite(cir::CallOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const929 mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite(
930     cir::CallOp op, OpAdaptor adaptor,
931     mlir::ConversionPatternRewriter &rewriter) const {
932   return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter,
933                              getTypeConverter(), op.getCalleeAttr());
934 }
935 
matchAndRewrite(cir::LoadOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const936 mlir::LogicalResult CIRToLLVMLoadOpLowering::matchAndRewrite(
937     cir::LoadOp op, OpAdaptor adaptor,
938     mlir::ConversionPatternRewriter &rewriter) const {
939   const mlir::Type llvmTy =
940       convertTypeForMemory(*getTypeConverter(), dataLayout, op.getType());
941   assert(!cir::MissingFeatures::opLoadStoreMemOrder());
942   std::optional<size_t> opAlign = op.getAlignment();
943   unsigned alignment =
944       (unsigned)opAlign.value_or(dataLayout.getTypeABIAlignment(llvmTy));
945 
946   assert(!cir::MissingFeatures::lowerModeOptLevel());
947 
948   // TODO: nontemporal, syncscope.
949   assert(!cir::MissingFeatures::opLoadStoreVolatile());
950   mlir::LLVM::LoadOp newLoad = rewriter.create<mlir::LLVM::LoadOp>(
951       op->getLoc(), llvmTy, adaptor.getAddr(), alignment,
952       /*volatile=*/false, /*nontemporal=*/false,
953       /*invariant=*/false, /*invariantGroup=*/false,
954       mlir::LLVM::AtomicOrdering::not_atomic);
955 
956   // Convert adapted result to its original type if needed.
957   mlir::Value result =
958       emitFromMemory(rewriter, dataLayout, op, newLoad.getResult());
959   rewriter.replaceOp(op, result);
960   assert(!cir::MissingFeatures::opLoadStoreTbaa());
961   return mlir::LogicalResult::success();
962 }
963 
matchAndRewrite(cir::StoreOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const964 mlir::LogicalResult CIRToLLVMStoreOpLowering::matchAndRewrite(
965     cir::StoreOp op, OpAdaptor adaptor,
966     mlir::ConversionPatternRewriter &rewriter) const {
967   assert(!cir::MissingFeatures::opLoadStoreMemOrder());
968   const mlir::Type llvmTy =
969       getTypeConverter()->convertType(op.getValue().getType());
970   std::optional<size_t> opAlign = op.getAlignment();
971   unsigned alignment =
972       (unsigned)opAlign.value_or(dataLayout.getTypeABIAlignment(llvmTy));
973 
974   assert(!cir::MissingFeatures::lowerModeOptLevel());
975 
976   // Convert adapted value to its memory type if needed.
977   mlir::Value value = emitToMemory(rewriter, dataLayout,
978                                    op.getValue().getType(), adaptor.getValue());
979   // TODO: nontemporal, syncscope.
980   assert(!cir::MissingFeatures::opLoadStoreVolatile());
981   mlir::LLVM::StoreOp storeOp = rewriter.create<mlir::LLVM::StoreOp>(
982       op->getLoc(), value, adaptor.getAddr(), alignment, /*volatile=*/false,
983       /*nontemporal=*/false, /*invariantGroup=*/false,
984       mlir::LLVM::AtomicOrdering::not_atomic);
985   rewriter.replaceOp(op, storeOp);
986   assert(!cir::MissingFeatures::opLoadStoreTbaa());
987   return mlir::LogicalResult::success();
988 }
989 
hasTrailingZeros(cir::ConstArrayAttr attr)990 bool hasTrailingZeros(cir::ConstArrayAttr attr) {
991   auto array = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts());
992   return attr.hasTrailingZeros() ||
993          (array && std::count_if(array.begin(), array.end(), [](auto elt) {
994             auto ar = dyn_cast<cir::ConstArrayAttr>(elt);
995             return ar && hasTrailingZeros(ar);
996           }));
997 }
998 
matchAndRewrite(cir::ConstantOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const999 mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
1000     cir::ConstantOp op, OpAdaptor adaptor,
1001     mlir::ConversionPatternRewriter &rewriter) const {
1002   mlir::Attribute attr = op.getValue();
1003 
1004   if (mlir::isa<mlir::IntegerType>(op.getType())) {
1005     // Verified cir.const operations cannot actually be of these types, but the
1006     // lowering pass may generate temporary cir.const operations with these
1007     // types. This is OK since MLIR allows unverified operations to be alive
1008     // during a pass as long as they don't live past the end of the pass.
1009     attr = op.getValue();
1010   } else if (mlir::isa<cir::BoolType>(op.getType())) {
1011     int value = mlir::cast<cir::BoolAttr>(op.getValue()).getValue();
1012     attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()),
1013                                    value);
1014   } else if (mlir::isa<cir::IntType>(op.getType())) {
1015     assert(!cir::MissingFeatures::opGlobalViewAttr());
1016 
1017     attr = rewriter.getIntegerAttr(
1018         typeConverter->convertType(op.getType()),
1019         mlir::cast<cir::IntAttr>(op.getValue()).getValue());
1020   } else if (mlir::isa<cir::FPTypeInterface>(op.getType())) {
1021     attr = rewriter.getFloatAttr(
1022         typeConverter->convertType(op.getType()),
1023         mlir::cast<cir::FPAttr>(op.getValue()).getValue());
1024   } else if (mlir::isa<cir::PointerType>(op.getType())) {
1025     // Optimize with dedicated LLVM op for null pointers.
1026     if (mlir::isa<cir::ConstPtrAttr>(op.getValue())) {
1027       if (mlir::cast<cir::ConstPtrAttr>(op.getValue()).isNullValue()) {
1028         rewriter.replaceOpWithNewOp<mlir::LLVM::ZeroOp>(
1029             op, typeConverter->convertType(op.getType()));
1030         return mlir::success();
1031       }
1032     }
1033     assert(!cir::MissingFeatures::opGlobalViewAttr());
1034     attr = op.getValue();
1035   } else if (const auto arrTy = mlir::dyn_cast<cir::ArrayType>(op.getType())) {
1036     const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(op.getValue());
1037     if (!constArr && !isa<cir::ZeroAttr, cir::UndefAttr>(op.getValue()))
1038       return op.emitError() << "array does not have a constant initializer";
1039 
1040     std::optional<mlir::Attribute> denseAttr;
1041     if (constArr && hasTrailingZeros(constArr)) {
1042       const mlir::Value newOp =
1043           lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter());
1044       rewriter.replaceOp(op, newOp);
1045       return mlir::success();
1046     } else if (constArr &&
1047                (denseAttr = lowerConstArrayAttr(constArr, typeConverter))) {
1048       attr = denseAttr.value();
1049     } else {
1050       const mlir::Value initVal =
1051           lowerCirAttrAsValue(op, op.getValue(), rewriter, typeConverter);
1052       rewriter.replaceAllUsesWith(op, initVal);
1053       rewriter.eraseOp(op);
1054       return mlir::success();
1055     }
1056   } else if (const auto vecTy = mlir::dyn_cast<cir::VectorType>(op.getType())) {
1057     rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter,
1058                                                getTypeConverter()));
1059     return mlir::success();
1060   } else if (auto complexTy = mlir::dyn_cast<cir::ComplexType>(op.getType())) {
1061     mlir::Type complexElemTy = complexTy.getElementType();
1062     mlir::Type complexElemLLVMTy = typeConverter->convertType(complexElemTy);
1063 
1064     if (auto zeroInitAttr = mlir::dyn_cast<cir::ZeroAttr>(op.getValue())) {
1065       mlir::TypedAttr zeroAttr = rewriter.getZeroAttr(complexElemLLVMTy);
1066       mlir::ArrayAttr array = rewriter.getArrayAttr({zeroAttr, zeroAttr});
1067       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
1068           op, getTypeConverter()->convertType(op.getType()), array);
1069       return mlir::success();
1070     }
1071 
1072     auto complexAttr = mlir::cast<cir::ConstComplexAttr>(op.getValue());
1073 
1074     mlir::Attribute components[2];
1075     if (mlir::isa<cir::IntType>(complexElemTy)) {
1076       components[0] = rewriter.getIntegerAttr(
1077           complexElemLLVMTy,
1078           mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
1079       components[1] = rewriter.getIntegerAttr(
1080           complexElemLLVMTy,
1081           mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
1082     } else {
1083       components[0] = rewriter.getFloatAttr(
1084           complexElemLLVMTy,
1085           mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
1086       components[1] = rewriter.getFloatAttr(
1087           complexElemLLVMTy,
1088           mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
1089     }
1090 
1091     attr = rewriter.getArrayAttr(components);
1092   } else {
1093     return op.emitError() << "unsupported constant type " << op.getType();
1094   }
1095 
1096   rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
1097       op, getTypeConverter()->convertType(op.getType()), attr);
1098 
1099   return mlir::success();
1100 }
1101 
matchAndRewrite(cir::ExpectOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1102 mlir::LogicalResult CIRToLLVMExpectOpLowering::matchAndRewrite(
1103     cir::ExpectOp op, OpAdaptor adaptor,
1104     mlir::ConversionPatternRewriter &rewriter) const {
1105   // TODO(cir): do not generate LLVM intrinsics under -O0
1106   assert(!cir::MissingFeatures::optInfoAttr());
1107 
1108   std::optional<llvm::APFloat> prob = op.getProb();
1109   if (prob)
1110     rewriter.replaceOpWithNewOp<mlir::LLVM::ExpectWithProbabilityOp>(
1111         op, adaptor.getVal(), adaptor.getExpected(), prob.value());
1112   else
1113     rewriter.replaceOpWithNewOp<mlir::LLVM::ExpectOp>(op, adaptor.getVal(),
1114                                                       adaptor.getExpected());
1115   return mlir::success();
1116 }
1117 
1118 /// Convert the `cir.func` attributes to `llvm.func` attributes.
1119 /// Only retain those attributes that are not constructed by
1120 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out
1121 /// argument attributes.
lowerFuncAttributes(cir::FuncOp func,bool filterArgAndResAttrs,SmallVectorImpl<mlir::NamedAttribute> & result) const1122 void CIRToLLVMFuncOpLowering::lowerFuncAttributes(
1123     cir::FuncOp func, bool filterArgAndResAttrs,
1124     SmallVectorImpl<mlir::NamedAttribute> &result) const {
1125   assert(!cir::MissingFeatures::opFuncCallingConv());
1126   for (mlir::NamedAttribute attr : func->getAttrs()) {
1127     assert(!cir::MissingFeatures::opFuncCallingConv());
1128     if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() ||
1129         attr.getName() == func.getFunctionTypeAttrName() ||
1130         attr.getName() == getLinkageAttrNameString() ||
1131         attr.getName() == func.getGlobalVisibilityAttrName() ||
1132         attr.getName() == func.getDsoLocalAttrName() ||
1133         (filterArgAndResAttrs &&
1134          (attr.getName() == func.getArgAttrsAttrName() ||
1135           attr.getName() == func.getResAttrsAttrName())))
1136       continue;
1137 
1138     assert(!cir::MissingFeatures::opFuncExtraAttrs());
1139     result.push_back(attr);
1140   }
1141 }
1142 
matchAndRewrite(cir::FuncOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1143 mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite(
1144     cir::FuncOp op, OpAdaptor adaptor,
1145     mlir::ConversionPatternRewriter &rewriter) const {
1146 
1147   cir::FuncType fnType = op.getFunctionType();
1148   bool isDsoLocal = op.getDsoLocal();
1149   mlir::TypeConverter::SignatureConversion signatureConversion(
1150       fnType.getNumInputs());
1151 
1152   for (const auto &argType : llvm::enumerate(fnType.getInputs())) {
1153     mlir::Type convertedType = typeConverter->convertType(argType.value());
1154     if (!convertedType)
1155       return mlir::failure();
1156     signatureConversion.addInputs(argType.index(), convertedType);
1157   }
1158 
1159   mlir::Type resultType =
1160       getTypeConverter()->convertType(fnType.getReturnType());
1161 
1162   // Create the LLVM function operation.
1163   mlir::Type llvmFnTy = mlir::LLVM::LLVMFunctionType::get(
1164       resultType ? resultType : mlir::LLVM::LLVMVoidType::get(getContext()),
1165       signatureConversion.getConvertedTypes(),
1166       /*isVarArg=*/fnType.isVarArg());
1167   // LLVMFuncOp expects a single FileLine Location instead of a fused
1168   // location.
1169   mlir::Location loc = op.getLoc();
1170   if (mlir::FusedLoc fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(loc))
1171     loc = fusedLoc.getLocations()[0];
1172   assert((mlir::isa<mlir::FileLineColLoc>(loc) ||
1173           mlir::isa<mlir::UnknownLoc>(loc)) &&
1174          "expected single location or unknown location here");
1175 
1176   mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage());
1177   assert(!cir::MissingFeatures::opFuncCallingConv());
1178   mlir::LLVM::CConv cconv = mlir::LLVM::CConv::C;
1179   SmallVector<mlir::NamedAttribute, 4> attributes;
1180   lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);
1181 
1182   mlir::LLVM::LLVMFuncOp fn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
1183       loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv,
1184       mlir::SymbolRefAttr(), attributes);
1185 
1186   assert(!cir::MissingFeatures::opFuncMultipleReturnVals());
1187 
1188   fn.setVisibility_Attr(mlir::LLVM::VisibilityAttr::get(
1189       getContext(), lowerCIRVisibilityToLLVMVisibility(
1190                         op.getGlobalVisibilityAttr().getValue())));
1191 
1192   rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
1193   if (failed(rewriter.convertRegionTypes(&fn.getBody(), *typeConverter,
1194                                          &signatureConversion)))
1195     return mlir::failure();
1196 
1197   rewriter.eraseOp(op);
1198 
1199   return mlir::LogicalResult::success();
1200 }
1201 
matchAndRewrite(cir::GetGlobalOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1202 mlir::LogicalResult CIRToLLVMGetGlobalOpLowering::matchAndRewrite(
1203     cir::GetGlobalOp op, OpAdaptor adaptor,
1204     mlir::ConversionPatternRewriter &rewriter) const {
1205   // FIXME(cir): Premature DCE to avoid lowering stuff we're not using.
1206   // CIRGen should mitigate this and not emit the get_global.
1207   if (op->getUses().empty()) {
1208     rewriter.eraseOp(op);
1209     return mlir::success();
1210   }
1211 
1212   mlir::Type type = getTypeConverter()->convertType(op.getType());
1213   mlir::Operation *newop =
1214       rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), type, op.getName());
1215 
1216   assert(!cir::MissingFeatures::opGlobalThreadLocal());
1217 
1218   rewriter.replaceOp(op, newop);
1219   return mlir::success();
1220 }
1221 
1222 /// Replace CIR global with a region initialized LLVM global and update
1223 /// insertion point to the end of the initializer block.
setupRegionInitializedLLVMGlobalOp(cir::GlobalOp op,mlir::ConversionPatternRewriter & rewriter) const1224 void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
1225     cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const {
1226   const mlir::Type llvmType =
1227       convertTypeForMemory(*getTypeConverter(), dataLayout, op.getSymType());
1228 
1229   // FIXME: These default values are placeholders until the the equivalent
1230   //        attributes are available on cir.global ops. This duplicates code
1231   //        in CIRToLLVMGlobalOpLowering::matchAndRewrite() but that will go
1232   //        away when the placeholders are no longer needed.
1233   assert(!cir::MissingFeatures::opGlobalConstant());
1234   const bool isConst = false;
1235   assert(!cir::MissingFeatures::addressSpace());
1236   const unsigned addrSpace = 0;
1237   const bool isDsoLocal = op.getDsoLocal();
1238   assert(!cir::MissingFeatures::opGlobalThreadLocal());
1239   const bool isThreadLocal = false;
1240   const uint64_t alignment = op.getAlignment().value_or(0);
1241   const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage());
1242   const StringRef symbol = op.getSymName();
1243   mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter);
1244 
1245   SmallVector<mlir::NamedAttribute> attributes;
1246   mlir::LLVM::GlobalOp newGlobalOp =
1247       rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
1248           op, llvmType, isConst, linkage, symbol, nullptr, alignment, addrSpace,
1249           isDsoLocal, isThreadLocal, comdatAttr, attributes);
1250   newGlobalOp.getRegion().emplaceBlock();
1251   rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock());
1252 }
1253 
1254 mlir::LogicalResult
matchAndRewriteRegionInitializedGlobal(cir::GlobalOp op,mlir::Attribute init,mlir::ConversionPatternRewriter & rewriter) const1255 CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
1256     cir::GlobalOp op, mlir::Attribute init,
1257     mlir::ConversionPatternRewriter &rewriter) const {
1258   // TODO: Generalize this handling when more types are needed here.
1259   assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
1260               cir::ConstComplexAttr, cir::ZeroAttr>(init)));
1261 
1262   // TODO(cir): once LLVM's dialect has proper equivalent attributes this
1263   // should be updated. For now, we use a custom op to initialize globals
1264   // to the appropriate value.
1265   const mlir::Location loc = op.getLoc();
1266   setupRegionInitializedLLVMGlobalOp(op, rewriter);
1267   CIRAttrToValue valueConverter(op, rewriter, typeConverter);
1268   mlir::Value value = valueConverter.visit(init);
1269   rewriter.create<mlir::LLVM::ReturnOp>(loc, value);
1270   return mlir::success();
1271 }
1272 
matchAndRewrite(cir::GlobalOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1273 mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
1274     cir::GlobalOp op, OpAdaptor adaptor,
1275     mlir::ConversionPatternRewriter &rewriter) const {
1276 
1277   std::optional<mlir::Attribute> init = op.getInitialValue();
1278 
1279   // Fetch required values to create LLVM op.
1280   const mlir::Type cirSymType = op.getSymType();
1281 
1282   // This is the LLVM dialect type.
1283   const mlir::Type llvmType =
1284       convertTypeForMemory(*getTypeConverter(), dataLayout, cirSymType);
1285   // FIXME: These default values are placeholders until the the equivalent
1286   //        attributes are available on cir.global ops.
1287   assert(!cir::MissingFeatures::opGlobalConstant());
1288   const bool isConst = false;
1289   assert(!cir::MissingFeatures::addressSpace());
1290   const unsigned addrSpace = 0;
1291   const bool isDsoLocal = op.getDsoLocal();
1292   assert(!cir::MissingFeatures::opGlobalThreadLocal());
1293   const bool isThreadLocal = false;
1294   const uint64_t alignment = op.getAlignment().value_or(0);
1295   const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage());
1296   const StringRef symbol = op.getSymName();
1297   SmallVector<mlir::NamedAttribute> attributes;
1298   mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter);
1299 
1300   if (init.has_value()) {
1301     if (mlir::isa<cir::FPAttr, cir::IntAttr, cir::BoolAttr>(init.value())) {
1302       GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
1303       init = initRewriter.visit(init.value());
1304       // If initRewriter returned a null attribute, init will have a value but
1305       // the value will be null. If that happens, initRewriter didn't handle the
1306       // attribute type. It probably needs to be added to
1307       // GlobalInitAttrRewriter.
1308       if (!init.value()) {
1309         op.emitError() << "unsupported initializer '" << init.value() << "'";
1310         return mlir::failure();
1311       }
1312     } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
1313                          cir::ConstPtrAttr, cir::ConstComplexAttr,
1314                          cir::ZeroAttr>(init.value())) {
1315       // TODO(cir): once LLVM's dialect has proper equivalent attributes this
1316       // should be updated. For now, we use a custom op to initialize globals
1317       // to the appropriate value.
1318       return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
1319     } else {
1320       // We will only get here if new initializer types are added and this
1321       // code is not updated to handle them.
1322       op.emitError() << "unsupported initializer '" << init.value() << "'";
1323       return mlir::failure();
1324     }
1325   }
1326 
1327   // Rewrite op.
1328   rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
1329       op, llvmType, isConst, linkage, symbol, init.value_or(mlir::Attribute()),
1330       alignment, addrSpace, isDsoLocal, isThreadLocal, comdatAttr, attributes);
1331   return mlir::success();
1332 }
1333 
1334 mlir::SymbolRefAttr
getComdatAttr(cir::GlobalOp & op,mlir::OpBuilder & builder) const1335 CIRToLLVMGlobalOpLowering::getComdatAttr(cir::GlobalOp &op,
1336                                          mlir::OpBuilder &builder) const {
1337   if (!op.getComdat())
1338     return mlir::SymbolRefAttr{};
1339 
1340   mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();
1341   mlir::OpBuilder::InsertionGuard guard(builder);
1342   StringRef comdatName("__llvm_comdat_globals");
1343   if (!comdatOp) {
1344     builder.setInsertionPointToStart(module.getBody());
1345     comdatOp =
1346         builder.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
1347   }
1348 
1349   builder.setInsertionPointToStart(&comdatOp.getBody().back());
1350   auto selectorOp = builder.create<mlir::LLVM::ComdatSelectorOp>(
1351       comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any);
1352   return mlir::SymbolRefAttr::get(
1353       builder.getContext(), comdatName,
1354       mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()));
1355 }
1356 
matchAndRewrite(cir::SwitchFlatOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1357 mlir::LogicalResult CIRToLLVMSwitchFlatOpLowering::matchAndRewrite(
1358     cir::SwitchFlatOp op, OpAdaptor adaptor,
1359     mlir::ConversionPatternRewriter &rewriter) const {
1360 
1361   llvm::SmallVector<mlir::APInt, 8> caseValues;
1362   for (mlir::Attribute val : op.getCaseValues()) {
1363     auto intAttr = cast<cir::IntAttr>(val);
1364     caseValues.push_back(intAttr.getValue());
1365   }
1366 
1367   llvm::SmallVector<mlir::Block *, 8> caseDestinations;
1368   llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
1369 
1370   for (mlir::Block *x : op.getCaseDestinations())
1371     caseDestinations.push_back(x);
1372 
1373   for (mlir::OperandRange x : op.getCaseOperands())
1374     caseOperands.push_back(x);
1375 
1376   // Set switch op to branch to the newly created blocks.
1377   rewriter.setInsertionPoint(op);
1378   rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
1379       op, adaptor.getCondition(), op.getDefaultDestination(),
1380       op.getDefaultOperands(), caseValues, caseDestinations, caseOperands);
1381   return mlir::success();
1382 }
1383 
matchAndRewrite(cir::UnaryOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1384 mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
1385     cir::UnaryOp op, OpAdaptor adaptor,
1386     mlir::ConversionPatternRewriter &rewriter) const {
1387   assert(op.getType() == op.getInput().getType() &&
1388          "Unary operation's operand type and result type are different");
1389   mlir::Type type = op.getType();
1390   mlir::Type elementType = elementTypeIfVector(type);
1391   bool isVector = mlir::isa<cir::VectorType>(type);
1392   mlir::Type llvmType = getTypeConverter()->convertType(type);
1393   mlir::Location loc = op.getLoc();
1394 
1395   // Integer unary operations: + - ~ ++ --
1396   if (mlir::isa<cir::IntType>(elementType)) {
1397     mlir::LLVM::IntegerOverflowFlags maybeNSW =
1398         op.getNoSignedWrap() ? mlir::LLVM::IntegerOverflowFlags::nsw
1399                              : mlir::LLVM::IntegerOverflowFlags::none;
1400     switch (op.getKind()) {
1401     case cir::UnaryOpKind::Inc: {
1402       assert(!isVector && "++ not allowed on vector types");
1403       auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
1404       rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(
1405           op, llvmType, adaptor.getInput(), one, maybeNSW);
1406       return mlir::success();
1407     }
1408     case cir::UnaryOpKind::Dec: {
1409       assert(!isVector && "-- not allowed on vector types");
1410       auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
1411       rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, adaptor.getInput(),
1412                                                      one, maybeNSW);
1413       return mlir::success();
1414     }
1415     case cir::UnaryOpKind::Plus:
1416       rewriter.replaceOp(op, adaptor.getInput());
1417       return mlir::success();
1418     case cir::UnaryOpKind::Minus: {
1419       mlir::Value zero;
1420       if (isVector)
1421         zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
1422       else
1423         zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 0);
1424       rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
1425           op, zero, adaptor.getInput(), maybeNSW);
1426       return mlir::success();
1427     }
1428     case cir::UnaryOpKind::Not: {
1429       // bit-wise compliment operator, implemented as an XOR with -1.
1430       mlir::Value minusOne;
1431       if (isVector) {
1432         const uint64_t numElements =
1433             mlir::dyn_cast<cir::VectorType>(type).getSize();
1434         std::vector<int32_t> values(numElements, -1);
1435         mlir::DenseIntElementsAttr denseVec = rewriter.getI32VectorAttr(values);
1436         minusOne =
1437             rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, denseVec);
1438       } else {
1439         minusOne = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, -1);
1440       }
1441       rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
1442                                                      minusOne);
1443       return mlir::success();
1444     }
1445     }
1446     llvm_unreachable("Unexpected unary op for int");
1447   }
1448 
1449   // Floating point unary operations: + - ++ --
1450   if (mlir::isa<cir::FPTypeInterface>(elementType)) {
1451     switch (op.getKind()) {
1452     case cir::UnaryOpKind::Inc: {
1453       assert(!isVector && "++ not allowed on vector types");
1454       mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>(
1455           loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
1456       rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, one,
1457                                                       adaptor.getInput());
1458       return mlir::success();
1459     }
1460     case cir::UnaryOpKind::Dec: {
1461       assert(!isVector && "-- not allowed on vector types");
1462       mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
1463           loc, llvmType, rewriter.getFloatAttr(llvmType, -1.0));
1464       rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, minusOne,
1465                                                       adaptor.getInput());
1466       return mlir::success();
1467     }
1468     case cir::UnaryOpKind::Plus:
1469       rewriter.replaceOp(op, adaptor.getInput());
1470       return mlir::success();
1471     case cir::UnaryOpKind::Minus:
1472       rewriter.replaceOpWithNewOp<mlir::LLVM::FNegOp>(op, llvmType,
1473                                                       adaptor.getInput());
1474       return mlir::success();
1475     case cir::UnaryOpKind::Not:
1476       return op.emitError() << "Unary not is invalid for floating-point types";
1477     }
1478     llvm_unreachable("Unexpected unary op for float");
1479   }
1480 
1481   // Boolean unary operations: ! only. (For all others, the operand has
1482   // already been promoted to int.)
1483   if (mlir::isa<cir::BoolType>(elementType)) {
1484     switch (op.getKind()) {
1485     case cir::UnaryOpKind::Inc:
1486     case cir::UnaryOpKind::Dec:
1487     case cir::UnaryOpKind::Plus:
1488     case cir::UnaryOpKind::Minus:
1489       // Some of these are allowed in source code, but we shouldn't get here
1490       // with a boolean type.
1491       return op.emitError() << "Unsupported unary operation on boolean type";
1492     case cir::UnaryOpKind::Not: {
1493       assert(!isVector && "NYI: op! on vector mask");
1494       auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
1495       rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
1496                                                      one);
1497       return mlir::success();
1498     }
1499     }
1500     llvm_unreachable("Unexpected unary op for bool");
1501   }
1502 
1503   // Pointer unary operations: + only.  (++ and -- of pointers are implemented
1504   // with cir.ptr_stride, not cir.unary.)
1505   if (mlir::isa<cir::PointerType>(elementType)) {
1506     return op.emitError()
1507            << "Unary operation on pointer types is not yet implemented";
1508   }
1509 
1510   return op.emitError() << "Unary operation has unsupported type: "
1511                         << elementType;
1512 }
1513 
1514 mlir::LLVM::IntegerOverflowFlags
getIntOverflowFlag(cir::BinOp op) const1515 CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const {
1516   if (op.getNoUnsignedWrap())
1517     return mlir::LLVM::IntegerOverflowFlags::nuw;
1518 
1519   if (op.getNoSignedWrap())
1520     return mlir::LLVM::IntegerOverflowFlags::nsw;
1521 
1522   return mlir::LLVM::IntegerOverflowFlags::none;
1523 }
1524 
isIntTypeUnsigned(mlir::Type type)1525 static bool isIntTypeUnsigned(mlir::Type type) {
1526   // TODO: Ideally, we should only need to check cir::IntType here.
1527   return mlir::isa<cir::IntType>(type)
1528              ? mlir::cast<cir::IntType>(type).isUnsigned()
1529              : mlir::cast<mlir::IntegerType>(type).isUnsigned();
1530 }
1531 
matchAndRewrite(cir::BinOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1532 mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
1533     cir::BinOp op, OpAdaptor adaptor,
1534     mlir::ConversionPatternRewriter &rewriter) const {
1535   if (adaptor.getLhs().getType() != adaptor.getRhs().getType())
1536     return op.emitError() << "inconsistent operands' types not supported yet";
1537 
1538   mlir::Type type = op.getRhs().getType();
1539   if (!mlir::isa<cir::IntType, cir::BoolType, cir::FPTypeInterface,
1540                  mlir::IntegerType, cir::VectorType>(type))
1541     return op.emitError() << "operand type not supported yet";
1542 
1543   const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
1544   const mlir::Type llvmEltTy = elementTypeIfVector(llvmTy);
1545 
1546   const mlir::Value rhs = adaptor.getRhs();
1547   const mlir::Value lhs = adaptor.getLhs();
1548   type = elementTypeIfVector(type);
1549 
1550   switch (op.getKind()) {
1551   case cir::BinOpKind::Add:
1552     if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
1553       if (op.getSaturated()) {
1554         if (isIntTypeUnsigned(type)) {
1555           rewriter.replaceOpWithNewOp<mlir::LLVM::UAddSat>(op, lhs, rhs);
1556           break;
1557         }
1558         rewriter.replaceOpWithNewOp<mlir::LLVM::SAddSat>(op, lhs, rhs);
1559         break;
1560       }
1561       rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs,
1562                                                      getIntOverflowFlag(op));
1563     } else {
1564       rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, lhs, rhs);
1565     }
1566     break;
1567   case cir::BinOpKind::Sub:
1568     if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
1569       if (op.getSaturated()) {
1570         if (isIntTypeUnsigned(type)) {
1571           rewriter.replaceOpWithNewOp<mlir::LLVM::USubSat>(op, lhs, rhs);
1572           break;
1573         }
1574         rewriter.replaceOpWithNewOp<mlir::LLVM::SSubSat>(op, lhs, rhs);
1575         break;
1576       }
1577       rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs,
1578                                                      getIntOverflowFlag(op));
1579     } else {
1580       rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, lhs, rhs);
1581     }
1582     break;
1583   case cir::BinOpKind::Mul:
1584     if (mlir::isa<mlir::IntegerType>(llvmEltTy))
1585       rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs,
1586                                                      getIntOverflowFlag(op));
1587     else
1588       rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, lhs, rhs);
1589     break;
1590   case cir::BinOpKind::Div:
1591     if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
1592       auto isUnsigned = isIntTypeUnsigned(type);
1593       if (isUnsigned)
1594         rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, lhs, rhs);
1595       else
1596         rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, lhs, rhs);
1597     } else {
1598       rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, lhs, rhs);
1599     }
1600     break;
1601   case cir::BinOpKind::Rem:
1602     if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
1603       auto isUnsigned = isIntTypeUnsigned(type);
1604       if (isUnsigned)
1605         rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, lhs, rhs);
1606       else
1607         rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, lhs, rhs);
1608     } else {
1609       rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, lhs, rhs);
1610     }
1611     break;
1612   case cir::BinOpKind::And:
1613     rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, lhs, rhs);
1614     break;
1615   case cir::BinOpKind::Or:
1616     rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, lhs, rhs);
1617     break;
1618   case cir::BinOpKind::Xor:
1619     rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, lhs, rhs);
1620     break;
1621   case cir::BinOpKind::Max:
1622     if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
1623       auto isUnsigned = isIntTypeUnsigned(type);
1624       if (isUnsigned)
1625         rewriter.replaceOpWithNewOp<mlir::LLVM::UMaxOp>(op, llvmTy, lhs, rhs);
1626       else
1627         rewriter.replaceOpWithNewOp<mlir::LLVM::SMaxOp>(op, llvmTy, lhs, rhs);
1628     }
1629     break;
1630   }
1631   return mlir::LogicalResult::success();
1632 }
1633 
1634 /// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
1635 static mlir::LLVM::ICmpPredicate
convertCmpKindToICmpPredicate(cir::CmpOpKind kind,bool isSigned)1636 convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) {
1637   using CIR = cir::CmpOpKind;
1638   using LLVMICmp = mlir::LLVM::ICmpPredicate;
1639   switch (kind) {
1640   case CIR::eq:
1641     return LLVMICmp::eq;
1642   case CIR::ne:
1643     return LLVMICmp::ne;
1644   case CIR::lt:
1645     return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
1646   case CIR::le:
1647     return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
1648   case CIR::gt:
1649     return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
1650   case CIR::ge:
1651     return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
1652   }
1653   llvm_unreachable("Unknown CmpOpKind");
1654 }
1655 
1656 /// Convert from a CIR comparison kind to an LLVM IR floating-point comparison
1657 /// kind.
1658 static mlir::LLVM::FCmpPredicate
convertCmpKindToFCmpPredicate(cir::CmpOpKind kind)1659 convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) {
1660   using CIR = cir::CmpOpKind;
1661   using LLVMFCmp = mlir::LLVM::FCmpPredicate;
1662   switch (kind) {
1663   case CIR::eq:
1664     return LLVMFCmp::oeq;
1665   case CIR::ne:
1666     return LLVMFCmp::une;
1667   case CIR::lt:
1668     return LLVMFCmp::olt;
1669   case CIR::le:
1670     return LLVMFCmp::ole;
1671   case CIR::gt:
1672     return LLVMFCmp::ogt;
1673   case CIR::ge:
1674     return LLVMFCmp::oge;
1675   }
1676   llvm_unreachable("Unknown CmpOpKind");
1677 }
1678 
matchAndRewrite(cir::CmpOp cmpOp,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1679 mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
1680     cir::CmpOp cmpOp, OpAdaptor adaptor,
1681     mlir::ConversionPatternRewriter &rewriter) const {
1682   mlir::Type type = cmpOp.getLhs().getType();
1683 
1684   assert(!cir::MissingFeatures::dataMemberType());
1685   assert(!cir::MissingFeatures::methodType());
1686 
1687   if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
1688     bool isSigned = mlir::isa<cir::IntType>(type)
1689                         ? mlir::cast<cir::IntType>(type).isSigned()
1690                         : mlir::cast<mlir::IntegerType>(type).isSigned();
1691     mlir::LLVM::ICmpPredicate kind =
1692         convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned);
1693     rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
1694         cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1695     return mlir::success();
1696   }
1697 
1698   if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
1699     mlir::LLVM::ICmpPredicate kind =
1700         convertCmpKindToICmpPredicate(cmpOp.getKind(),
1701                                       /* isSigned=*/false);
1702     rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
1703         cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1704     return mlir::success();
1705   }
1706 
1707   if (mlir::isa<cir::FPTypeInterface>(type)) {
1708     mlir::LLVM::FCmpPredicate kind =
1709         convertCmpKindToFCmpPredicate(cmpOp.getKind());
1710     rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(
1711         cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
1712     return mlir::success();
1713   }
1714 
1715   if (mlir::isa<cir::ComplexType>(type)) {
1716     mlir::Value lhs = adaptor.getLhs();
1717     mlir::Value rhs = adaptor.getRhs();
1718     mlir::Location loc = cmpOp.getLoc();
1719 
1720     auto complexType = mlir::cast<cir::ComplexType>(cmpOp.getLhs().getType());
1721     mlir::Type complexElemTy =
1722         getTypeConverter()->convertType(complexType.getElementType());
1723 
1724     auto lhsReal =
1725         rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
1726     auto lhsImag =
1727         rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
1728     auto rhsReal =
1729         rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
1730     auto rhsImag =
1731         rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
1732 
1733     if (cmpOp.getKind() == cir::CmpOpKind::eq) {
1734       if (complexElemTy.isInteger()) {
1735         auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1736             loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal);
1737         auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1738             loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag);
1739         rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp);
1740         return mlir::success();
1741       }
1742 
1743       auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1744           loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal);
1745       auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1746           loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag);
1747       rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp);
1748       return mlir::success();
1749     }
1750 
1751     if (cmpOp.getKind() == cir::CmpOpKind::ne) {
1752       if (complexElemTy.isInteger()) {
1753         auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1754             loc, mlir::LLVM::ICmpPredicate::ne, lhsReal, rhsReal);
1755         auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
1756             loc, mlir::LLVM::ICmpPredicate::ne, lhsImag, rhsImag);
1757         rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp);
1758         return mlir::success();
1759       }
1760 
1761       auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1762           loc, mlir::LLVM::FCmpPredicate::une, lhsReal, rhsReal);
1763       auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
1764           loc, mlir::LLVM::FCmpPredicate::une, lhsImag, rhsImag);
1765       rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp);
1766       return mlir::success();
1767     }
1768   }
1769 
1770   return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
1771 }
1772 
matchAndRewrite(cir::ShiftOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1773 mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
1774     cir::ShiftOp op, OpAdaptor adaptor,
1775     mlir::ConversionPatternRewriter &rewriter) const {
1776   assert((op.getValue().getType() == op.getType()) &&
1777          "inconsistent operands' types NYI");
1778 
1779   const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
1780   mlir::Value amt = adaptor.getAmount();
1781   mlir::Value val = adaptor.getValue();
1782 
1783   auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
1784   bool isUnsigned;
1785   if (cirAmtTy) {
1786     auto cirValTy = mlir::cast<cir::IntType>(op.getValue().getType());
1787     isUnsigned = cirValTy.isUnsigned();
1788 
1789     // Ensure shift amount is the same type as the value. Some undefined
1790     // behavior might occur in the casts below as per [C99 6.5.7.3].
1791     // Vector type shift amount needs no cast as type consistency is expected to
1792     // be already be enforced at CIRGen.
1793     if (cirAmtTy)
1794       amt = getLLVMIntCast(rewriter, amt, llvmTy, true, cirAmtTy.getWidth(),
1795                            cirValTy.getWidth());
1796   } else {
1797     auto cirValVTy = mlir::cast<cir::VectorType>(op.getValue().getType());
1798     isUnsigned =
1799         mlir::cast<cir::IntType>(cirValVTy.getElementType()).isUnsigned();
1800   }
1801 
1802   // Lower to the proper LLVM shift operation.
1803   if (op.getIsShiftleft()) {
1804     rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
1805     return mlir::success();
1806   }
1807 
1808   if (isUnsigned)
1809     rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
1810   else
1811     rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
1812   return mlir::success();
1813 }
1814 
matchAndRewrite(cir::SelectOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const1815 mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
1816     cir::SelectOp op, OpAdaptor adaptor,
1817     mlir::ConversionPatternRewriter &rewriter) const {
1818   auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr {
1819     auto definingOp =
1820         mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp());
1821     if (!definingOp)
1822       return {};
1823 
1824     auto constValue = mlir::dyn_cast<cir::BoolAttr>(definingOp.getValue());
1825     if (!constValue)
1826       return {};
1827 
1828     return constValue;
1829   };
1830 
1831   // Two special cases in the LLVMIR codegen of select op:
1832   // - select %0, %1, false => and %0, %1
1833   // - select %0, true, %1 => or %0, %1
1834   if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
1835     cir::BoolAttr trueValue = getConstantBool(op.getTrueValue());
1836     cir::BoolAttr falseValue = getConstantBool(op.getFalseValue());
1837     if (falseValue && !falseValue.getValue()) {
1838       // select %0, %1, false => and %0, %1
1839       rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
1840                                                      adaptor.getTrueValue());
1841       return mlir::success();
1842     }
1843     if (trueValue && trueValue.getValue()) {
1844       // select %0, true, %1 => or %0, %1
1845       rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
1846                                                     adaptor.getFalseValue());
1847       return mlir::success();
1848     }
1849   }
1850 
1851   mlir::Value llvmCondition = adaptor.getCondition();
1852   rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
1853       op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());
1854 
1855   return mlir::success();
1856 }
1857 
prepareTypeConverter(mlir::LLVMTypeConverter & converter,mlir::DataLayout & dataLayout)1858 static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
1859                                  mlir::DataLayout &dataLayout) {
1860   converter.addConversion([&](cir::PointerType type) -> mlir::Type {
1861     // Drop pointee type since LLVM dialect only allows opaque pointers.
1862     assert(!cir::MissingFeatures::addressSpace());
1863     unsigned targetAS = 0;
1864 
1865     return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS);
1866   });
1867   converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
1868     mlir::Type ty =
1869         convertTypeForMemory(converter, dataLayout, type.getElementType());
1870     return mlir::LLVM::LLVMArrayType::get(ty, type.getSize());
1871   });
1872   converter.addConversion([&](cir::VectorType type) -> mlir::Type {
1873     const mlir::Type ty = converter.convertType(type.getElementType());
1874     return mlir::VectorType::get(type.getSize(), ty);
1875   });
1876   converter.addConversion([&](cir::BoolType type) -> mlir::Type {
1877     return mlir::IntegerType::get(type.getContext(), 1,
1878                                   mlir::IntegerType::Signless);
1879   });
1880   converter.addConversion([&](cir::IntType type) -> mlir::Type {
1881     // LLVM doesn't work with signed types, so we drop the CIR signs here.
1882     return mlir::IntegerType::get(type.getContext(), type.getWidth());
1883   });
1884   converter.addConversion([&](cir::SingleType type) -> mlir::Type {
1885     return mlir::Float32Type::get(type.getContext());
1886   });
1887   converter.addConversion([&](cir::DoubleType type) -> mlir::Type {
1888     return mlir::Float64Type::get(type.getContext());
1889   });
1890   converter.addConversion([&](cir::FP80Type type) -> mlir::Type {
1891     return mlir::Float80Type::get(type.getContext());
1892   });
1893   converter.addConversion([&](cir::FP128Type type) -> mlir::Type {
1894     return mlir::Float128Type::get(type.getContext());
1895   });
1896   converter.addConversion([&](cir::LongDoubleType type) -> mlir::Type {
1897     return converter.convertType(type.getUnderlying());
1898   });
1899   converter.addConversion([&](cir::FP16Type type) -> mlir::Type {
1900     return mlir::Float16Type::get(type.getContext());
1901   });
1902   converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
1903     return mlir::BFloat16Type::get(type.getContext());
1904   });
1905   converter.addConversion([&](cir::ComplexType type) -> mlir::Type {
1906     // A complex type is lowered to an LLVM struct that contains the real and
1907     // imaginary part as data fields.
1908     mlir::Type elementTy = converter.convertType(type.getElementType());
1909     mlir::Type structFields[2] = {elementTy, elementTy};
1910     return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(),
1911                                                   structFields);
1912   });
1913   converter.addConversion([&](cir::FuncType type) -> std::optional<mlir::Type> {
1914     auto result = converter.convertType(type.getReturnType());
1915     llvm::SmallVector<mlir::Type> arguments;
1916     arguments.reserve(type.getNumInputs());
1917     if (converter.convertTypes(type.getInputs(), arguments).failed())
1918       return std::nullopt;
1919     auto varArg = type.isVarArg();
1920     return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg);
1921   });
1922   converter.addConversion([&](cir::RecordType type) -> mlir::Type {
1923     // Convert struct members.
1924     llvm::SmallVector<mlir::Type> llvmMembers;
1925     switch (type.getKind()) {
1926     case cir::RecordType::Class:
1927     case cir::RecordType::Struct:
1928       for (mlir::Type ty : type.getMembers())
1929         llvmMembers.push_back(convertTypeForMemory(converter, dataLayout, ty));
1930       break;
1931     // Unions are lowered as only the largest member.
1932     case cir::RecordType::Union:
1933       if (auto largestMember = type.getLargestMember(dataLayout))
1934         llvmMembers.push_back(
1935             convertTypeForMemory(converter, dataLayout, largestMember));
1936       if (type.getPadded()) {
1937         auto last = *type.getMembers().rbegin();
1938         llvmMembers.push_back(
1939             convertTypeForMemory(converter, dataLayout, last));
1940       }
1941       break;
1942     }
1943 
1944     // Record has a name: lower as an identified record.
1945     mlir::LLVM::LLVMStructType llvmStruct;
1946     if (type.getName()) {
1947       llvmStruct = mlir::LLVM::LLVMStructType::getIdentified(
1948           type.getContext(), type.getPrefixedName());
1949       if (llvmStruct.setBody(llvmMembers, type.getPacked()).failed())
1950         llvm_unreachable("Failed to set body of record");
1951     } else { // Record has no name: lower as literal record.
1952       llvmStruct = mlir::LLVM::LLVMStructType::getLiteral(
1953           type.getContext(), llvmMembers, type.getPacked());
1954     }
1955 
1956     return llvmStruct;
1957   });
1958 }
1959 
// The applyPartialConversion function traverses blocks in dominance order,
// so it does not lower any operations that are not reachable from the
// operations passed in as arguments. Since we do need to lower such code in
// order to avoid verification errors, we cannot just pass the module op
// to applyPartialConversion. We must build a set of unreachable ops and
// explicitly add them, along with the module, to the vector we pass to
// applyPartialConversion.
//
// For instance, this CIR code:
//
//    cir.func @foo(%arg0: !s32i) -> !s32i {
//      %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
//      cir.if %4 {
//        %5 = cir.const #cir.int<1> : !s32i
//        cir.return %5 : !s32i
//      } else {
//        %5 = cir.const #cir.int<0> : !s32i
//        cir.return %5 : !s32i
//      }
//      cir.return %arg0 : !s32i
//    }
//
// contains an unreachable return operation (the last one). After the
// flattening pass it will be placed into an unreachable block. Without this
// workaround, the lowering pass may then fail with: error: 'cir.return' op
// expects parent op to be one of 'cir.func, cir.scope, cir.if ... This
// happens because the operation itself was not lowered while its new parent
// is llvm.func.
//
// In the future we may want to get rid of this function and use a DCE pass or
// something similar. But for now we need to guarantee the absence of
// dialect verification errors.
collectUnreachable(mlir::Operation * parent,llvm::SmallVector<mlir::Operation * > & ops)1991 static void collectUnreachable(mlir::Operation *parent,
1992                                llvm::SmallVector<mlir::Operation *> &ops) {
1993 
1994   llvm::SmallVector<mlir::Block *> unreachableBlocks;
1995   parent->walk([&](mlir::Block *blk) { // check
1996     if (blk->hasNoPredecessors() && !blk->isEntryBlock())
1997       unreachableBlocks.push_back(blk);
1998   });
1999 
2000   std::set<mlir::Block *> visited;
2001   for (mlir::Block *root : unreachableBlocks) {
2002     // We create a work list for each unreachable block.
2003     // Thus we traverse operations in some order.
2004     std::deque<mlir::Block *> workList;
2005     workList.push_back(root);
2006 
2007     while (!workList.empty()) {
2008       mlir::Block *blk = workList.back();
2009       workList.pop_back();
2010       if (visited.count(blk))
2011         continue;
2012       visited.emplace(blk);
2013 
2014       for (mlir::Operation &op : *blk)
2015         ops.push_back(&op);
2016 
2017       for (mlir::Block *succ : blk->getSuccessors())
2018         workList.push_back(succ);
2019     }
2020   }
2021 }
2022 
processCIRAttrs(mlir::ModuleOp module)2023 void ConvertCIRToLLVMPass::processCIRAttrs(mlir::ModuleOp module) {
2024   // Lower the module attributes to LLVM equivalents.
2025   if (mlir::Attribute tripleAttr =
2026           module->getAttr(cir::CIRDialect::getTripleAttrName()))
2027     module->setAttr(mlir::LLVM::LLVMDialect::getTargetTripleAttrName(),
2028                     tripleAttr);
2029 }
2030 
runOnOperation()2031 void ConvertCIRToLLVMPass::runOnOperation() {
2032   llvm::TimeTraceScope scope("Convert CIR to LLVM Pass");
2033 
2034   mlir::ModuleOp module = getOperation();
2035   mlir::DataLayout dl(module);
2036   mlir::LLVMTypeConverter converter(&getContext());
2037   prepareTypeConverter(converter, dl);
2038 
2039   mlir::RewritePatternSet patterns(&getContext());
2040 
2041   patterns.add<CIRToLLVMReturnOpLowering>(patterns.getContext());
2042   // This could currently be merged with the group below, but it will get more
2043   // arguments later, so we'll keep it separate for now.
2044   patterns.add<CIRToLLVMAllocaOpLowering>(converter, patterns.getContext(), dl);
2045   patterns.add<CIRToLLVMLoadOpLowering>(converter, patterns.getContext(), dl);
2046   patterns.add<CIRToLLVMStoreOpLowering>(converter, patterns.getContext(), dl);
2047   patterns.add<CIRToLLVMGlobalOpLowering>(converter, patterns.getContext(), dl);
2048   patterns.add<CIRToLLVMCastOpLowering>(converter, patterns.getContext(), dl);
2049   patterns.add<CIRToLLVMPtrStrideOpLowering>(converter, patterns.getContext(),
2050                                              dl);
2051   patterns.add<
2052       // clang-format off
2053                CIRToLLVMAssumeOpLowering,
2054                CIRToLLVMBaseClassAddrOpLowering,
2055                CIRToLLVMBinOpLowering,
2056                CIRToLLVMBitClrsbOpLowering,
2057                CIRToLLVMBitClzOpLowering,
2058                CIRToLLVMBitCtzOpLowering,
2059                CIRToLLVMBitParityOpLowering,
2060                CIRToLLVMBitPopcountOpLowering,
2061                CIRToLLVMBitReverseOpLowering,
2062                CIRToLLVMBrCondOpLowering,
2063                CIRToLLVMBrOpLowering,
2064                CIRToLLVMByteSwapOpLowering,
2065                CIRToLLVMCallOpLowering,
2066                CIRToLLVMCmpOpLowering,
2067                CIRToLLVMComplexAddOpLowering,
2068                CIRToLLVMComplexCreateOpLowering,
2069                CIRToLLVMComplexImagOpLowering,
2070                CIRToLLVMComplexImagPtrOpLowering,
2071                CIRToLLVMComplexRealOpLowering,
2072                CIRToLLVMComplexRealPtrOpLowering,
2073                CIRToLLVMComplexSubOpLowering,
2074                CIRToLLVMConstantOpLowering,
2075                CIRToLLVMExpectOpLowering,
2076                CIRToLLVMFuncOpLowering,
2077                CIRToLLVMGetBitfieldOpLowering,
2078                CIRToLLVMGetGlobalOpLowering,
2079                CIRToLLVMGetMemberOpLowering,
2080                CIRToLLVMSelectOpLowering,
2081                CIRToLLVMSetBitfieldOpLowering,
2082                CIRToLLVMShiftOpLowering,
2083                CIRToLLVMStackRestoreOpLowering,
2084                CIRToLLVMStackSaveOpLowering,
2085                CIRToLLVMSwitchFlatOpLowering,
2086                CIRToLLVMTrapOpLowering,
2087                CIRToLLVMUnaryOpLowering,
2088                CIRToLLVMVecCmpOpLowering,
2089                CIRToLLVMVecCreateOpLowering,
2090                CIRToLLVMVecExtractOpLowering,
2091                CIRToLLVMVecInsertOpLowering,
2092                CIRToLLVMVecShuffleDynamicOpLowering,
2093                CIRToLLVMVecShuffleOpLowering,
2094                CIRToLLVMVecSplatOpLowering,
2095                CIRToLLVMVecTernaryOpLowering
2096       // clang-format on
2097       >(converter, patterns.getContext());
2098 
2099   processCIRAttrs(module);
2100 
2101   mlir::ConversionTarget target(getContext());
2102   target.addLegalOp<mlir::ModuleOp>();
2103   target.addLegalDialect<mlir::LLVM::LLVMDialect>();
2104   target.addIllegalDialect<mlir::BuiltinDialect, cir::CIRDialect,
2105                            mlir::func::FuncDialect>();
2106 
2107   llvm::SmallVector<mlir::Operation *> ops;
2108   ops.push_back(module);
2109   collectUnreachable(module, ops);
2110 
2111   if (failed(applyPartialConversion(ops, target, std::move(patterns))))
2112     signalPassFailure();
2113 }
2114 
matchAndRewrite(cir::BrOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2115 mlir::LogicalResult CIRToLLVMBrOpLowering::matchAndRewrite(
2116     cir::BrOp op, OpAdaptor adaptor,
2117     mlir::ConversionPatternRewriter &rewriter) const {
2118   rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(op, adaptor.getOperands(),
2119                                                 op.getDest());
2120   return mlir::LogicalResult::success();
2121 }
2122 
matchAndRewrite(cir::GetMemberOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2123 mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite(
2124     cir::GetMemberOp op, OpAdaptor adaptor,
2125     mlir::ConversionPatternRewriter &rewriter) const {
2126   mlir::Type llResTy = getTypeConverter()->convertType(op.getType());
2127   const auto recordTy =
2128       mlir::cast<cir::RecordType>(op.getAddrTy().getPointee());
2129   assert(recordTy && "expected record type");
2130 
2131   switch (recordTy.getKind()) {
2132   case cir::RecordType::Class:
2133   case cir::RecordType::Struct: {
2134     // Since the base address is a pointer to an aggregate, the first offset
2135     // is always zero. The second offset tell us which member it will access.
2136     llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()};
2137     const mlir::Type elementTy = getTypeConverter()->convertType(recordTy);
2138     rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResTy, elementTy,
2139                                                    adaptor.getAddr(), offset);
2140     return mlir::success();
2141   }
2142   case cir::RecordType::Union:
2143     // Union members share the address space, so we just need a bitcast to
2144     // conform to type-checking.
2145     rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy,
2146                                                        adaptor.getAddr());
2147     return mlir::success();
2148   }
2149 }
2150 
matchAndRewrite(cir::TrapOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2151 mlir::LogicalResult CIRToLLVMTrapOpLowering::matchAndRewrite(
2152     cir::TrapOp op, OpAdaptor adaptor,
2153     mlir::ConversionPatternRewriter &rewriter) const {
2154   mlir::Location loc = op->getLoc();
2155   rewriter.eraseOp(op);
2156 
2157   rewriter.create<mlir::LLVM::Trap>(loc);
2158 
2159   // Note that the call to llvm.trap is not a terminator in LLVM dialect.
2160   // So we must emit an additional llvm.unreachable to terminate the current
2161   // block.
2162   rewriter.create<mlir::LLVM::UnreachableOp>(loc);
2163 
2164   return mlir::success();
2165 }
2166 
matchAndRewrite(cir::StackSaveOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2167 mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
2168     cir::StackSaveOp op, OpAdaptor adaptor,
2169     mlir::ConversionPatternRewriter &rewriter) const {
2170   const mlir::Type ptrTy = getTypeConverter()->convertType(op.getType());
2171   rewriter.replaceOpWithNewOp<mlir::LLVM::StackSaveOp>(op, ptrTy);
2172   return mlir::success();
2173 }
2174 
matchAndRewrite(cir::StackRestoreOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2175 mlir::LogicalResult CIRToLLVMStackRestoreOpLowering::matchAndRewrite(
2176     cir::StackRestoreOp op, OpAdaptor adaptor,
2177     mlir::ConversionPatternRewriter &rewriter) const {
2178   rewriter.replaceOpWithNewOp<mlir::LLVM::StackRestoreOp>(op, adaptor.getPtr());
2179   return mlir::success();
2180 }
2181 
matchAndRewrite(cir::VecCreateOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2182 mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
2183     cir::VecCreateOp op, OpAdaptor adaptor,
2184     mlir::ConversionPatternRewriter &rewriter) const {
2185   // Start with an 'undef' value for the vector.  Then 'insertelement' for
2186   // each of the vector elements.
2187   const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
2188   const mlir::Type llvmTy = typeConverter->convertType(vecTy);
2189   const mlir::Location loc = op.getLoc();
2190   mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
2191   assert(vecTy.getSize() == op.getElements().size() &&
2192          "cir.vec.create op count doesn't match vector type elements count");
2193 
2194   for (uint64_t i = 0; i < vecTy.getSize(); ++i) {
2195     const mlir::Value indexValue =
2196         rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
2197     result = rewriter.create<mlir::LLVM::InsertElementOp>(
2198         loc, result, adaptor.getElements()[i], indexValue);
2199   }
2200 
2201   rewriter.replaceOp(op, result);
2202   return mlir::success();
2203 }
2204 
matchAndRewrite(cir::VecExtractOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2205 mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
2206     cir::VecExtractOp op, OpAdaptor adaptor,
2207     mlir::ConversionPatternRewriter &rewriter) const {
2208   rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
2209       op, adaptor.getVec(), adaptor.getIndex());
2210   return mlir::success();
2211 }
2212 
matchAndRewrite(cir::VecInsertOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2213 mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
2214     cir::VecInsertOp op, OpAdaptor adaptor,
2215     mlir::ConversionPatternRewriter &rewriter) const {
2216   rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(
2217       op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex());
2218   return mlir::success();
2219 }
2220 
matchAndRewrite(cir::VecCmpOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2221 mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
2222     cir::VecCmpOp op, OpAdaptor adaptor,
2223     mlir::ConversionPatternRewriter &rewriter) const {
2224   mlir::Type elementType = elementTypeIfVector(op.getLhs().getType());
2225   mlir::Value bitResult;
2226   if (auto intType = mlir::dyn_cast<cir::IntType>(elementType)) {
2227     bitResult = rewriter.create<mlir::LLVM::ICmpOp>(
2228         op.getLoc(),
2229         convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()),
2230         adaptor.getLhs(), adaptor.getRhs());
2231   } else if (mlir::isa<cir::FPTypeInterface>(elementType)) {
2232     bitResult = rewriter.create<mlir::LLVM::FCmpOp>(
2233         op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()),
2234         adaptor.getLhs(), adaptor.getRhs());
2235   } else {
2236     return op.emitError() << "unsupported type for VecCmpOp: " << elementType;
2237   }
2238 
2239   // LLVM IR vector comparison returns a vector of i1. This one-bit vector
2240   // must be sign-extended to the correct result type.
2241   rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(
2242       op, typeConverter->convertType(op.getType()), bitResult);
2243   return mlir::success();
2244 }
2245 
matchAndRewrite(cir::VecSplatOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2246 mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
2247     cir::VecSplatOp op, OpAdaptor adaptor,
2248     mlir::ConversionPatternRewriter &rewriter) const {
2249   // Vector splat can be implemented with an `insertelement` and a
2250   // `shufflevector`, which is better than an `insertelement` for each
2251   // element in the vector. Start with an undef vector. Insert the value into
2252   // the first element. Then use a `shufflevector` with a mask of all 0 to
2253   // fill out the entire vector with that value.
2254   cir::VectorType vecTy = op.getType();
2255   mlir::Type llvmTy = typeConverter->convertType(vecTy);
2256   mlir::Location loc = op.getLoc();
2257   mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
2258 
2259   mlir::Value elementValue = adaptor.getValue();
2260   if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
2261     // If the splat value is poison, then we can just use poison value
2262     // for the entire vector.
2263     rewriter.replaceOp(op, poison);
2264     return mlir::success();
2265   }
2266 
2267   if (auto constValue =
2268           dyn_cast<mlir::LLVM::ConstantOp>(elementValue.getDefiningOp())) {
2269     if (auto intAttr = dyn_cast<mlir::IntegerAttr>(constValue.getValue())) {
2270       mlir::DenseIntElementsAttr denseVec = mlir::DenseIntElementsAttr::get(
2271           mlir::cast<mlir::ShapedType>(llvmTy), intAttr.getValue());
2272       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
2273           op, denseVec.getType(), denseVec);
2274       return mlir::success();
2275     }
2276 
2277     if (auto fpAttr = dyn_cast<mlir::FloatAttr>(constValue.getValue())) {
2278       mlir::DenseFPElementsAttr denseVec = mlir::DenseFPElementsAttr::get(
2279           mlir::cast<mlir::ShapedType>(llvmTy), fpAttr.getValue());
2280       rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
2281           op, denseVec.getType(), denseVec);
2282       return mlir::success();
2283     }
2284   }
2285 
2286   mlir::Value indexValue =
2287       rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
2288   mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
2289       loc, poison, elementValue, indexValue);
2290   SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
2291   rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement,
2292                                                            poison, zeroValues);
2293   return mlir::success();
2294 }
2295 
matchAndRewrite(cir::VecShuffleOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2296 mlir::LogicalResult CIRToLLVMVecShuffleOpLowering::matchAndRewrite(
2297     cir::VecShuffleOp op, OpAdaptor adaptor,
2298     mlir::ConversionPatternRewriter &rewriter) const {
2299   // LLVM::ShuffleVectorOp takes an ArrayRef of int for the list of indices.
2300   // Convert the ClangIR ArrayAttr of IntAttr constants into a
2301   // SmallVector<int>.
2302   SmallVector<int, 8> indices;
2303   std::transform(
2304       op.getIndices().begin(), op.getIndices().end(),
2305       std::back_inserter(indices), [](mlir::Attribute intAttr) {
2306         return mlir::cast<cir::IntAttr>(intAttr).getValue().getSExtValue();
2307       });
2308   rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(
2309       op, adaptor.getVec1(), adaptor.getVec2(), indices);
2310   return mlir::success();
2311 }
2312 
matchAndRewrite(cir::VecShuffleDynamicOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2313 mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
2314     cir::VecShuffleDynamicOp op, OpAdaptor adaptor,
2315     mlir::ConversionPatternRewriter &rewriter) const {
2316   // LLVM IR does not have an operation that corresponds to this form of
2317   // the built-in.
2318   //     __builtin_shufflevector(V, I)
2319   // is implemented as this pseudocode, where the for loop is unrolled
2320   // and N is the number of elements:
2321   //
2322   // result = undef
2323   // maskbits = NextPowerOf2(N - 1)
2324   // masked = I & maskbits
2325   // for (i in 0 <= i < N)
2326   //    result[i] = V[masked[i]]
2327   mlir::Location loc = op.getLoc();
2328   mlir::Value input = adaptor.getVec();
2329   mlir::Type llvmIndexVecType =
2330       getTypeConverter()->convertType(op.getIndices().getType());
2331   mlir::Type llvmIndexType = getTypeConverter()->convertType(
2332       elementTypeIfVector(op.getIndices().getType()));
2333   uint64_t numElements =
2334       mlir::cast<cir::VectorType>(op.getVec().getType()).getSize();
2335 
2336   uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
2337   mlir::Value maskValue = rewriter.create<mlir::LLVM::ConstantOp>(
2338       loc, llvmIndexType, rewriter.getIntegerAttr(llvmIndexType, maskBits));
2339   mlir::Value maskVector =
2340       rewriter.create<mlir::LLVM::UndefOp>(loc, llvmIndexVecType);
2341 
2342   for (uint64_t i = 0; i < numElements; ++i) {
2343     mlir::Value idxValue =
2344         rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
2345     maskVector = rewriter.create<mlir::LLVM::InsertElementOp>(
2346         loc, maskVector, maskValue, idxValue);
2347   }
2348 
2349   mlir::Value maskedIndices = rewriter.create<mlir::LLVM::AndOp>(
2350       loc, llvmIndexVecType, adaptor.getIndices(), maskVector);
2351   mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
2352       loc, getTypeConverter()->convertType(op.getVec().getType()));
2353   for (uint64_t i = 0; i < numElements; ++i) {
2354     mlir::Value iValue =
2355         rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
2356     mlir::Value indexValue = rewriter.create<mlir::LLVM::ExtractElementOp>(
2357         loc, maskedIndices, iValue);
2358     mlir::Value valueAtIndex =
2359         rewriter.create<mlir::LLVM::ExtractElementOp>(loc, input, indexValue);
2360     result = rewriter.create<mlir::LLVM::InsertElementOp>(loc, result,
2361                                                           valueAtIndex, iValue);
2362   }
2363   rewriter.replaceOp(op, result);
2364   return mlir::success();
2365 }
2366 
matchAndRewrite(cir::VecTernaryOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2367 mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
2368     cir::VecTernaryOp op, OpAdaptor adaptor,
2369     mlir::ConversionPatternRewriter &rewriter) const {
2370   // Convert `cond` into a vector of i1, then use that in a `select` op.
2371   mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>(
2372       op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(),
2373       rewriter.create<mlir::LLVM::ZeroOp>(
2374           op.getCond().getLoc(),
2375           typeConverter->convertType(op.getCond().getType())));
2376   rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
2377       op, bitVec, adaptor.getLhs(), adaptor.getRhs());
2378   return mlir::success();
2379 }
2380 
matchAndRewrite(cir::ComplexAddOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2381 mlir::LogicalResult CIRToLLVMComplexAddOpLowering::matchAndRewrite(
2382     cir::ComplexAddOp op, OpAdaptor adaptor,
2383     mlir::ConversionPatternRewriter &rewriter) const {
2384   mlir::Value lhs = adaptor.getLhs();
2385   mlir::Value rhs = adaptor.getRhs();
2386   mlir::Location loc = op.getLoc();
2387 
2388   auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2389   mlir::Type complexElemTy =
2390       getTypeConverter()->convertType(complexType.getElementType());
2391   auto lhsReal =
2392       rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2393   auto lhsImag =
2394       rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2395   auto rhsReal =
2396       rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2397   auto rhsImag =
2398       rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2399 
2400   mlir::Value newReal;
2401   mlir::Value newImag;
2402   if (complexElemTy.isInteger()) {
2403     newReal = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsReal,
2404                                                  rhsReal);
2405     newImag = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsImag,
2406                                                  rhsImag);
2407   } else {
2408     assert(!cir::MissingFeatures::fastMathFlags());
2409     assert(!cir::MissingFeatures::fpConstraints());
2410     newReal = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsReal,
2411                                                   rhsReal);
2412     newImag = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsImag,
2413                                                   rhsImag);
2414   }
2415 
2416   mlir::Type complexLLVMTy =
2417       getTypeConverter()->convertType(op.getResult().getType());
2418   auto initialComplex =
2419       rewriter.create<mlir::LLVM::PoisonOp>(op->getLoc(), complexLLVMTy);
2420 
2421   auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
2422       op->getLoc(), initialComplex, newReal, 0);
2423 
2424   rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(op, realComplex,
2425                                                          newImag, 1);
2426 
2427   return mlir::success();
2428 }
2429 
matchAndRewrite(cir::ComplexCreateOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2430 mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite(
2431     cir::ComplexCreateOp op, OpAdaptor adaptor,
2432     mlir::ConversionPatternRewriter &rewriter) const {
2433   mlir::Type complexLLVMTy =
2434       getTypeConverter()->convertType(op.getResult().getType());
2435   auto initialComplex =
2436       rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy);
2437 
2438   auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
2439       op->getLoc(), initialComplex, adaptor.getReal(), 0);
2440 
2441   auto complex = rewriter.create<mlir::LLVM::InsertValueOp>(
2442       op->getLoc(), realComplex, adaptor.getImag(), 1);
2443 
2444   rewriter.replaceOp(op, complex);
2445   return mlir::success();
2446 }
2447 
matchAndRewrite(cir::ComplexRealOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2448 mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
2449     cir::ComplexRealOp op, OpAdaptor adaptor,
2450     mlir::ConversionPatternRewriter &rewriter) const {
2451   mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2452   rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
2453       op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{0});
2454   return mlir::success();
2455 }
2456 
matchAndRewrite(cir::ComplexSubOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2457 mlir::LogicalResult CIRToLLVMComplexSubOpLowering::matchAndRewrite(
2458     cir::ComplexSubOp op, OpAdaptor adaptor,
2459     mlir::ConversionPatternRewriter &rewriter) const {
2460   mlir::Value lhs = adaptor.getLhs();
2461   mlir::Value rhs = adaptor.getRhs();
2462   mlir::Location loc = op.getLoc();
2463 
2464   auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2465   mlir::Type complexElemTy =
2466       getTypeConverter()->convertType(complexType.getElementType());
2467   auto lhsReal =
2468       rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2469   auto lhsImag =
2470       rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2471   auto rhsReal =
2472       rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2473   auto rhsImag =
2474       rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2475 
2476   mlir::Value newReal;
2477   mlir::Value newImag;
2478   if (complexElemTy.isInteger()) {
2479     newReal = rewriter.create<mlir::LLVM::SubOp>(loc, complexElemTy, lhsReal,
2480                                                  rhsReal);
2481     newImag = rewriter.create<mlir::LLVM::SubOp>(loc, complexElemTy, lhsImag,
2482                                                  rhsImag);
2483   } else {
2484     assert(!cir::MissingFeatures::fastMathFlags());
2485     assert(!cir::MissingFeatures::fpConstraints());
2486     newReal = rewriter.create<mlir::LLVM::FSubOp>(loc, complexElemTy, lhsReal,
2487                                                   rhsReal);
2488     newImag = rewriter.create<mlir::LLVM::FSubOp>(loc, complexElemTy, lhsImag,
2489                                                   rhsImag);
2490   }
2491 
2492   mlir::Type complexLLVMTy =
2493       getTypeConverter()->convertType(op.getResult().getType());
2494   auto initialComplex =
2495       rewriter.create<mlir::LLVM::PoisonOp>(op->getLoc(), complexLLVMTy);
2496 
2497   auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
2498       op->getLoc(), initialComplex, newReal, 0);
2499 
2500   rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(op, realComplex,
2501                                                          newImag, 1);
2502 
2503   return mlir::success();
2504 }
2505 
matchAndRewrite(cir::ComplexImagOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2506 mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
2507     cir::ComplexImagOp op, OpAdaptor adaptor,
2508     mlir::ConversionPatternRewriter &rewriter) const {
2509   mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2510   rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
2511       op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{1});
2512   return mlir::success();
2513 }
2514 
computeBitfieldIntType(mlir::Type storageType,mlir::MLIRContext * context,unsigned & storageSize)2515 mlir::IntegerType computeBitfieldIntType(mlir::Type storageType,
2516                                          mlir::MLIRContext *context,
2517                                          unsigned &storageSize) {
2518   return TypeSwitch<mlir::Type, mlir::IntegerType>(storageType)
2519       .Case<cir::ArrayType>([&](cir::ArrayType atTy) {
2520         storageSize = atTy.getSize() * 8;
2521         return mlir::IntegerType::get(context, storageSize);
2522       })
2523       .Case<cir::IntType>([&](cir::IntType intTy) {
2524         storageSize = intTy.getWidth();
2525         return mlir::IntegerType::get(context, storageSize);
2526       })
2527       .Default([](mlir::Type) -> mlir::IntegerType {
2528         llvm_unreachable(
2529             "Either ArrayType or IntType expected for bitfields storage");
2530       });
2531 }
2532 
matchAndRewrite(cir::SetBitfieldOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2533 mlir::LogicalResult CIRToLLVMSetBitfieldOpLowering::matchAndRewrite(
2534     cir::SetBitfieldOp op, OpAdaptor adaptor,
2535     mlir::ConversionPatternRewriter &rewriter) const {
2536   mlir::OpBuilder::InsertionGuard guard(rewriter);
2537   rewriter.setInsertionPoint(op);
2538 
2539   cir::BitfieldInfoAttr info = op.getBitfieldInfo();
2540   uint64_t size = info.getSize();
2541   uint64_t offset = info.getOffset();
2542   mlir::Type storageType = info.getStorageType();
2543   mlir::MLIRContext *context = storageType.getContext();
2544 
2545   unsigned storageSize = 0;
2546 
2547   mlir::IntegerType intType =
2548       computeBitfieldIntType(storageType, context, storageSize);
2549 
2550   mlir::Value srcVal = createIntCast(rewriter, adaptor.getSrc(), intType);
2551   unsigned srcWidth = storageSize;
2552   mlir::Value resultVal = srcVal;
2553 
2554   if (storageSize != size) {
2555     assert(storageSize > size && "Invalid bitfield size.");
2556 
2557     mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>(
2558         op.getLoc(), intType, adaptor.getAddr(), /* alignment */ 0,
2559         op.getIsVolatile());
2560 
2561     srcVal =
2562         createAnd(rewriter, srcVal, llvm::APInt::getLowBitsSet(srcWidth, size));
2563     resultVal = srcVal;
2564     srcVal = createShL(rewriter, srcVal, offset);
2565 
2566     // Mask out the original value.
2567     val = createAnd(rewriter, val,
2568                     ~llvm::APInt::getBitsSet(srcWidth, offset, offset + size));
2569 
2570     // Or together the unchanged values and the source value.
2571     srcVal = rewriter.create<mlir::LLVM::OrOp>(op.getLoc(), val, srcVal);
2572   }
2573 
2574   rewriter.create<mlir::LLVM::StoreOp>(op.getLoc(), srcVal, adaptor.getAddr(),
2575                                        /* alignment */ 0, op.getIsVolatile());
2576 
2577   mlir::Type resultTy = getTypeConverter()->convertType(op.getType());
2578 
2579   if (info.getIsSigned()) {
2580     assert(size <= storageSize);
2581     unsigned highBits = storageSize - size;
2582 
2583     if (highBits) {
2584       resultVal = createShL(rewriter, resultVal, highBits);
2585       resultVal = createAShR(rewriter, resultVal, highBits);
2586     }
2587   }
2588 
2589   resultVal = createIntCast(rewriter, resultVal,
2590                             mlir::cast<mlir::IntegerType>(resultTy),
2591                             info.getIsSigned());
2592 
2593   rewriter.replaceOp(op, resultVal);
2594   return mlir::success();
2595 }
2596 
matchAndRewrite(cir::ComplexImagPtrOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2597 mlir::LogicalResult CIRToLLVMComplexImagPtrOpLowering::matchAndRewrite(
2598     cir::ComplexImagPtrOp op, OpAdaptor adaptor,
2599     mlir::ConversionPatternRewriter &rewriter) const {
2600   cir::PointerType operandTy = op.getOperand().getType();
2601   mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2602   mlir::Type elementLLVMTy =
2603       getTypeConverter()->convertType(operandTy.getPointee());
2604 
2605   mlir::LLVM::GEPArg gepIndices[2] = {{0}, {1}};
2606   mlir::LLVM::GEPNoWrapFlags inboundsNuw =
2607       mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw;
2608   rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
2609       op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
2610       inboundsNuw);
2611   return mlir::success();
2612 }
2613 
matchAndRewrite(cir::ComplexRealPtrOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2614 mlir::LogicalResult CIRToLLVMComplexRealPtrOpLowering::matchAndRewrite(
2615     cir::ComplexRealPtrOp op, OpAdaptor adaptor,
2616     mlir::ConversionPatternRewriter &rewriter) const {
2617   cir::PointerType operandTy = op.getOperand().getType();
2618   mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2619   mlir::Type elementLLVMTy =
2620       getTypeConverter()->convertType(operandTy.getPointee());
2621 
2622   mlir::LLVM::GEPArg gepIndices[2] = {0, 0};
2623   mlir::LLVM::GEPNoWrapFlags inboundsNuw =
2624       mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw;
2625   rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
2626       op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
2627       inboundsNuw);
2628   return mlir::success();
2629 }
2630 
matchAndRewrite(cir::GetBitfieldOp op,OpAdaptor adaptor,mlir::ConversionPatternRewriter & rewriter) const2631 mlir::LogicalResult CIRToLLVMGetBitfieldOpLowering::matchAndRewrite(
2632     cir::GetBitfieldOp op, OpAdaptor adaptor,
2633     mlir::ConversionPatternRewriter &rewriter) const {
2634 
2635   mlir::OpBuilder::InsertionGuard guard(rewriter);
2636   rewriter.setInsertionPoint(op);
2637 
2638   cir::BitfieldInfoAttr info = op.getBitfieldInfo();
2639   uint64_t size = info.getSize();
2640   uint64_t offset = info.getOffset();
2641   mlir::Type storageType = info.getStorageType();
2642   mlir::MLIRContext *context = storageType.getContext();
2643   unsigned storageSize = 0;
2644 
2645   mlir::IntegerType intType =
2646       computeBitfieldIntType(storageType, context, storageSize);
2647 
2648   mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>(
2649       op.getLoc(), intType, adaptor.getAddr(), 0, op.getIsVolatile());
2650   val = rewriter.create<mlir::LLVM::BitcastOp>(op.getLoc(), intType, val);
2651 
2652   if (info.getIsSigned()) {
2653     assert(static_cast<unsigned>(offset + size) <= storageSize);
2654     unsigned highBits = storageSize - offset - size;
2655     val = createShL(rewriter, val, highBits);
2656     val = createAShR(rewriter, val, offset + highBits);
2657   } else {
2658     val = createLShR(rewriter, val, offset);
2659 
2660     if (static_cast<unsigned>(offset) + size < storageSize)
2661       val = createAnd(rewriter, val,
2662                       llvm::APInt::getLowBitsSet(storageSize, size));
2663   }
2664 
2665   mlir::Type resTy = getTypeConverter()->convertType(op.getType());
2666   mlir::Value newOp = createIntCast(
2667       rewriter, val, mlir::cast<mlir::IntegerType>(resTy), info.getIsSigned());
2668   rewriter.replaceOp(op, newOp);
2669   return mlir::success();
2670 }
2671 
createConvertCIRToLLVMPass()2672 std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
2673   return std::make_unique<ConvertCIRToLLVMPass>();
2674 }
2675 
populateCIRToLLVMPasses(mlir::OpPassManager & pm)2676 void populateCIRToLLVMPasses(mlir::OpPassManager &pm) {
2677   mlir::populateCIRPreLoweringPasses(pm);
2678   pm.addPass(createConvertCIRToLLVMPass());
2679 }
2680 
2681 std::unique_ptr<llvm::Module>
lowerDirectlyFromCIRToLLVMIR(mlir::ModuleOp mlirModule,LLVMContext & llvmCtx)2682 lowerDirectlyFromCIRToLLVMIR(mlir::ModuleOp mlirModule, LLVMContext &llvmCtx) {
2683   llvm::TimeTraceScope scope("lower from CIR to LLVM directly");
2684 
2685   mlir::MLIRContext *mlirCtx = mlirModule.getContext();
2686 
2687   mlir::PassManager pm(mlirCtx);
2688   populateCIRToLLVMPasses(pm);
2689 
2690   (void)mlir::applyPassManagerCLOptions(pm);
2691 
2692   if (mlir::failed(pm.run(mlirModule))) {
2693     // FIXME: Handle any errors where they occurs and return a nullptr here.
2694     report_fatal_error(
2695         "The pass manager failed to lower CIR to LLVMIR dialect!");
2696   }
2697 
2698   mlir::registerBuiltinDialectTranslation(*mlirCtx);
2699   mlir::registerLLVMDialectTranslation(*mlirCtx);
2700   mlir::registerCIRDialectTranslation(*mlirCtx);
2701 
2702   llvm::TimeTraceScope translateScope("translateModuleToLLVMIR");
2703 
2704   StringRef moduleName = mlirModule.getName().value_or("CIRToLLVMModule");
2705   std::unique_ptr<llvm::Module> llvmModule =
2706       mlir::translateModuleToLLVMIR(mlirModule, llvmCtx, moduleName);
2707 
2708   if (!llvmModule) {
2709     // FIXME: Handle any errors where they occurs and return a nullptr here.
2710     report_fatal_error("Lowering from LLVMIR dialect to llvm IR failed!");
2711   }
2712 
2713   return llvmModule;
2714 }
2715 } // namespace direct
2716 } // namespace cir
2717