MLIR: lib/Dialect/Vector/Transforms/VectorUnroll.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

17 #include "llvm/ADT/MapVector.h"

18 #include "llvm/ADT/STLExtras.h"

19 #include "llvm/Support/Debug.h"

20 #include "llvm/Support/InterleavedRange.h"

21 #include

22

23 #define DEBUG_TYPE "vector-unroll"

24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

25 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

26

27 using namespace mlir;

29

30

37 auto isBroadcast = [](AffineExpr expr) {

38 if (auto constExpr = dyn_cast(expr))

39 return constExpr.getValue() == 0;

40 return false;

41 };

42

45 if (isBroadcast(dim.value()))

46 continue;

47 unsigned pos = cast(dim.value()).getPosition();

50 auto map = AffineMap::get(1, 0, expr);

51 slicedIndices[pos] =

52 builder.createaffine::AffineApplyOp(loc, map, indices[pos]);

53 }

54 return slicedIndices;

55 }

56

57

58

65 }

66

67

68

69 static std::optional<SmallVector<int64_t>>

73 if (options.filterConstraint && failed(options.filterConstraint(op))) {

74 LDBG("--no filter constraint -> BAIL");

75 return std::nullopt;

76 }

77 assert(options.nativeShape &&

78 "vector unrolling expects the native shape or native"

79 "shape call back function to be set");

80 auto unrollableVectorOp = dyn_cast(op);

81 if (!unrollableVectorOp) {

82 LDBG("--not an unrollable op -> BAIL");

83 return std::nullopt;

84 }

85 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();

86 if (!maybeUnrollShape) {

87 LDBG("--could not get shape of op " << *op << " -> BAIL");

88 return std::nullopt;

89 }

90 LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape));

91

92 std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);

93 if (!targetShape) {

94 LDBG("--no unrolling target shape defined " << *op << "-> SKIP");

95 return std::nullopt;

96 }

97 LDBG("--target shape: " << llvm::interleaved(*targetShape));

98

99 auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);

100 if (!maybeShapeRatio) {

101 LDBG("--could not compute integral shape ratio -> BAIL");

102 return std::nullopt;

103 }

104 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {

105 LDBG("--no unrolling needed -> SKIP");

106 return std::nullopt;

107 }

108 LDBG("--found an integral shape ratio to unroll to -> SUCCESS");

109 return targetShape;

110 }

111

116 llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));

117 if (options.traversalOrderCallback != nullptr) {

118 std::optional<SmallVector<int64_t>> order =

119 options.traversalOrderCallback(op);

120 if (order) {

121 loopOrder = std::move(*order);

122 }

123 }

124 return loopOrder;

125 }

126

127 namespace {

128

129 struct UnrollTransferReadPattern

131 UnrollTransferReadPattern(MLIRContext *context,

134 : OpRewritePatternvector::TransferReadOp(context, benefit),

136

137 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,

139

140 if (readOp.getTransferRank() == 0)

141 return failure();

142 if (readOp.getMask())

143 return failure();

145 if (!targetShape)

146 return failure();

147 auto sourceVectorType = readOp.getVectorType();

149 Location loc = readOp.getLoc();

150 ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

151

152

153 Value result = rewriter.createarith::ConstantOp(

154 loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));

155 auto targetType =

156 VectorType::get(*targetShape, sourceVectorType.getElementType());

158 readOp.getIndices().end());

165 readOp.getPermutationMap(), loc, rewriter);

166 auto slicedRead = rewriter.createvector::TransferReadOp(

167 loc, targetType, readOp.getBase(), indices,

168 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),

169 readOp.getInBoundsAttr());

170

171 result = rewriter.createOrFoldvector::InsertStridedSliceOp(

172 loc, slicedRead, result, elementOffsets, strides);

173 }

174 rewriter.replaceOp(readOp, result);

175 return success();

176 }

177

178 private:

180 };

181

182 struct UnrollTransferWritePattern

184 UnrollTransferWritePattern(MLIRContext *context,

187 : OpRewritePatternvector::TransferWriteOp(context, benefit),

189

190 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,

192

193 if (writeOp.getTransferRank() == 0)

194 return failure();

195

196 if (writeOp.getMask())

197 return failure();

199 if (!targetShape)

200 return failure();

201 auto sourceVectorType = writeOp.getVectorType();

203 Location loc = writeOp.getLoc();

206 writeOp.getIndices().end());

