xref: /freebsd/contrib/llvm-project/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that inlines the regions of CIR operations into
10 // the parent function region.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/IR/Block.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "clang/CIR/Dialect/IR/CIRDialect.h"
23 #include "clang/CIR/Dialect/Passes.h"
24 #include "clang/CIR/MissingFeatures.h"
25 
26 using namespace mlir;
27 using namespace cir;
28 
29 namespace {
30 
31 /// Lowers operations with the terminator trait that have a single successor.
lowerTerminator(mlir::Operation * op,mlir::Block * dest,mlir::PatternRewriter & rewriter)32 void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
33                      mlir::PatternRewriter &rewriter) {
34   assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
35   mlir::OpBuilder::InsertionGuard guard(rewriter);
36   rewriter.setInsertionPoint(op);
37   rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
38 }
39 
40 /// Walks a region while skipping operations of type `Ops`. This ensures the
41 /// callback is not applied to said operations and its children.
42 template <typename... Ops>
walkRegionSkipping(mlir::Region & region,mlir::function_ref<mlir::WalkResult (mlir::Operation *)> callback)43 void walkRegionSkipping(
44     mlir::Region &region,
45     mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
46   region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
47     if (isa<Ops...>(op))
48       return mlir::WalkResult::skip();
49     return callback(op);
50   });
51 }
52 
53 struct CIRFlattenCFGPass : public CIRFlattenCFGBase<CIRFlattenCFGPass> {
54 
55   CIRFlattenCFGPass() = default;
56   void runOnOperation() override;
57 };
58 
59 struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> {
60   using OpRewritePattern<IfOp>::OpRewritePattern;
61 
62   mlir::LogicalResult
matchAndRewrite__anonc87c85480111::CIRIfFlattening63   matchAndRewrite(cir::IfOp ifOp,
64                   mlir::PatternRewriter &rewriter) const override {
65     mlir::OpBuilder::InsertionGuard guard(rewriter);
66     mlir::Location loc = ifOp.getLoc();
67     bool emptyElse = ifOp.getElseRegion().empty();
68     mlir::Block *currentBlock = rewriter.getInsertionBlock();
69     mlir::Block *remainingOpsBlock =
70         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
71     mlir::Block *continueBlock;
72     if (ifOp->getResults().empty())
73       continueBlock = remainingOpsBlock;
74     else
75       llvm_unreachable("NYI");
76 
77     // Inline the region
78     mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
79     mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
80     rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);
81 
82     rewriter.setInsertionPointToEnd(thenAfterBody);
83     if (auto thenYieldOp =
84             dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
85       rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
86                                              continueBlock);
87     }
88 
89     rewriter.setInsertionPointToEnd(continueBlock);
90 
91     // Has else region: inline it.
92     mlir::Block *elseBeforeBody = nullptr;
93     mlir::Block *elseAfterBody = nullptr;
94     if (!emptyElse) {
95       elseBeforeBody = &ifOp.getElseRegion().front();
96       elseAfterBody = &ifOp.getElseRegion().back();
97       rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
98     } else {
99       elseBeforeBody = elseAfterBody = continueBlock;
100     }
101 
102     rewriter.setInsertionPointToEnd(currentBlock);
103     rewriter.create<cir::BrCondOp>(loc, ifOp.getCondition(), thenBeforeBody,
104                                    elseBeforeBody);
105 
106     if (!emptyElse) {
107       rewriter.setInsertionPointToEnd(elseAfterBody);
108       if (auto elseYieldOP =
109               dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
110         rewriter.replaceOpWithNewOp<cir::BrOp>(
111             elseYieldOP, elseYieldOP.getArgs(), continueBlock);
112       }
113     }
114 
115     rewriter.replaceOp(ifOp, continueBlock->getArguments());
116     return mlir::success();
117   }
118 };
119 
120 class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
121 public:
122   using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;
123 
124   mlir::LogicalResult
matchAndRewrite(cir::ScopeOp scopeOp,mlir::PatternRewriter & rewriter) const125   matchAndRewrite(cir::ScopeOp scopeOp,
126                   mlir::PatternRewriter &rewriter) const override {
127     mlir::OpBuilder::InsertionGuard guard(rewriter);
128     mlir::Location loc = scopeOp.getLoc();
129 
130     // Empty scope: just remove it.
131     // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
132     // trivially dead operations. MLIR canonicalizer is too aggressive and we
133     // need to either (a) make sure all our ops model all side-effects and/or
134     // (b) have more options in the canonicalizer in MLIR to temper
135     // aggressiveness level.
136     if (scopeOp.isEmpty()) {
137       rewriter.eraseOp(scopeOp);
138       return mlir::success();
139     }
140 
141     // Split the current block before the ScopeOp to create the inlining
142     // point.
143     mlir::Block *currentBlock = rewriter.getInsertionBlock();
144     mlir::Block *continueBlock =
145         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
146     if (scopeOp.getNumResults() > 0)
147       continueBlock->addArguments(scopeOp.getResultTypes(), loc);
148 
149     // Inline body region.
150     mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
151     mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
152     rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);
153 
154     // Save stack and then branch into the body of the region.
155     rewriter.setInsertionPointToEnd(currentBlock);
156     assert(!cir::MissingFeatures::stackSaveOp());
157     rewriter.create<cir::BrOp>(loc, mlir::ValueRange(), beforeBody);
158 
159     // Replace the scopeop return with a branch that jumps out of the body.
160     // Stack restore before leaving the body region.
161     rewriter.setInsertionPointToEnd(afterBody);
162     if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
163       rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
164                                              continueBlock);
165     }
166 
167     // Replace the op with values return from the body region.
168     rewriter.replaceOp(scopeOp, continueBlock->getArguments());
169 
170     return mlir::success();
171   }
172 };
173 
174 class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
175 public:
176   using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
177 
rewriteYieldOp(mlir::PatternRewriter & rewriter,cir::YieldOp yieldOp,mlir::Block * destination) const178   inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
179                              cir::YieldOp yieldOp,
180                              mlir::Block *destination) const {
181     rewriter.setInsertionPoint(yieldOp);
182     rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
183                                            destination);
184   }
185 
186   // Return the new defaultDestination block.
condBrToRangeDestination(cir::SwitchOp op,mlir::PatternRewriter & rewriter,mlir::Block * rangeDestination,mlir::Block * defaultDestination,const APInt & lowerBound,const APInt & upperBound) const187   Block *condBrToRangeDestination(cir::SwitchOp op,
188                                   mlir::PatternRewriter &rewriter,
189                                   mlir::Block *rangeDestination,
190                                   mlir::Block *defaultDestination,
191                                   const APInt &lowerBound,
192                                   const APInt &upperBound) const {
193     assert(lowerBound.sle(upperBound) && "Invalid range");
194     mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
195     cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
196     cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);
197 
198     cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(
199         op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound));
200 
201     cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(
202         op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
203     cir::BinOp diffValue =
204         rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub,
205                                     op.getCondition(), lowerBoundValue);
206 
207     // Use unsigned comparison to check if the condition is in the range.
208     cir::CastOp uDiffValue = rewriter.create<cir::CastOp>(
209         op.getLoc(), uIntType, CastKind::integral, diffValue);
210     cir::CastOp uRangeLength = rewriter.create<cir::CastOp>(
211         op.getLoc(), uIntType, CastKind::integral, rangeLength);
212 
213     cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(
214         op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le,
215         uDiffValue, uRangeLength);
216     rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination,
217                                    defaultDestination);
218     return resBlock;
219   }
220 
221   mlir::LogicalResult
matchAndRewrite(cir::SwitchOp op,mlir::PatternRewriter & rewriter) const222   matchAndRewrite(cir::SwitchOp op,
223                   mlir::PatternRewriter &rewriter) const override {
224     llvm::SmallVector<CaseOp> cases;
225     op.collectCases(cases);
226 
227     // Empty switch statement: just erase it.
228     if (cases.empty()) {
229       rewriter.eraseOp(op);
230       return mlir::success();
231     }
232 
233     // Create exit block from the next node of cir.switch op.
234     mlir::Block *exitBlock = rewriter.splitBlock(
235         rewriter.getBlock(), op->getNextNode()->getIterator());
236 
237     // We lower cir.switch op in the following process:
238     // 1. Inline the region from the switch op after switch op.
239     // 2. Traverse each cir.case op:
240     //    a. Record the entry block, block arguments and condition for every
241     //    case. b. Inline the case region after the case op.
242     // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
243     //    recorded block and conditions.
244 
245     // inline everything from switch body between the switch op and the exit
246     // block.
247     {
248       cir::YieldOp switchYield = nullptr;
249       // Clear switch operation.
250       for (mlir::Block &block :
251            llvm::make_early_inc_range(op.getBody().getBlocks()))
252         if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
253           switchYield = yieldOp;
254 
255       assert(!op.getBody().empty());
256       mlir::Block *originalBlock = op->getBlock();
257       mlir::Block *swopBlock =
258           rewriter.splitBlock(originalBlock, op->getIterator());
259       rewriter.inlineRegionBefore(op.getBody(), exitBlock);
260 
261       if (switchYield)
262         rewriteYieldOp(rewriter, switchYield, exitBlock);
263 
264       rewriter.setInsertionPointToEnd(originalBlock);
265       rewriter.create<cir::BrOp>(op.getLoc(), swopBlock);
266     }
267 
268     // Allocate required data structures (disconsider default case in
269     // vectors).
270     llvm::SmallVector<mlir::APInt, 8> caseValues;
271     llvm::SmallVector<mlir::Block *, 8> caseDestinations;
272     llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
273 
274     llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
275     llvm::SmallVector<mlir::Block *> rangeDestinations;
276     llvm::SmallVector<mlir::ValueRange> rangeOperands;
277 
278     // Initialize default case as optional.
279     mlir::Block *defaultDestination = exitBlock;
280     mlir::ValueRange defaultOperands = exitBlock->getArguments();
281 
282     // Digest the case statements values and bodies.
283     for (cir::CaseOp caseOp : cases) {
284       mlir::Region &region = caseOp.getCaseRegion();
285 
286       // Found default case: save destination and operands.
287       switch (caseOp.getKind()) {
288       case cir::CaseOpKind::Default:
289         defaultDestination = &region.front();
290         defaultOperands = defaultDestination->getArguments();
291         break;
292       case cir::CaseOpKind::Range:
293         assert(caseOp.getValue().size() == 2 &&
294                "Case range should have 2 case value");
295         rangeValues.push_back(
296             {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
297              cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
298         rangeDestinations.push_back(&region.front());
299         rangeOperands.push_back(rangeDestinations.back()->getArguments());
300         break;
301       case cir::CaseOpKind::Anyof:
302       case cir::CaseOpKind::Equal:
303         // AnyOf cases kind can have multiple values, hence the loop below.
304         for (const mlir::Attribute &value : caseOp.getValue()) {
305           caseValues.push_back(cast<cir::IntAttr>(value).getValue());
306           caseDestinations.push_back(&region.front());
307           caseOperands.push_back(caseDestinations.back()->getArguments());
308         }
309         break;
310       }
311 
312       // Handle break statements.
313       walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
314           region, [&](mlir::Operation *op) {
315             if (!isa<cir::BreakOp>(op))
316               return mlir::WalkResult::advance();
317 
318             lowerTerminator(op, exitBlock, rewriter);
319             return mlir::WalkResult::skip();
320           });
321 
322       // Track fallthrough in cases.
323       for (mlir::Block &blk : region.getBlocks()) {
324         if (blk.getNumSuccessors())
325           continue;
326 
327         if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
328           mlir::Operation *nextOp = caseOp->getNextNode();
329           assert(nextOp && "caseOp is not expected to be the last op");
330           mlir::Block *oldBlock = nextOp->getBlock();
331           mlir::Block *newBlock =
332               rewriter.splitBlock(oldBlock, nextOp->getIterator());
333           rewriter.setInsertionPointToEnd(oldBlock);
334           rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(),
335                                      newBlock);
336           rewriteYieldOp(rewriter, yieldOp, newBlock);
337         }
338       }
339 
340       mlir::Block *oldBlock = caseOp->getBlock();
341       mlir::Block *newBlock =
342           rewriter.splitBlock(oldBlock, caseOp->getIterator());
343 
344       mlir::Block &entryBlock = caseOp.getCaseRegion().front();
345       rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
346 
347       // Create a branch to the entry of the inlined region.
348       rewriter.setInsertionPointToEnd(oldBlock);
349       rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock);
350     }
351 
352     // Remove all cases since we've inlined the regions.
353     for (cir::CaseOp caseOp : cases) {
354       mlir::Block *caseBlock = caseOp->getBlock();
355       // Erase the block with no predecessors here to make the generated code
356       // simpler a little bit.
357       if (caseBlock->hasNoPredecessors())
358         rewriter.eraseBlock(caseBlock);
359       else
360         rewriter.eraseOp(caseOp);
361     }
362 
363     for (auto [rangeVal, operand, destination] :
364          llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
365       APInt lowerBound = rangeVal.first;
366       APInt upperBound = rangeVal.second;
367 
368       // The case range is unreachable, skip it.
369       if (lowerBound.sgt(upperBound))
370         continue;
371 
372       // If range is small, add multiple switch instruction cases.
373       // This magical number is from the original CGStmt code.
374       constexpr int kSmallRangeThreshold = 64;
375       if ((upperBound - lowerBound)
376               .ult(llvm::APInt(32, kSmallRangeThreshold))) {
377         for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
378           caseValues.push_back(iValue);
379           caseOperands.push_back(operand);
380           caseDestinations.push_back(destination);
381         }
382         continue;
383       }
384 
385       defaultDestination =
386           condBrToRangeDestination(op, rewriter, destination,
387                                    defaultDestination, lowerBound, upperBound);
388       defaultOperands = operand;
389     }
390 
391     // Set switch op to branch to the newly created blocks.
392     rewriter.setInsertionPoint(op);
393     rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
394         op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
395         caseDestinations, caseOperands);
396 
397     return mlir::success();
398   }
399 };
400 
401 class CIRLoopOpInterfaceFlattening
402     : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
403 public:
404   using mlir::OpInterfaceRewritePattern<
405       cir::LoopOpInterface>::OpInterfaceRewritePattern;
406 
lowerConditionOp(cir::ConditionOp op,mlir::Block * body,mlir::Block * exit,mlir::PatternRewriter & rewriter) const407   inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
408                                mlir::Block *exit,
409                                mlir::PatternRewriter &rewriter) const {
410     mlir::OpBuilder::InsertionGuard guard(rewriter);
411     rewriter.setInsertionPoint(op);
412     rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
413                                                exit);
414   }
415 
416   mlir::LogicalResult
matchAndRewrite(cir::LoopOpInterface op,mlir::PatternRewriter & rewriter) const417   matchAndRewrite(cir::LoopOpInterface op,
418                   mlir::PatternRewriter &rewriter) const final {
419     // Setup CFG blocks.
420     mlir::Block *entry = rewriter.getInsertionBlock();
421     mlir::Block *exit =
422         rewriter.splitBlock(entry, rewriter.getInsertionPoint());
423     mlir::Block *cond = &op.getCond().front();
424     mlir::Block *body = &op.getBody().front();
425     mlir::Block *step =
426         (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);
427 
428     // Setup loop entry branch.
429     rewriter.setInsertionPointToEnd(entry);
430     rewriter.create<cir::BrOp>(op.getLoc(), &op.getEntry().front());
431 
432     // Branch from condition region to body or exit.
433     auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator());
434     lowerConditionOp(conditionOp, body, exit, rewriter);
435 
436     // TODO(cir): Remove the walks below. It visits operations unnecessarily.
437     // However, to solve this we would likely need a custom DialectConversion
438     // driver to customize the order that operations are visited.
439 
440     // Lower continue statements.
441     mlir::Block *dest = (step ? step : cond);
442     op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
443       if (!isa<cir::ContinueOp>(op))
444         return mlir::WalkResult::advance();
445 
446       lowerTerminator(op, dest, rewriter);
447       return mlir::WalkResult::skip();
448     });
449 
450     // Lower break statements.
451     assert(!cir::MissingFeatures::switchOp());
452     walkRegionSkipping<cir::LoopOpInterface>(
453         op.getBody(), [&](mlir::Operation *op) {
454           if (!isa<cir::BreakOp>(op))
455             return mlir::WalkResult::advance();
456 
457           lowerTerminator(op, exit, rewriter);
458           return mlir::WalkResult::skip();
459         });
460 
461     // Lower optional body region yield.
462     for (mlir::Block &blk : op.getBody().getBlocks()) {
463       auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
464       if (bodyYield)
465         lowerTerminator(bodyYield, (step ? step : cond), rewriter);
466     }
467 
468     // Lower mandatory step region yield.
469     if (step)
470       lowerTerminator(cast<cir::YieldOp>(step->getTerminator()), cond,
471                       rewriter);
472 
473     // Move region contents out of the loop op.
474     rewriter.inlineRegionBefore(op.getCond(), exit);
475     rewriter.inlineRegionBefore(op.getBody(), exit);
476     if (step)
477       rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);
478 
479     rewriter.eraseOp(op);
480     return mlir::success();
481   }
482 };
483 
484 class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
485 public:
486   using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
487 
488   mlir::LogicalResult
matchAndRewrite(cir::TernaryOp op,mlir::PatternRewriter & rewriter) const489   matchAndRewrite(cir::TernaryOp op,
490                   mlir::PatternRewriter &rewriter) const override {
491     Location loc = op->getLoc();
492     Block *condBlock = rewriter.getInsertionBlock();
493     Block::iterator opPosition = rewriter.getInsertionPoint();
494     Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
495     llvm::SmallVector<mlir::Location, 2> locs;
496     // Ternary result is optional, make sure to populate the location only
497     // when relevant.
498     if (op->getResultTypes().size())
499       locs.push_back(loc);
500     Block *continueBlock =
501         rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
502     rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
503 
504     Region &trueRegion = op.getTrueRegion();
505     Block *trueBlock = &trueRegion.front();
506     mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
507     rewriter.setInsertionPointToEnd(&trueRegion.back());
508     auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
509 
510     rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
511                                            continueBlock);
512     rewriter.inlineRegionBefore(trueRegion, continueBlock);
513 
514     Block *falseBlock = continueBlock;
515     Region &falseRegion = op.getFalseRegion();
516 
517     falseBlock = &falseRegion.front();
518     mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
519     rewriter.setInsertionPointToEnd(&falseRegion.back());
520     auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
521     rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
522                                            continueBlock);
523     rewriter.inlineRegionBefore(falseRegion, continueBlock);
524 
525     rewriter.setInsertionPointToEnd(condBlock);
526     rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
527 
528     rewriter.replaceOp(op, continueBlock->getArguments());
529 
530     // Ok, we're done!
531     return mlir::success();
532   }
533 };
534 
populateFlattenCFGPatterns(RewritePatternSet & patterns)535 void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
536   patterns
537       .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
538            CIRSwitchOpFlattening, CIRTernaryOpFlattening>(
539           patterns.getContext());
540 }
541 
runOnOperation()542 void CIRFlattenCFGPass::runOnOperation() {
543   RewritePatternSet patterns(&getContext());
544   populateFlattenCFGPatterns(patterns);
545 
546   // Collect operations to apply patterns.
547   llvm::SmallVector<Operation *, 16> ops;
548   getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
549     assert(!cir::MissingFeatures::ifOp());
550     assert(!cir::MissingFeatures::switchOp());
551     assert(!cir::MissingFeatures::tryOp());
552     if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op))
553       ops.push_back(op);
554   });
555 
556   // Apply patterns.
557   if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
558     signalPassFailure();
559 }
560 
561 } // namespace
562 
563 namespace mlir {
564 
createCIRFlattenCFGPass()565 std::unique_ptr<Pass> createCIRFlattenCFGPass() {
566   return std::make_unique<CIRFlattenCFGPass>();
567 }
568 
569 } // namespace mlir
570