MLIR: lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

14

27 #include "llvm/ADT/TypeSwitch.h"

28

29 #include

30 #include

31

32 using namespace mlir;

34

35 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"

36

// Dialect initialization hook: registers this dialect's operations and
// attributes.
//
// NOTE(review): the concrete op and attribute class lists are not written out
// here; they are expanded from the generated .cpp.inc files included inside
// the template argument lists. The GET_OP_LIST / GET_ATTRDEF_LIST defines
// presumably select the "list" section of each generated file — confirm
// against the generated sources.
37 void AMDGPUDialect::initialize() {

// Register every operation listed in AMDGPU.cpp.inc.
38 addOperations<

39 #define GET_OP_LIST

40 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"

41 >();

// Register every attribute listed in AMDGPUAttributes.cpp.inc.
42 addAttributes<

43 #define GET_ATTRDEF_LIST

44 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

45 >();

46 }

47

48

49

50

52 if (getExisting() && getExisting().getType() != getResult().getType())

53 return emitOpError("existing values must have same type as result");

54 return success();

55 }

56

58 if (getExisting() && getExisting().getType() != getResult().getType())

59 return emitOpError("existing values must have same type as result");

60 return success();

61 }

62

63

64

65

67 if (getExisting() && getExisting().getType() != getResult().getType())

68 return emitOpError("existing values must have same type as result");

69 return success();

70 }

71

72

73

74

75

76

77

78

79

81 bool resetOffset) {

86 MemRefLayoutAttrInterface layout = source.getLayout();

87 if (resetOffset && !layout.isIdentity()) {

88 auto stridedLayout = dyn_cast(layout);

89 if (!stridedLayout)

90 return failure();

92 }

93 return (MemRefType)(mb);

94 }

95

96 LogicalResult FatRawBufferCastOp::inferReturnTypes(

100 Adaptor adaptor(operands, attributes, properties, regions);

101 auto sourceType =

102 dyn_cast_if_present(adaptor.getSource().getType());

103 if (!sourceType)

104 return failure();

105 FailureOr resultType =

107 if (failed(resultType))

108 return failure();

110 return success();

111 }

112

114 FailureOr expectedResultType =

116 if (failed(expectedResultType))

117 return emitOpError("source type ")

118 << getSource().getType() << " can't have its offset reset";

119 if (getResult().getType() != *expectedResultType)

120 return emitOpError("expected result type to be ")

121 << *expectedResultType << " but got " << getResult().getType();

122 return success();

123 }

124

126 if (!memorySpace)

127 return true;

128 if (auto intMemorySpace = dyn_cast(memorySpace))

129 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;

130 if (auto gpuMemorySpace = dyn_castgpu::AddressSpaceAttr(memorySpace))

131 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;

132 return false;

133 }

134

136 if (auto intMemorySpace = dyn_cast(memorySpace))

137 return intMemorySpace.getInt() == 3;

138 if (auto gpuMemorySpace = dyn_castgpu::AddressSpaceAttr(memorySpace))

139 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;

140 return false;

141 }

142

144 if (auto intMemorySpace = dyn_cast(memorySpace))

145 return intMemorySpace.getInt() == 7;

146 if (auto gpuMemorySpace = dyn_castamdgpu::AddressSpaceAttr(memorySpace))

147 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;

148 return false;

149 }

150

151

152

153

154 template

156 MemRefType bufferType = llvm::cast(op.getMemref().getType());

158

159 if (!isGlobal)

160 return op.emitOpError(

161 "Buffer ops must operate on a memref in global memory");

162 if (!bufferType.hasRank())

163 return op.emitOpError(

164 "Cannot meaningfully buffer_store to an unranked memref");

165 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())

166 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +

167 " indices to memref");

168 return success();

169 }

170

172

174

177 }

178

181 }

182

185 }

186

189 }

190

193 }

194

196 APInt cst;

198 return std::nullopt;

200 return cst.getZExtValue();

201 return std::nullopt;

202 }

203

204 template

206 if (!op.getBoundsCheck())

207 return false;

208 MemRefType bufferType = op.getMemref().getType();

209 if (!bufferType.hasStaticShape())

210 return false;

211 int64_t offset;

213 if (failed(bufferType.getStridesAndOffset(strides, offset)))

214 return false;

215 int64_t result = offset + op.getIndexOffset().value_or(0);

216 if (op.getSgprOffset()) {

217 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());

218 if (!sgprOffset)

219 return false;

220 result += *sgprOffset;

221 }

222 if (strides.size() != op.getIndices().size())

223 return false;

224 int64_t indexVal = 0;

225 for (auto pair : llvm::zip(strides, op.getIndices())) {

226 int64_t stride = std::get<0>(pair);

227 Value idx = std::get<1>(pair);

229 if (!idxVal)

230 return false;

231 indexVal += stride * *idxVal;

232 }

233 result += indexVal;

235

236 return false;

237 return result >= bufferType.getNumElements();

238 }

239

