MLIR: lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

31#include "llvm/ADT/STLExtras.h"

32#include "llvm/ADT/SmallVector.h"

33#include "llvm/ADT/StringExtras.h"

34#include "llvm/Support/Debug.h"

35#include "llvm/Support/MathExtras.h"

36

37#include

38

39#define DEBUG_TYPE "mlir-spirv-conversion"

40

41using namespace mlir;

42

43namespace {

44

45

46

47

48

49static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {

50 LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");

51 if (vecType.isScalable()) {

52 LLVM_DEBUG(llvm::dbgs()

53 << "--scalable vectors are not supported -> BAIL\n");

54 return std::nullopt;

55 }

59 if (!targetShape) {

60 LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");

61 return std::nullopt;

62 }

63 auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);

64 if (!maybeShapeRatio) {

65 LLVM_DEBUG(llvm::dbgs()

66 << "--could not compute integral shape ratio -> BAIL\n");

67 return std::nullopt;

68 }

69 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {

70 LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");

71 return std::nullopt;

72 }

73 LLVM_DEBUG(llvm::dbgs()

74 << "--found an integral shape ratio to unroll to -> SUCCESS\n");

75 return targetShape;

76}

77

78

79

80

81

82

83

84template

85static LogicalResult checkExtensionRequirements(

88 for (const auto &ors : candidates) {

89 if (targetEnv.allows(ors))

90 continue;

91

92 LLVM_DEBUG({

94 for (spirv::Extension ext : ors)

95 extStrings.push_back(spirv::stringifyExtension(ext));

96

97 llvm::dbgs() << label << " illegal: requires at least one extension in ["

98 << llvm::join(extStrings, ", ")

99 << "] but none allowed in target environment\n";

100 });

101 return failure();

102 }

104}

105

106

107

108

109

110

111

112template

113static LogicalResult checkCapabilityRequirements(

116 for (const auto &ors : candidates) {

117 if (targetEnv.allows(ors))

118 continue;

119

120 LLVM_DEBUG({

122 for (spirv::Capability cap : ors)

123 capStrings.push_back(spirv::stringifyCapability(cap));

124

125 llvm::dbgs() << label << " illegal: requires at least one capability in ["

126 << llvm::join(capStrings, ", ")

127 << "] but none allowed in target environment\n";

128 });

129 return failure();

130 }

132}

133

134

135

136static bool needsExplicitLayout(spirv::StorageClass storageClass) {

137 switch (storageClass) {

138 case spirv::StorageClass::PhysicalStorageBuffer:

139 case spirv::StorageClass::PushConstant:

140 case spirv::StorageClass::StorageBuffer:

141 case spirv::StorageClass::Uniform:

142 return true;

143 default:

144 return false;

145 }

146}

147

148

149

151wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {

152 auto structType = needsExplicitLayout(storageClass)

156}

157

158

159

160

161

164 return castspirv::ScalarType(

165 IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));

166}

167

168

169

170static std::optional<int64_t>

172 if (isaspirv::ScalarType(type)) {

174

175

176

177

178

179

180 if (bitWidth == 1)

181 return std::nullopt;

182 return bitWidth / 8;

183 }

184

185

186 if (options.emulateUnsupportedFloatTypes && isa(type)) {

188 if (bitWidth == 8)

189 return bitWidth / 8;

190 return std::nullopt;

191 }

192

193 if (auto complexType = dyn_cast(type)) {

194 auto elementSize = getTypeNumBytes(options, complexType.getElementType());

195 if (!elementSize)

196 return std::nullopt;

197 return 2 * *elementSize;

198 }

199

200 if (auto vecType = dyn_cast(type)) {

201 auto elementSize = getTypeNumBytes(options, vecType.getElementType());

202 if (!elementSize)

203 return std::nullopt;

204 return vecType.getNumElements() * *elementSize;

205 }

206

207 if (auto memRefType = dyn_cast(type)) {

208

209

212 if (!memRefType.hasStaticShape() ||

213 failed(memRefType.getStridesAndOffset(strides, offset)))

214 return std::nullopt;

215

216

217

218

219 auto elementSize = getTypeNumBytes(options, memRefType.getElementType());

220 if (!elementSize)

221 return std::nullopt;

222

223 if (memRefType.getRank() == 0)

224 return elementSize;

225

226 auto dims = memRefType.getShape();

227 if (llvm::is_contained(dims, ShapedType::kDynamic) ||

228 ShapedType::isDynamic(offset) ||

229 llvm::is_contained(strides, ShapedType::kDynamic))

230 return std::nullopt;

231

233 for (const auto &shape : enumerate(dims))

234 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);

235

236 return (offset + memrefSize) * *elementSize;

237 }

238

239 if (auto tensorType = dyn_cast(type)) {

240 if (!tensorType.hasStaticShape())

241 return std::nullopt;

242

243 auto elementSize = getTypeNumBytes(options, tensorType.getElementType());

244 if (!elementSize)

245 return std::nullopt;

246

247 int64_t size = *elementSize;

248 for (auto shape : tensorType.getShape())

250

251 return size;

252 }

253

254

255 return std::nullopt;

256}

257

258

262 std::optionalspirv::StorageClass storageClass = {}) {

263

266 type.getExtensions(extensions, storageClass);

267 type.getCapabilities(capabilities, storageClass);

268

269

270 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&

271 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))

272 return type;

273

274

275

276 if (options.emulateLT32BitScalarTypes)

277 return nullptr;

278

279

281 LLVM_DEBUG(llvm::dbgs()

282 << type

283 << " not converted to 32-bit for SPIR-V to avoid truncation\n");

284 return nullptr;

285 }

286

287 if (auto floatType = dyn_cast(type)) {

288 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");

290 }

291

292 auto intType = cast(type);

293 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");

294 return IntegerType::get(targetEnv.getContext(), 32,

295 intType.getSignedness());

296}

297

298

299

300

301

302

303

304

305

307 IntegerType type) {

308 if (type.getWidth() > 8) {

309 LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");

310 return nullptr;

311 }

313 LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");

314 return nullptr;

315 }

