MLIR: lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

24 #include "llvm/Support/Debug.h"

25 #include

26 #include

27

28 #define DEBUG_TYPE "memref-to-spirv-pattern"

29

30 using namespace mlir;

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

48 int targetBits, OpBuilder &builder) {

49 assert(targetBits % sourceBits == 0);

51 IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);

52 auto idx = builder.createOrFoldspirv::ConstantOp(loc, type, idxAttr);

53 IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);

54 auto srcBitsValue =

55 builder.createOrFoldspirv::ConstantOp(loc, type, srcBitsAttr);

56 auto m = builder.createOrFoldspirv::UModOp(loc, srcIdx, idx);

57 return builder.createOrFoldspirv::IMulOp(loc, type, m, srcBitsValue);

58 }

59

60

61

62

63

64

65

66

67

70 spirv::AccessChainOp op, int sourceBits,

71 int targetBits, OpBuilder &builder) {

72 assert(targetBits % sourceBits == 0);

73 const auto loc = op.getLoc();

74 Value lastDim = op->getOperand(op.getNumOperands() - 1);

76 IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);

77 auto idx = builder.createOrFoldspirv::ConstantOp(loc, type, attr);

78 auto indices = llvm::to_vector<4>(op.getIndices());

79

80 assert(indices.size() == 2);

81 indices.back() = builder.createOrFoldspirv::SDivOp(loc, lastDim, idx);

82 Type t = typeConverter.convertType(op.getComponentPtr().getType());

83 return builder.createspirv::AccessChainOp(loc, t, op.getBasePtr(), indices);

84 }

85

86

91 return srcBool;

93 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);

94 return builder.createOrFoldspirv::SelectOp(loc, dstType, srcBool, one,

95 zero);

96 }

97

98

99

102 IntegerType dstType = cast(mask.getType());

103 int targetBits = static_cast<int>(dstType.getWidth());

105 assert(valueBits <= targetBits);

106

107 if (valueBits == 1) {

108 value = castBoolToIntN(loc, value, dstType, builder);

109 } else {

110 if (valueBits < targetBits) {

111 value = builder.createspirv::UConvertOp(

113 }

114

115 value = builder.createOrFoldspirv::BitwiseAndOp(loc, value, mask);

116 }

117 return builder.createOrFoldspirv::ShiftLeftLogicalOp(loc, value.getType(),

118 value, offset);

119 }

120

121

122

124 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {

125 auto sc = dyn_cast_or_nullspirv::StorageClassAttr(type.getMemorySpace());

126 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)

127 return false;

128 } else if (isamemref::AllocaOp(allocOp)) {

129 auto sc = dyn_cast_or_nullspirv::StorageClassAttr(type.getMemorySpace());

130 if (!sc || sc.getValue() != spirv::StorageClass::Function)

131 return false;

132 } else {

133 return false;

134 }

135

136

137

138 if (!type.hasStaticShape())

139 return false;

140

141 Type elementType = type.getElementType();

142 if (auto vecType = dyn_cast(elementType))

143 elementType = vecType.getElementType();

145 }

146

147

148

149

151 auto sc = dyn_cast_or_nullspirv::StorageClassAttr(type.getMemorySpace());

152 switch (sc.getValue()) {

153 case spirv::StorageClass::StorageBuffer:

154 return spirv::Scope::Device;

155 case spirv::StorageClass::Workgroup:

156 return spirv::Scope::Workgroup;

157 default:

158 break;

159 }

160 return {};

161 }

162

163

166 return srcInt;

167

169 return builder.createOrFoldspirv::INotEqualOp(loc, srcInt, one);

170 }

171

172

173

174

175

176

177

178

179

180 namespace {

181

182

183 class AllocaOpPattern final : public OpConversionPatternmemref::AllocaOp {

184 public:

186

187 LogicalResult

188 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,

190 };

191

192

193

194

195

196 class AllocOpPattern final : public OpConversionPatternmemref::AllocOp {

197 public:

199

200 LogicalResult

201 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,

203 };

204

205

206 class AtomicRMWOpPattern final

208 public:

210

211 LogicalResult

212 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,

214 };

