MLIR: lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp Source File
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())
             .isLastDimUnitStride();
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}

namespace {
/// A vector.create_mask op and the position at which a mask was extracted
/// from its result (empty if the mask is used directly).
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by
/// vector.create_mask, return that op.
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: Mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: Mask is an extract from the result of a vector.create_mask.
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases are not supported.
  return failure();
}
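// Illustrative note (assumed IR snippet, not from the original file; SSA names
// are made up): the two mask patterns accepted above look like
//   %m0 = vector.create_mask %sz : vector<4xi1>
// or a row extracted from a higher-dimensional mask:
//   %mask = vector.create_mask %a, %b : vector<2x4xi1>
//   %m1 = vector.extract %mask[1] : vector<4xi1> from vector<2x4xi1>
// Any other mask producer results in failure() and the copy is not converted.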

/// Build an SSA value that represents the number of read elements.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // Case 1: No mask, so the transfer reads the full vector.
  if (!transferMask->createMaskOp)
    return Value();

  // Case 2: Mask is directly a vector.create_mask: the number of read
  // elements is its single operand.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // Case 3: Mask is an extract from an N-D vector.create_mask. If the
  // extracted position is within the mask bounds, the number of read elements
  // is the size of the innermost mask dimension; otherwise it is 0.
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more mask operand than extract positions; zip stops at
  // the last position and the trailing operand is the innermost mask size.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt,
                              arith::ConstantIndexOp::create(b, loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = arith::AndIOp::create(b, loc, cmp, cond);
  }
  return arith::SelectOp::create(
      b, loc, cond, transferMask->createMaskOp->getOperands().back(),
      arith::ConstantIndexOp::create(b, loc, 0));
}
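// Worked example for Case 3 above (using the %mask/%m1 snippet shown after
// getMaskOp): the extract position is [1] and the mask operands are (%a, %b),
// so the generated value is select(1 < %a, %b, 0), i.e. the innermost mask
// size when the extracted row is active and 0 otherwise.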

/// Return "true" if the conversion to async copy is supported by the
/// "async copy" (cp.async) operation.
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported. For cp.async the
  // NVIDIA documentation requires the address to be naturally aligned to a
  // multiple of the access size; if an address is not properly aligned, the
  // resulting behavior is undefined.
  return true;
}
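// Worked example for the size check above: a vector<4xf32> copy moves
// 4 * 32 bits = 16 bytes and is supported (and is the only size eligible for
// the L1 bypass below); vector<2xf16> moves 4 bytes and is supported;
// vector<3xf32> moves 12 bytes, matches none of {4, 8, 16}, and is rejected.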

void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for a vector.transfer_write or vector.store op.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    // The store must write to shared memory.
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must be produced by a contiguous vector.transfer_read
    // or vector.load op.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // Should be reading from global memory (not shared memory).
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for a compatible mask and padding.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check that both accesses are supported before generating the async
    // copy; this is necessary for the correctness of DeviceAsyncCopyOp.
    VectorType vecType = cast<VectorType>(vectorVal.getType());

    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first remaining copy candidate.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look at the following nodes for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore read ops that do not touch shared memory; they cannot conflict
      // with the copies in the group.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Value memrefOperand = nvgpu::getMemrefOperand(nextNode);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another candidate copy; add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // Any other op with side effects ends the group.
      break;
    }
    // Emit the group of async copy ops.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // bypass_l1 is only possible with a 16-byte transfer.
      Value token = nvgpu::DeviceAsyncCopyOp::create(
          rewriter, writeOp->getLoc(),
          nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right after.
    Value groupToken = nvgpu::DeviceAsyncCreateGroupOp::create(
        rewriter, op->getLoc(),
        nvgpu::DeviceAsyncTokenType::get(op->getContext()), tokens);
    nvgpu::DeviceAsyncWaitOp::create(rewriter, op->getLoc(), groupToken,
                                     /*numGroups=*/nullptr);
    // Clean up the original copy ops.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}
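
//===----------------------------------------------------------------------===//
// Usage sketch (illustrative, not part of the upstream file). It shows one way
// to drive nvgpu::createAsyncGroups from a minimal pass; the pass name, the
// argument string, and the extra includes below are assumptions made for the
// example. Upstream, the rewrite is typically reached through the transform
// dialect or a larger GPU lowering pipeline.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

namespace {
struct CreateAsyncGroupsSketchPass
    : public PassWrapper<CreateAsyncGroupsSketchPass,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CreateAsyncGroupsSketchPass)

  StringRef getArgument() const final { return "sketch-create-async-groups"; }
  StringRef getDescription() const final {
    return "Illustrative wrapper around nvgpu::createAsyncGroups";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    // The rewrite introduces nvgpu and arith ops.
    registry.insert<nvgpu::NVGPUDialect, arith::ArithDialect>();
  }

  void runOnOperation() override {
    IRRewriter rewriter(&getContext());
    // Request L1 bypass; it only takes effect for 16-byte copies (see the
    // bypassL1 handling above).
    nvgpu::createAsyncGroups(rewriter, getOperation(), /*bypassL1=*/true);
  }
};
} // namespace

// The sketch pass could then be registered with, e.g.:
//   PassRegistration<CreateAsyncGroupsSketchPass>();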