316

317 if (!llvm::isPowerOf2_32(type.getWidth())) {

318 LLVM_DEBUG(llvm::dbgs()

319 << "unsupported non-power-of-two bitwidth in sub-byte" << type

320 << "\n");

321 return nullptr;

322 }

323

324 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");

325 return IntegerType::get(type.getContext(), 32,

326 type.getSignedness());

327}

328

329

330

332 FloatType type) {

333 if (options.emulateUnsupportedFloatTypes)

334 return nullptr;

335

336 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,

337 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,

338 Float8E8M0FNUType>(type))

339 return IntegerType::get(type.getContext(), type.getWidth());

340 LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");

341 return nullptr;

342}

343

344

345

346

347static ShapedType

348convertShaped8BitFloatType(ShapedType type,

350 if (options.emulateUnsupportedFloatTypes)

351 return type;

352 Type srcElementType = type.getElementType();

353 Type convertedElementType = nullptr;

354

355 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,

356 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,

357 Float8E8M0FNUType>(srcElementType))

358 convertedElementType = IntegerType::get(

360

361 if (!convertedElementType)

362 return type;

363

364 return type.clone(convertedElementType);

365}

366

367

368

369

370static ShapedType

371convertIndexElementType(ShapedType type,

373 Type indexType = dyn_cast(type.getElementType());

374 if (!indexType)

375 return type;

376

378}

379

380

384 std::optionalspirv::StorageClass storageClass = {}) {

385 type = cast(convertIndexElementType(type, options));

386 type = cast(convertShaped8BitFloatType(type, options));

387 auto scalarType = dyn_cast_or_nullspirv::ScalarType(type.getElementType());

388 if (!scalarType) {

389

390

391 auto intType = dyn_cast(type.getElementType());

392 if (!intType) {

393 LLVM_DEBUG(llvm::dbgs()

394 << type

395 << " illegal: cannot convert non-scalar element type\n");

396 return nullptr;

397 }

398

399 Type elementType = convertSubByteIntegerType(options, intType);

400 if (!elementType)

401 return nullptr;

402

403 if (type.getRank() <= 1 && type.getNumElements() == 1)

404 return elementType;

405

406 if (type.getNumElements() > 4) {

407 LLVM_DEBUG(llvm::dbgs()

408 << type << " illegal: > 4-element unimplemented\n");

409 return nullptr;

410 }

411

412 return VectorType::get(type.getShape(), elementType);

413 }

414

415 if (type.getRank() <= 1 && type.getNumElements() == 1)

416 return convertScalarType(targetEnv, options, scalarType, storageClass);

417

419 LLVM_DEBUG(llvm::dbgs()

420 << type << " illegal: not a valid composite type\n");

421 return nullptr;

422 }

423

424

427 castspirv::CompositeType(type).getExtensions(extensions, storageClass);

428 castspirv::CompositeType(type).getCapabilities(capabilities, storageClass);

429

430

431 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&

432 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))

433 return type;

434

435 auto elementType =

436 convertScalarType(targetEnv, options, scalarType, storageClass);

437 if (elementType)

438 return VectorType::get(type.getShape(), elementType);

439 return nullptr;

440}

441

445 std::optionalspirv::StorageClass storageClass = {}) {

446 auto scalarType = dyn_cast_or_nullspirv::ScalarType(type.getElementType());

447 if (!scalarType) {

448 LLVM_DEBUG(llvm::dbgs()

449 << type << " illegal: cannot convert non-scalar element type\n");

450 return nullptr;

451 }

452

453 auto elementType =

454 convertScalarType(targetEnv, options, scalarType, storageClass);

455 if (!elementType)

456 return nullptr;

457 if (elementType != type.getElementType()) {

458 LLVM_DEBUG(llvm::dbgs()

459 << type << " illegal: complex type emulation unsupported\n");

460 return nullptr;

461 }

462

463 return VectorType::get(2, elementType);

464}

465

466

467

468

469

470

471

475

476 if (!type.hasStaticShape()) {

477 LLVM_DEBUG(llvm::dbgs()

478 << type << " illegal: dynamic shape unimplemented\n");

479 return nullptr;

480 }

481

482 type = cast(convertIndexElementType(type, options));

483 type = cast(convertShaped8BitFloatType(type, options));

484 auto scalarType = dyn_cast_or_nullspirv::ScalarType(type.getElementType());

485 if (!scalarType) {

486 LLVM_DEBUG(llvm::dbgs()

487 << type << " illegal: cannot convert non-scalar element type\n");

488 return nullptr;

489 }

490

491 std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);

492 std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);

493 if (!scalarSize || !tensorSize) {

494 LLVM_DEBUG(llvm::dbgs()

495 << type << " illegal: cannot deduce element count\n");

496 return nullptr;

497 }

498

499 int64_t arrayElemCount = *tensorSize / *scalarSize;

500 if (arrayElemCount == 0) {

501 LLVM_DEBUG(llvm::dbgs()

502 << type << " illegal: cannot handle zero-element tensors\n");

503 return nullptr;

504 }

505 if (arrayElemCount > std::numeric_limits::max()) {

506 LLVM_DEBUG(llvm::dbgs()

507 << type << " illegal: cannot fit tensor into target type\n");

508 return nullptr;

509 }

510

511 Type arrayElemType = convertScalarType(targetEnv, options, scalarType);

512 if (!arrayElemType)

513 return nullptr;

514 std::optional<int64_t> arrayElemSize =

515 getTypeNumBytes(options, arrayElemType);

516 if (!arrayElemSize) {

517 LLVM_DEBUG(llvm::dbgs()

518 << type << " illegal: cannot deduce converted element size\n");

519 return nullptr;

520 }

521

523}

524

527 MemRefType type,