215

216

217

218 class DeallocOpPattern final : public OpConversionPatternmemref::DeallocOp {

219 public:

221

222 LogicalResult

223 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,

225 };

226

227

228 class IntLoadOpPattern final : public OpConversionPatternmemref::LoadOp {

229 public:

231

232 LogicalResult

233 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,

235 };

236

237

239 public:

241

242 LogicalResult

243 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,

245 };

246

247

248 class IntStoreOpPattern final : public OpConversionPatternmemref::StoreOp {

249 public:

251

252 LogicalResult

253 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,

255 };

256

257

258 class MemorySpaceCastOpPattern final

260 public:

262

263 LogicalResult

264 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,

266 };

267

268

269 class StoreOpPattern final : public OpConversionPatternmemref::StoreOp {

270 public:

272

273 LogicalResult

274 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,

276 };

277

278 class ReinterpretCastPattern final

280 public:

282

283 LogicalResult

284 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,

286 };

287

289 public:

291

292 LogicalResult

293 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,

295 Value src = adaptor.getSource();

297

298 const TypeConverter *converter = getTypeConverter();

300 if (srcType != dstType)

302 diag << "types doesn't match: " << srcType << " and " << dstType;

303 });

304

306 return success();

307 }

308 };

309

310

311 class ExtractAlignedPointerAsIndexOpPattern final

313 public:

315

316 LogicalResult

317 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,

318 OpAdaptor adaptor,

320 };

321 }

322

323

324

325

326

327 LogicalResult

328 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,

330 MemRefType allocType = allocaOp.getType();

332 return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

333

334

335 Type spirvType = getTypeConverter()->convertType(allocType);

336 if (!spirvType)

337 return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

338

340 spirv::StorageClass::Function,

341 nullptr);

342 return success();

343 }

344

345

346

347

348

349 LogicalResult

350 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,

352 MemRefType allocType = operation.getType();

354 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

355

356

357 Type spirvType = getTypeConverter()->convertType(allocType);

358 if (!spirvType)

359 return rewriter.notifyMatchFailure(operation, "type conversion failed");

360

361

364 if (!parent)

365 return failure();

366 Location loc = operation.getLoc();

367 spirv::GlobalVariableOp varOp;

368 {

372 auto varOps = entryBlock.getOpsspirv::GlobalVariableOp();

373 std::string varName =

374 std::string("__workgroup_mem__") +

375 std::to_string(std::distance(varOps.begin(), varOps.end()));

376 varOp = rewriter.createspirv::GlobalVariableOp(loc, spirvType, varName,

377 nullptr);

378 }

379

380

382 return success();

383 }

384

385

386

387

388

389 LogicalResult

390 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,

391 OpAdaptor adaptor,

393 if (isa(atomicOp.getType()))

395 "unimplemented floating-point case");

396

397 auto memrefType = cast(atomicOp.getMemref().getType());

398 std::optionalspirv::Scope scope = getAtomicOpScope(memrefType);

399 if (!scope)

401 "unsupported memref memory space");

402

403 auto &typeConverter = *getTypeConverter();

404 Type resultType = typeConverter.convertType(atomicOp.getType());

405 if (!resultType)

407 "failed to convert result type");

408

409 auto loc = atomicOp.getLoc();

412 adaptor.getIndices(), loc, rewriter);

413

414 if (!ptr)

415 return failure();

416

417 #define ATOMIC_CASE(kind, spirvOp) \

418 case arith::AtomicRMWKind::kind: \

419 rewriter.replaceOpWithNewOpspirv::spirvOp( \

420 atomicOp, resultType, ptr, *scope, \

421 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \

422 break

423

424 switch (atomicOp.getKind()) {

432 default:

433 return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");

434 }

435

436 #undef ATOMIC_CASE

437

438 return success();

439 }