240 namespace {

241 template

242 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern {

244

245 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {

247 return failure();

248 Type loadType = op.getResult().getType();

251 return success();

252 }

253 };

254

255 template

256 struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern {

258

259 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {

261 return failure();

262

264 return success();

265 }

266 };

267 }

268

269 void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,

271 results.add<RemoveStaticallyOobBufferLoads>(context);

272 }

273

274 void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,

276 results.add<RemoveStaticallyOobBufferWrites>(context);

277 }

278

279 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(

281 results.add<RemoveStaticallyOobBufferWrites>(context);

282 }

283

284 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(

286 results.add<RemoveStaticallyOobBufferWrites>(context);

287 }

288

289 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(

291 results.add<RemoveStaticallyOobBufferWrites>(context);

292 }

293

294 void RawBufferAtomicUminOp::getCanonicalizationPatterns(

296 results.add<RemoveStaticallyOobBufferWrites>(context);

297 }

298

299 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(

301 results.add<RemoveStaticallyOobBufferLoads>(

302 context);

303 }

304

305

306

307

// Op verification logic for an op with operands sourceA, sourceB and destC
// (the function header line is not present in this listing). Rejects the op
// when:
//  - the two source vectors have different element counts;
//  - the destination is float (f32/f16/bf16) but source A is not a
//    supported float type (f16/bf16/fp8), or vice versa;
//  - source A and B element types differ, unless both are fp8
//    (E5M2 / E4M3FN may be mixed).
// Returns success() otherwise.
309 Type sourceAType = getSourceA().getType();

310 Type sourceBType = getSourceB().getType();

311 Type destType = getDestC().getType();

312

// NOTE(review): the cast results below are used without a null check, so this
// presumably relies on the op definition constraining all three operands to
// vector types — confirm against the op's ODS definition.
313 VectorType sourceVectorAType = dyn_cast(sourceAType);

314 VectorType sourceVectorBType = dyn_cast(sourceBType);

315 VectorType destVectorType = dyn_cast(destType);

316

317 Type sourceAElemType = sourceVectorAType.getElementType();

318 Type sourceBElemType = sourceVectorBType.getElementType();

319 Type destElemType = destVectorType.getElementType();

320

321 if (sourceVectorAType.getNumElements() !=

322 sourceVectorBType.getNumElements()) {

323 return emitOpError("source vectors have different lengths: ")

324 << sourceVectorAType << " vs. " << sourceVectorBType;

325 }

326

// Only source A's element type is classified here; the element-type check at
// the end ties source B to source A.
327 bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);

328 bool isSrcFloat =

329 isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(

330 sourceAElemType);

331

332 if (isDestFloat && !isSrcFloat) {

333 return emitOpError("Expected float sources with float destination");

334 }

335

336 if (!isDestFloat && isSrcFloat) {

337 return emitOpError("Expected int sources with int destination");

338 }

339

// Source element types must be identical, except that the two fp8 formats
// (E5M2 and E4M3FN) may be mixed between A and B.
340 if (sourceAElemType != sourceBElemType &&

341 !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&

342 isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {

343 return emitOpError(

344 "source element types must match (except for fp8) but have ")

345 << sourceAType << " and " << sourceBType;

346 }

347 return success();

348 }

349

350

351

352

354 constexpr uint32_t waveSize = 64;

356

357 Type sourceType = getSourceA().getType();

358 Type destType = getDestC().getType();

359

360 Type sourceElem = sourceType, destElem = destType;

361 uint32_t sourceLen = 1, destLen = 1;

362 if (auto sourceVector = llvm::dyn_cast(sourceType)) {

363 sourceLen = sourceVector.getNumElements();

364 sourceElem = sourceVector.getElementType();

365 }

366 if (auto destVector = llvm::dyn_cast(destType)) {

367 destLen = destVector.getNumElements();

368 destElem = destVector.getElementType();

369 }

370

371 Type sourceBType = getSourceB().getType();

373 int64_t sourceBLen = 1;

374 Type sourceBElem = sourceBType;

375 if (auto sourceBVector = llvm::dyn_cast(sourceBType)) {

376 sourceBLen = sourceBVector.getNumElements();

377 sourceBElem = sourceBVector.getElementType();

378 }

379 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&

381 return emitOpError("expected both source operands to have small-float "

382 "elements if one does");

383 if (sourceLen != sourceBLen)

384 return emitOpError(

385 "expected both small-float source vectors to have the same length");

386 } else {

387 if (sourceType != sourceBType)

388 return emitOpError("expected both non-small-float source operand types "

389 "to match exactly");

390 }

391

393 sourceLen *= 4;

394 sourceElem = b.getI8Type();

395 }

397 sourceLen *= 8;

398 sourceElem = b.getI8Type();

399 }

400

401 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;

402 if (sourceLen != numSourceElems)

403 return emitOpError("expected " + Twine(numSourceElems) +

404 " source values for this operation but got " +

405 Twine(sourceLen));

406

407 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;

408 if (destLen != numDestElems)