528 spirv::StorageClass storageClass) {

529 unsigned numBoolBits = options.boolNumBits;

530 if (numBoolBits != 8) {

531 LLVM_DEBUG(llvm::dbgs()

532 << "using non-8-bit storage for bool types unimplemented");

533 return nullptr;

534 }

535 auto elementType = dyn_castspirv::ScalarType(

536 IntegerType::get(type.getContext(), numBoolBits));

537 if (!elementType)

538 return nullptr;

539 Type arrayElemType =

540 convertScalarType(targetEnv, options, elementType, storageClass);

541 if (!arrayElemType)

542 return nullptr;

543 std::optional<int64_t> arrayElemSize =

544 getTypeNumBytes(options, arrayElemType);

545 if (!arrayElemSize) {

546 LLVM_DEBUG(llvm::dbgs()

547 << type << " illegal: cannot deduce converted element size\n");

548 return nullptr;

549 }

550

551 if (!type.hasStaticShape()) {

552

553

554 if (targetEnv.allows(spirv::Capability::Kernel))

556 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;

558

559

560 return wrapInStructAndGetPointer(arrayType, storageClass);

561 }

562

563 if (type.getNumElements() == 0) {

564 LLVM_DEBUG(llvm::dbgs()

565 << type << " illegal: zero-element memrefs are not supported\n");

566 return nullptr;

567 }

568

569 int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);

570 int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);

571 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;

573 if (targetEnv.allows(spirv::Capability::Kernel))

575 return wrapInStructAndGetPointer(arrayType, storageClass);

576}

577

580 MemRefType type,

581 spirv::StorageClass storageClass) {

582 IntegerType elementType = cast(type.getElementType());

583 Type arrayElemType = convertSubByteIntegerType(options, elementType);

584 if (!arrayElemType)

585 return nullptr;

586 int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);

587

588 if (!type.hasStaticShape()) {

589

590

591 if (targetEnv.allows(spirv::Capability::Kernel))

593 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;

595

596

597 return wrapInStructAndGetPointer(arrayType, storageClass);

598 }

599

600 if (type.getNumElements() == 0) {

601 LLVM_DEBUG(llvm::dbgs()

602 << type << " illegal: zero-element memrefs are not supported\n");

603 return nullptr;

604 }

605

607 llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);

608 int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);

609 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;

611 if (targetEnv.allows(spirv::Capability::Kernel))

613 return wrapInStructAndGetPointer(arrayType, storageClass);

614}

615

616static spirv::Dim convertRank(int64_t rank) {

617 switch (rank) {

618 case 1:

619 return spirv::Dim::Dim1D;

620 case 2:

621 return spirv::Dim::Dim2D;

622 case 3:

623 return spirv::Dim::Dim3D;

624 default:

625 llvm_unreachable("Invalid memref rank!");

626 }

627}

628

629static spirv::ImageFormat getImageFormat(Type elementType) {

631 .Case([](Float16Type) { return spirv::ImageFormat::R16f; })

632 .Case([](Float32Type) { return spirv::ImageFormat::R32f; })

633 .Case([](IntegerType intType) {

634 auto const isSigned = intType.isSigned() || intType.isSignless();

635#define BIT_WIDTH_CASE(BIT_WIDTH) \

636 case BIT_WIDTH: \

637 return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \

638 : spirv::ImageFormat::R##BIT_WIDTH##ui

639

640 switch (intType.getWidth()) {

643 default:

644 llvm_unreachable("Unhandled integer type!");

645 }

646 })

647 .DefaultUnreachable("Unhandled element type!");

648#undef BIT_WIDTH_CASE

649}

650

653 MemRefType type) {

654 auto attr = dyn_cast_or_nullspirv::StorageClassAttr(type.getMemorySpace());

655 if (!attr) {

656 LLVM_DEBUG(

657 llvm::dbgs()

658 << type

659 << " illegal: expected memory space to be a SPIR-V storage class "

660 "attribute; please use MemorySpaceToStorageClassConverter to map "

661 "numeric memory spaces beforehand\n");

662 return nullptr;

663 }

664 spirv::StorageClass storageClass = attr.getValue();

665

666

667

668

669 if (storageClass == spirv::StorageClass::Image) {

670 const int64_t rank = type.getRank();

671 if (rank < 1 || rank > 3) {

672 LLVM_DEBUG(llvm::dbgs()

673 << type << " illegal: cannot lower memref of rank " << rank

674 << " to a SPIR-V Image\n");

675 return nullptr;

676 }

677

678

679

680 auto elementType = type.getElementType();

681 if (!isaspirv::ScalarType(elementType)) {

682 LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "

683 << elementType << " to a SPIR-V Image\n");

684 return nullptr;

685 }

686

687

688

689

691 elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,

692 spirv::ImageArrayedInfo::NonArrayed,

693 spirv::ImageSamplingInfo::SingleSampled,

694 spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));

697 spvSampledImageType, spirv::StorageClass::UniformConstant);

698 return imagePtrType;

699 }

700

701 if (isa(type.getElementType())) {

702 if (type.getElementTypeBitWidth() == 1)

703 return convertBoolMemrefType(targetEnv, options, type, storageClass);

704 if (type.getElementTypeBitWidth() < 8)

705 return convertSubByteMemrefType(targetEnv, options, type, storageClass);

706 }

707

708 Type arrayElemType;

709 Type elementType = type.getElementType();

710 if (auto vecType = dyn_cast(elementType)) {

711 arrayElemType =

712 convertVectorType(targetEnv, options, vecType, storageClass);

713 } else if (auto complexType = dyn_cast(elementType)) {

714 arrayElemType =

715 convertComplexType(targetEnv, options, complexType, storageClass);

716 } else if (auto scalarType = dyn_castspirv::ScalarType(elementType)) {

717 arrayElemType =

718 convertScalarType(targetEnv, options, scalarType, storageClass);

719 } else if (auto indexType = dyn_cast(elementType)) {

720 type = cast(convertIndexElementType(type, options));

721 arrayElemType = type.getElementType();

722 } else if (auto floatType = dyn_cast(elementType)) {

723

724 type = cast(convertShaped8BitFloatType(type, options));

725 arrayElemType = type.getElementType();

726 } else {

727 LLVM_DEBUG(

728 llvm::dbgs()

729 << type

730 << " unhandled: can only convert scalar or vector element type\n");

731 return nullptr;

732 }

