MLIR: lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

15

27

28 namespace mlir {

29 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS

30 #include "mlir/Conversion/Passes.h.inc"

31 }

32

33 using namespace mlir;

34

35

36

37

38

39

40

41 template <typename... OpTy>

43 if (block.empty() || llvm::hasSingleElement(block) ||

44 std::next(block.begin(), 2) != block.end())

45 return false;

46

48 return false;

49

52 0, combinerOps);

53

54 if (!reducedVal || !isa(reducedVal) || combinerOps.size() != 1)

55 return false;

56

57 return isa<OpTy...>(combinerOps[0]) &&

58 isascf::ReduceReturnOp(block.back()) &&

60 }

61

62

63

64

65

66

67

68

69

70

71

72

73 template <

74 typename CompareOpTy, typename SelectOpTy,

75 typename Predicate = decltype(std::declval().getPredicate())>

76 static bool

79 static_assert(

80 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,

81 "only arithmetic and llvm select ops are supported");

82

83

84 if (block.empty() || llvm::hasSingleElement(block) ||

85 std::next(block.begin(), 2) == block.end() ||

86 std::next(block.begin(), 3) != block.end())

87 return false;

88

89

90 auto compare = dyn_cast(block.front());

91 auto select = dyn_cast(block.front().getNextNode());

92 auto terminator = dyn_castscf::ReduceReturnOp(block.back());

93 if (compare || !select || !terminator)

94 return false;

95

96

98 return false;

99

100

101 bool isLess;

102 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {

103 isLess = true;

104 } else if (llvm::is_contained(greaterThanPredicates,

105 compare.getPredicate())) {

106 isLess = false;

107 } else {

108 return false;

109 }

110

111 if (select.getCondition() != compare.getResult())

112 return false;

113

114

115

116

117

118

119 constexpr unsigned kTrueValue = 1;

120 constexpr unsigned kFalseValue = 2;

121 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&

122 select.getOperand(kFalseValue) == compare.getRhs();

123 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&

124 select.getOperand(kFalseValue) == compare.getLhs();

125 if (!sameOperands && !swappedOperands)

126 return false;

127

128 if (select.getResult() != terminator.getResult())

129 return false;

130

131

132

133 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);

134 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);

135 }

136

137

139 if (type.isF16())

140 return llvm::APFloat::IEEEhalf();

141 if (type.isF32())

142 return llvm::APFloat::IEEEsingle();

143 if (type.isF64())

144 return llvm::APFloat::IEEEdouble();

145 if (type.isF128())

146 return llvm::APFloat::IEEEquad();

147 if (type.isBF16())

148 return llvm::APFloat::BFloat();

149 if (type.isF80())

150 return llvm::APFloat::x87DoubleExtended();

151 llvm_unreachable("unknown float type");

152 }

153

154

155

157 auto fltType = cast(type);

160 }

161

162

163

164

166 auto intType = cast(type);

167 unsigned bitwidth = intType.getWidth();

168 return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)

169 : llvm::APInt::getSignedMaxValue(bitwidth));

170 }

171

172

173

174

176 auto intType = cast(type);

177 unsigned bitwidth = intType.getWidth();

179 : llvm::APInt::getAllOnes(bitwidth));

180 }

181

182

183

184

185

186 static omp::DeclareReductionOp

188 scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {

192 "__scf_reduction", type);

193 symbolTable.insert(decl);

194

195 builder.createBlock(&decl.getInitializerRegion(),

196 decl.getInitializerRegion().end(), {type},

197 {reduce.getOperands()[reductionIndex].getLoc()});

202

204 &reduce.getReductions()[reductionIndex].front().back();

205 assert(isascf::ReduceReturnOp(terminator) &&

206 "expected reduce op to be terminated by redure return");

209 terminator->getOperands());

211 decl.getReductionRegion(),

212 decl.getReductionRegion().end());

213 return decl;

214 }

215

216

217

219 LLVM::AtomicBinOp atomicKind,

220 omp::DeclareReductionOp decl,

221 scf::ReduceOp reduce,

222 int64_t reductionIndex) {

226 builder.createBlock(&decl.getAtomicReductionRegion(),

227 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},

228 {reduceOperandLoc, reduceOperandLoc});

229 Block *atomicBlock = &decl.getAtomicReductionRegion().back();

235 LLVM::AtomicOrdering::monotonic);

237 return decl;

238 }

239

240

241

242

243

245 scf::ReduceOp reduce,

246 int64_t reductionIndex) {

249

250

251

253 while (insertionPoint->getParentOp() != container)

254 insertionPoint = insertionPoint->getParentOp();

257

258 assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&

259 "expected reduction region to have a single element");