209 Value resultTensor;

212 Value slicedVector = rewriter.createOrFoldvector::ExtractStridedSliceOp(

213 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);

216 writeOp.getPermutationMap(), loc, rewriter);

217 Operation *slicedWrite = rewriter.createvector::TransferWriteOp(

218 loc, slicedVector, resultTensor ? resultTensor : writeOp.getBase(),

219 indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());

220

221 if (!slicedWrite->getResults().empty())

222 resultTensor = slicedWrite->getResult(0);

223 }

224 if (resultTensor)

225 rewriter.replaceOp(writeOp, resultTensor);

226 else

227 rewriter.eraseOp(writeOp);

228 return success();

229 }

230

231 private:

233 };

234

235 struct OffsetMapInfo {

237

239

241 return static_cast<unsigned>(llvm::hash_combine_range(v));

242 }

243

246 return lhs == rhs;

247 }

248 };

249

250 struct UnrollContractionPattern

252 UnrollContractionPattern(MLIRContext *context,

257

258 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,

261 if (!targetShape)

262 return failure();

263 auto dstVecType = cast(contractOp.getResultType());

265

266 Location loc = contractOp.getLoc();

267 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();

268 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];

269 llvm::MapVector<

272 accCache;

273

275 contractOp.getIteratorTypes().size(), contractOp, options);

276

280

281

282 auto extractOperand = [&](unsigned index, Value operand,

288 slicesOperands[index] =

289 rewriter.createOrFoldvector::ExtractStridedSliceOp(

290 loc, operand, operandOffets, operandShape, operandStrides);

291 };

292

293

294 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];

297 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

298

299

300 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];

303 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

304

305 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];

308

309

310 auto *accIt = accCache.find(accOffets);

311 if (accIt != accCache.end())

312 slicesOperands[2] = accIt->second;

313 else

314 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

315

318 auto targetType = VectorType::get(dstShape, dstVecType.getElementType());

320 rewriter, loc, contractOp, slicesOperands, targetType);

321

324

325

326 accCache[dstOffets] = newOp->getResult(0);

327 }

328

329 Value result = rewriter.createarith::ConstantOp(

330 loc, dstVecType, rewriter.getZeroAttr(dstVecType));

331 for (const auto &it : accCache) {

333 result = rewriter.createOrFoldvector::InsertStridedSliceOp(

334 loc, it.second, result, it.first, dstStrides);

335 }

336 rewriter.replaceOp(contractOp, result);

337 return success();

338 }

339

340 private:

342 };

343

344 struct UnrollMultiReductionPattern

346 UnrollMultiReductionPattern(MLIRContext *context,

349 : OpRewritePatternvector::MultiDimReductionOp(context, benefit),

351

352 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,

354 auto resultType = reductionOp->getResult(0).getType();

355 if (resultType.isIntOrFloat()) {

357 "Unrolling scalars is not supported");

358 }

359 std::optional<SmallVector<int64_t>> targetShape =

361 if (!targetShape)

362 return failure();

364 llvm::MapVector<

367 accCache;

368 Location loc = reductionOp.getLoc();

369

370

371

376 Value slicedOperand =

377 rewriter.createOrFoldvector::ExtractStridedSliceOp(

378 loc, reductionOp.getSource(), offsets, *targetShape,

379 operandStrides);

380 operands.push_back(slicedOperand);

383 for (size_t i : llvm::seq(size_t(0), targetShape->size())) {

384 if (!reductionOp.isReducedDim(i)) {

385 destOffset.push_back(offsets[i]);

386 dstShape.push_back((*targetShape)[i]);

387 }

388 }

391

392

393 auto *accIt = accCache.find(destOffset);

394 if (accIt != accCache.end())

395 acc = accIt->second;

396 else

397 acc = rewriter.createOrFoldvector::ExtractStridedSliceOp(

398 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);

399 operands.push_back(acc);

401 dstShape, reductionOp.getSourceVectorType().getElementType());

403 operands, targetType);

405 accCache[destOffset] = result;

406 }

407