733 if (!arrayElemType)

734 return nullptr;

735

736 std::optional<int64_t> arrayElemSize =

737 getTypeNumBytes(options, arrayElemType);

738 if (!arrayElemSize) {

739 LLVM_DEBUG(llvm::dbgs()

740 << type << " illegal: cannot deduce converted element size\n");

741 return nullptr;

742 }

743

744 if (!type.hasStaticShape()) {

745

746

747 if (targetEnv.allows(spirv::Capability::Kernel))

749 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;

751

752

753 return wrapInStructAndGetPointer(arrayType, storageClass);

754 }

755

756 std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);

757 if (!memrefSize) {

758 LLVM_DEBUG(llvm::dbgs()

759 << type << " illegal: cannot deduce element count\n");

760 return nullptr;

761 }

762

763 if (*memrefSize == 0) {

764 LLVM_DEBUG(llvm::dbgs()

765 << type << " illegal: zero-element memrefs are not supported\n");

766 return nullptr;

767 }

768

769 int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);

770 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;

772 if (targetEnv.allows(spirv::Capability::Kernel))

774 return wrapInStructAndGetPointer(arrayType, storageClass);

775}

776

777

778

779

780

781

782

783

784

785

786

787

788

789

790

791

792

793

797

798 if (inputs.size() != 1) {

799 auto castOp =

800 UnrealizedConversionCastOp::create(builder, loc, type, inputs);

801 return castOp.getResult(0);

802 }

803 Value input = inputs.front();

804

805

806 if (!isa(type)) {

807 auto castOp =

808 UnrealizedConversionCastOp::create(builder, loc, type, inputs);

809 return castOp.getResult(0);

810 }

811 auto inputType = cast(input.getType());

812

813 auto scalarType = dyn_castspirv::ScalarType(type);

814 if (!scalarType) {

815 auto castOp =

816 UnrealizedConversionCastOp::create(builder, loc, type, inputs);

817 return castOp.getResult(0);

818 }

819

820

821

822

823 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {

824 auto castOp =

825 UnrealizedConversionCastOp::create(builder, loc, type, inputs);

826 return castOp.getResult(0);

827 }

828

829

831 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);

832 return spirv::IEqualOp::create(builder, loc, input, one);

833 }

834

835

838 scalarType.getExtensions(exts);

839 scalarType.getCapabilities(caps);

840 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||

841 failed(checkExtensionRequirements(type, targetEnv, exts))) {

842 auto castOp =

843 UnrealizedConversionCastOp::create(builder, loc, type, inputs);

844 return castOp.getResult(0);

845 }

846

847

848

849

851 return spirv::SConvertOp::create(builder, loc, type, input);

852 }

853 return spirv::UConvertOp::create(builder, loc, type, input);

854}

855

856

857

858

859

// Searches `body` for an existing spirv::GlobalVariableOp that is tagged with
// the given `builtin`.
//
// Returns the first matching variable op, or a null op when no global
// variable in `body` carries that builtin.
//
// NOTE(review): this listing appears to have lost its template argument
// brackets (e.g. after `getOps` and `getAttrOfType`) — confirm against the
// real source file before relying on the exact spelling below.
860static spirv::GlobalVariableOp getBuiltinVariable(Block &body,

861 spirv::BuiltIn builtin) {

862

863

// Walk every global variable op directly contained in `body`.
864 for (auto varOp : body.getOpsspirv::GlobalVariableOp()) {

// Only variables that carry the attribute named after the BuiltIn
// decoration are candidates; all others are ignored.
865 if (auto builtinAttr = varOp->getAttrOfType(

866 spirv::SPIRVDialect::getAttributeName(

867 spirv::Decoration::BuiltIn))) {

// Turn the attribute's value back into a BuiltIn so it can be compared
// with the requested one.
868 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());

869 if (varBuiltIn == builtin) {

870 return varOp;

871 }

872 }

873 }

// No global variable in `body` is tagged with `builtin`.
874 return nullptr;

875}

876

877

878std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,

879 StringRef suffix) {

880 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();

881}

882

883

884static spirv::GlobalVariableOp

887 StringRef prefix, StringRef suffix) {

888 if (auto varOp = getBuiltinVariable(body, builtin))

889 return varOp;

890

893

894 spirv::GlobalVariableOp newVarOp;

896 case spirv::BuiltIn::NumWorkgroups:

897 case spirv::BuiltIn::WorkgroupSize:

898 case spirv::BuiltIn::WorkgroupId:

899 case spirv::BuiltIn::LocalInvocationId:

900 case spirv::BuiltIn::GlobalInvocationId: {

902 spirv::StorageClass::Input);

903 std::string name = getBuiltinVarName(builtin, prefix, suffix);

904 newVarOp =

905 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);

906 break;

907 }

908 case spirv::BuiltIn::SubgroupId:

909 case spirv::BuiltIn::NumSubgroups:

910 case spirv::BuiltIn::SubgroupSize:

911 case spirv::BuiltIn::SubgroupLocalInvocationId: {

912 auto ptrType =

914 std::string name = getBuiltinVarName(builtin, prefix, suffix);

915 newVarOp =

916 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);

917 break;

918 }

919 default:

920 emitError(loc, "unimplemented builtin variable generation for ")

921 << stringifyBuiltIn(builtin);

922 }

923 return newVarOp;

924}

925

926

927

928

929

930

931

932static spirv::PointerType getPushConstantStorageType(unsigned elementCount,

934 Type indexType) {

936 4);

939}

940

941

942

943static spirv::GlobalVariableOp getPushConstantVariable(Block &body,

944 unsigned elementCount) {

945 for (auto varOp : body.getOpsspirv::GlobalVariableOp()) {

946 auto ptrType = dyn_castspirv::PointerType(varOp.getType());

947 if (!ptrType)

948 continue;

949

950

951

952

953 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {

954 auto numElements = castspirv::ArrayType(

955 castspirv::StructType(ptrType.getPointeeType())

956 .getElementType(0))

957 .getNumElements();

958 if (numElements == elementCount)

959 return varOp;

960 }

961 }

962 return nullptr;

963}