440

441

442

443

444

445 LogicalResult

446 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,

447 OpAdaptor adaptor,

449 MemRefType deallocType = cast(operation.getMemref().getType());

451 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

452 rewriter.eraseOp(operation);

453 return success();

454 }

455

456

457

458

459

463 };

464

465

466

467 static FailureOr

470

472 if (isNontemporal) {

473 memoryAccess = spirv::MemoryAccess::Nontemporal;

474 }

475

476 auto ptrType = castspirv::PointerType(accessedPtr.getType());

477 if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {

480 }

482 IntegerAttr{}};

483 }

484

485

486 auto pointeeType = dyn_castspirv::ScalarType(ptrType.getPointeeType());

487 if (!pointeeType)

488 return failure();

489

490

491 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();

492 if (!sizeInBytes.has_value())

493 return failure();

494

495 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;

499 }

500

501

502

503

504 template

505 static FailureOr

507 static_assert(

508 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,

509 "Must be called on either memref::LoadOp or memref::StoreOp");

510

511 Operation *memrefAccessOp = loadOrStoreOp.getOperation();

512 auto memrefMemAccess = memrefAccessOp->getAttrOfTypespirv::MemoryAccessAttr(

513 spirv::attributeNamespirv::MemoryAccess());

514 auto memrefAlignment =

515 memrefAccessOp->getAttrOfType("alignment");

516 if (memrefMemAccess && memrefAlignment)

518

520 loadOrStoreOp.getNontemporal());

521 }

522

523 LogicalResult

524 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,

526 auto loc = loadOp.getLoc();

527 auto memrefType = cast(loadOp.getMemref().getType());

528 if (!memrefType.getElementType().isSignlessInteger())

529 return failure();

530

531 const auto &typeConverter = *getTypeConverter();

532 Value accessChain =

534 adaptor.getIndices(), loc, rewriter);

535

536 if (!accessChain)

537 return failure();

538

539 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

540 bool isBool = srcBits == 1;

541 if (isBool)

542 srcBits = typeConverter.getOptions().boolNumBits;

543

544 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);

545 if (!pointerType)

546 return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

547

548 Type pointeeType = pointerType.getPointeeType();

549 Type dstType;

550 if (typeConverter.allows(spirv::Capability::Kernel)) {

551 if (auto arrayType = dyn_castspirv::ArrayType(pointeeType))

552 dstType = arrayType.getElementType();

553 else

554 dstType = pointeeType;

555 } else {

556

557 Type structElemType =

558 castspirv::StructType(pointeeType).getElementType(0);

559 if (auto arrayType = dyn_castspirv::ArrayType(structElemType))

560 dstType = arrayType.getElementType();

561 else

562 dstType = castspirv::RuntimeArrayType(structElemType).getElementType();

563 }

565 assert(dstBits % srcBits == 0);

566

567

568

569 if (srcBits == dstBits) {

571 if (failed(memoryRequirements))

573 loadOp, "failed to determine memory requirements");

574

575 auto [memoryAccess, alignment] = *memoryRequirements;

576 Value loadVal = rewriter.createspirv::LoadOp(loc, accessChain,

577 memoryAccess, alignment);

578 if (isBool)

580 rewriter.replaceOp(loadOp, loadVal);

581 return success();

582 }

583

584

585

586 if (typeConverter.allows(spirv::Capability::Kernel))

587 return failure();

588

589 auto accessChainOp = accessChain.getDefiningOpspirv::AccessChainOp();

590 if (!accessChainOp)

591 return failure();

592

593

594

595

596 assert(accessChainOp.getIndices().size() == 2);

598 srcBits, dstBits, rewriter);

600 if (failed(memoryRequirements))

602 loadOp, "failed to determine memory requirements");

603

604 auto [memoryAccess, alignment] = *memoryRequirements;

