MLIR: lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp Source File
//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a
// vector.transfer op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is statically guaranteed to be in-bounds, it does not
    // participate in the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                getValueOrCreateConstantIndexOp(b, loc, sum),
                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
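
// Illustrative sketch (assumed example, not verbatim upstream documentation):
// for a read of vector<4x8xf32> from %A : memref<?x?xf32> at (%i, %j) with no
// dimension statically in-bounds, the condition built above is roughly:
//   %d0 = memref.dim %A, %c0
//   %d1 = memref.dim %A, %c1
//   %e0 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
//   %e1 = affine.apply affine_map<(d0) -> (d0 + 8)>(%j)
//   %b0 = arith.cmpi sle, %e0, %d0     // %i + 4 <= dim(%A, 0)
//   %b1 = arith.cmpi sle, %e1, %d1     // %j + 8 <= dim(%A, 1)
//   %inBounds = arith.andi %b0, %b1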

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// The fast path directly reuses the original memref (after a potential cast),
/// while the slow path stages the transfer through a small alloca'ed buffer of
/// exactly one vector, so that the rewritten vector.transfer is always
/// in-bounds. See `splitFullAndPartialTransfer` below for the produced IR.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map.
///  2. the rank of `xferOp.getBase()` and the rank of `xferOp.getVector()` must
///     be equal. This will be relaxed in the future but requires rank-reducing
///     subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();

  // Don't split transfer operations directly under an scf.if, this won't
  // result in performance improvements.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}
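
// Illustrative sketch (assumed example, not verbatim upstream documentation):
// a transfer that qualifies for splitting vs. one that does not.
//
//   // Splittable: minor identity map, has an out-of-bounds dim.
//   %r = vector.transfer_read %A[%i, %j], %pad
//          {in_bounds = [false, false]} : memref<?x?xf32>, vector<4x8xf32>
//
//   // Not splittable: already fully in-bounds (hasOutOfBoundsDim() is false).
//   %s = vector.transfer_read %A[%c0, %c0], %pad
//          {in_bounds = [true, true]} : memref<16x16xf32>, vector<4x8xf32>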

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`,
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}
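
// Illustrative sketch (assumed example, not verbatim upstream documentation):
// for aT = memref<4x8xf32> and bT = memref<4x16xf32> (not cast-compatible, the
// static sizes 8 and 16 disagree), the helper above returns
//   memref<4x?xf32, strided<[?, 1]>>
// i.e. sizes/strides that agree stay static, the others become dynamic, and
// the common offset 0 is preserved.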

/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is first inserted, followed by a `memref.cast`.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getBase()` @ `xferOp.getIndices()` and the view `alloc`.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef =
        b.create<memref::DimOp>(xferOp.getLoc(), xferOp.getBase(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
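
// Illustrative sketch (assumed example, not verbatim upstream documentation):
// for a transfer_read of vector<4xf32> from %A : memref<?xf32> at [%i] with a
// staging buffer %alloc : memref<4xf32>, the copy bounds computed above are:
//   %d  = memref.dim %A, %c0
//   %da = memref.dim %alloc, %c0       // folds to 4
//   %sz = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d, %i, %da)
//   copySrc  = memref.subview %A[%i] [%sz] [1]
//   copyDest = memref.subview %alloc[0] [%sz] [1]
// i.e. the intersection of the transfer's footprint with the staging buffer
// (the roles of %A and %alloc swap for a transfer_write).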

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      linalg.fill(%pad, %alloc)
///      %3 = memref.subview %view [...][...][...]
///      %4 = memref.subview %alloc [0, 0][...][...]
///      memref.copy %3, %4
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getBase();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
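
// Illustrative sketch (assumed example, not verbatim upstream documentation):
// for a read of vector<8xf32> from %A : memref<?xf32> at [%i], the else branch
// built above amounts to:
//   linalg.fill ins(%pad : f32) outs(%alloc : memref<8xf32>)
//   memref.copy (subview %A[%i][%sz][1]), (subview %alloc[0][%sz][1])
//   %view = memref.cast %alloc : memref<8xf32> to memref<?xf32>
//   scf.yield %view, %c0 : memref<?xf32>, index
// so the rewritten transfer_read always reads a full, padded buffer at [0].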

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %alloc : memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc : memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getBase();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<vector::TransferReadOp>(newXfer).getVector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
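
// Illustrative note (not upstream documentation): the variant above keeps the
// slow path as a single cloned, still out-of-bounds vector.transfer_read whose
// result is stored into the staging alloca through vector.type_cast, instead
// of materializing a linalg.fill + memref.copy:
//   %v = vector.transfer_read %A[%i, %j], %pad ...   // clone of the original
//   %m = vector.type_cast %alloc : memref<4x8xf32> to memref<vector<4x8xf32>>
//   memref.store %v, %m[]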

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange getLocationToWriteFullVec(RewriterBase &b,
                                            vector::TransferWriteOp xferOp,
                                            TypeRange returnTypes,
                                            Value inBoundsCond,
                                            MemRefType compatibleMemRefType,
                                            Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getBase();
  return b
      .create<scf::IfOp>(
          loc, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res =
                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            llvm::append_range(viewAndIndices, xferOp.getIndices());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}
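
// Illustrative note (not upstream documentation): the ValueRange returned
// above is (%memref, %idx0, ..., %idxN): either the casted original
// destination with the original indices (fast path) or the casted staging
// alloca with all-zero indices (slow path). The caller clones the
// vector.transfer_write onto this location with in_bounds = [true, ...].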

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = memref.subview %alloc [...][...][...]
///      %4 = memref.subview %view [0, 0][...][...]
///      memref.copy %3, %4
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}
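
// Illustrative sketch (assumed example, not verbatim upstream documentation):
// for a write of vector<8xf32> to %A : memref<?xf32> at [%i], the guarded
// copy-back built above is roughly:
//   %not = arith.xori %inBounds, %true
//   scf.if %not {
//     // %sz = intersection size from createSubViewIntersection
//     memref.copy (subview %alloc[0][%sz][1]), (subview %A[%i][%sz][1])
//   }
// i.e. only the in-bounds prefix of the staged vector reaches %A.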

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}
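
// Illustrative note (not upstream documentation): the variant above reloads
// the full staged vector from %alloc (via vector.type_cast + memref.load) and
// re-issues the original, possibly out-of-bounds vector.transfer_write to
// %view, relying on the transfer's own masking instead of a partial copy.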

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate for the fact that the original vector.transfer indexing may
/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer,
/// the scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block
/// to further copy a partial buffer into the final result in the slow path
/// case.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map.
///  2. the rank of `xferOp.getBase()` and the rank of `xferOp.getVector()` must
///     be equal. This will be relaxed in the future but requires rank-reducing
///     subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.transfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds, it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}
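
// Illustrative usage sketch (hypothetical caller, not part of this file):
// splitting one candidate transfer op in place. `xfer` and `ctx` are assumed
// to be provided by the surrounding pass or test.
//
//   IRRewriter rewriter(ctx);
//   VectorTransformsOptions opts;
//   opts.vectorTransferSplit = VectorTransferSplit::LinalgCopy;
//   scf::IfOp ifOp;
//   if (succeeded(vector::splitFullAndPartialTransfer(rewriter, xfer, opts,
//                                                     &ifOp))) {
//     // `ifOp` now selects between the original view and the staging alloca;
//     // the transfer itself is in-bounds on whichever view is yielded.
//   }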

namespace {

/// Apply `splitFullAndPartialTransfer` selectively via a pattern.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options), filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
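
// Illustrative usage sketch (hypothetical pass body, not part of this file):
// wiring the pattern into a greedy rewrite. `getOperation()` and
// `signalPassFailure()` come from the assumed surrounding Pass.
//
//   RewritePatternSet patterns(&getContext());
//   VectorTransformsOptions options;
//   options.vectorTransferSplit = VectorTransferSplit::VectorTransfer;
//   vector::populateVectorTransferFullPartialPatterns(patterns, options);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();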