964

965

966

967static spirv::GlobalVariableOp

968getOrInsertPushConstantVariable(Location loc, Block &block,

970 Type indexType) {

971 if (auto varOp = getPushConstantVariable(block, elementCount))

972 return varOp;

973

975 auto type = getPushConstantStorageType(elementCount, builder, indexType);

976 const char *name = "__push_constant_var__";

977 return spirv::GlobalVariableOp::create(builder, loc, type, name,

978 nullptr);

979}

980

981

982

983

984

985

986

987struct FuncOpConversion final : OpConversionPatternfunc::FuncOp {

988 using Base::Base;

989

990 LogicalResult

991 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,

992 ConversionPatternRewriter &rewriter) const override {

993 FunctionType fnType = funcOp.getFunctionType();

994 if (fnType.getNumResults() > 1)

995 return failure();

996

997 TypeConverter::SignatureConversion signatureConverter(

998 fnType.getNumInputs());

999 for (const auto &argType : enumerate(fnType.getInputs())) {

1000 auto convertedType = getTypeConverter()->convertType(argType.value());

1001 if (!convertedType)

1002 return failure();

1003 signatureConverter.addInputs(argType.index(), convertedType);

1004 }

1005

1006 Type resultType;

1007 if (fnType.getNumResults() == 1) {

1008 resultType = getTypeConverter()->convertType(fnType.getResult(0));

1009 if (!resultType)

1010 return failure();

1011 }

1012

1013

1014 auto newFuncOp = spirv::FuncOp::create(

1015 rewriter, funcOp.getLoc(), funcOp.getName(),

1016 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),

1017 resultType ? TypeRange(resultType)

1019

1020

1021 for (const auto &namedAttr : funcOp->getAttrs()) {

1022 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&

1024 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());

1025 }

1026

1027 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),

1028 newFuncOp.end());

1029 if (failed(rewriter.convertRegionTypes(

1030 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))

1031 return failure();

1032 rewriter.eraseOp(funcOp);

1034 }

1035};

1036

1037

1038

1039struct FuncOpVectorUnroll final : OpRewritePatternfunc::FuncOp {

1040 using Base::Base;

1041

1042 LogicalResult matchAndRewrite(func::FuncOp funcOp,

1044 FunctionType fnType = funcOp.getFunctionType();

1045

1046

1047 if (funcOp.isDeclaration()) {

1048 LLVM_DEBUG(llvm::dbgs()

1049 << fnType << " illegal: declarations are unsupported\n");

1050 return failure();

1051 }

1052

1053

1054 auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),

1055 funcOp.getName(), fnType);

1057 newFuncOp.end());

1058

1059 Location loc = newFuncOp.getBody().getLoc();

1060

1061 Block &entryBlock = newFuncOp.getBlocks().front();

1064

1065 TypeConverter::SignatureConversion oneToNTypeMapping(

1066 fnType.getInputs().size());

1067

1068

1069

1070

1072 size_t newInputNo = 0;

1073

1074

1075

1076

1077

1078 llvm::SmallDenseMap<Operation *, size_t> tmpOps;

1079

1080

1081 size_t newOpCount = 0;

1082

1083

1084 for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {

1085

1086 auto origVecType = dyn_cast(origType);

1087 if (!origVecType) {

1088

1089 Value result = arith::ConstantOp::create(

1090 rewriter, loc, origType, rewriter.getZeroAttr(origType));

1092 tmpOps.insert({result.getDefiningOp(), newInputNo});

1093 oneToNTypeMapping.addInputs(origInputNo, origType);

1094 ++newInputNo;

1095 ++newOpCount;

1096 continue;

1097 }

1098

1100 if (!targetShape) {

1101

1102 Value result = arith::ConstantOp::create(

1103 rewriter, loc, origType, rewriter.getZeroAttr(origType));

1105 tmpOps.insert({result.getDefiningOp(), newInputNo});

1106 oneToNTypeMapping.addInputs(origInputNo, origType);

1107 ++newInputNo;

1108 ++newOpCount;

1109 continue;

1110 }

1111 VectorType unrolledType =

1112 VectorType::get(*targetShape, origVecType.getElementType());

1113 auto originalShape =

1114 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());

1115

1116

1117 Value result = arith::ConstantOp::create(

1118 rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType));

1119 ++newOpCount;

1120

1121 Value dummy = arith::ConstantOp::create(

1122 rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType));

1123 ++newOpCount;

1124

1125

1130 result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,

1131 result, offsets, strides);

1132 newTypes.push_back(unrolledType);

1133 unrolledInputNums.push_back(newInputNo);

1134 ++newInputNo;

1135 ++newOpCount;

1136 }

1138 oneToNTypeMapping.addInputs(origInputNo, newTypes);

1139 }

1140

1141

1142 auto convertedTypes = oneToNTypeMapping.getConvertedTypes();

1143 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());

1145 [&] { newFuncOp.setFunctionType(newFnType); });

1146

1147

1148 entryBlock.eraseArguments(0, fnType.getNumInputs());

1150 entryBlock.addArguments(convertedTypes, locs);

1151

1152

1153

1154 for (auto &[placeholderOp, argIdx] : tmpOps) {

1155 if (!placeholderOp)

1156 continue;

1159 }

1160

1161

1162

1163

1164

1165 size_t unrolledInputIdx = 0;

1166 for (auto [count, op] : enumerate(entryBlock.getOperations())) {

1168

1169

1170

1171 if (count >= newOpCount)

1172 continue;

1173 if (auto vecOp = dyn_castvector::InsertStridedSliceOp(op)) {

1174 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];

1176 curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));

1177 });

1178 ++unrolledInputIdx;

1179 }

1180 }

1181

1182

1183

1184 rewriter.eraseOp(funcOp);

1186 }

1187};

1188

1189

1190

1191

1192

1193

1194

