1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Emit OpenACC Stmt nodes as CIR code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CIRGenBuilder.h" 14 #include "CIRGenFunction.h" 15 #include "mlir/Dialect/OpenACC/OpenACC.h" 16 #include "clang/AST/OpenACCClause.h" 17 #include "clang/AST/StmtOpenACC.h" 18 19 using namespace clang; 20 using namespace clang::CIRGen; 21 using namespace cir; 22 using namespace mlir::acc; 23 24 template <typename Op, typename TermOp> 25 mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt( 26 mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind, 27 SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses, 28 const Stmt *associatedStmt) { 29 mlir::LogicalResult res = mlir::success(); 30 31 llvm::SmallVector<mlir::Type> retTy; 32 llvm::SmallVector<mlir::Value> operands; 33 auto op = builder.create<Op>(start, retTy, operands); 34 35 emitOpenACCClauses(op, dirKind, dirLoc, clauses); 36 37 { 38 mlir::Block &block = op.getRegion().emplaceBlock(); 39 mlir::OpBuilder::InsertionGuard guardCase(builder); 40 builder.setInsertionPointToEnd(&block); 41 42 LexicalScope ls{*this, start, builder.getInsertionBlock()}; 43 res = emitStmt(associatedStmt, /*useCurrentScope=*/true); 44 45 builder.create<TermOp>(end); 46 } 47 return res; 48 } 49 50 namespace { 51 template <typename Op> struct CombinedType; 52 template <> struct CombinedType<ParallelOp> { 53 static constexpr mlir::acc::CombinedConstructsType value = 54 mlir::acc::CombinedConstructsType::ParallelLoop; 55 }; 56 template <> struct CombinedType<SerialOp> { 57 static constexpr mlir::acc::CombinedConstructsType value = 58 mlir::acc::CombinedConstructsType::SerialLoop; 59 }; 60 template <> struct CombinedType<KernelsOp> { 61 static constexpr mlir::acc::CombinedConstructsType value = 62 mlir::acc::CombinedConstructsType::KernelsLoop; 63 }; 64 } // namespace 65 66 template <typename Op, typename TermOp> 67 mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct( 68 mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind, 69 SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses, 70 const Stmt *loopStmt) { 71 mlir::LogicalResult res = mlir::success(); 72 73 llvm::SmallVector<mlir::Type> retTy; 74 llvm::SmallVector<mlir::Value> operands; 75 76 auto computeOp = builder.create<Op>(start, retTy, operands); 77 computeOp.setCombinedAttr(builder.getUnitAttr()); 78 mlir::acc::LoopOp loopOp; 79 80 // First, emit the bodies of both operations, with the loop inside the body of 81 // the combined construct. 82 { 83 mlir::Block &block = computeOp.getRegion().emplaceBlock(); 84 mlir::OpBuilder::InsertionGuard guardCase(builder); 85 builder.setInsertionPointToEnd(&block); 86 87 LexicalScope ls{*this, start, builder.getInsertionBlock()}; 88 auto loopOp = builder.create<LoopOp>(start, retTy, operands); 89 loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get( 90 builder.getContext(), CombinedType<Op>::value)); 91 92 { 93 mlir::Block &innerBlock = loopOp.getRegion().emplaceBlock(); 94 mlir::OpBuilder::InsertionGuard guardCase(builder); 95 builder.setInsertionPointToEnd(&innerBlock); 96 97 LexicalScope ls{*this, start, builder.getInsertionBlock()}; 98 ActiveOpenACCLoopRAII activeLoop{*this, &loopOp}; 99 100 res = emitStmt(loopStmt, /*useCurrentScope=*/true); 101 102 builder.create<mlir::acc::YieldOp>(end); 103 } 104 105 emitOpenACCClauses(computeOp, loopOp, dirKind, dirLoc, clauses); 106 107 updateLoopOpParallelism(loopOp, /*isOrphan=*/false, dirKind); 108 109 builder.create<TermOp>(end); 110 } 111 112 return res; 113 } 114 115 template <typename Op> 116 Op CIRGenFunction::emitOpenACCOp( 117 mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc, 118 llvm::ArrayRef<const OpenACCClause *> clauses) { 119 llvm::SmallVector<mlir::Type> retTy; 120 llvm::SmallVector<mlir::Value> operands; 121 auto op = builder.create<Op>(start, retTy, operands); 122 123 emitOpenACCClauses(op, dirKind, dirLoc, clauses); 124 return op; 125 } 126 127 mlir::LogicalResult 128 CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) { 129 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 130 mlir::Location end = getLoc(s.getSourceRange().getEnd()); 131 132 switch (s.getDirectiveKind()) { 133 case OpenACCDirectiveKind::Parallel: 134 return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>( 135 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), 136 s.getStructuredBlock()); 137 case OpenACCDirectiveKind::Serial: 138 return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>( 139 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), 140 s.getStructuredBlock()); 141 case OpenACCDirectiveKind::Kernels: 142 return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>( 143 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), 144 s.getStructuredBlock()); 145 default: 146 llvm_unreachable("invalid compute construct kind"); 147 } 148 } 149 150 mlir::LogicalResult 151 CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) { 152 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 153 mlir::Location end = getLoc(s.getSourceRange().getEnd()); 154 155 return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>( 156 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), 157 s.getStructuredBlock()); 158 } 159 160 mlir::LogicalResult 161 CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) { 162 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 163 emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(), 164 s.clauses()); 165 return mlir::success(); 166 } 167 168 mlir::LogicalResult 169 CIRGenFunction::emitOpenACCSetConstruct(const OpenACCSetConstruct &s) { 170 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 171 emitOpenACCOp<SetOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(), 172 s.clauses()); 173 return mlir::success(); 174 } 175 176 mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct( 177 const OpenACCShutdownConstruct &s) { 178 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 179 emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(), 180 s.getDirectiveLoc(), s.clauses()); 181 return mlir::success(); 182 } 183 184 mlir::LogicalResult 185 CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) { 186 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 187 auto waitOp = emitOpenACCOp<WaitOp>(start, s.getDirectiveKind(), 188 s.getDirectiveLoc(), s.clauses()); 189 190 auto createIntExpr = [this](const Expr *intExpr) { 191 mlir::Value expr = emitScalarExpr(intExpr); 192 mlir::Location exprLoc = cgm.getLoc(intExpr->getBeginLoc()); 193 194 mlir::IntegerType targetType = mlir::IntegerType::get( 195 &getMLIRContext(), getContext().getIntWidth(intExpr->getType()), 196 intExpr->getType()->isSignedIntegerOrEnumerationType() 197 ? mlir::IntegerType::SignednessSemantics::Signed 198 : mlir::IntegerType::SignednessSemantics::Unsigned); 199 200 auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>( 201 exprLoc, targetType, expr); 202 return conversionOp.getResult(0); 203 }; 204 205 // Emit the correct 'wait' clauses. 206 { 207 mlir::OpBuilder::InsertionGuard guardCase(builder); 208 builder.setInsertionPoint(waitOp); 209 210 if (s.hasDevNumExpr()) 211 waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr())); 212 213 for (Expr *QueueExpr : s.getQueueIdExprs()) 214 waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr)); 215 } 216 217 return mlir::success(); 218 } 219 220 mlir::LogicalResult CIRGenFunction::emitOpenACCCombinedConstruct( 221 const OpenACCCombinedConstruct &s) { 222 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 223 mlir::Location end = getLoc(s.getSourceRange().getEnd()); 224 225 switch (s.getDirectiveKind()) { 226 case OpenACCDirectiveKind::ParallelLoop: 227 return emitOpenACCOpCombinedConstruct<ParallelOp, mlir::acc::YieldOp>( 228 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), 229 s.getLoop()); 230 case OpenACCDirectiveKind::SerialLoop: 231 return emitOpenACCOpCombinedConstruct<SerialOp, mlir::acc::YieldOp>( 232 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), 233 s.getLoop()); 234 case OpenACCDirectiveKind::KernelsLoop: 235 return emitOpenACCOpCombinedConstruct<KernelsOp, mlir::acc::TerminatorOp>( 236 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), 237 s.getLoop()); 238 default: 239 llvm_unreachable("invalid compute construct kind"); 240 } 241 } 242 243 mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct( 244 const OpenACCHostDataConstruct &s) { 245 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 246 mlir::Location end = getLoc(s.getSourceRange().getEnd()); 247 248 return emitOpenACCOpAssociatedStmt<HostDataOp, mlir::acc::TerminatorOp>( 249 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), 250 s.getStructuredBlock()); 251 } 252 253 mlir::LogicalResult CIRGenFunction::emitOpenACCEnterDataConstruct( 254 const OpenACCEnterDataConstruct &s) { 255 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 256 emitOpenACCOp<EnterDataOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(), 257 s.clauses()); 258 return mlir::success(); 259 } 260 261 mlir::LogicalResult CIRGenFunction::emitOpenACCExitDataConstruct( 262 const OpenACCExitDataConstruct &s) { 263 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 264 emitOpenACCOp<ExitDataOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(), 265 s.clauses()); 266 return mlir::success(); 267 } 268 269 mlir::LogicalResult 270 CIRGenFunction::emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct &s) { 271 mlir::Location start = getLoc(s.getSourceRange().getBegin()); 272 emitOpenACCOp<UpdateOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(), 273 s.clauses()); 274 return mlir::success(); 275 } 276 277 mlir::LogicalResult 278 CIRGenFunction::emitOpenACCCacheConstruct(const OpenACCCacheConstruct &s) { 279 // The 'cache' directive 'may' be at the top of a loop by standard, but 280 // doesn't have to be. Additionally, there is nothing that requires this be a 281 // loop affected by an OpenACC pragma. Sema doesn't do any level of 282 // enforcement here, since it isn't particularly valuable to do so thanks to 283 // that. Instead, we treat cache as a 'noop' if there is no acc.loop to apply 284 // it to. 285 if (!activeLoopOp) 286 return mlir::success(); 287 288 mlir::acc::LoopOp loopOp = *activeLoopOp; 289 290 mlir::OpBuilder::InsertionGuard guard(builder); 291 builder.setInsertionPoint(loopOp); 292 293 for (const Expr *var : s.getVarList()) { 294 CIRGenFunction::OpenACCDataOperandInfo opInfo = 295 getOpenACCDataOperandInfo(var); 296 297 auto cacheOp = builder.create<CacheOp>( 298 opInfo.beginLoc, opInfo.varValue, 299 /*structured=*/false, /*implicit=*/false, opInfo.name, opInfo.bounds); 300 301 loopOp.getCacheOperandsMutable().append(cacheOp.getResult()); 302 } 303 304 return mlir::success(); 305 } 306 307 mlir::LogicalResult 308 CIRGenFunction::emitOpenACCAtomicConstruct(const OpenACCAtomicConstruct &s) { 309 cgm.errorNYI(s.getSourceRange(), "OpenACC Atomic Construct"); 310 return mlir::failure(); 311 } 312