605 Value spvLoadOp = rewriter.createspirv::LoadOp(loc, dstType, adjustedPtr,

606 memoryAccess, alignment);

607

608

609

610 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);

612 Value result = rewriter.createOrFoldspirv::ShiftRightArithmeticOp(

613 loc, spvLoadOp.getType(), spvLoadOp, offset);

614

615

617 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));

618 result =

619 rewriter.createOrFoldspirv::BitwiseAndOp(loc, dstType, result, mask);

620

621

622

623

624 IntegerAttr shiftValueAttr =

627 rewriter.createOrFoldspirv::ConstantOp(loc, dstType, shiftValueAttr);

628 result = rewriter.createOrFoldspirv::ShiftLeftLogicalOp(loc, dstType,

630 result = rewriter.createOrFoldspirv::ShiftRightArithmeticOp(

632

633 rewriter.replaceOp(loadOp, result);

634

635 assert(accessChainOp.use_empty());

636 rewriter.eraseOp(accessChainOp);

637

638 return success();

639 }

640

641 LogicalResult

642 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,

644 auto memrefType = cast(loadOp.getMemref().getType());

645 if (memrefType.getElementType().isSignlessInteger())

646 return failure();

648 *getTypeConverter(), memrefType, adaptor.getMemref(),

649 adaptor.getIndices(), loadOp.getLoc(), rewriter);

650

651 if (!loadPtr)

652 return failure();

653

655 if (failed(memoryRequirements))

657 loadOp, "failed to determine memory requirements");

658

659 auto [memoryAccess, alignment] = *memoryRequirements;

660 rewriter.replaceOpWithNewOpspirv::LoadOp(loadOp, loadPtr, memoryAccess,

661 alignment);

662 return success();

663 }

664

665 LogicalResult

666 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,

668 auto memrefType = cast(storeOp.getMemref().getType());

669 if (!memrefType.getElementType().isSignlessInteger())

671 "element type is not a signless int");

672

673 auto loc = storeOp.getLoc();

674 auto &typeConverter = *getTypeConverter();

675 Value accessChain =

677 adaptor.getIndices(), loc, rewriter);

678

679 if (!accessChain)

681 storeOp, "failed to convert element pointer type");

682

683 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

684

685 bool isBool = srcBits == 1;

686 if (isBool)

687 srcBits = typeConverter.getOptions().boolNumBits;

688

689 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);

690 if (!pointerType)

692 "failed to convert memref type");

693

694 Type pointeeType = pointerType.getPointeeType();

695 IntegerType dstType;

696 if (typeConverter.allows(spirv::Capability::Kernel)) {

697 if (auto arrayType = dyn_castspirv::ArrayType(pointeeType))

698 dstType = dyn_cast(arrayType.getElementType());

699 else

700 dstType = dyn_cast(pointeeType);

701 } else {

702

703 Type structElemType =

704 castspirv::StructType(pointeeType).getElementType(0);

705 if (auto arrayType = dyn_castspirv::ArrayType(structElemType))

706 dstType = dyn_cast(arrayType.getElementType());

707 else

708 dstType = dyn_cast(

709 castspirv::RuntimeArrayType(structElemType).getElementType());

710 }

711

712 if (!dstType)

714 storeOp, "failed to determine destination element type");

715

716 int dstBits = static_cast<int>(dstType.getWidth());

717 assert(dstBits % srcBits == 0);

718

719 if (srcBits == dstBits) {

721 if (failed(memoryRequirements))

723 storeOp, "failed to determine memory requirements");

724

725 auto [memoryAccess, alignment] = *memoryRequirements;

726 Value storeVal = adaptor.getValue();

727 if (isBool)

728 storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);

729 rewriter.replaceOpWithNewOpspirv::StoreOp(storeOp, accessChain, storeVal,

730 memoryAccess, alignment);

731 return success();

732 }

733

734

735