1195struct ReturnOpVectorUnroll final : OpRewritePatternfunc::ReturnOp {

1196 using Base::Base;

1197

1198 LogicalResult matchAndRewrite(func::ReturnOp returnOp,

1200

1201 auto funcOp = dyn_castfunc::FuncOp(returnOp->getParentOp());

1202 if (!funcOp)

1203 return failure();

1204

1205 FunctionType fnType = funcOp.getFunctionType();

1206 TypeConverter::SignatureConversion oneToNTypeMapping(

1207 fnType.getResults().size());

1208 Location loc = returnOp.getLoc();

1209

1210

1212

1213

1214 for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {

1215

1216 auto origVecType = dyn_cast(origType);

1217 if (!origVecType) {

1218 oneToNTypeMapping.addInputs(origResultNo, origType);

1219 newOperands.push_back(returnOp.getOperand(origResultNo));

1220 continue;

1221 }

1222

1224 if (!targetShape) {

1225

1226 oneToNTypeMapping.addInputs(origResultNo, origType);

1227 newOperands.push_back(returnOp.getOperand(origResultNo));

1228 continue;

1229 }

1230 VectorType unrolledType =

1231 VectorType::get(*targetShape, origVecType.getElementType());

1232

1233

1234

1235 auto originalShape =

1236 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());

1239 extractShape.back() = targetShape->back();

1241 Value returnValue = returnOp.getOperand(origResultNo);

1244 Value result = vector::ExtractStridedSliceOp::create(

1245 rewriter, loc, returnValue, offsets, extractShape, strides);

1246 if (originalShape.size() > 1) {

1249 vector::ExtractOp::create(rewriter, loc, result, extractIndices);

1250 }

1251 newOperands.push_back(result);

1252 newTypes.push_back(unrolledType);

1253 }

1254 oneToNTypeMapping.addInputs(origResultNo, newTypes);

1255 }

1256

1257

1258 auto newFnType =

1260 TypeRange(oneToNTypeMapping.getConvertedTypes()));

1262 [&] { funcOp.setFunctionType(newFnType); });

1263

1264

1265

1267 func::ReturnOp::create(rewriter, loc, newOperands));

1268

1270 }

1271};

1272

1273}

1274

1275

1276

1277

1278

1282 StringRef prefix, StringRef suffix) {

1284 if (!parent) {

1285 op->emitError("expected operation to be within a module-like op");

1286 return nullptr;

1287 }

1288

1289 spirv::GlobalVariableOp varOp =

1291 builtin, integerType, builder, prefix, suffix);

1292 Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp);

1293 return spirv::LoadOp::create(builder, op->getLoc(), ptr);

1294}

1295

1296

1297

1298

1299

1301 unsigned offset, Type integerType,

1305 if (!parent) {

1306 op->emitError("expected operation to be within a module-like op");

1307 return nullptr;

1308 }

1309

1310 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(

1311 loc, parent->getRegion(0).front(), elementCount, builder, integerType);

1312

1313 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);

1314 Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,

1316 auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);

1317 auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,

1319 return spirv::LoadOp::create(builder, loc, acOp);

1320}

1321

1322

1323

1324

1325

1329 assert(indices.size() == strides.size() &&

1330 "must provide indices for all dimensions");

1331

1332

1333

1334

1335

1336

1337 Value linearizedIndex = builder.createOrFoldspirv::ConstantOp(

1338 loc, integerType, IntegerAttr::get(integerType, offset));

1339 for (const auto &index : llvm::enumerate(indices)) {

1341 loc, integerType,

1342 IntegerAttr::get(integerType, strides[index.index()]));

1344 builder.createOrFoldspirv::IMulOp(loc, index.value(), strideVal);

1345 linearizedIndex =

1346 builder.createOrFoldspirv::IAddOp(loc, update, linearizedIndex);

1347 }

1348 return linearizedIndex;

1349}

1350

1352 MemRefType baseType, Value basePtr,

1355

1356

1359 if (failed(baseType.getStridesAndOffset(strides, offset)) ||

1360 llvm::is_contained(strides, ShapedType::kDynamic) ||

1361 ShapedType::isDynamic(offset)) {

1362 return nullptr;

1363 }

1364

1365 auto indexType = typeConverter.getIndexType();

1366

1368 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);

1369

1370

1371 linearizedIndices.push_back(zero);

1372

1373 if (baseType.getRank() == 0) {

1374 linearizedIndices.push_back(zero);

1375 } else {

1376 linearizedIndices.push_back(

1378 }

1379 return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);

1380}

1381

1383 MemRefType baseType, Value basePtr,

1386

1387

1390 if (failed(baseType.getStridesAndOffset(strides, offset)) ||

1391 llvm::is_contained(strides, ShapedType::kDynamic) ||

1392 ShapedType::isDynamic(offset)) {

1393 return nullptr;

1394 }

1395

1396 auto indexType = typeConverter.getIndexType();

1397

1399 Value linearIndex;

1400 if (baseType.getRank() == 0) {

1401 linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);

1402 } else {

1403 linearIndex =

1405 }

1406 Type pointeeType =

1407 castspirv::PointerType(basePtr.getType()).getPointeeType();

1408 if (isaspirv::ArrayType(pointeeType)) {

1409 linearizedIndices.push_back(linearIndex);

1410 return spirv::AccessChainOp::create(builder, loc, basePtr,

1411 linearizedIndices);

1412 }

1413 return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,

1414 linearizedIndices);

1415}

1416

1418 MemRefType baseType, Value basePtr,

1421

1422 if (typeConverter.allows(spirv::Capability::Kernel)) {

1424 builder);

1425 }

1426

1428 builder);

1429}

1430

1431

1432

1433

1434

1436 for (int i : {4, 3, 2}) {

1437 if (size % i == 0)

1438 return i;

1439 }

1440 return 1;

1441}

1442

1445 VectorType srcVectorType = op.getSourceVectorType();

1446 assert(srcVectorType.getRank() == 1);

1449 return {vectorSize};

1450}

1451

1454 VectorType vectorType = op.getResultVectorType();

