MLIR: lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

28 #include "llvm/ADT/STLExtras.h"

29 #include "llvm/ADT/StringSwitch.h"

30

31 #include

32

33 namespace mlir {

34

35

36

37

38

39

40

41

42

44 gpu::SubgroupMmaElementwiseOp op, Type coopType,

46 assert((isaspirv::CooperativeMatrixType(coopType)));

47

48 switch (op.getOpType()) {

49 case gpu::MMAElementwiseOp::ADDF:

51 return true;

54 return true;

55 case gpu::MMAElementwiseOp::SUBF:

57 return true;

60 return true;

61 case gpu::MMAElementwiseOp::DIVF:

63 return true;

64 case gpu::MMAElementwiseOp::DIVS:

66 return true;

67 case gpu::MMAElementwiseOp::DIVU:

69 return true;

70 case gpu::MMAElementwiseOp::NEGATEF:

72 return true;

73 case gpu::MMAElementwiseOp::NEGATES:

75 return true;

76 case gpu::MMAElementwiseOp::EXTF:

77 builder.replaceOpWithNewOpspirv::FConvertOp(op, coopType, operands);

78 return true;

79 default:

80 break;

81 }

82 return false;

83 }

84

86 assert(!operands.empty());

87 if (!llvm::all_equal(

88 llvm::map_range(operands, [](Value v) { return v.getType(); })))

89 return false;

90

91 return isaspirv::CooperativeMatrixType(operands.front().getType());

92 }

93

94 namespace {

95

96

97 struct WmmaConstantOpToSPIRVLowering final

98 : OpConversionPatterngpu::SubgroupMmaConstantMatrixOp {

100

101 LogicalResult

102 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,

103 ConversionPatternRewriter &rewriter) const override {

104 Value cst = llvm::getSingleElement(adaptor.getOperands());

105 auto coopType = getTypeConverter()->convertType(op.getType());

106 if (!coopType)

107 return rewriter.notifyMatchFailure(op, "type conversion failed");

108

109 rewriter.replaceOpWithNewOpspirv::CompositeConstructOp(op, coopType, cst);

110 return success();

111 }

112 };

113

114

115

116 struct WmmaExtractOpToSPIRVLowering final

117 : OpConversionPatterngpu::SubgroupMmaExtractThreadLocalOp {

119

120 LogicalResult

121 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,

122 ConversionPatternRewriter &rewriter) const override {

123 Value matrix = adaptor.getMatrix();

124 auto coopType =

125 getTypeConverter()->convertTypespirv::CooperativeMatrixType(

126 matrix.getType());

127 if (!coopType)

128 return rewriter.notifyMatchFailure(op, "type conversion failed");

129

130 SmallVector<int32_t> intValues;

131 for (Value val : op.getIndices()) {

132 if (auto constOp = val.getDefiningOparith::ConstantIndexOp()) {

133 intValues.push_back(static_cast<int32_t>(constOp.value()));

134 } else {

135 return rewriter.notifyMatchFailure(op, "indices must be constants");

136 }

137 }

138

139 Type elementType = coopType.getElementType();

140 rewriter.replaceOpWithNewOpspirv::CompositeExtractOp(

141 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));

142 return success();

143 }

144 };

145

146

147

148 struct WmmaInsertOpToSPIRVLowering final

149 : OpConversionPatterngpu::SubgroupMmaInsertThreadLocalOp {

151

152 LogicalResult

153 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,

154 ConversionPatternRewriter &rewriter) const override {

155 Value value = adaptor.getValue();

156 Value matrix = adaptor.getMatrix();

157 auto coopType = getTypeConverter()->convertType(matrix.getType());

158 if (!coopType)

159 return rewriter.notifyMatchFailure(op, "type conversion failed");

160

161 SmallVector<int32_t> intValues;

162 for (Value val : op.getIndices()) {

163 if (auto constOp = val.getDefiningOparith::ConstantIndexOp()) {

164 intValues.push_back(static_cast<int32_t>(constOp.value()));

165 } else {

166 return rewriter.notifyMatchFailure(op, "indices must be constants");

167 }

168 }

169

170 rewriter.replaceOpWithNewOpspirv::CompositeInsertOp(

171 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));

172 return success();

173 }

174 };

175

176

177

178 struct WmmaElementwiseOpToSPIRVDefaultLowering final

179 : OpConversionPatterngpu::SubgroupMmaElementwiseOp {

181

182 LogicalResult

183 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,

184 ConversionPatternRewriter &rewriter) const override {

185

187 return rewriter.notifyMatchFailure(op,

188 "not all operands are coop matrices");

189 }

190

191 auto coopType = getTypeConverter()->convertType(op.getType());

192 if (!coopType)

193 return rewriter.notifyMatchFailure(op, "type conversion failed");

194

195 return success(

197 }

198 };

199

200

201

