1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/Region.h"
15 #include "mlir/Support/LogicalResult.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 #include "clang/CIR/Dialect/IR/CIRDialect.h"
18 #include "clang/CIR/Dialect/Passes.h"
19 #include "llvm/ADT/SmallVector.h"
20
21 using namespace mlir;
22 using namespace cir;
23
24 //===----------------------------------------------------------------------===//
25 // Rewrite patterns
26 //===----------------------------------------------------------------------===//
27
28 namespace {
29
30 /// Simplify suitable ternary operations into select operations.
31 ///
32 /// For now we only simplify those ternary operations whose true and false
33 /// branches directly yield a value or a constant. That is, both of the true and
34 /// the false branch must either contain a cir.yield operation as the only
35 /// operation in the branch, or contain a cir.const operation followed by a
36 /// cir.yield operation that yields the constant value.
37 ///
38 /// For example, we will simplify the following ternary operation:
39 ///
40 /// %0 = ...
41 /// %1 = cir.ternary (%condition, true {
42 /// %2 = cir.const ...
43 /// cir.yield %2
44 /// } false {
45 /// cir.yield %0
46 ///
47 /// into the following sequence of operations:
48 ///
49 /// %1 = cir.const ...
50 /// %0 = cir.select if %condition then %1 else %2
51 struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
52 using OpRewritePattern<TernaryOp>::OpRewritePattern;
53
matchAndRewrite__anond71145f50111::SimplifyTernary54 LogicalResult matchAndRewrite(TernaryOp op,
55 PatternRewriter &rewriter) const override {
56 if (op->getNumResults() != 1)
57 return mlir::failure();
58
59 if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
60 !isSimpleTernaryBranch(op.getFalseRegion()))
61 return mlir::failure();
62
63 cir::YieldOp trueBranchYieldOp =
64 mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
65 cir::YieldOp falseBranchYieldOp =
66 mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
67 mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
68 mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
69
70 rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
71 rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
72 rewriter.eraseOp(trueBranchYieldOp);
73 rewriter.eraseOp(falseBranchYieldOp);
74 rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
75 falseValue);
76
77 return mlir::success();
78 }
79
80 private:
isSimpleTernaryBranch__anond71145f50111::SimplifyTernary81 bool isSimpleTernaryBranch(mlir::Region ®ion) const {
82 if (!region.hasOneBlock())
83 return false;
84
85 mlir::Block &onlyBlock = region.front();
86 mlir::Block::OpListType &ops = onlyBlock.getOperations();
87
88 // The region/block could only contain at most 2 operations.
89 if (ops.size() > 2)
90 return false;
91
92 if (ops.size() == 1) {
93 // The region/block only contain a cir.yield operation.
94 return true;
95 }
96
97 // Check whether the region/block contains a cir.const followed by a
98 // cir.yield that yields the value.
99 auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
100 auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
101 yieldOp.getArgs()[0].getDefiningOp());
102 return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
103 }
104 };
105
106 /// Simplify select operations with boolean constants into simpler forms.
107 ///
108 /// This pattern simplifies select operations where both true and false values
109 /// are boolean constants. Two specific cases are handled:
110 ///
111 /// 1. When selecting between true and false based on a condition,
112 /// the operation simplifies to just the condition itself:
113 ///
114 /// %0 = cir.select if %condition then true else false
115 /// ->
116 /// (replaced with %condition directly)
117 ///
118 /// 2. When selecting between false and true based on a condition,
119 /// the operation simplifies to the logical negation of the condition:
120 ///
121 /// %0 = cir.select if %condition then false else true
122 /// ->
123 /// %0 = cir.unary not %condition
124 struct SimplifySelect : public OpRewritePattern<SelectOp> {
125 using OpRewritePattern<SelectOp>::OpRewritePattern;
126
matchAndRewrite__anond71145f50111::SimplifySelect127 LogicalResult matchAndRewrite(SelectOp op,
128 PatternRewriter &rewriter) const final {
129 mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
130 mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
131 auto trueValueConstOp =
132 mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
133 auto falseValueConstOp =
134 mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
135 if (!trueValueConstOp || !falseValueConstOp)
136 return mlir::failure();
137
138 auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
139 auto falseValue =
140 mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
141 if (!trueValue || !falseValue)
142 return mlir::failure();
143
144 // cir.select if %0 then #true else #false -> %0
145 if (trueValue.getValue() && !falseValue.getValue()) {
146 rewriter.replaceAllUsesWith(op, op.getCondition());
147 rewriter.eraseOp(op);
148 return mlir::success();
149 }
150
151 // cir.select if %0 then #false else #true -> cir.unary not %0
152 if (!trueValue.getValue() && falseValue.getValue()) {
153 rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
154 op.getCondition());
155 return mlir::success();
156 }
157
158 return mlir::failure();
159 }
160 };
161
162 /// Simplify `cir.switch` operations by folding cascading cases
163 /// into a single `cir.case` with the `anyof` kind.
164 ///
165 /// This pattern identifies cascading cases within a `cir.switch` operation.
166 /// Cascading cases are defined as consecutive `cir.case` operations of kind
167 /// `equal`, each containing a single `cir.yield` operation in their body.
168 ///
169 /// The pattern merges these cascading cases into a single `cir.case` operation
170 /// with kind `anyof`, aggregating all the case values.
171 ///
172 /// The merging process continues until a `cir.case` with a different body
173 /// (e.g., containing `cir.break` or compound stmt) is encountered, which
174 /// breaks the chain.
175 ///
176 /// Example:
177 ///
178 /// Before:
179 /// cir.case equal, [#cir.int<0> : !s32i] {
180 /// cir.yield
181 /// }
182 /// cir.case equal, [#cir.int<1> : !s32i] {
183 /// cir.yield
184 /// }
185 /// cir.case equal, [#cir.int<2> : !s32i] {
186 /// cir.break
187 /// }
188 ///
189 /// After applying SimplifySwitch:
190 /// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
191 /// !s32i] {
192 /// cir.break
193 /// }
194 struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
195 using OpRewritePattern<SwitchOp>::OpRewritePattern;
matchAndRewrite__anond71145f50111::SimplifySwitch196 LogicalResult matchAndRewrite(SwitchOp op,
197 PatternRewriter &rewriter) const override {
198
199 LogicalResult changed = mlir::failure();
200 SmallVector<CaseOp, 8> cases;
201 SmallVector<CaseOp, 4> cascadingCases;
202 SmallVector<mlir::Attribute, 4> cascadingCaseValues;
203
204 op.collectCases(cases);
205 if (cases.empty())
206 return mlir::failure();
207
208 auto flushMergedOps = [&]() {
209 for (CaseOp &c : cascadingCases)
210 rewriter.eraseOp(c);
211 cascadingCases.clear();
212 cascadingCaseValues.clear();
213 };
214
215 auto mergeCascadingInto = [&](CaseOp &target) {
216 rewriter.modifyOpInPlace(target, [&]() {
217 target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
218 target.setKind(CaseOpKind::Anyof);
219 });
220 changed = mlir::success();
221 };
222
223 for (CaseOp c : cases) {
224 cir::CaseOpKind kind = c.getKind();
225 if (kind == cir::CaseOpKind::Equal &&
226 isa<YieldOp>(c.getCaseRegion().front().front())) {
227 // If the case contains only a YieldOp, collect it for cascading merge
228 cascadingCases.push_back(c);
229 cascadingCaseValues.push_back(c.getValue()[0]);
230 } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
231 // merge previously collected cascading cases
232 cascadingCaseValues.push_back(c.getValue()[0]);
233 mergeCascadingInto(c);
234 flushMergedOps();
235 } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
236 // If a Default, Anyof or Range case is found and there are previous
237 // cascading cases, merge all of them into the last cascading case.
238 // We don't currently fold case range statements with other case
239 // statements.
240 assert(!cir::MissingFeatures::foldRangeCase());
241 CaseOp lastCascadingCase = cascadingCases.back();
242 mergeCascadingInto(lastCascadingCase);
243 cascadingCases.pop_back();
244 flushMergedOps();
245 } else {
246 cascadingCases.clear();
247 cascadingCaseValues.clear();
248 }
249 }
250
251 // Edge case: all cases are simple cascading cases
252 if (cascadingCases.size() == cases.size()) {
253 CaseOp lastCascadingCase = cascadingCases.back();
254 mergeCascadingInto(lastCascadingCase);
255 cascadingCases.pop_back();
256 flushMergedOps();
257 }
258
259 return changed;
260 }
261 };
262
263 struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {
264 using OpRewritePattern<VecSplatOp>::OpRewritePattern;
matchAndRewrite__anond71145f50111::SimplifyVecSplat265 LogicalResult matchAndRewrite(VecSplatOp op,
266 PatternRewriter &rewriter) const override {
267 mlir::Value splatValue = op.getValue();
268 auto constant =
269 mlir::dyn_cast_if_present<cir::ConstantOp>(splatValue.getDefiningOp());
270 if (!constant)
271 return mlir::failure();
272
273 auto value = constant.getValue();
274 if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
275 !mlir::isa_and_nonnull<cir::FPAttr>(value))
276 return mlir::failure();
277
278 cir::VectorType resultType = op.getResult().getType();
279 SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
280 auto constVecAttr = cir::ConstVectorAttr::get(
281 resultType, mlir::ArrayAttr::get(getContext(), elements));
282
283 rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
284 return mlir::success();
285 }
286 };
287
288 //===----------------------------------------------------------------------===//
289 // CIRSimplifyPass
290 //===----------------------------------------------------------------------===//
291
292 struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
293 using CIRSimplifyBase::CIRSimplifyBase;
294
295 void runOnOperation() override;
296 };
297
populateMergeCleanupPatterns(RewritePatternSet & patterns)298 void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
299 // clang-format off
300 patterns.add<
301 SimplifyTernary,
302 SimplifySelect,
303 SimplifySwitch,
304 SimplifyVecSplat
305 >(patterns.getContext());
306 // clang-format on
307 }
308
runOnOperation()309 void CIRSimplifyPass::runOnOperation() {
310 // Collect rewrite patterns.
311 RewritePatternSet patterns(&getContext());
312 populateMergeCleanupPatterns(patterns);
313
314 // Collect operations to apply patterns.
315 llvm::SmallVector<Operation *, 16> ops;
316 getOperation()->walk([&](Operation *op) {
317 if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
318 ops.push_back(op);
319 });
320
321 // Apply patterns.
322 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
323 signalPassFailure();
324 }
325
326 } // namespace
327
createCIRSimplifyPass()328 std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
329 return std::make_unique<CIRSimplifyPass>();
330 }
331