1456 nativeSize.back() =

1458 return nativeSize;

1459}

1460

1461std::optional<SmallVector<int64_t>>

1464 if (auto vecType = dyn_cast(op->getResultTypes()[0])) {

1466 nativeSize.back() =

1468 return nativeSize;

1469 }

1470 }

1471

1473 .Case<vector::ReductionOp, vector::TransposeOp>(

1475 .Default(std::nullopt);

1476}

1477

1490

1493

1494

1495 {

1501 return failure();

1502 }

1503

1504

1505

1506 {

1509 patterns, vector::VectorTransposeLowering::EltWise);

1512 return failure();

1513 }

1514

1515

1516 {

1518

1519

1520 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);

1521 vector::ReductionOp::getCanonicalizationPatterns(patterns, context);

1522 vector::TransposeOp::getCanonicalizationPatterns(patterns, context);

1523

1524

1525

1526 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(

1528 vector::InsertOp::getCanonicalizationPatterns(patterns, context);

1529 vector::ExtractOp::getCanonicalizationPatterns(patterns, context);

1530

1531

1532

1533 vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);

1534 vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);

1535

1537 return failure();

1538 }

1540}

1541

1542

1543

1544

1545

1548 : targetEnv(targetAttr), options(options) {

1549

1550

1551

1552

1553

1554

1555

1556

1557

1558

1560

1561 addConversion([this](IndexType ) { return getIndexType(); });

1562

1563 addConversion([this](IntegerType intType) -> std::optional {

1564 if (auto scalarType = dyn_castspirv::ScalarType(intType))

1565 return convertScalarType(this->targetEnv, this->options, scalarType);

1566 if (intType.getWidth() < 8)

1567 return convertSubByteIntegerType(this->options, intType);

1568 return Type();

1569 });

1570

1571 addConversion([this](FloatType floatType) -> std::optional {

1572 if (auto scalarType = dyn_castspirv::ScalarType(floatType))

1573 return convertScalarType(this->targetEnv, this->options, scalarType);

1574 if (floatType.getWidth() == 8)

1575 return convert8BitFloatType(this->options, floatType);

1576 return Type();

1577 });

1578

1579 addConversion([this](ComplexType complexType) {

1580 return convertComplexType(this->targetEnv, this->options, complexType);

1581 });

1582

1583 addConversion([this](VectorType vectorType) {

1584 return convertVectorType(this->targetEnv, this->options, vectorType);

1585 });

1586

1587 addConversion([this](TensorType tensorType) {

1588 return convertTensorType(this->targetEnv, this->options, tensorType);

1589 });

1590

1591 addConversion([this](MemRefType memRefType) {

1592 return convertMemrefType(this->targetEnv, this->options, memRefType);

1593 });

1594

1595

1596 addSourceMaterialization(

1598 return castToSourceType(this->targetEnv, builder, type, inputs, loc);

1599 });

1602 auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);

1603 return cast.getResult(0);

1604 });

1605}

1606

1608 return ::getIndexType(getContext(), options);

1609}

1610

1611MLIRContext *SPIRVTypeConverter::getContext() const {

1612 return targetEnv.getAttr().getContext();

1613}

1614

1616 return targetEnv.allows(capability);

1617}

1618

1619

1620

1621

1622

1623std::unique_ptr

1625 std::unique_ptr target(

1626

1627 new SPIRVConversionTarget(targetAttr));

1628 SPIRVConversionTarget *targetPtr = target.get();

1629 target->addDynamicallyLegalDialectspirv::SPIRVDialect(

1630

1631

1632 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });

1634}

1635

1636SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)

1638

// Decides whether `op` is legal for the target environment held in
// `this->targetEnv`. Returns false as soon as any of the checks below fails
// and true only when all of them pass:
//   1. the op's minimum/maximum SPIR-V version (when it implements the
//      corresponding query interface) is compatible with the target version;
//   2. the op's own extension and capability requirements are satisfied;
//   3. every collected value type is a SPIR-V type;
//   4. the extension and capability requirements of each collected type
//      (plus the global variable type, if any) are satisfied.
1639bool SPIRVConversionTarget::isLegalOp(Operation *op) {

1640

1641

1642
// Version lower bound: reject ops that need a newer version than the target.
1643 if (auto minVersionIfx = dyn_castspirv::QueryMinVersionInterface(op)) {

1644 std::optionalspirv::Version minVersion = minVersionIfx.getMinVersion();

1645 if (minVersion && *minVersion > this->targetEnv.getVersion()) {

1646 LLVM_DEBUG(llvm::dbgs()

1647 << op->getName() << " illegal: requiring min version "

1648 << spirv::stringifyVersion(*minVersion) << "\n");

1649 return false;

1650 }

1651 }
// Version upper bound: reject ops whose maximum version is below the target's.
1652 if (auto maxVersionIfx = dyn_castspirv::QueryMaxVersionInterface(op)) {

1653 std::optionalspirv::Version maxVersion = maxVersionIfx.getMaxVersion();

1654 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {

1655 LLVM_DEBUG(llvm::dbgs()

1656 << op->getName() << " illegal: requiring max version "

1657 << spirv::stringifyVersion(*maxVersion) << "\n");

1658 return false;

1659 }

1660 }

1661

1662

1663

1664
// Extensions required by the op itself. checkExtensionRequirements fails when
// some candidate group has no extension allowed by the target environment.
1665 if (auto extensions = dyn_castspirv::QueryExtensionInterface(op))

1666 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,

1667 extensions.getExtensions())))

1668 return false;

1669

1670

1671

1672
// Capabilities required by the op itself.
1673 if (auto capabilities = dyn_castspirv::QueryCapabilityInterface(op))

1674 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,

1675 capabilities.getCapabilities())))

1676 return false;

1677
// NOTE(review): original lines 1679-1680, which fill `valueTypes`, are missing
// from this extract; presumably they append the op's operand and result
// types -- confirm against the upstream file.
1678 SmallVector<Type, 4> valueTypes;

1681