409 return emitOpError("expected " + Twine(numDestElems) +

410 " result values for this operation but got " +

411 Twine(destLen));

412

413 if (destElem.isF64() && getBlgp() != MFMAPermB::none)

414 return emitOpError(

415 "double-precision ops do not support permuting lanes of B");

416 if (destElem.isF64() && getCbsz() != 0)

417 return emitOpError(

418 "double-precision ops do not support permuting lanes of A");

419 if (getAbid() >= (1u << getCbsz()))

420 return emitOpError(

421 "block ID for permuting A (abid) must be below 2 ** cbsz");

422

423 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())

424 return emitOpError(

425 "negation flags only available for double-precision operations");

426

427 return success();

428 }

429

430

431

432

434 Type srcType = getSrc().getType();

436 return emitOpError("integer and floating point types larger than 64 bits "

437 "are not supported");

438 }

439

440 DPPPerm kind = getKind();

442

443 switch (kind) {

444

445 case DPPPerm::quad_perm: {

446 auto quadPermAttr = dyn_cast_or_null(permArgument);

447 if (!quadPermAttr || quadPermAttr.size() != 4) {

448 return emitOpError("quad_perm attribute must have exactly 4 elements");

449 }

450 for (auto elem : quadPermAttr.getAsRange()) {

451 int32_t num = elem.getInt();

452 if (num < 0 || num > 3) {

453 return emitOpError(

454 "Each element of quad_perm must be in the range [0, 3]");

455 }

456 }

457 } break;

458

459 case DPPPerm::row_shl:

460 case DPPPerm::row_shr:

461 case DPPPerm::row_ror: {

462 if (!permArgument) {

463 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +

464 "' value not specified");

465 }

466 if (auto intAttr = dyn_cast(permArgument)) {

467 uint32_t attrValue = intAttr.getInt();

468 if (attrValue < 1 || attrValue > 15) {

469 return emitOpError("Attribute value must be between 1 and 15");

470 }

471 }

472 } break;

473

474 case DPPPerm::wave_shl:

475 case DPPPerm::wave_shr:

476 case DPPPerm::wave_rol:

477 case DPPPerm::wave_ror:

478 case DPPPerm::row_mirror:

479 case DPPPerm::row_half_mirror:

480 case DPPPerm::row_bcast_15:

481 case DPPPerm::row_bcast_31: {

482 if (permArgument && !isa(permArgument)) {

483 return emitOpError("Expected unit attribute for permArgument, but found "

484 "non-trivial argument");

485 }

486 break;

487 }

488 }

489 return success();

490 }

491

493 MemRefType srcType = cast(getSrc().getType());

494 MemRefType dstType = cast(getDst().getType());

495

496 if (!dstType.areTrailingDimsContiguous(dstType.getRank()))

497 return emitOpError("destination types must be contiguous");

498

499 auto elemType = srcType.getElementType();

500

501 if (elemType != dstType.getElementType())

502 return emitOpError("source and destination element types must match");

503

504

505 auto transferType = getTransferType();

506 size_t transferSize;

507 if (auto vectorTransfer = dyn_cast(transferType)) {

508 transferSize = vectorTransfer.getNumElements() *

509 vectorTransfer.getElementTypeBitWidth();

510 } else {

511 transferSize = transferType.getIntOrFloatBitWidth();

512 }

513 if (transferSize != 8 && transferSize != 16 && transferSize != 32)

514 return emitOpError("Transfering type size must be 8, 16, or 32 bits");

515

518 return emitOpError(

519 "source memory address space must be global or fat raw buffer");

520

522 return emitOpError("destination memory address space must be Workgroup");

523

524 return success();

525 }

526

527 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

528

529 #define GET_ATTRDEF_CLASSES

530 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

531

532 #define GET_OP_CLASSES

533 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"

static LogicalResult verifyRawBufferOp(T &op)

static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)

Convert the type `source` to one with the same sizes and strides - and offset, unless `resetOffset` is true.

static bool hasGlobalMemorySpace(Attribute memorySpace)

static bool hasWorkgroupMemorySpace(Attribute memorySpace)

static std::optional< uint32_t > getConstantUint32(Value v)

static bool hasFatRawBufferMemorySpace(Attribute memorySpace)

static bool staticallyOutOfBounds(OpType op)

static MLIRContext * getContext(OpFoldResult val)

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

Attributes are known-constant values of operations.

This class is a general helper class for creating context-global objects like types,...

TypedAttr getZeroAttr(Type type)

MLIRContext is the top-level object for a collection of MLIR operations.

This is a builder type that keeps local references to arguments.

Builder & setLayout(MemRefLayoutAttrInterface newLayout)

Builder & setMemorySpace(Attribute newMemorySpace)

Simple wrapper around a void* in order to express generically how to pass in op properties through AP...

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class provides an abstraction over the different types of ranges over Regions.

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isFloat() const

Return true if this is an float type (with the specified width).

bool isInteger() const

Return true if this is an integer type (with the specified width).

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

uint64_t getN(LevelType lt)

uint64_t getM(LevelType lt)

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...