MLIR: lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
15
27
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
30 #include "mlir/Conversion/Passes.h.inc"
31 }
32
33 using namespace mlir;
34
35
36
37
38
39
40
// Matches a block containing a "simple" reduction: the block must hold exactly
// two operations, a single combiner operation of one of the `OpTy...` kinds,
// and end with an scf::ReduceReturnOp.
// NOTE(review): numbered lines 42 (signature), 47, 50-51 and 59 are absent
// from this listing; restore them from the upstream file before building.
41 template <typename... OpTy>
// Reject blocks that do not contain exactly two operations.
43 if (block.empty() || llvm::hasSingleElement(block) ||
44 std::next(block.begin(), 2) != block.end())
45 return false;
46
48 return false;
49
52 0, combinerOps);
53
// A reduced value must have been found and exactly one combiner must produce
// it.
54 if (!reducedVal || !isa(reducedVal) || combinerOps.size() != 1)
55 return false;
56
// The combiner must be one of the accepted op kinds and the block must be
// terminated by the reduce-return op.
57 return isa<OpTy...>(combinerOps[0]) &&
58 isascf::ReduceReturnOp(block.back()) &&
60 }
61
62
63
64
65
66
67
68
69
70
71
72
// Matches a block containing a select-based min/max reduction: a comparison
// whose result is the condition of a select, whose result in turn is returned
// by the scf::ReduceReturnOp terminator. On a successful match `isMin` is set
// to true for a minimum reduction and to false for a maximum reduction.
// `lessThanPredicates` / `greaterThanPredicates` list the comparison
// predicates treated as "less" and "greater" respectively.
// NOTE(review): numbered lines 77-78 (the signature) and 97 (the condition
// guarding the `return false` on line 98) are absent from this listing.
73 template <
74 typename CompareOpTy, typename SelectOpTy,
75 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
76 static bool
79 static_assert(
80 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
81 "only arithmetic and llvm select ops are supported");
82
83
// Reject blocks that do not contain exactly three operations.
84 if (block.empty() || llvm::hasSingleElement(block) ||
85 std::next(block.begin(), 2) == block.end() ||
86 std::next(block.begin(), 3) != block.end())
87 return false;
88
89
// The three operations must be, in order: compare, select, terminator.
90 auto compare = dyn_cast<CompareOpTy>(block.front());
91 auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
92 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
93 if (!compare || !select || !terminator)
94 return false;
95
96
98 return false;
99
100
// Classify the predicate; anything outside both lists is not a min/max.
101 bool isLess;
102 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
103 isLess = true;
104 } else if (llvm::is_contained(greaterThanPredicates,
105 compare.getPredicate())) {
106 isLess = false;
107 } else {
108 return false;
109 }
110
// The select must be driven by this very comparison.
111 if (select.getCondition() != compare.getResult())
112 return false;
113
114
115
116
117
118
// Operand positions of the "true" and "false" values of the select op. The
// select must pick between the two compared values, in either order.
119 constexpr unsigned kTrueValue = 1;
120 constexpr unsigned kFalseValue = 2;
121 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
122 select.getOperand(kFalseValue) == compare.getRhs();
123 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
124 select.getOperand(kFalseValue) == compare.getLhs();
125 if (!sameOperands && !swappedOperands)
126 return false;
127
// The terminator must return the selected value.
128 if (select.getResult() != terminator.getResult())
129 return false;
130
131
132
// "less" with operands in compare order, or "greater" with swapped operands,
// selects the smaller value; the two remaining combinations select the larger.
133 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
134 return isMin || (isLess && swappedOperands) || (!isLess && sameOperands);
135 }
136
137
139 if (type.isF16())
140 return llvm::APFloat::IEEEhalf();
141 if (type.isF32())
142 return llvm::APFloat::IEEEsingle();
143 if (type.isF64())
144 return llvm::APFloat::IEEEdouble();
145 if (type.isF128())
146 return llvm::APFloat::IEEEquad();
147 if (type.isBF16())
148 return llvm::APFloat::BFloat();
149 if (type.isF80())
150 return llvm::APFloat::x87DoubleExtended();
151 llvm_unreachable("unknown float type");
152 }
153
154
155
// Returns an attribute with the minimum (if `min` is set) or the maximum value
// (otherwise) -- presumably for the given float type; TODO confirm.
// NOTE(review): numbered lines 156 (signature) and 158-159 (the return
// statement) are absent from this listing.
157 auto fltType = cast(type);
160 }
161
162
163
164
166 auto intType = cast(type);
167 unsigned bitwidth = intType.getWidth();
168 return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
169 : llvm::APInt::getSignedMaxValue(bitwidth));
170 }
171
172
173
174
// Returns an attribute with the unsigned integer minimum (if `min` is set) or
// the maximum value (otherwise); the maximum is the all-ones pattern of the
// type's bit width.
// NOTE(review): numbered lines 175 (signature) and 178 (start of the return
// statement, including the minimum branch) are absent from this listing.
176 auto intType = cast(type);
177 unsigned bitwidth = intType.getWidth();
179 : llvm::APInt::getAllOnes(bitwidth));
180 }
181
182
183
184
185
// Creates an OpenMP reduction declaration and inserts it into the provided
// symbol table. The declaration gets an initializer block taking one argument
// of the reduction type, and a reduction region derived from the
// `reductionIndex`-th reduction of `reduce`.
// NOTE(review): numbered lines 187, 189-191, 198-201, 203, 207-208 and 210 are
// absent from this listing.
186 static omp::DeclareReductionOp
188 scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
192 "__scf_reduction", type);
// The symbol table may rename the declaration to avoid collisions.
193 symbolTable.insert(decl);
194
// Initializer block: one argument of the reduction type, located at the
// corresponding reduce operand.
195 builder.createBlock(&decl.getInitializerRegion(),
196 decl.getInitializerRegion().end(), {type},
197 {reduce.getOperands()[reductionIndex].getLoc()});
202
// The last operation of the reduction block must be its terminator.
204 &reduce.getReductions()[reductionIndex].front().back();
205 assert(isa<scf::ReduceReturnOp>(terminator) &&
206 "expected reduce op to be terminated by reduce return");
209 terminator->getOperands());
211 decl.getReductionRegion(),
212 decl.getReductionRegion().end());
213 return decl;
214 }
215
216
217
// Adds an atomic reduction combiner to the given OpenMP reduction declaration
// and returns that declaration. `atomicKind` selects the atomic binary
// operation to use.
// NOTE(review): numbered lines 218 (start of the signature), 223-225, 230-234
// and 236 are absent from this listing.
219 LLVM::AtomicBinOp atomicKind,
220 omp::DeclareReductionOp decl,
221 scf::ReduceOp reduce,
222 int64_t reductionIndex) {
// The atomic region gets one block with two pointer-typed arguments, both
// located at the reduce operand.
226 builder.createBlock(&decl.getAtomicReductionRegion(),
227 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
228 {reduceOperandLoc, reduceOperandLoc});
229 Block *atomicBlock = &decl.getAtomicReductionRegion().back();
// Monotonic ordering is requested for the atomic operation.
235 LLVM::AtomicOrdering::monotonic);
237 return decl;
238 }
239
240
241
242
243
// Creates an OpenMP reduction declaration that corresponds to the given SCF
// reduction and returns it. The reduction block is matched against a fixed
// list of known patterns (add, or, xor, and, mul, select-based min/max);
// returns nullptr when none of them matches.
// NOTE(review): numbered lines 244 (start of the signature), 247-248, 252,
// 255-256, 262 and the bodies of most `createDecl` / `addAtomicRMW` calls are
// absent from this listing.
245 scf::ReduceOp reduce,
246 int64_t reductionIndex) {
249
250
251
// Walk up until the insertion point is directly nested in `container`.
253 while (insertionPoint->getParentOp() != container)
254 insertionPoint = insertionPoint->getParentOp();
257
258 assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
259 "expected reduction region to have a single element");
260
261
// Simple reductions that are followed by an atomic combiner.
263 Block &reduction = reduce.getReductions()[reductionIndex].front();
264 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
265 omp::DeclareReductionOp decl =
269 reductionIndex);
270 }
271 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
272 omp::DeclareReductionOp decl =
276 reductionIndex);
277 }
278 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
279 omp::DeclareReductionOp decl =
283 reductionIndex);
284 }
285 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
286 omp::DeclareReductionOp decl =
290 reductionIndex);
291 }
292 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
293 omp::DeclareReductionOp decl = createDecl(
294 builder, symbolTable, reduce, reductionIndex,
298 reductionIndex);
299 }
300
301
302
303
// Multiplications return the declaration from createDecl directly.
304 if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
305 return createDecl(builder, symbolTable, reduce, reductionIndex,
307 }
308 if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
309 return createDecl(builder, symbolTable, reduce, reductionIndex,
311 }
312
313
// Select-based min/max reductions; `isMin` is filled in by the matcher.
314 bool isMin;
315 if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
316 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
317 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
318 matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
319 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
320 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
321 return createDecl(builder, symbolTable, reduce, reductionIndex,
323 }
324 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
325 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
326 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
327 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
328 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
329 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
330 omp::DeclareReductionOp decl =
335 decl, reduce, reductionIndex);
336 }
// Unsigned min/max. The "less than" list must hold ult/ule for both dialects;
// listing ugt there would classify an unsigned "greater than" comparison as
// "less" and pick the wrong min/max.
337 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
338 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
339 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
340 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
341 reduction, {LLVM::ICmpPredicate::ult, LLVM::ICmpPredicate::ule},
342 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
343 omp::DeclareReductionOp decl =
347 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
348 decl, reduce, reductionIndex);
349 }
350
// No known reduction pattern matched.
351 return nullptr;
352 }
353
354 namespace {
355
// Rewrite pattern for scf::ParallelOp that emits omp::ParallelOp,
// omp::WsloopOp and omp::LoopNestOp, with reduction values kept in stack slots
// created through LLVM alloca/store/load operations.
// NOTE(review): many numbered lines are absent from this listing (e.g. 360,
// 365, 369-370, 373, 378, 385-386, 388, 390, 394, 405-406, 412, 415, 418,
// 420-421, 423-424, 426, 437, 442, 447-448, 451, 456-457, 461, 466, 471, 473,
// 478, 484, 487-488, 490, 503, 505, 508, 511, 513, 516-517, 519-520, 526) and
// the angle brackets of template arguments were stripped.
356 struct ParallelOpLowering : public OpRewritePatternscf::ParallelOp {
// A thread count of 0 means that no explicit num_threads value is emitted
// (see the `numThreads > 0` check in matchAndRewrite).
357 static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
358 unsigned numThreads;
359
361 unsigned numThreads = kUseOpenMPDefaultNumThreads)
362 : OpRewritePatternscf::ParallelOp(context), numThreads(numThreads) {}
363
364 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
366
367
368
// Collect one reduction declaration per reduction; the rewrite fails as soon
// as a declaration is null.
371 auto reduce = castscf::ReduceOp(parallelOp.getBody()->getTerminator());
372 for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
374 ompReductionDecls.push_back(decl);
375 if (!decl)
376 return failure();
377 reductionSyms.push_back(
379 }
380
381
382
// Allocate one stack slot per reduction and store the initial value in it.
383 Location loc = parallelOp.getLoc();
384 Value one = rewriter.createLLVM::ConstantOp(
387 reductionVariables.reserve(parallelOp.getNumReductions());
389 for (Value init : parallelOp.getInitVals()) {
391 isaLLVM::PointerElementTypeInterface(init.getType())) &&
392 "cannot create a reduction variable if the type is not an LLVM "
393 "pointer element");
395 rewriter.createLLVM::AllocaOp(loc, ptrType, init.getType(), one, 0);
396 rewriter.createLLVM::StoreOp(loc, init, storage);
397 reductionVariables.push_back(storage);
398 }
399
400
401
402
// For every (variable, reduce operand, declaration) triple, add a private
// reduction variable as a region argument and clone the declaration's
// reduction region in front of the scf.reduce terminator.
403 for (auto [x, y, rD] : llvm::zip_equal(
404 reductionVariables, reduce.getOperands(), ompReductionDecls)) {
407 Region &redRegion = rD.getReductionRegion();
408
409
410
411
413 "expect reduction region to have one block");
414 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
416 rD.getType(), pvtRedVar);
417
419 builder.setInsertionPoint(reduce);
422 "expect reduction region to have two arguments");
425 for (auto &op : redRegion.getOps()) {
// A cloned omp::YieldOp is replaced by a store of its single result into the
// private reduction variable; cloning stops there.
427 if (auto yieldOp = dyn_castomp::YieldOp(*cloneOp)) {
428 assert(yieldOp && yieldOp.getResults().size() == 1 &&
429 "expect YieldOp in reduction region to return one result");
430 Value redVal = yieldOp.getResults()[0];
431 rewriter.createLLVM::StoreOp(loc, redVal, pvtRedVar);
432 rewriter.eraseOp(yieldOp);
433 break;
434 }
435 }
436 }
438
// The thread-count value stays null unless a positive count was requested.
439 Value numThreadsVar;
440 if (numThreads > 0) {
441 numThreadsVar = rewriter.createLLVM::ConstantOp(
443 }
444
445 auto ompParallel = rewriter.createomp::ParallelOp(
446 loc,
449 Value{},
450 numThreadsVar,
452 nullptr,
453 nullptr,
454 omp::ClauseProcBindKindAttr{},
455 nullptr,
458 ArrayAttr{});
459 {
460
462 rewriter.createBlock(&ompParallel.getRegion());
463
464
465 {
467
// The worksharing loop carries the reduction symbols and variables when there
// are any; all reductions are marked as not by-ref.
468 auto wsloopOp = rewriter.createomp::WsloopOp(parallelOp.getLoc());
469 if (!reductionVariables.empty()) {
470 wsloopOp.setReductionSymsAttr(
472 wsloopOp.getReductionVarsMutable().append(reductionVariables);
474
475
476 reductionByRef.resize(reductionVariables.size(), false);
477 wsloopOp.setReductionByref(
479 }
480 rewriter.createomp::TerminatorOp(loc);
481
482
483
485 reductionTypes.reserve(reductionVariables.size());
486 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
489 &wsloopOp.getRegion(), {}, reductionTypes,
491 parallelOp.getLoc()));
492
493
// The loop nest reuses the bounds and steps of the scf.parallel op and takes
// over its body region.
494 auto loopOp = rewriter.createomp::LoopNestOp(
495 parallelOp.getLoc(), parallelOp.getLowerBound(),
496 parallelOp.getUpperBound(), parallelOp.getStep());
497
498 rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
499 loopOp.getRegion().begin());
500
501
502
// Entry-block arguments after the first `numLoops` ones are paired with the
// worksharing loop's region arguments.
504 unsigned numLoops = parallelOp.getNumLoops();
506 loopOpEntryBlock.getArguments().drop_front(numLoops),
507 wsloopOp.getRegion().getArguments());
509 numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
510
512 rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
514
515 auto scope = rewriter.creatememref::AllocaScopeOp(parallelOp.getLoc(),
518 Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
521 rewriter.creatememref::AllocaScopeReturnOp(loc, ValueRange());
522 }
523 }
524
525
// Load the final reduction values and use them as the results replacing the
// scf.parallel op.
527 results.reserve(reductionVariables.size());
528 for (auto [variable, type] :
529 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
530 Value res = rewriter.createLLVM::LoadOp(loc, type, variable);
531 results.push_back(res);
532 }
533 rewriter.replaceOp(parallelOp, results);
534
535 return success();
536 }
537 };
538
539
// Sets up the conversion for `module`: scf::ReduceOp, scf::ReduceReturnOp and
// scf::ParallelOp are illegal, while the OpenMP, LLVM and MemRef dialects are
// legal. `numThreads` is forwarded to the registered pattern.
// NOTE(review): numbered lines 541, 546 and 548-549 (target/pattern set
// declarations and the final return) are absent from this listing.
540 static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
542 target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
543 target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
544 memref::MemRefDialect>();
545
547 patterns.add(module.getContext(), numThreads);
550 }
551
552
553 struct SCFToOpenMPPass
554 : public impl::ConvertSCFToOpenMPPassBase {
555
556 using Base::Base;
557
558
559 void runOnOperation() override {
560 if (failed(applyPatterns(getOperation(), numThreads)))
561 signalPassFailure();
562 }
563 };
564
565 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...