//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that canonicalizes CIR operations, eliminating
// redundant branches, empty scopes, and other unnecessary operations.
//
//===----------------------------------------------------------------------===//
13*700637cbSDimitry Andric
14*700637cbSDimitry Andric #include "PassDetail.h"
15*700637cbSDimitry Andric #include "mlir/Dialect/Func/IR/FuncOps.h"
16*700637cbSDimitry Andric #include "mlir/IR/Block.h"
17*700637cbSDimitry Andric #include "mlir/IR/Operation.h"
18*700637cbSDimitry Andric #include "mlir/IR/PatternMatch.h"
19*700637cbSDimitry Andric #include "mlir/IR/Region.h"
20*700637cbSDimitry Andric #include "mlir/Support/LogicalResult.h"
21*700637cbSDimitry Andric #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22*700637cbSDimitry Andric #include "clang/CIR/Dialect/IR/CIRDialect.h"
23*700637cbSDimitry Andric #include "clang/CIR/Dialect/Passes.h"
24*700637cbSDimitry Andric #include "clang/CIR/MissingFeatures.h"
25*700637cbSDimitry Andric
26*700637cbSDimitry Andric using namespace mlir;
27*700637cbSDimitry Andric using namespace cir;
28*700637cbSDimitry Andric
29*700637cbSDimitry Andric namespace {
30*700637cbSDimitry Andric
31*700637cbSDimitry Andric /// Removes branches between two blocks if it is the only branch.
32*700637cbSDimitry Andric ///
33*700637cbSDimitry Andric /// From:
34*700637cbSDimitry Andric /// ^bb0:
35*700637cbSDimitry Andric /// cir.br ^bb1
36*700637cbSDimitry Andric /// ^bb1: // pred: ^bb0
37*700637cbSDimitry Andric /// cir.return
38*700637cbSDimitry Andric ///
39*700637cbSDimitry Andric /// To:
40*700637cbSDimitry Andric /// ^bb0:
41*700637cbSDimitry Andric /// cir.return
42*700637cbSDimitry Andric struct RemoveRedundantBranches : public OpRewritePattern<BrOp> {
43*700637cbSDimitry Andric using OpRewritePattern<BrOp>::OpRewritePattern;
44*700637cbSDimitry Andric
matchAndRewrite__anon7e637cf80111::RemoveRedundantBranches45*700637cbSDimitry Andric LogicalResult matchAndRewrite(BrOp op,
46*700637cbSDimitry Andric PatternRewriter &rewriter) const final {
47*700637cbSDimitry Andric Block *block = op.getOperation()->getBlock();
48*700637cbSDimitry Andric Block *dest = op.getDest();
49*700637cbSDimitry Andric
50*700637cbSDimitry Andric assert(!cir::MissingFeatures::labelOp());
51*700637cbSDimitry Andric
52*700637cbSDimitry Andric // Single edge between blocks: merge it.
53*700637cbSDimitry Andric if (block->getNumSuccessors() == 1 &&
54*700637cbSDimitry Andric dest->getSinglePredecessor() == block) {
55*700637cbSDimitry Andric rewriter.eraseOp(op);
56*700637cbSDimitry Andric rewriter.mergeBlocks(dest, block);
57*700637cbSDimitry Andric return success();
58*700637cbSDimitry Andric }
59*700637cbSDimitry Andric
60*700637cbSDimitry Andric return failure();
61*700637cbSDimitry Andric }
62*700637cbSDimitry Andric };
63*700637cbSDimitry Andric
64*700637cbSDimitry Andric struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
65*700637cbSDimitry Andric using OpRewritePattern<ScopeOp>::OpRewritePattern;
66*700637cbSDimitry Andric
matchAndRewrite__anon7e637cf80111::RemoveEmptyScope67*700637cbSDimitry Andric LogicalResult matchAndRewrite(ScopeOp op,
68*700637cbSDimitry Andric PatternRewriter &rewriter) const final {
69*700637cbSDimitry Andric // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
70*700637cbSDimitry Andric // trivially dead operations
71*700637cbSDimitry Andric if (op.isEmpty()) {
72*700637cbSDimitry Andric rewriter.eraseOp(op);
73*700637cbSDimitry Andric return success();
74*700637cbSDimitry Andric }
75*700637cbSDimitry Andric
76*700637cbSDimitry Andric Region ®ion = op.getScopeRegion();
77*700637cbSDimitry Andric if (region.getBlocks().front().getOperations().size() == 1 &&
78*700637cbSDimitry Andric isa<YieldOp>(region.getBlocks().front().front())) {
79*700637cbSDimitry Andric rewriter.eraseOp(op);
80*700637cbSDimitry Andric return success();
81*700637cbSDimitry Andric }
82*700637cbSDimitry Andric
83*700637cbSDimitry Andric return failure();
84*700637cbSDimitry Andric }
85*700637cbSDimitry Andric };
86*700637cbSDimitry Andric
87*700637cbSDimitry Andric struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
88*700637cbSDimitry Andric using OpRewritePattern<SwitchOp>::OpRewritePattern;
89*700637cbSDimitry Andric
matchAndRewrite__anon7e637cf80111::RemoveEmptySwitch90*700637cbSDimitry Andric LogicalResult matchAndRewrite(SwitchOp op,
91*700637cbSDimitry Andric PatternRewriter &rewriter) const final {
92*700637cbSDimitry Andric if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front())))
93*700637cbSDimitry Andric return failure();
94*700637cbSDimitry Andric
95*700637cbSDimitry Andric rewriter.eraseOp(op);
96*700637cbSDimitry Andric return success();
97*700637cbSDimitry Andric }
98*700637cbSDimitry Andric };
99*700637cbSDimitry Andric
100*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
101*700637cbSDimitry Andric // CIRCanonicalizePass
102*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
103*700637cbSDimitry Andric
104*700637cbSDimitry Andric struct CIRCanonicalizePass : public CIRCanonicalizeBase<CIRCanonicalizePass> {
105*700637cbSDimitry Andric using CIRCanonicalizeBase::CIRCanonicalizeBase;
106*700637cbSDimitry Andric
107*700637cbSDimitry Andric // The same operation rewriting done here could have been performed
108*700637cbSDimitry Andric // by CanonicalizerPass (adding hasCanonicalizer for target Ops and
109*700637cbSDimitry Andric // implementing the same from above in CIRDialects.cpp). However, it's
110*700637cbSDimitry Andric // currently too aggressive for static analysis purposes, since it might
111*700637cbSDimitry Andric // remove things where a diagnostic can be generated.
112*700637cbSDimitry Andric //
113*700637cbSDimitry Andric // FIXME: perhaps we can add one more mode to GreedyRewriteConfig to
114*700637cbSDimitry Andric // disable this behavior.
115*700637cbSDimitry Andric void runOnOperation() override;
116*700637cbSDimitry Andric };
117*700637cbSDimitry Andric
populateCIRCanonicalizePatterns(RewritePatternSet & patterns)118*700637cbSDimitry Andric void populateCIRCanonicalizePatterns(RewritePatternSet &patterns) {
119*700637cbSDimitry Andric // clang-format off
120*700637cbSDimitry Andric patterns.add<
121*700637cbSDimitry Andric RemoveRedundantBranches,
122*700637cbSDimitry Andric RemoveEmptyScope
123*700637cbSDimitry Andric >(patterns.getContext());
124*700637cbSDimitry Andric // clang-format on
125*700637cbSDimitry Andric }
126*700637cbSDimitry Andric
runOnOperation()127*700637cbSDimitry Andric void CIRCanonicalizePass::runOnOperation() {
128*700637cbSDimitry Andric // Collect rewrite patterns.
129*700637cbSDimitry Andric RewritePatternSet patterns(&getContext());
130*700637cbSDimitry Andric populateCIRCanonicalizePatterns(patterns);
131*700637cbSDimitry Andric
132*700637cbSDimitry Andric // Collect operations to apply patterns.
133*700637cbSDimitry Andric llvm::SmallVector<Operation *, 16> ops;
134*700637cbSDimitry Andric getOperation()->walk([&](Operation *op) {
135*700637cbSDimitry Andric assert(!cir::MissingFeatures::switchOp());
136*700637cbSDimitry Andric assert(!cir::MissingFeatures::tryOp());
137*700637cbSDimitry Andric assert(!cir::MissingFeatures::complexRealOp());
138*700637cbSDimitry Andric assert(!cir::MissingFeatures::complexImagOp());
139*700637cbSDimitry Andric assert(!cir::MissingFeatures::callOp());
140*700637cbSDimitry Andric
141*700637cbSDimitry Andric // Many operations are here to perform a manual `fold` in
142*700637cbSDimitry Andric // applyOpPatternsGreedily.
143*700637cbSDimitry Andric if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
144*700637cbSDimitry Andric ComplexCreateOp, ComplexImagOp, ComplexRealOp, VecCmpOp,
145*700637cbSDimitry Andric VecCreateOp, VecExtractOp, VecShuffleOp, VecShuffleDynamicOp,
146*700637cbSDimitry Andric VecTernaryOp>(op))
147*700637cbSDimitry Andric ops.push_back(op);
148*700637cbSDimitry Andric });
149*700637cbSDimitry Andric
150*700637cbSDimitry Andric // Apply patterns.
151*700637cbSDimitry Andric if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
152*700637cbSDimitry Andric signalPassFailure();
153*700637cbSDimitry Andric }
154*700637cbSDimitry Andric
155*700637cbSDimitry Andric } // namespace
156*700637cbSDimitry Andric
createCIRCanonicalizePass()157*700637cbSDimitry Andric std::unique_ptr<Pass> mlir::createCIRCanonicalizePass() {
158*700637cbSDimitry Andric return std::make_unique<CIRCanonicalizePass>();
159*700637cbSDimitry Andric }
160