408 Value result = rewriter.createarith::ConstantOp(

409 loc, reductionOp.getDestType(),

410 rewriter.getZeroAttr(reductionOp.getDestType()));

411 for (const auto &it : accCache) {

413 result = rewriter.createOrFoldvector::InsertStridedSliceOp(

414 loc, it.second, result, it.first, dstStrides);

415 }

416 rewriter.replaceOp(reductionOp, result);

417 return success();

418 }

419

420 private:

422 };

423

424 struct UnrollElementwisePattern : public RewritePattern {

425 UnrollElementwisePattern(MLIRContext *context,

428 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),

430

431 LogicalResult matchAndRewrite(Operation *op,

434 return failure();

436 if (!targetShape)

437 return failure();

438 auto dstVecType = cast(op->getResult(0).getType());

440 *cast(op).getShapeForUnroll();

441

442

443

444 if (originalSize.size() != targetShape->size())

446 op, "expected input vector rank to match target shape rank");

448

449 Value result = rewriter.createarith::ConstantOp(

450 loc, dstVecType, rewriter.getZeroAttr(dstVecType));

452 VectorType newVecType =

453 VectorType::get(*targetShape, dstVecType.getElementType());

454

455

460 auto vecType = dyn_cast(operand.get().getType());

461 if (!vecType) {

462 extractOperands.push_back(operand.get());

463 continue;

464 }

465 extractOperands.push_back(

466 rewriter.createOrFoldvector::ExtractStridedSliceOp(

467 loc, operand.get(), offsets, *targetShape, strides));

468 }

470 rewriter, loc, op, extractOperands, newVecType);

471 result = rewriter.createOrFoldvector::InsertStridedSliceOp(

472 loc, newOp->getResult(0), result, offsets, strides);

473 }

475 return success();

476 }

477

478 private:

480 };

481

482 struct UnrollReductionPattern : public OpRewritePatternvector::ReductionOp {

483 UnrollReductionPattern(MLIRContext *context,

488

489 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,

491 std::optional<SmallVector<int64_t>> targetShape =

493 if (!targetShape)

494 return failure();

496

497

498 Location loc = reductionOp.getLoc();

499 Value accumulator = nullptr;

503 Value slicedOperand =

504 rewriter.createOrFoldvector::ExtractStridedSliceOp(

505 loc, reductionOp.getVector(), offsets, *targetShape, strides);

507 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());

508 Value result = newOp->getResult(0);

509

510 if (!accumulator) {

511

512 accumulator = result;

513 } else {

514

515 accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),

516 accumulator, result);

517 }

518 }

519

520 rewriter.replaceOp(reductionOp, accumulator);

521 return success();

522 }

523

524 private:

526 };

527

528 struct UnrollTransposePattern : public OpRewritePatternvector::TransposeOp {

529 UnrollTransposePattern(MLIRContext *context,

534

535 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,

537 if (transposeOp.getResultVectorType().getRank() == 0)

538 return failure();

540 if (!targetShape)

541 return failure();

542 auto originalVectorType = transposeOp.getResultVectorType();

544 Location loc = transposeOp.getLoc();

546

547

548 Value result = rewriter.createarith::ConstantOp(

549 loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));

551

552

557

559 permutedOffsets[indices.value()] = elementOffsets[indices.index()];

560 permutedShape[indices.value()] = (*targetShape)[indices.index()];

561 }

562 Value slicedOperand =

563 rewriter.createOrFoldvector::ExtractStridedSliceOp(

564 loc, transposeOp.getVector(), permutedOffsets, permutedShape,

565 strides);

566 Value transposedSlice = rewriter.createOrFoldvector::TransposeOp(

567 loc, slicedOperand, permutation);

568 result = rewriter.createOrFoldvector::InsertStridedSliceOp(

569 loc, transposedSlice, result, elementOffsets, strides);

570 }

571 rewriter.replaceOp(transposeOp, result);

572 return success();

573 }

574

575 private:

577 };

578

579 struct UnrollGatherPattern : public OpRewritePatternvector::GatherOp {

580 UnrollGatherPattern(MLIRContext *context,

584 }

585

586 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,

588 VectorType sourceVectorType = gatherOp.getVectorType();

589 if (sourceVectorType.getRank() == 0)

590 return failure();

592 if (!targetShape)

593 return failure();

595 Location loc = gatherOp.getLoc();

596 ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

597

598

599 Value result = rewriter.createarith::ConstantOp(

600 loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));

