MLIR: lib/Dialect/NVGPU/Utils/MMAUtils.cpp Source File
//===- MMAUtils.cpp - MLIR NVGPU dialect utilities for MMA operations-----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}
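
// Worked example (editor's sketch, not part of the upstream file): for a
// vector<16x16xf16> "A" operand, inferTileWidthInBits returns 128, so
//   numRegisters = (16 / 8) * (16 * 16) / 128 = 4,
// the four 32-bit registers per thread that PTX assigns to an f16 A fragment.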

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128b square, a thread is responsible for one 32b register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}
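
// Worked example (editor's sketch): getTileShape({16, 16}, f16, 128) returns
// {2, 2}: the operand decomposes into four tiles of 8 rows by 128 bits
// (eight f16 elements per tile row).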

/// Returns the first user of the `op` that is vector.contract. If no such
/// user exists, return failure.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}
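
// For illustration (editor's sketch): given
//   %a = vector.transfer_read ... : memref<16x16xf16>, vector<16x16xf16>
//   %d = vector.contract ... %a, %b, %c : ... into vector<16x16xf32>
// getUserContract on the defining op of %a returns the vector.contract,
// while a value with no vector.contract user yields failure().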

/// If `op` is a vector.transfer_write, return the WarpMatrixInfo for the
/// vector operand. If `op` is a vector.transfer_read, vector.contraction, or
/// arith.constant, return the WarpMatrixInfo for the result value.
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type of the operand or result being transferred.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. The value is assumed to be an
  // accumulator/result (role C) unless it feeds a vector.contract directly.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}
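
// For illustration (editor's sketch): in the snippet above, %a is the LHS of
// the contract, so getWarpMatrixInfo on its defining op yields
// { vectorType = vector<16x16xf16>, operandRole = MatMulOperandRole::A };
// a vector.transfer_write of %d would yield role C (accumulator/result).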

/// Returns the number of bits in a single tile row. It is either 128, 256,
/// or 512 bits depending on the data type and whether the operand is an
/// accumulator/result operand.
int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}
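
// Worked examples (editor's sketch): an f16 or i8 A/B operand falls through
// to the 128-bit case; an f32 accumulator hits the first branch (256 bits);
// an f64 operand yields 512 bits as accumulator and 256 bits otherwise.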

/// Returns a FragmentElementInfo struct describing the register types for the
/// given matrix fragment type.
FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand.
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand.
  if (elType.isInteger(8)) {
    return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand.
  if (elType.isInteger(4)) {
    return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // 32-bit integer accumulator operands.
  if (elType.isInteger(32)) {
    return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)),
                               2, 64, inferNumRegistersPerMatrixFragment(type)};
  }

  // 32-bit floating point operands.
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
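
// Worked example (editor's sketch): for a vector<16x16xf16> A operand this
// returns { registerLLVMType = vector<2xf16>, elementsPerRegister = 2,
// registerWidthBits = 32, numRegistersPerFragment = 4 }, i.e. each thread
// holds four 32-bit registers of two halves each -- the eight f16 values of
// the PTX m16n8k16 A fragment.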

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}
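
// Worked example (editor's sketch): for a 16x16 f16 A operand, lineSize = 128,
// so elementsPerLine = 8 and num8x128bTiles = {2, 2}; register indices
// 0, 1, 2, 3 then map to tile offsets (0,0), (8,0), (0,8) and (8,8), i.e. the
// four 8x8-element tiles are visited down each tile column first.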

/// Returns an AffineMap which maps two dimensions representing
/// (laneId, logicalValueId) to two results representing offsets within a
/// matrix operand.
FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register the logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
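
// Worked example (editor's sketch): for the 16x16 f16 A operand above,
// elementsPerRegister = 2, so the resulting map is
//   (laneId, valueId) -> (tileRow + laneId floordiv 4,
//                         tileCol + (laneId mod 4) * 2 + valueId mod 2).
// E.g. lane 5, value 0 lands at (1, 2), matching the PTX convention that
// laneId / 4 selects the row within a tile.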

/// Given a warp-matrix operand description and whether or not the load is a
/// transposed load, return the parameters for lowering to nvvm.ldmatrix.
FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

  if (params.contiguousDimType == vector::IteratorType::reduction) {
    // The matrix is loaded row-major: the contiguous dim is the reduction dim.
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    // The matrix is loaded col-major.
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}
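
// Worked example (editor's sketch): a non-transposed vector<16x16xf16> A
// operand gives contiguousDimType = reduction, targetLayout = row and
// numTiles = (16 / 8) * ((16 * 16) / 128) = 4, i.e. an x4 ldmatrix.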

/// Returns an AffineMap which maps a single dimension representing the laneId
/// to two results representing the offsets within the matrix operand that the
/// thread should pass to the ldmatrix instruction.
FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
  // the `srcMemref` memory of the LdMatrixOp.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // Affine expressions in the strided and contiguous dimensions encode the
  // coordinate of the element each thread points to for the warp-wide
  // LdMatrixOp.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // This case corresponds to row-major matrixA, col-major matrixB, or
  // row-major matrixC: the memory layout in `srcMemref` matches the mma.sync
  // hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // This case corresponds to col-major matrixA, row-major matrixB, or
  // col-major matrixC: the memory layout in `srcMemref` does not match the
  // mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}
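
// Worked example (editor's sketch): for the non-transposed 16x16 f16 case,
// idx = 0 and kElementsPer128b = 8, so the map is
//   (d0) -> (d0 mod 16, (d0 floordiv 16) * 8):
// lanes 0-15 point at rows 0-15 of the first 128-bit column chunk and lanes
// 16-31 at the chunk starting at element column 8, one 128-bit row pointer
// per lane for the x4 ldmatrix.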

/// Returns whether the `vector.transfer_read` instruction can be interpreted
/// as a warp-level cooperative matrix load operation.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // to cope with extra unit dimensions that failed to fold, but that would
  // significantly increase the complexity of the downstream conversion step.
  // For now, we rely on other patterns to put the op into canonical 2D form
  // when possible.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
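
// For illustration (editor's sketch), a read that qualifies:
//   %v = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]}
//        : memref<16x16xf16>, vector<16x16xf16>
// Masked or out-of-bounds reads, non-2D or dynamically shaped results,
// tensor sources, and memrefs with a non-unit innermost stride are rejected.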

/// Returns whether the `vector.transfer_write` instruction can be interpreted
/// as a warp-level cooperative matrix store operation.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation, so the
  // permutation map must be a minor identity and the last dimension of the
  // destination memref must be contiguous.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
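
// For illustration (editor's sketch): a 2-D in-bounds transfer_write with a
// minor-identity permutation map into a row-major memref qualifies, whereas
// a transposing write (permutation map (d0, d1) -> (d1, d0)) returns false.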