MLIR: lib/Dialect/NVGPU/Utils/MMAUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

9

15

16 using namespace mlir;

18

19

22

25 }

26

27

28

31 auto shape = type.vectorType.getShape();

33 (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /

34 lineSize;

35 }

36

37

38

40 Type elementType,

41 int64_t lineSizeBits) {

42

45 lineSizeBits};

46 }

47

48

49

52 if (auto contractOp = dyn_castvector::ContractionOp(user))

53 return contractOp;

54 }

55 return failure();

56 }

57

60

61

62 if (vector::TransferWriteOp writeOp = dyn_castvector::TransferWriteOp(op)) {

63 info.vectorType = writeOp.getVectorType();

64 } else if (isa<vector::TransferReadOp, vector::ContractionOp,

65 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {

67 } else {

69 << "unhandled operation type in nvgpu.mma.sync conversion path";

70 }

71

72

73

75 FailureOrvector::ContractionOp contractOp = getUserContract(op);

76 if (failed(contractOp))

77 return info;

78

79 if ((*contractOp).getLhs() == op->getResult(0))

81 else if ((*contractOp).getRhs() == op->getResult(0))

83

84 return info;

85 }

86

91 return 256;

92 }

94 return isAcc ? 512 : 256;

95 }

96 return 128;

97 }

98

99 FailureOr

103

105 if (elType.isF16()) {

108 }

109

110

112 if (elType.isF64()) {

113 return isAccum

118 }

119

120

124 }

125

126

130 }

131

132

136 }

137

138

139 if (elType.isF32()) {

141 return isAccum

146 }

147 return failure();

148 }

149

151 Type elementType,

153 bool isAccumulator,

154 int64_t elementsPerRegister,

156 const int64_t elementsPerLine =

158 const std::array<int64_t, 2> num8x128bTiles =

159 getTileShape(operandShape, elementType, lineSize);

160 AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);

162 2, 0,

163 {(registerIdx % num8x128bTiles[0]) * 8,

164 (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},

166 }

167

168 FailureOr

171 Type elementType = fragmentType.vectorType.getElementType();

173 FailureOrnvgpu::FragmentElementInfo regInfo =

175 if (failed(regInfo))

176 return failure();

177

179 const int64_t elementsPerRegister =

180 regInfo->registerWidthBits / elementBitWidth;

182

183 AffineExpr laneId, logicalValueIdDim;

185

186

187

189 lineSize, elementType, operandShape,

191 logicalValueIdDim);

192

195 };

196

197 auto tileRow = registerIndexToTileCoord.getResult(0);

198 auto tileCol = registerIndexToTileCoord.getResult(1);

200 tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +

201 (logicalValueIdDim % elementsPerRegister)});

202 }

203

204 FailureOrnvgpu::LdMatrixParams

212 } else {

214 }

217 : vector::IteratorType::reduction;

218

219 if (params.contiguousDimType == vector::IteratorType::reduction) {

222 } else {

225 }

226

228 return failure();

229

230 return params;

231 }

232

233 FailureOr

236

237 const int bitsPerElement = static_cast<int>(

238 params.fragmentType.getElementType().getIntOrFloatBitWidth());

239 const int kElementsPer128b = (128 / bitsPerElement);

242

245 };

246

247

248

249 int idx =

250 (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

251

252

253

254 AffineExpr strided = d0 % (operandShape[idx]);

255 AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

256

257

258

259

261 return makeMap({strided, contiguous});

262

263

264

265

267 return makeMap({contiguous, strided});

268

269 return failure();

270 }

271

273 if (op.getMask() || op.hasOutOfBoundsDim())

274 return false;

275 VectorType type = op.getType();

276

277

278

279

280

281 if (!type.hasStaticShape() || type.getRank() != 2)

282 return false;

283

284

285

286

287

288 auto sourceType = dyn_cast(op.getBase().getType());

289 if (!sourceType)

290 return false;

291

292

293

294

295 auto [strides, offset] = sourceType.getStridesAndOffset();

296 return strides.back() == 1;

297 }

298

300 if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)

301 return false;

302 VectorType type = op.getVectorType();

303 if (!type.hasStaticShape() || type.getRank() != 2)

304 return false;

305

306

307

308 if (!op.getPermutationMap().isMinorIdentity())

309 return false;

310

311

312 auto sourceType = dyn_cast(op.getBase().getType());

313 if (!sourceType)

314 return false;

315

316

317

318

319 auto [strides, offset] = sourceType.getStridesAndOffset();

320 return strides.back() == 1;

321 }

static constexpr int64_t kNumRowsPerTile

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize, Type elementType, ArrayRef< int64_t > operandShape, bool isAccumulator, int64_t elementsPerRegister, AffineExpr logicalValueId)

static constexpr int64_t kThreadsPerRow

There are always 4 threads per [128|256|512] bit row.

static bool isAccumulatorOrResult(MatMulOperandRole operandType)

static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type)

Returns the number of registers which compose a matrix fragment held by a single thread.

static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)

Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.

Base type for affine expression.

AffineExpr floorDiv(uint64_t v) const

A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

AffineExpr getResult(unsigned idx) const

MLIRContext * getContext() const

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

This class helps build Operations.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.

user_range getUsers()

Returns a range of all users.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

MLIRContext * getContext() const

Return the MLIRContext in which this type was uniqued.

bool isInteger() const

Return true if this is an integer type (with the specified width).

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

Type getType() const

Return the type of this value.

int64_t inferTileWidthInBits(const WarpMatrixInfo &type)

Returns the number of bits in a single tile row.

FailureOr< vector::ContractionOp > getUserContract(Operation *op)

Returns the first user of the op that is vector.contract.

FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)

Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) and returns two results representing offsets within a matrix operand.

FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)

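If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand. If op is a vector.transfer_read, vector.contract, vector.extract_strided_slice, or arith.constant, return the WarpMatrixInfo for the op's result; any other op is rejected with an error.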

FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)

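Returns an AffineMap which maps a single dimension representing the laneId to two results representing offsets within the matrix operand; these offsets are the pointer locations a thread should pass to the ldmatrix instruction.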

MatMulOperandRole

Represents the role of an operand in an MMA instruction: result := matmul(A, B) + C

FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)

Given `type`, which contains info for a warp-matrix operand, and whether or not the load is a transposed load, return the LdMatrixParams.

FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)

Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type.

bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)

Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative matrix load operation.

static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)

Include the generated interface declarations.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

Specifies information about the registers which compose a matrix fragment according to the PTX documentation.

Encapsulates the parameters needed to lower a nvgpu.ldmatrix operation to nvvm.ldmatrix.

NVVM::MMALayout targetLayout

vector::IteratorType contiguousDimType

Collects information about a warp-level matrix operand represented by a VectorType.

MatMulOperandRole operandRole