601 auto targetType =

602 VectorType::get(*targetShape, sourceVectorType.getElementType());

603

608

609

610

611 Value indexSubVec = rewriter.createOrFoldvector::ExtractStridedSliceOp(

612 loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);

613 Value maskSubVec = rewriter.createOrFoldvector::ExtractStridedSliceOp(

614 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);

615 Value passThruSubVec =

616 rewriter.createOrFoldvector::ExtractStridedSliceOp(

617 loc, gatherOp.getPassThru(), elementOffsets, *targetShape,

618 strides);

619 auto slicedGather = rewriter.createvector::GatherOp(

620 loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),

621 indexSubVec, maskSubVec, passThruSubVec);

622

623 result = rewriter.createOrFoldvector::InsertStridedSliceOp(

624 loc, slicedGather, result, elementOffsets, strides);

625 }

626 rewriter.replaceOp(gatherOp, result);

627 return success();

628 }

629

630 private:

632 };

633

634 struct UnrollBroadcastPattern : public OpRewritePatternvector::BroadcastOp {

635 UnrollBroadcastPattern(MLIRContext *context,

640

641 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,

644 if (!targetShape)

645 return failure();

646

647 Location loc = broadcastOp.getLoc();

648 VectorType srcType = dyn_cast(broadcastOp.getSourceType());

649 VectorType resType = broadcastOp.getResultVectorType();

650 VectorType targetType =

651 resType.cloneWith(*targetShape, resType.getElementType());

652 Value result = rewriter.createarith::ConstantOp(

653 loc, resType, rewriter.getZeroAttr(resType));

654

657

661 if (!srcType) {

662

663 newSrc = broadcastOp.getSource();

664 } else {

665

666 int64_t rank = srcType.getRank();

669 targetShape->end());

671

672 for (int64_t i = 0; i < rank; ++i) {

673 if (srcType.getDimSize(i) == 1) {

674 srcOffsets[i] = 0;

675 srcShape[i] = 1;

676 }

677 }

678 newSrc = rewriter.createOrFoldvector::ExtractStridedSliceOp(

679 loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);

680 }

681

683 newSrc, targetType);

684

685 result = rewriter.createOrFoldvector::InsertStridedSliceOp(

686 loc, newOp->getResult(0), result, offsets, strides);

687 }

688

689 rewriter.replaceOp(broadcastOp, result);

690 return success();

691 }

692

693 private:

695 };

696

697 }

698

699 void mlir::vector::populateVectorUnrollPatterns(

703 .add<UnrollTransferReadPattern, UnrollTransferWritePattern,

704 UnrollContractionPattern, UnrollElementwisePattern,

705 UnrollReductionPattern, UnrollMultiReductionPattern,

706 UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(

708 }

static llvm::ManagedStatic< PassManagerOptions > options

static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)

Compute the indices of the slice for a transfer op.

static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)

Return the target shape for unrolling for the given op.

static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)

static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)

Base type for affine expressions.

A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

ArrayRef< AffineExpr > getResults() const

TypedAttr getZeroAttr(Type type)

MLIRContext * getContext() const

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

This class helps build Operations.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents an operand of an operation.

StringRef getStringRef() const

Return the name of this operation. This always succeeds.

StringAttr getIdentifier() const

Return the name of this operation as a StringAttr.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

Location getLoc()

The source location the operation was defined or derived from.

ArrayRef< NamedAttribute > getAttrs()

Return all of the attributes on this operation.

OperationName getName()

The name of an operation is the key identifier for it.

MutableArrayRef< OpOperand > getOpOperands()

result_range getResults()

unsigned getNumResults()

Return the number of results held by this operation.

This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

RewritePattern is the common base class for all DAG to DAG replacements.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

bool hasElementwiseMappableTraits(Operation *op)

Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)

Returns the result value of reducing two scalar/vector values with the corresponding arith operation.

Include the generated interface declarations.

SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)

Apply a permutation from map to source and return the result.

const FrozenRewritePatternSet & patterns

AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)

Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

Options that control the vector unrolling.