1682
// Any collected type that is not a spirv::SPIRVType makes the op illegal.
1683 if (llvm::any_of(valueTypes,

1684 [](Type t) { return !isaspirv::SPIRVType(t); }))

1685 return false;

1686

1687

1688
// A global variable's type is appended after the SPIR-V type check above, so
// it only takes part in the per-type requirement checks below.
1689 if (auto globalVar = dyn_castspirv::GlobalVariableOp(op))

1690 valueTypes.push_back(globalVar.getType());

1691

1692

1693
// Scratch vectors reused for every type; cleared at the start of each use.
1694 SmallVector<ArrayRefspirv::Extension, 4> typeExtensions;

1695 SmallVector<ArrayRefspirv::Capability, 8> typeCapabilities;

1696 for (Type valueType : valueTypes) {

1697 typeExtensions.clear();

1698 castspirv::SPIRVType(valueType).getExtensions(typeExtensions);

1699 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,

1700 typeExtensions)))

1701 return false;

1702

1703 typeCapabilities.clear();

1704 castspirv::SPIRVType(valueType).getCapabilities(typeCapabilities);

1705 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,

1706 typeCapabilities)))

1707 return false;

1708 }

1709
// Every version, extension, capability and type check passed.
1710 return true;

1711}

1712

1713

1714

1715

1716

1719 patterns.add(typeConverter, patterns.getContext());

1720}

1721

1725

b

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...

*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`

static llvm::ManagedStatic< PassManagerOptions > options

#define BIT_WIDTH_CASE(BIT_WIDTH)

static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)

Return the target shape for unrolling for the given op.

Block represents an ordered list of Operations.

iterator_range< op_iterator< OpT > > getOps()

Return an iterator range over the operations within this block that are of 'OpT'.

iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)

Add one argument to the argument list for each type specified in the list.

OpListType & getOperations()

void eraseArguments(unsigned start, unsigned num)

Erases 'num' arguments from the index 'start'.

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getI32IntegerAttr(int32_t value)

TypedAttr getZeroAttr(Type type)

MLIRContext * getContext() const

This class allows control over how the GreedyPatternRewriteDriver works.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)

Create a builder and set the insertion point to before the first operation in the block but still inside the block.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold it.

Operation is the basic unit of execution within MLIR.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

void setOperand(unsigned idx, Value value)

operand_type_iterator operand_type_end()

Location getLoc()

The source location the operation was defined or derived from.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.

result_type_iterator result_type_end()

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

result_type_iterator result_type_begin()

OperationName getName()

The name of an operation is the key identifier for it.

result_type_range getResultTypes()

MLIRContext * getContext()

Return the context this operation is associated with.

unsigned getNumResults()

Return the number of results held by this operation.

operand_type_iterator operand_type_begin()

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements).

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)

Move the blocks that belong to "region" before the given position in another region "parent".

virtual void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)

Creates a SPIR-V conversion target for the given target environment.

Definition SPIRVConversion.cpp:1624

Type conversion from builtin types to SPIR-V types for shader interface.

Type getIndexType() const

Gets the SPIR-V correspondence for the standard index type.

Definition SPIRVConversion.cpp:1607

SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})

Definition SPIRVConversion.cpp:1546

bool allows(spirv::Capability capability) const

Checks if the SPIR-V capability inquired is supported.

Definition SPIRVConversion.cpp:1615

A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...

static StringRef getSymbolAttrName()

Return the name of the attribute used for symbol names.

static Operation * getNearestSymbolTable(Operation *from)

Returns the nearest symbol table from a given operation from.

Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.

Type getElementType() const

Returns the element type of this tensor type.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

MLIRContext * getContext() const

Return the MLIRContext in which this type was uniqued.

bool isSignedInteger() const

Return true if this is a signed integer type (with the specified width).

bool isInteger() const

Return true if this is an integer type (with the specified width).

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

static ArrayType get(Type elementType, unsigned elementCount)

static bool isValid(VectorType)

Returns true if the given vector type is valid for the SPIR-V dialect.

static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)

static PointerType get(Type pointeeType, StorageClass storageClass)

static RuntimeArrayType get(Type elementType)

SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector

The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...

SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector

The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...

static SampledImageType get(Type imageType)

static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})

Construct a literal StructType with at least one member.

An attribute that specifies the target version, allowed extensions and capabilities,...

A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...

Version getVersion() const

bool allows(Capability) const

Returns true if the given capability is allowed.

TargetEnvAttr getAttr() const

MLIRContext * getContext() const

Returns the MLIRContext.

bool hasElementwiseMappableTraits(Operation *op)

Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...

Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")

Returns the value for the given builtin variable.

Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)

Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...

Definition SPIRVConversion.cpp:1417

Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)

Definition SPIRVConversion.cpp:1382

Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)

Gets the value at the given offset of the push constant storage with a total of elementCount integerT...

Definition SPIRVConversion.cpp:1300

std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)

Definition SPIRVConversion.cpp:1462

Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)

Generates IR to perform index linearization with the given indices and their corresponding strides,...

Definition SPIRVConversion.cpp:1326

LogicalResult unrollVectorsInFuncBodies(Operation *op)

Definition SPIRVConversion.cpp:1491

Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)

Definition SPIRVConversion.cpp:1351

SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)

Definition SPIRVConversion.cpp:1444

int getComputeVectorSize(int64_t size)

Definition SPIRVConversion.cpp:1435

LogicalResult unrollVectorsInSignatures(Operation *op)

Definition SPIRVConversion.cpp:1478

void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Populate the pattern set with the following patterns:

void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)

Populate the pattern set with the following patterns:

Include the generated interface declarations.

void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)

Definition SPIRVConversion.cpp:1722

void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)

Definition SPIRVConversion.cpp:1726

@ Packed

Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)

Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V dialect.

Definition SPIRVConversion.cpp:1717

const FrozenRewritePatternSet & patterns

llvm::TypeSwitch< T, ResultT > TypeSwitch

@ ExistingOps

Only pre-existing ops are processed.

std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)

Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

Options that control the vector unrolling.

UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)