260

261

263 Block &reduction = reduce.getReductions()[reductionIndex].front();

264 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {

265 omp::DeclareReductionOp decl =

269 reductionIndex);

270 }

271 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {

272 omp::DeclareReductionOp decl =

276 reductionIndex);

277 }

278 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {

279 omp::DeclareReductionOp decl =

283 reductionIndex);

284 }

285 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {

286 omp::DeclareReductionOp decl =

290 reductionIndex);

291 }

292 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {

293 omp::DeclareReductionOp decl = createDecl(

294 builder, symbolTable, reduce, reductionIndex,

298 reductionIndex);

299 }

300

301

302

303

304 if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {

305 return createDecl(builder, symbolTable, reduce, reductionIndex,

307 }

308 if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {

309 return createDecl(builder, symbolTable, reduce, reductionIndex,

311 }

312

313

314 bool isMin;

315 if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(

316 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},

317 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||

318 matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(

319 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},

320 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {

321 return createDecl(builder, symbolTable, reduce, reductionIndex,

323 }

324 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(

325 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},

326 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||

327 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(

328 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},

329 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {

330 omp::DeclareReductionOp decl =

335 decl, reduce, reductionIndex);

336 }

337 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(

338 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},

339 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||

340 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(

341 reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},

342 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {

343 omp::DeclareReductionOp decl =

347 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,

348 decl, reduce, reductionIndex);

349 }

350

351 return nullptr;

352 }

353

354 namespace {

355

356 struct ParallelOpLowering : public OpRewritePatternscf::ParallelOp {

357 static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;

358 unsigned numThreads;

359

361 unsigned numThreads = kUseOpenMPDefaultNumThreads)

362 : OpRewritePatternscf::ParallelOp(context), numThreads(numThreads) {}

363

364 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,

366

367

368

371 auto reduce = castscf::ReduceOp(parallelOp.getBody()->getTerminator());

372 for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {

374 ompReductionDecls.push_back(decl);

375 if (!decl)

376 return failure();

377 reductionSyms.push_back(

379 }

380

381

382

383 Location loc = parallelOp.getLoc();

384 Value one = rewriter.createLLVM::ConstantOp(

387 reductionVariables.reserve(parallelOp.getNumReductions());

389 for (Value init : parallelOp.getInitVals()) {

391 isaLLVM::PointerElementTypeInterface(init.getType())) &&

392 "cannot create a reduction variable if the type is not an LLVM "

393 "pointer element");

395 rewriter.createLLVM::AllocaOp(loc, ptrType, init.getType(), one, 0);

396 rewriter.createLLVM::StoreOp(loc, init, storage);

397 reductionVariables.push_back(storage);

398 }

399

400

401

402

403 for (auto [x, y, rD] : llvm::zip_equal(

404 reductionVariables, reduce.getOperands(), ompReductionDecls)) {

407 Region &redRegion = rD.getReductionRegion();

408

409

410

411

413 "expect reduction region to have one block");

414 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);

416 rD.getType(), pvtRedVar);

417

419 builder.setInsertionPoint(reduce);

422 "expect reduction region to have two arguments");

425 for (auto &op : redRegion.getOps()) {

427 if (auto yieldOp = dyn_castomp::YieldOp(*cloneOp)) {

428 assert(yieldOp && yieldOp.getResults().size() == 1 &&

429 "expect YieldOp in reduction region to return one result");

430 Value redVal = yieldOp.getResults()[0];

431 rewriter.createLLVM::StoreOp(loc, redVal, pvtRedVar);

432 rewriter.eraseOp(yieldOp);

433 break;

434 }

435 }

436 }

438

439 Value numThreadsVar;

440 if (numThreads > 0) {

441 numThreadsVar = rewriter.createLLVM::ConstantOp(

443 }

444

445 auto ompParallel = rewriter.createomp::ParallelOp(

446 loc,

449 Value{},

450 numThreadsVar,

452 nullptr,

453 nullptr,

454 omp::ClauseProcBindKindAttr{},

455 nullptr,

458 ArrayAttr{});

459 {

460

462 rewriter.createBlock(&ompParallel.getRegion());

463

464

465 {

467

468 auto wsloopOp = rewriter.createomp::WsloopOp(parallelOp.getLoc());

469 if (!reductionVariables.empty()) {

470 wsloopOp.setReductionSymsAttr(

472 wsloopOp.getReductionVarsMutable().append(reductionVariables);

474

475

476 reductionByRef.resize(reductionVariables.size(), false);

477 wsloopOp.setReductionByref(

479 }

480 rewriter.createomp::TerminatorOp(loc);

481

482

483

485 reductionTypes.reserve(reductionVariables.size());

486 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),

489 &wsloopOp.getRegion(), {}, reductionTypes,

491 parallelOp.getLoc()));

492

493

494 auto loopOp = rewriter.createomp::LoopNestOp(

495 parallelOp.getLoc(), parallelOp.getLowerBound(),

496 parallelOp.getUpperBound(), parallelOp.getStep());

497

498 rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),

499 loopOp.getRegion().begin());

