1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements pass that canonicalizes CIR operations, eliminating 10 // redundant branches, empty scopes, and other unnecessary operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/IR/Block.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/Region.h" 20 #include "mlir/Support/LogicalResult.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 #include "clang/CIR/Dialect/IR/CIRDialect.h" 23 #include "clang/CIR/Dialect/Passes.h" 24 #include "clang/CIR/MissingFeatures.h" 25 26 using namespace mlir; 27 using namespace cir; 28 29 namespace { 30 31 /// Removes branches between two blocks if it is the only branch. 32 /// 33 /// From: 34 /// ^bb0: 35 /// cir.br ^bb1 36 /// ^bb1: // pred: ^bb0 37 /// cir.return 38 /// 39 /// To: 40 /// ^bb0: 41 /// cir.return 42 struct RemoveRedundantBranches : public OpRewritePattern<BrOp> { 43 using OpRewritePattern<BrOp>::OpRewritePattern; 44 45 LogicalResult matchAndRewrite(BrOp op, 46 PatternRewriter &rewriter) const final { 47 Block *block = op.getOperation()->getBlock(); 48 Block *dest = op.getDest(); 49 50 assert(!cir::MissingFeatures::labelOp()); 51 52 // Single edge between blocks: merge it. 53 if (block->getNumSuccessors() == 1 && 54 dest->getSinglePredecessor() == block) { 55 rewriter.eraseOp(op); 56 rewriter.mergeBlocks(dest, block); 57 return success(); 58 } 59 60 return failure(); 61 } 62 }; 63 64 struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> { 65 using OpRewritePattern<ScopeOp>::OpRewritePattern; 66 67 LogicalResult matchAndRewrite(ScopeOp op, 68 PatternRewriter &rewriter) const final { 69 // TODO: Remove this logic once CIR uses MLIR infrastructure to remove 70 // trivially dead operations 71 if (op.isEmpty()) { 72 rewriter.eraseOp(op); 73 return success(); 74 } 75 76 Region ®ion = op.getScopeRegion(); 77 if (region.getBlocks().front().getOperations().size() == 1 && 78 isa<YieldOp>(region.getBlocks().front().front())) { 79 rewriter.eraseOp(op); 80 return success(); 81 } 82 83 return failure(); 84 } 85 }; 86 87 struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> { 88 using OpRewritePattern<SwitchOp>::OpRewritePattern; 89 90 LogicalResult matchAndRewrite(SwitchOp op, 91 PatternRewriter &rewriter) const final { 92 if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front()))) 93 return failure(); 94 95 rewriter.eraseOp(op); 96 return success(); 97 } 98 }; 99 100 //===----------------------------------------------------------------------===// 101 // CIRCanonicalizePass 102 //===----------------------------------------------------------------------===// 103 104 struct CIRCanonicalizePass : public CIRCanonicalizeBase<CIRCanonicalizePass> { 105 using CIRCanonicalizeBase::CIRCanonicalizeBase; 106 107 // The same operation rewriting done here could have been performed 108 // by CanonicalizerPass (adding hasCanonicalizer for target Ops and 109 // implementing the same from above in CIRDialects.cpp). However, it's 110 // currently too aggressive for static analysis purposes, since it might 111 // remove things where a diagnostic can be generated. 112 // 113 // FIXME: perhaps we can add one more mode to GreedyRewriteConfig to 114 // disable this behavior. 115 void runOnOperation() override; 116 }; 117 118 void populateCIRCanonicalizePatterns(RewritePatternSet &patterns) { 119 // clang-format off 120 patterns.add< 121 RemoveRedundantBranches, 122 RemoveEmptyScope 123 >(patterns.getContext()); 124 // clang-format on 125 } 126 127 void CIRCanonicalizePass::runOnOperation() { 128 // Collect rewrite patterns. 129 RewritePatternSet patterns(&getContext()); 130 populateCIRCanonicalizePatterns(patterns); 131 132 // Collect operations to apply patterns. 133 llvm::SmallVector<Operation *, 16> ops; 134 getOperation()->walk([&](Operation *op) { 135 assert(!cir::MissingFeatures::switchOp()); 136 assert(!cir::MissingFeatures::tryOp()); 137 assert(!cir::MissingFeatures::complexRealOp()); 138 assert(!cir::MissingFeatures::complexImagOp()); 139 assert(!cir::MissingFeatures::callOp()); 140 141 // Many operations are here to perform a manual `fold` in 142 // applyOpPatternsGreedily. 143 if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp, 144 ComplexCreateOp, ComplexImagOp, ComplexRealOp, VecCmpOp, 145 VecCreateOp, VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, 146 VecTernaryOp>(op)) 147 ops.push_back(op); 148 }); 149 150 // Apply patterns. 151 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed()) 152 signalPassFailure(); 153 } 154 155 } // namespace 156 157 std::unique_ptr<Pass> mlir::createCIRCanonicalizePass() { 158 return std::make_unique<CIRCanonicalizePass>(); 159 } 160