202 struct WmmaElementwiseOpToSPIRVScalarMulLowering final

203 : OpConversionPatterngpu::SubgroupMmaElementwiseOp {

205

206 LogicalResult

207 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,

208 ConversionPatternRewriter &rewriter) const override {

209 if (adaptor.getOperands().size() != 2)

210 return failure();

211

212

214 return rewriter.notifyMatchFailure(op,

215 "not all operands are coop matrices");

216 }

217

218 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)

219 return failure();

220

221

222

223 Value lhs = op.getOperands().front();

224 Value rhs = op.getOperands().back();

225 Value splat = nullptr;

226 Value matrix = nullptr;

227 if (lhs.getDefiningOpgpu::SubgroupMmaConstantMatrixOp()) {

228 splat = adaptor.getOperands().front();

229 matrix = adaptor.getOperands().back();

230 } else if (rhs.getDefiningOpgpu::SubgroupMmaConstantMatrixOp()) {

231 matrix = adaptor.getOperands().front();

232 splat = adaptor.getOperands().back();

233 }

234 if (!splat || !matrix)

235 return rewriter.notifyMatchFailure(op, "no splat operand");

236

237

239 auto cc = splat.getDefiningOpspirv::CompositeConstructOp();

240 if (!cc) {

241 return rewriter.notifyMatchFailure(op,

242 "splat is not a composite construct");

243 }

244

245 scalar = llvm::getSingleElement(cc.getConstituents());

246

247 auto coopType = getTypeConverter()->convertType(op.getType());

248 if (!coopType)

249 return rewriter.notifyMatchFailure(op, "type conversion failed");

250 rewriter.replaceOpWithNewOpspirv::MatrixTimesScalarOp(

251 op, coopType, ValueRange{matrix, scalar});

252 return success();

253 }

254 };

255 }

256

257

258

259

260

261 namespace khr {

262 namespace {

263

264

265

266 struct WmmaLoadOpToSPIRVLowering final

269

270 LogicalResult

271 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,

273 const auto &typeConverter = *getTypeConverter();

275

276 auto retType = castgpu::MMAMatrixType(op.getRes().getType());

277 MemRefType memrefType = op.getSrcMemref().getType();

278 Value bufferPtr =

280 adaptor.getIndices(), loc, rewriter);

281

282 auto coopType =

284 if (!coopType)

286

287 int64_t stride = op.getLeadDimension().getSExtValue();

288 IntegerType i32Type = rewriter.getI32Type();

289 auto strideValue = rewriter.createspirv::ConstantOp(

291

292 bool isColMajor = op.getTranspose().value_or(false);

293 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor

294 : spirv::CooperativeMatrixLayoutKHR::RowMajor;

295

297 op, coopType, bufferPtr, strideValue, layout);

298 return success();

299 }

300 };

301

302

303

304 struct WmmaStoreOpToSPIRVLowering final

307

308 LogicalResult

309 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,

311 const auto &typeConverter = *getTypeConverter();

313

314 auto memrefType = cast(op.getDstMemref().getType());

315 Value bufferPtr =

317 adaptor.getIndices(), loc, rewriter);

318

319 int64_t stride = op.getLeadDimension().getSExtValue();

320 IntegerType i32Type = rewriter.getI32Type();

321 auto strideValue = rewriter.createspirv::ConstantOp(

323

324 bool isColMajor = op.getTranspose().value_or(false);

325 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor

326 : spirv::CooperativeMatrixLayoutKHR::RowMajor;

327

329 op, bufferPtr, adaptor.getSrc(), strideValue, layout);

330 return success();

331 }

332 };

333

334

335

336 struct WmmaMmaOpToSPIRVLowering final

339

340 LogicalResult

341 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,

342 OpAdaptor adaptor,

345 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),

346 adaptor.getOpC());

347 return success();

348 }

349 };

350

351 }

352 }

353 }

354

357 using namespace mlir;

359 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,

360 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,

361 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,

362 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);

363

364 patterns.add(converter, context,

365 2);

366 }

367

373 auto use =

375 .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)

376 .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)

377 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);

378

380 retTypeShape[1],

381 spirv::Scope::Subgroup, use);

382 });

383 }

This class implements a pattern rewriter for use with ConversionPatterns.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...

OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

Type conversion from builtin types to SPIR-V types for shader interface.

void addConversion(FnT &&callback)

Register a conversion function.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

type_range getType() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.

ArrayRef< int64_t > getShape() const

Get shape of the matrix.

Type getElementType() const

Get elementType of a single element.

StringRef getOperand() const

The general form of operation this type supports is given by the equation C += A*B.

@ Type

An inlay hint that is for a type annotation.

Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)

Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...

Include the generated interface declarations.

static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)

Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...

void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)

Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...

const FrozenRewritePatternSet & patterns

bool allOperandsHaveSameCoopMatrixType(ValueRange operands)

void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)

Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...