736 if (typeConverter.allows(spirv::Capability::Kernel))

737 return failure();

738

739 auto accessChainOp = accessChain.getDefiningOpspirv::AccessChainOp();

740 if (!accessChainOp)

741 return failure();

742

743

744

745

746

747

748

749

750

751

752

753

754 assert(accessChainOp.getIndices().size() == 2);

755 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);

757

758

759

761 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));

762 Value clearBitsMask = rewriter.createOrFoldspirv::ShiftLeftLogicalOp(

763 loc, dstType, mask, offset);

764 clearBitsMask =

765 rewriter.createOrFoldspirv::NotOp(loc, dstType, clearBitsMask);

766

767 Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);

769 srcBits, dstBits, rewriter);

770 std::optionalspirv::Scope scope = getAtomicOpScope(memrefType);

771 if (!scope)

772 return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

773

774 Value result = rewriter.createspirv::AtomicAndOp(

775 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,

776 clearBitsMask);

777 result = rewriter.createspirv::AtomicOrOp(

778 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,

779 storeVal);

780

781

782

783

784

785 rewriter.eraseOp(storeOp);

786

787 assert(accessChainOp.use_empty());

788 rewriter.eraseOp(accessChainOp);

789

790 return success();

791 }

792

793

794

795

796

797 LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(

798 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,

800 Location loc = addrCastOp.getLoc();

801 auto &typeConverter = *getTypeConverter();

802 if (!typeConverter.allows(spirv::Capability::Kernel))

804 loc, "address space casts require kernel capability");

805

806 auto sourceType = dyn_cast(addrCastOp.getSource().getType());

807 if (!sourceType)

809 loc, "SPIR-V lowering requires ranked memref types");

810 auto resultType = cast(addrCastOp.getResult().getType());

811

812 auto sourceStorageClassAttr =

813 dyn_cast_or_nullspirv::StorageClassAttr(sourceType.getMemorySpace());

814 if (!sourceStorageClassAttr)

816 diag << "source address space " << sourceType.getMemorySpace()

817 << " must be a SPIR-V storage class";

818 });

819 auto resultStorageClassAttr =

820 dyn_cast_or_nullspirv::StorageClassAttr(resultType.getMemorySpace());

821 if (!resultStorageClassAttr)

823 diag << "result address space " << resultType.getMemorySpace()

824 << " must be a SPIR-V storage class";

825 });

826

827 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();

828 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

829

830 Value result = adaptor.getSource();

831 Type resultPtrType = typeConverter.convertType(resultType);

832 if (!resultPtrType)

834 "failed to convert memref type");

835

836 Type genericPtrType = resultPtrType;

837

838

839

840

841

842

843

844 if (sourceSc != spirv::StorageClass::Generic &&

845 resultSc != spirv::StorageClass::Generic) {

846 Type intermediateType =

847 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),

848 sourceType.getLayout(),

849 rewriter.getAttrspirv::StorageClassAttr(

850 spirv::StorageClass::Generic));

851 genericPtrType = typeConverter.convertType(intermediateType);

852 }

853 if (sourceSc != spirv::StorageClass::Generic) {

854 result =

855 rewriter.createspirv::PtrCastToGenericOp(loc, genericPtrType, result);

856 }

857 if (resultSc != spirv::StorageClass::Generic) {

858 result =

859 rewriter.createspirv::GenericCastToPtrOp(loc, resultPtrType, result);

860 }

861 rewriter.replaceOp(addrCastOp, result);

862 return success();

863 }

864

865 LogicalResult

866 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,

868 auto memrefType = cast(storeOp.getMemref().getType());

869 if (memrefType.getElementType().isSignlessInteger())

872 *getTypeConverter(), memrefType, adaptor.getMemref(),

873 adaptor.getIndices(), storeOp.getLoc(), rewriter);

874

875 if (!storePtr)

876 return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

877

879 if (failed(memoryRequirements))

881 storeOp, "failed to determine memory requirements");

