xref: /freebsd/contrib/llvm-project/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp (revision e64bea71c21eb42e97aa615188ba91f6cce0d36d)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/Region.h"
15 #include "mlir/Support/LogicalResult.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 #include "clang/CIR/Dialect/IR/CIRDialect.h"
18 #include "clang/CIR/Dialect/Passes.h"
19 #include "llvm/ADT/SmallVector.h"
20 
21 using namespace mlir;
22 using namespace cir;
23 
24 //===----------------------------------------------------------------------===//
25 // Rewrite patterns
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 
30 /// Simplify suitable ternary operations into select operations.
31 ///
32 /// For now we only simplify those ternary operations whose true and false
33 /// branches directly yield a value or a constant. That is, both of the true and
34 /// the false branch must either contain a cir.yield operation as the only
35 /// operation in the branch, or contain a cir.const operation followed by a
36 /// cir.yield operation that yields the constant value.
37 ///
38 /// For example, we will simplify the following ternary operation:
39 ///
40 ///   %0 = ...
41 ///   %1 = cir.ternary (%condition, true {
42 ///     %2 = cir.const ...
43 ///     cir.yield %2
44 ///   } false {
45 ///     cir.yield %0
46 ///
47 /// into the following sequence of operations:
48 ///
49 ///   %1 = cir.const ...
50 ///   %0 = cir.select if %condition then %1 else %2
51 struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
52   using OpRewritePattern<TernaryOp>::OpRewritePattern;
53 
54   LogicalResult matchAndRewrite(TernaryOp op,
55                                 PatternRewriter &rewriter) const override {
56     if (op->getNumResults() != 1)
57       return mlir::failure();
58 
59     if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
60         !isSimpleTernaryBranch(op.getFalseRegion()))
61       return mlir::failure();
62 
63     cir::YieldOp trueBranchYieldOp =
64         mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
65     cir::YieldOp falseBranchYieldOp =
66         mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
67     mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
68     mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
69 
70     rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
71     rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
72     rewriter.eraseOp(trueBranchYieldOp);
73     rewriter.eraseOp(falseBranchYieldOp);
74     rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
75                                                falseValue);
76 
77     return mlir::success();
78   }
79 
80 private:
81   bool isSimpleTernaryBranch(mlir::Region &region) const {
82     if (!region.hasOneBlock())
83       return false;
84 
85     mlir::Block &onlyBlock = region.front();
86     mlir::Block::OpListType &ops = onlyBlock.getOperations();
87 
88     // The region/block could only contain at most 2 operations.
89     if (ops.size() > 2)
90       return false;
91 
92     if (ops.size() == 1) {
93       // The region/block only contain a cir.yield operation.
94       return true;
95     }
96 
97     // Check whether the region/block contains a cir.const followed by a
98     // cir.yield that yields the value.
99     auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
100     auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
101         yieldOp.getArgs()[0].getDefiningOp());
102     return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
103   }
104 };
105 
106 /// Simplify select operations with boolean constants into simpler forms.
107 ///
108 /// This pattern simplifies select operations where both true and false values
109 /// are boolean constants. Two specific cases are handled:
110 ///
111 /// 1. When selecting between true and false based on a condition,
112 ///    the operation simplifies to just the condition itself:
113 ///
114 ///    %0 = cir.select if %condition then true else false
115 ///    ->
116 ///    (replaced with %condition directly)
117 ///
118 /// 2. When selecting between false and true based on a condition,
119 ///    the operation simplifies to the logical negation of the condition:
120 ///
121 ///    %0 = cir.select if %condition then false else true
122 ///    ->
123 ///    %0 = cir.unary not %condition
124 struct SimplifySelect : public OpRewritePattern<SelectOp> {
125   using OpRewritePattern<SelectOp>::OpRewritePattern;
126 
127   LogicalResult matchAndRewrite(SelectOp op,
128                                 PatternRewriter &rewriter) const final {
129     mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
130     mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
131     auto trueValueConstOp =
132         mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
133     auto falseValueConstOp =
134         mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
135     if (!trueValueConstOp || !falseValueConstOp)
136       return mlir::failure();
137 
138     auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
139     auto falseValue =
140         mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
141     if (!trueValue || !falseValue)
142       return mlir::failure();
143 
144     // cir.select if %0 then #true else #false -> %0
145     if (trueValue.getValue() && !falseValue.getValue()) {
146       rewriter.replaceAllUsesWith(op, op.getCondition());
147       rewriter.eraseOp(op);
148       return mlir::success();
149     }
150 
151     // cir.select if %0 then #false else #true -> cir.unary not %0
152     if (!trueValue.getValue() && falseValue.getValue()) {
153       rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
154                                                 op.getCondition());
155       return mlir::success();
156     }
157 
158     return mlir::failure();
159   }
160 };
161 
162 /// Simplify `cir.switch` operations by folding cascading cases
163 /// into a single `cir.case` with the `anyof` kind.
164 ///
165 /// This pattern identifies cascading cases within a `cir.switch` operation.
166 /// Cascading cases are defined as consecutive `cir.case` operations of kind
167 /// `equal`, each containing a single `cir.yield` operation in their body.
168 ///
169 /// The pattern merges these cascading cases into a single `cir.case` operation
170 /// with kind `anyof`, aggregating all the case values.
171 ///
172 /// The merging process continues until a `cir.case` with a different body
173 /// (e.g., containing `cir.break` or compound stmt) is encountered, which
174 /// breaks the chain.
175 ///
176 /// Example:
177 ///
178 /// Before:
179 ///   cir.case equal, [#cir.int<0> : !s32i] {
180 ///     cir.yield
181 ///   }
182 ///   cir.case equal, [#cir.int<1> : !s32i] {
183 ///     cir.yield
184 ///   }
185 ///   cir.case equal, [#cir.int<2> : !s32i] {
186 ///     cir.break
187 ///   }
188 ///
189 /// After applying SimplifySwitch:
190 ///   cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
191 ///   !s32i] {
192 ///     cir.break
193 ///   }
194 struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
195   using OpRewritePattern<SwitchOp>::OpRewritePattern;
196   LogicalResult matchAndRewrite(SwitchOp op,
197                                 PatternRewriter &rewriter) const override {
198 
199     LogicalResult changed = mlir::failure();
200     SmallVector<CaseOp, 8> cases;
201     SmallVector<CaseOp, 4> cascadingCases;
202     SmallVector<mlir::Attribute, 4> cascadingCaseValues;
203 
204     op.collectCases(cases);
205     if (cases.empty())
206       return mlir::failure();
207 
208     auto flushMergedOps = [&]() {
209       for (CaseOp &c : cascadingCases)
210         rewriter.eraseOp(c);
211       cascadingCases.clear();
212       cascadingCaseValues.clear();
213     };
214 
215     auto mergeCascadingInto = [&](CaseOp &target) {
216       rewriter.modifyOpInPlace(target, [&]() {
217         target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
218         target.setKind(CaseOpKind::Anyof);
219       });
220       changed = mlir::success();
221     };
222 
223     for (CaseOp c : cases) {
224       cir::CaseOpKind kind = c.getKind();
225       if (kind == cir::CaseOpKind::Equal &&
226           isa<YieldOp>(c.getCaseRegion().front().front())) {
227         // If the case contains only a YieldOp, collect it for cascading merge
228         cascadingCases.push_back(c);
229         cascadingCaseValues.push_back(c.getValue()[0]);
230       } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
231         // merge previously collected cascading cases
232         cascadingCaseValues.push_back(c.getValue()[0]);
233         mergeCascadingInto(c);
234         flushMergedOps();
235       } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
236         // If a Default, Anyof or Range case is found and there are previous
237         // cascading cases, merge all of them into the last cascading case.
238         // We don't currently fold case range statements with other case
239         // statements.
240         assert(!cir::MissingFeatures::foldRangeCase());
241         CaseOp lastCascadingCase = cascadingCases.back();
242         mergeCascadingInto(lastCascadingCase);
243         cascadingCases.pop_back();
244         flushMergedOps();
245       } else {
246         cascadingCases.clear();
247         cascadingCaseValues.clear();
248       }
249     }
250 
251     // Edge case: all cases are simple cascading cases
252     if (cascadingCases.size() == cases.size()) {
253       CaseOp lastCascadingCase = cascadingCases.back();
254       mergeCascadingInto(lastCascadingCase);
255       cascadingCases.pop_back();
256       flushMergedOps();
257     }
258 
259     return changed;
260   }
261 };
262 
263 struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {
264   using OpRewritePattern<VecSplatOp>::OpRewritePattern;
265   LogicalResult matchAndRewrite(VecSplatOp op,
266                                 PatternRewriter &rewriter) const override {
267     mlir::Value splatValue = op.getValue();
268     auto constant =
269         mlir::dyn_cast_if_present<cir::ConstantOp>(splatValue.getDefiningOp());
270     if (!constant)
271       return mlir::failure();
272 
273     auto value = constant.getValue();
274     if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
275         !mlir::isa_and_nonnull<cir::FPAttr>(value))
276       return mlir::failure();
277 
278     cir::VectorType resultType = op.getResult().getType();
279     SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
280     auto constVecAttr = cir::ConstVectorAttr::get(
281         resultType, mlir::ArrayAttr::get(getContext(), elements));
282 
283     rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
284     return mlir::success();
285   }
286 };
287 
288 //===----------------------------------------------------------------------===//
289 // CIRSimplifyPass
290 //===----------------------------------------------------------------------===//
291 
292 struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
293   using CIRSimplifyBase::CIRSimplifyBase;
294 
295   void runOnOperation() override;
296 };
297 
298 void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
299   // clang-format off
300   patterns.add<
301     SimplifyTernary,
302     SimplifySelect,
303     SimplifySwitch,
304     SimplifyVecSplat
305   >(patterns.getContext());
306   // clang-format on
307 }
308 
309 void CIRSimplifyPass::runOnOperation() {
310   // Collect rewrite patterns.
311   RewritePatternSet patterns(&getContext());
312   populateMergeCleanupPatterns(patterns);
313 
314   // Collect operations to apply patterns.
315   llvm::SmallVector<Operation *, 16> ops;
316   getOperation()->walk([&](Operation *op) {
317     if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
318       ops.push_back(op);
319   });
320 
321   // Apply patterns.
322   if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
323     signalPassFailure();
324 }
325 
326 } // namespace
327 
328 std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
329   return std::make_unique<CIRSimplifyPass>();
330 }
331