500

501

502

504 unsigned numLoops = parallelOp.getNumLoops();

506 loopOpEntryBlock.getArguments().drop_front(numLoops),

507 wsloopOp.getRegion().getArguments());

509 numLoops, loopOpEntryBlock.getNumArguments() - numLoops);

510

512 rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());

514

515 auto scope = rewriter.creatememref::AllocaScopeOp(parallelOp.getLoc(),

518 Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());

521 rewriter.creatememref::AllocaScopeReturnOp(loc, ValueRange());

522 }

523 }

524

525

527 results.reserve(reductionVariables.size());

528 for (auto [variable, type] :

529 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {

530 Value res = rewriter.createLLVM::LoadOp(loc, type, variable);

531 results.push_back(res);

532 }

533 rewriter.replaceOp(parallelOp, results);

534

535 return success();

536 }

537 };

538

539

540 static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {

542 target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();

543 target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,

544 memref::MemRefDialect>();

545

547 patterns.add(module.getContext(), numThreads);

550 }

551

552

553 struct SCFToOpenMPPass

554 : public impl::ConvertSCFToOpenMPPassBase {

555

556 using Base::Base;

557

558

559 void runOnOperation() override {

560 if (failed(applyPatterns(getOperation(), numThreads)))

561 signalPassFailure();

562 }

563 };

564

565 }

static Value getZero(OpBuilder &b, Location loc, Type elementType)

Get zero value for an element type.

static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)

static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)

We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...

static Attribute minMaxValueForFloat(Type type, bool min)

Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...

static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)

Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....

static bool matchSimpleReduction(Block &block)

Matches a block containing a "simple" reduction.

static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)

Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.

static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)

Matches a block containing a select-based min/max reduction.

static const llvm::fltSemantics & fltSemanticsForType(FloatType type)

Returns the float semantics for the given float type.

static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)

Creates an OpenMP reduction declaration and inserts it into the provided symbol table.

static Attribute minMaxValueForSignedInt(Type type, bool min)

Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...

static Attribute minMaxValueForUnsignedInt(Type type, bool min)

Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

BlockArgument getArgument(unsigned i)

unsigned getNumArguments()

void eraseArguments(unsigned start, unsigned num)

Erases 'num' arguments from the index 'start'.

BlockArgListType getArguments()

IntegerAttr getI32IntegerAttr(int32_t value)

IntegerAttr getIntegerAttr(Type type, int64_t value)

FloatAttr getFloatAttr(Type type, double value)

IntegerAttr getI64IntegerAttr(int64_t value)

IntegerType getIntegerType(unsigned width)

MLIRContext * getContext() const

This class describes a specific conversion target.

This class represents a frozen set of patterns that can be processed by a pattern applicator.

This is a utility class for mapping one set of IR entities to another.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

Operation is the basic unit of execution within MLIR.

Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())

Create a deep copy of this operation, remapping any operands that use values outside of the operation...

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...

operand_range getOperands()

Returns an iterator on the underlying Value's.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class contains a list of basic blocks and a link to the parent operation it is attached to.

iterator_range< OpIterator > getOps()

unsigned getNumArguments()

BlockArgument getArgument(unsigned i)

bool hasOneBlock()

Return true if this region has exactly one block.

Block * splitBlock(Block *block, Block::iterator before)

Split the operations starting at "before" (inclusive) out of the given block into a new block,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into the end of block 'dest'.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)

Move the blocks that belong to "region" before the given position in another region "parent".

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class allows for representing and managing the symbol table used by operations with the 'SymbolT...

StringAttr insert(Operation *symbol, Block::iterator insertPt={})

Insert a new symbol into the table, and rename it as necessary to avoid collisions.

static Operation * getNearestSymbolTable(Operation *from)

Returns the nearest symbol table from a given operation from.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Location getLoc() const

Return the location of this value.

Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...

static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)

Builder from ArrayRef.

bool isCompatibleType(Type type)

Returns true if the given type is compatible with the LLVM dialect.

int compare(const Fraction &x, const Fraction &y)

Three-way comparison between two fractions.

Include the generated interface declarations.

Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)

Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...

const FrozenRewritePatternSet & patterns

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())

Below we define several entry points for operation conversion.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...