882

883 auto [memoryAccess, alignment] = *memoryRequirements;

885 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);

886 return success();

887 }

888

889 LogicalResult ReinterpretCastPattern::matchAndRewrite(

890 memref::ReinterpretCastOp op, OpAdaptor adaptor,

892 Value src = adaptor.getSource();

893 auto srcType = dyn_castspirv::PointerType(src.getType());

894

895 if (!srcType)

897 diag << "invalid src type " << src.getType();

898 });

899

900 const TypeConverter *converter = getTypeConverter();

901

903 if (dstType != srcType)

905 diag << "invalid dst type " << op.getType();

906 });

907

909 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)

910 .front();

913 return success();

914 }

915

917 if (!intType)

918 return rewriter.notifyMatchFailure(op, "failed to convert index type");

919

921 auto offsetValue = [&]() -> Value {

922 if (auto val = dyn_cast(offset))

923 return val;

924

925 int64_t attrVal = cast(cast(offset)).getInt();

927 return rewriter.createOrFoldspirv::ConstantOp(loc, intType, attr);

928 }();

929

931 op, src, offsetValue, std::nullopt);

932 return success();

933 }

934

935

936

937

938

939 LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(

940 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,

942 auto &typeConverter = *getTypeConverter();

943 Type indexType = typeConverter.getIndexType();

944 rewriter.replaceOpWithNewOpspirv::ConvertPtrToUOp(extractOp, indexType,

945 adaptor.getSource());

946 return success();

947 }

948

949

950

951

952

953 namespace mlir {

957 .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,

958 DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,

959 MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,

960 CastPattern, ExtractAlignedPointerAsIndexOpPattern>(

961 typeConverter, patterns.getContext());

962 }

963 }

static Value getZero(OpBuilder &b, Location loc, Type elementType)

Get zero value for an element type.

static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)

Casts the given srcInt into a boolean value.

static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)

Returns the targetBits-bit value shifted by the given offset and cast to the destination type,...

static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)

Returns an adjusted spirv::AccessChainOp.

static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)

Returns the scope to use for atomic operations used for emulating store operations of unsupported integer bitwidths,...

static bool isAllocationSupported(Operation *allocOp, MemRefType type)

Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.

static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)

Returns the offset of the value in targetBits representation.

#define ATOMIC_CASE(kind, spirvOp)

static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal)

Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.

static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)

Casts the given srcBool into an integer of dstType.

static std::string diag(const llvm::Value &value)

static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)

Walks the given type hierarchy with the given indices, potentially down to component granularity,...

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

iterator_range< op_iterator< OpT > > getOps()

Return an iterator range over the operations within this block that are of 'OpT'.

IntegerAttr getIntegerAttr(Type type, int64_t value)

IntegerType getIntegerType(unsigned width)

Attr getAttr(Args &&...args)

Get or construct an instance of the attribute Attr with provided arguments.

This class implements a pattern rewriter for use with ConversionPatterns.

void replaceOp(Operation *op, ValueRange newValues) override

Replace the given operation with the new values.

void eraseOp(Operation *op) override

PatternRewriter hook for erasing a dead operation.

This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...

OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)

This class represents a single result from folding an operation.

Operation is the basic unit of execution within MLIR.

AttrClass getAttrOfType(StringAttr name)

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

Type conversion from builtin types to SPIR-V types for shader interface.

static Operation * getNearestSymbolTable(Operation *from)

Returns the nearest symbol table from a given operation from.

LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const

Convert the given type.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isInteger() const

Return true if this is an integer type (with the specified width).

bool isIntOrFloat() const

Return true if this is an integer (of any signedness) or a float type.

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

MLIRContext * getContext() const

Utility to get the associated MLIRContext that this value is defined in.

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)

Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...

Include the generated interface declarations.

const FrozenRewritePatternSet & patterns

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)

Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...

void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)

Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.

spirv::MemoryAccessAttr memoryAccess