MLIR: lib/Dialect/Vector/Transforms/VectorUnroll.cpp Source File

//===- VectorUnroll.cpp - patterns to do vector unrolling ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>

#define DEBUG_TYPE "vector-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

/// Compute the indices of the slice `index` for a transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by shifting each index used by the permutation
  // map with the corresponding element offset.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}

/// Clone `op` with the same name and attributes, but with the given operands
/// and result types.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given op.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG("Get unroll shape for op " << op->getName().getStringRef());
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG("--no filter constraint -> BAIL");
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native"
         "shape call back function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG("--not an unrollable op -> BAIL");
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG("--could not get shape of op " << *op << " -> BAIL");
    return std::nullopt;
  }
  LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape));

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG("--no unrolling target shape defined " << *op << "-> SKIP");
    return std::nullopt;
  }
  LDBG("--target shape: " << llvm::interleaved(*targetShape));

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG("--could not compute integral shape ratio -> BAIL");
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG("--no unrolling needed -> SKIP");
    return std::nullopt;
  }
  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
  return targetShape;
}

/// Compute the unroll traversal order: identity order by default, or the
/// order returned by the traversal order callback when one is provided.
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

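/// Unrolls a vector.transfer_read into transfer_reads of the target shape:
/// each slice reads at indices shifted by its slice offsets (computed with
/// sliceTransferIndices) and is inserted back into a zero-initialized result
/// vector with vector.insert_strided_slice. Masked reads are not handled.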
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

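/// Unrolls a vector.transfer_write by extracting target-shaped slices of the
/// written vector and emitting one transfer_write per slice. When the write
/// has tensor semantics, the slices are chained through the produced tensor;
/// otherwise the original op is simply erased. Masked writes are not handled.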
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getBase(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());

      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

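/// DenseMap-style key info so that offset vectors (SmallVector<int64_t>) can
/// be used as keys in the accumulator caches of the patterns below.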
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

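/// Unrolls vector.contract to contractions of the target shape. Operand
/// slices are extracted through each operand's indexing map, and partial
/// accumulators are kept in accCache (keyed by destination offsets) so that
/// slices along reduction dimensions keep updating the same accumulator; the
/// cached results are reassembled into the full result vector at the end.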
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it,
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since
      // reduction loops keep updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

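/// Unrolls vector.multi_reduction over target-shaped slices of the source.
/// The non-reduced dimensions determine the destination offset; slices that
/// map to the same destination offset reuse the cached accumulator so the
/// partial reductions are chained before the final result is reassembled.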
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it,
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

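/// Unrolls elementwise ops (any op with elementwise-mappable traits and a
/// single vector result): each target-shaped slice of every vector operand is
/// extracted, the op is cloned on the slices, and the results are inserted
/// back into a zero-initialized vector of the original type.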
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();

    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          op, "expected input vector rank to match target shape rank");
    Location loc = op->getLoc();

    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

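/// Unrolls vector.reduction by reducing each target-shaped slice separately
/// and combining the partial results with makeArithReduction, using the first
/// slice's result as the initial accumulator.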
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

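/// Unrolls vector.transpose: for each target-shaped tile of the result, the
/// corresponding (permuted) tile of the source is extracted, transposed with
/// the same permutation, and inserted into the result.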
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

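/// Unrolls vector.gather by slicing the index, mask, and pass-through vectors
/// with the same offsets and emitting one smaller gather per slice.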
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // indexes and offsets as the original operation.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

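/// Unrolls vector.broadcast: each unrolled broadcast either reuses the scalar
/// source directly or extracts the matching source slice (clamped to size 1
/// along broadcast dimensions) and broadcasts it to the target shape.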
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar source: every unrolled broadcast uses it directly.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector source: extract the slice that feeds this unrolled broadcast.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(rank, 1);
        // Clamp the offset and shape along broadcast dims (size 1 in source).
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

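// Example usage (a minimal sketch, not part of the upstream file): a pass or
// test typically populates these patterns through the UnrollVectorOptions
// setters declared in VectorRewritePatterns.h and runs them with the greedy
// driver (applyPatternsGreedily, or applyPatternsAndFoldGreedily on older
// revisions). The 4x4x2 native shape below is purely hypothetical.
//
//   RewritePatternSet patterns(ctx);
//   vector::UnrollVectorOptions options;
//   options.setNativeShapeFn(
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         if (isa<vector::ContractionOp>(op))
//           return SmallVector<int64_t>{4, 4, 2};
//         return std::nullopt;
//       });
//   vector::populateVectorUnrollPatterns(patterns, options);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();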
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns
      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
           UnrollContractionPattern, UnrollElementwisePattern,
           UnrollReductionPattern, UnrollMultiReductionPattern,
           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
          patterns.getContext(), options, benefit);
}