1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/IR/Block.h" 12 #include "mlir/IR/Operation.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/IR/Region.h" 15 #include "mlir/Support/LogicalResult.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 #include "clang/CIR/Dialect/IR/CIRDialect.h" 18 #include "clang/CIR/Dialect/Passes.h" 19 #include "llvm/ADT/SmallVector.h" 20 21 using namespace mlir; 22 using namespace cir; 23 24 //===----------------------------------------------------------------------===// 25 // Rewrite patterns 26 //===----------------------------------------------------------------------===// 27 28 namespace { 29 30 /// Simplify suitable ternary operations into select operations. 31 /// 32 /// For now we only simplify those ternary operations whose true and false 33 /// branches directly yield a value or a constant. That is, both of the true and 34 /// the false branch must either contain a cir.yield operation as the only 35 /// operation in the branch, or contain a cir.const operation followed by a 36 /// cir.yield operation that yields the constant value. 37 /// 38 /// For example, we will simplify the following ternary operation: 39 /// 40 /// %0 = ... 41 /// %1 = cir.ternary (%condition, true { 42 /// %2 = cir.const ... 43 /// cir.yield %2 44 /// } false { 45 /// cir.yield %0 46 /// 47 /// into the following sequence of operations: 48 /// 49 /// %1 = cir.const ... 50 /// %0 = cir.select if %condition then %1 else %2 51 struct SimplifyTernary final : public OpRewritePattern<TernaryOp> { 52 using OpRewritePattern<TernaryOp>::OpRewritePattern; 53 54 LogicalResult matchAndRewrite(TernaryOp op, 55 PatternRewriter &rewriter) const override { 56 if (op->getNumResults() != 1) 57 return mlir::failure(); 58 59 if (!isSimpleTernaryBranch(op.getTrueRegion()) || 60 !isSimpleTernaryBranch(op.getFalseRegion())) 61 return mlir::failure(); 62 63 cir::YieldOp trueBranchYieldOp = 64 mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator()); 65 cir::YieldOp falseBranchYieldOp = 66 mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator()); 67 mlir::Value trueValue = trueBranchYieldOp.getArgs()[0]; 68 mlir::Value falseValue = falseBranchYieldOp.getArgs()[0]; 69 70 rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op); 71 rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op); 72 rewriter.eraseOp(trueBranchYieldOp); 73 rewriter.eraseOp(falseBranchYieldOp); 74 rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue, 75 falseValue); 76 77 return mlir::success(); 78 } 79 80 private: 81 bool isSimpleTernaryBranch(mlir::Region ®ion) const { 82 if (!region.hasOneBlock()) 83 return false; 84 85 mlir::Block &onlyBlock = region.front(); 86 mlir::Block::OpListType &ops = onlyBlock.getOperations(); 87 88 // The region/block could only contain at most 2 operations. 89 if (ops.size() > 2) 90 return false; 91 92 if (ops.size() == 1) { 93 // The region/block only contain a cir.yield operation. 94 return true; 95 } 96 97 // Check whether the region/block contains a cir.const followed by a 98 // cir.yield that yields the value. 99 auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator()); 100 auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>( 101 yieldOp.getArgs()[0].getDefiningOp()); 102 return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock; 103 } 104 }; 105 106 /// Simplify select operations with boolean constants into simpler forms. 107 /// 108 /// This pattern simplifies select operations where both true and false values 109 /// are boolean constants. Two specific cases are handled: 110 /// 111 /// 1. When selecting between true and false based on a condition, 112 /// the operation simplifies to just the condition itself: 113 /// 114 /// %0 = cir.select if %condition then true else false 115 /// -> 116 /// (replaced with %condition directly) 117 /// 118 /// 2. When selecting between false and true based on a condition, 119 /// the operation simplifies to the logical negation of the condition: 120 /// 121 /// %0 = cir.select if %condition then false else true 122 /// -> 123 /// %0 = cir.unary not %condition 124 struct SimplifySelect : public OpRewritePattern<SelectOp> { 125 using OpRewritePattern<SelectOp>::OpRewritePattern; 126 127 LogicalResult matchAndRewrite(SelectOp op, 128 PatternRewriter &rewriter) const final { 129 mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp(); 130 mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp(); 131 auto trueValueConstOp = 132 mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp); 133 auto falseValueConstOp = 134 mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp); 135 if (!trueValueConstOp || !falseValueConstOp) 136 return mlir::failure(); 137 138 auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue()); 139 auto falseValue = 140 mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue()); 141 if (!trueValue || !falseValue) 142 return mlir::failure(); 143 144 // cir.select if %0 then #true else #false -> %0 145 if (trueValue.getValue() && !falseValue.getValue()) { 146 rewriter.replaceAllUsesWith(op, op.getCondition()); 147 rewriter.eraseOp(op); 148 return mlir::success(); 149 } 150 151 // cir.select if %0 then #false else #true -> cir.unary not %0 152 if (!trueValue.getValue() && falseValue.getValue()) { 153 rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not, 154 op.getCondition()); 155 return mlir::success(); 156 } 157 158 return mlir::failure(); 159 } 160 }; 161 162 /// Simplify `cir.switch` operations by folding cascading cases 163 /// into a single `cir.case` with the `anyof` kind. 164 /// 165 /// This pattern identifies cascading cases within a `cir.switch` operation. 166 /// Cascading cases are defined as consecutive `cir.case` operations of kind 167 /// `equal`, each containing a single `cir.yield` operation in their body. 168 /// 169 /// The pattern merges these cascading cases into a single `cir.case` operation 170 /// with kind `anyof`, aggregating all the case values. 171 /// 172 /// The merging process continues until a `cir.case` with a different body 173 /// (e.g., containing `cir.break` or compound stmt) is encountered, which 174 /// breaks the chain. 175 /// 176 /// Example: 177 /// 178 /// Before: 179 /// cir.case equal, [#cir.int<0> : !s32i] { 180 /// cir.yield 181 /// } 182 /// cir.case equal, [#cir.int<1> : !s32i] { 183 /// cir.yield 184 /// } 185 /// cir.case equal, [#cir.int<2> : !s32i] { 186 /// cir.break 187 /// } 188 /// 189 /// After applying SimplifySwitch: 190 /// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : 191 /// !s32i] { 192 /// cir.break 193 /// } 194 struct SimplifySwitch : public OpRewritePattern<SwitchOp> { 195 using OpRewritePattern<SwitchOp>::OpRewritePattern; 196 LogicalResult matchAndRewrite(SwitchOp op, 197 PatternRewriter &rewriter) const override { 198 199 LogicalResult changed = mlir::failure(); 200 SmallVector<CaseOp, 8> cases; 201 SmallVector<CaseOp, 4> cascadingCases; 202 SmallVector<mlir::Attribute, 4> cascadingCaseValues; 203 204 op.collectCases(cases); 205 if (cases.empty()) 206 return mlir::failure(); 207 208 auto flushMergedOps = [&]() { 209 for (CaseOp &c : cascadingCases) 210 rewriter.eraseOp(c); 211 cascadingCases.clear(); 212 cascadingCaseValues.clear(); 213 }; 214 215 auto mergeCascadingInto = [&](CaseOp &target) { 216 rewriter.modifyOpInPlace(target, [&]() { 217 target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues)); 218 target.setKind(CaseOpKind::Anyof); 219 }); 220 changed = mlir::success(); 221 }; 222 223 for (CaseOp c : cases) { 224 cir::CaseOpKind kind = c.getKind(); 225 if (kind == cir::CaseOpKind::Equal && 226 isa<YieldOp>(c.getCaseRegion().front().front())) { 227 // If the case contains only a YieldOp, collect it for cascading merge 228 cascadingCases.push_back(c); 229 cascadingCaseValues.push_back(c.getValue()[0]); 230 } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) { 231 // merge previously collected cascading cases 232 cascadingCaseValues.push_back(c.getValue()[0]); 233 mergeCascadingInto(c); 234 flushMergedOps(); 235 } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) { 236 // If a Default, Anyof or Range case is found and there are previous 237 // cascading cases, merge all of them into the last cascading case. 238 // We don't currently fold case range statements with other case 239 // statements. 240 assert(!cir::MissingFeatures::foldRangeCase()); 241 CaseOp lastCascadingCase = cascadingCases.back(); 242 mergeCascadingInto(lastCascadingCase); 243 cascadingCases.pop_back(); 244 flushMergedOps(); 245 } else { 246 cascadingCases.clear(); 247 cascadingCaseValues.clear(); 248 } 249 } 250 251 // Edge case: all cases are simple cascading cases 252 if (cascadingCases.size() == cases.size()) { 253 CaseOp lastCascadingCase = cascadingCases.back(); 254 mergeCascadingInto(lastCascadingCase); 255 cascadingCases.pop_back(); 256 flushMergedOps(); 257 } 258 259 return changed; 260 } 261 }; 262 263 struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> { 264 using OpRewritePattern<VecSplatOp>::OpRewritePattern; 265 LogicalResult matchAndRewrite(VecSplatOp op, 266 PatternRewriter &rewriter) const override { 267 mlir::Value splatValue = op.getValue(); 268 auto constant = 269 mlir::dyn_cast_if_present<cir::ConstantOp>(splatValue.getDefiningOp()); 270 if (!constant) 271 return mlir::failure(); 272 273 auto value = constant.getValue(); 274 if (!mlir::isa_and_nonnull<cir::IntAttr>(value) && 275 !mlir::isa_and_nonnull<cir::FPAttr>(value)) 276 return mlir::failure(); 277 278 cir::VectorType resultType = op.getResult().getType(); 279 SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value); 280 auto constVecAttr = cir::ConstVectorAttr::get( 281 resultType, mlir::ArrayAttr::get(getContext(), elements)); 282 283 rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr); 284 return mlir::success(); 285 } 286 }; 287 288 //===----------------------------------------------------------------------===// 289 // CIRSimplifyPass 290 //===----------------------------------------------------------------------===// 291 292 struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> { 293 using CIRSimplifyBase::CIRSimplifyBase; 294 295 void runOnOperation() override; 296 }; 297 298 void populateMergeCleanupPatterns(RewritePatternSet &patterns) { 299 // clang-format off 300 patterns.add< 301 SimplifyTernary, 302 SimplifySelect, 303 SimplifySwitch, 304 SimplifyVecSplat 305 >(patterns.getContext()); 306 // clang-format on 307 } 308 309 void CIRSimplifyPass::runOnOperation() { 310 // Collect rewrite patterns. 311 RewritePatternSet patterns(&getContext()); 312 populateMergeCleanupPatterns(patterns); 313 314 // Collect operations to apply patterns. 315 llvm::SmallVector<Operation *, 16> ops; 316 getOperation()->walk([&](Operation *op) { 317 if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op)) 318 ops.push_back(op); 319 }); 320 321 // Apply patterns. 322 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed()) 323 signalPassFailure(); 324 } 325 326 } // namespace 327 328 std::unique_ptr<Pass> mlir::createCIRSimplifyPass() { 329 return std::make_unique<CIRSimplifyPass>(); 330 } 331