MLIR: lib/Dialect/GPU/IR/GPUDialect.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

14

34 #include "llvm/ADT/STLExtras.h"

35 #include "llvm/ADT/TypeSwitch.h"

36 #include "llvm/Support/CommandLine.h"

37 #include "llvm/Support/ErrorHandling.h"

38 #include "llvm/Support/FormatVariadic.h"

39 #include "llvm/Support/InterleavedRange.h"

40 #include "llvm/Support/StringSaver.h"

41 #include

42 #include

43

44 using namespace mlir;

46

47 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

48

49

50

51

52

53 int64_t GPUBlockMappingAttr::getMappingId() const {

54 return static_cast<int64_t>(getBlock());

55 }

56

57 bool GPUBlockMappingAttr::isLinearMapping() const {

58 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);

59 }

60

61 int64_t GPUBlockMappingAttr::getRelativeIndex() const {

62 return isLinearMapping()

63 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)

64 : getMappingId();

65 }

66

67 int64_t GPUWarpgroupMappingAttr::getMappingId() const {

68 return static_cast<int64_t>(getWarpgroup());

69 }

70

71 bool GPUWarpgroupMappingAttr::isLinearMapping() const {

72 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);

73 }

74

75 int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {

76 return isLinearMapping()

77 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)

78 : getMappingId();

79 }

80

81 int64_t GPUWarpMappingAttr::getMappingId() const {

82 return static_cast<int64_t>(getWarp());

83 }

84

85 bool GPUWarpMappingAttr::isLinearMapping() const {

86 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);

87 }

88

89 int64_t GPUWarpMappingAttr::getRelativeIndex() const {

90 return isLinearMapping()

91 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)

92 : getMappingId();

93 }

94

95 int64_t GPUThreadMappingAttr::getMappingId() const {

96 return static_cast<int64_t>(getThread());

97 }

98

99 bool GPUThreadMappingAttr::isLinearMapping() const {

100 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);

101 }

102

103 int64_t GPUThreadMappingAttr::getRelativeIndex() const {

104 return isLinearMapping()

105 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)

106 : getMappingId();

107 }

108

109 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {

110 return static_cast<int64_t>(getAddressSpace());

111 }

112

113 bool GPUMemorySpaceMappingAttr::isLinearMapping() const {

114 llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");

115 }

116

117 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {

118 llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");

119 }

120

121

122

123

124

126 StringRef operand) {

128 }

129

133 StringRef operand) {

135 elementType, operand);

136 }

137

139

141 return getImpl()->getShape();

142 }

143

145

147

149 return elementType.isF16() || elementType.isF32() ||

152 }

153

154 LogicalResult

157 StringRef operand) {

158 if (operand != "AOp" && operand != "BOp" && operand != "COp")

159 return emitError() << "operand expected to be one of AOp, BOp or COp";

160

161 if (shape.size() != 2)

162 return emitError() << "MMAMatrixType must have exactly two dimensions";

163

166 << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

167

168 return success();

169 }

170

171

172

173

174

175 bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {

176 if (!memorySpace)

177 return false;

178 if (auto gpuAttr = llvm::dyn_castgpu::AddressSpaceAttr(memorySpace))

179 return gpuAttr.getValue() == getWorkgroupAddressSpace();

180 return false;

181 }

182

183 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {

184 Attribute memorySpace = type.getMemorySpace();

185 return isWorkgroupMemoryAddressSpace(memorySpace);

186 }

187

188 bool GPUDialect::isKernel(Operation *op) {

189 UnitAttr isKernelAttr = op->getAttrOfType(getKernelFuncAttrName());

190 return static_cast<bool>(isKernelAttr);

191 }

192

193 namespace {

194

195

198

199

201 return true;

202 }

203 };

204 }

205

// Registers the GPU dialect's types, operations, attributes and interfaces,
// and declares interfaces whose implementations are promised elsewhere.
// NOTE(review): this copy looks like an HTML extraction. The leading numbers
// are source line numbers, and the template argument lists of addTypes<...>()
// and addInterfaces<...>() appear to have been stripped. Restore them from the
// upstream file before compiling.
206 void GPUDialect::initialize() {

// Register the dialect's types, one addTypes call per type. The type names are
// missing in this copy.
207 addTypes();

208 addTypes();

209 addTypes();

210 addTypes();

211 addTypes();

// Register every op listed by the generated GPUOps.cpp.inc (GET_OP_LIST).
212 addOperations<

213 #define GET_OP_LIST

214 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

215 >();

// Register every attribute listed by the generated GPUOpsAttributes.cpp.inc
// (GET_ATTRDEF_LIST).
216 addAttributes<

217 #define GET_ATTRDEF_LIST

218 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

219 >();

// Register the dialect interface(s). The class name is missing in this copy;
// presumably it is the interface defined in the anonymous namespace earlier
// in the file. TODO: confirm.
220 addInterfaces();

// Promise that TerminatorOp implements BufferDeallocationOpInterface. The
// implementation is expected to be attached separately (external model).
221 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,

222 TerminatorOp>();

// Likewise promise ValueBoundsOpInterface for the listed ops.
223 declarePromisedInterfaces<

224 ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,

225 ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,

226 SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();

227 }

228

230 switch (kind) {

232 return "sparse.dntensor_handle";

234 return "sparse.spmat_handle";

236 return "sparse.spgemmop_handle";

237 }

238 llvm_unreachable("unknown sparse handle kind");

239 return "";

240 }

241

243

244 StringRef keyword;

246 return Type();

248

249

250 if (keyword == "async.token")

252

253 if (keyword == "mma_matrix") {

254 SMLoc beginLoc = parser.getNameLoc();

255

256

258 return nullptr;

259

260

262 Type elementType;

265 return nullptr;

266

267

269 return nullptr;

270

271

272 std::string operand;

274 return nullptr;

275

276

278 return nullptr;

279

282 shape, elementType, operand);

283 }

284

291

293 return Type();

294 }

295

296

300 .Case([&](Type) {

302 })

303 .Case(

305 .Case([&](Type) {

307 })

309 os << "mma_matrix<";

310 auto shape = fragTy.getShape();

311 for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)

312 os << *dim << 'x';

313 os << shape.back() << 'x' << fragTy.getElementType();

314 os << ", \"" << fragTy.getOperand() << "\"" << '>';

315 })

316 .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });

317 }

318

321 auto array = dyn_cast(attr.getValue());

322 if (!array)

324 " must be a dense i32 array");

325 if (array.size() != 3)

327 " must contain exactly 3 elements");

328 return success();

329 }

330

331 LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,

333 if (attr.getName() == getKnownBlockSizeAttrHelper().getName())

335 if (attr.getName() == getKnownGridSizeAttrHelper().getName())

337 if (!llvm::isa(attr.getValue()) ||

338 attr.getName() != getContainerModuleAttrName())

339 return success();

340

341 auto module = dyn_cast(op);

342 if (!module)

343 return op->emitError("expected '")

344 << getContainerModuleAttrName() << "' attribute to be attached to '"

345 << ModuleOp::getOperationName() << '\'';

346

347 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {

348

349

350 if (!launchOp->getParentOp() ||

351 launchOp->getParentOp()->getParentOp() != module)

352 return success();

353

354

355

356 if (!launchOp->getAttrOfType(

357 LaunchFuncOp::getKernelAttrName(launchOp->getName())))

358 return success();

359

360

361 StringAttr kernelContainerName = launchOp.getKernelModuleName();

362 Operation *kernelContainer = module.lookupSymbol(kernelContainerName);

363 if (!kernelContainer)

365 << "kernel container '" << kernelContainerName.getValue()

366 << "' is undefined";

367

368

369 if (isa(kernelContainer))

370 return success();

371

372 auto kernelModule = dyn_cast(kernelContainer);

373 if (!kernelModule)

374 return launchOp.emitOpError()

375 << "kernel module '" << kernelContainerName.getValue()

376 << "' is undefined";

377

378

379 Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());

380 if (!kernelFunc)

381 return launchOp.emitOpError("kernel function '")

382 << launchOp.getKernel() << "' is undefined";

383 auto kernelConvertedFunction = dyn_cast(kernelFunc);

384 if (!kernelConvertedFunction) {

386 << "referenced kernel '" << launchOp.getKernel()

387 << "' is not a function";

388 diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";

390 }

391

393 GPUDialect::getKernelFuncAttrName()))

394 return launchOp.emitOpError("kernel function is missing the '")

395 << GPUDialect::getKernelFuncAttrName() << "' attribute";

396

397

398

399

400 auto kernelGPUFunction = dyn_castgpu::GPUFuncOp(kernelFunc);

401 if (!kernelGPUFunction)

402 return success();

403

404 unsigned actualNumArguments = launchOp.getNumKernelOperands();

405 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();

406 if (expectedNumArguments != actualNumArguments)

407 return launchOp.emitOpError("got ")

408 << actualNumArguments << " kernel operands but expected "

409 << expectedNumArguments;

410

411 auto functionType = kernelGPUFunction.getFunctionType();

412 for (unsigned i = 0; i < expectedNumArguments; ++i) {

413 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {

414 return launchOp.emitOpError("type of function argument ")

415 << i << " does not match";

416 }

417 }

418

419 return success();

420 });

421

422 return walkResult.wasInterrupted() ? failure() : success();

423 }

424

425

426

427

428

435 return parser.emitError(loc, "needs to be named when marked 'async'");

437 }

440 }

441

442

443

444

446 Type asyncTokenType,

448 if (asyncTokenType)

449 printer << "async";

450 if (asyncDependencies.empty())

451 return;

452 if (asyncTokenType)

453 printer << ' ';

454 printer << llvm::interleaved_array(asyncDependencies);

455 }

456

457

458

459

460

461

462

463

464

465 static ParseResult

468

470 return success();

471

473 true);

474 }

475

476

479 if (values.empty())

480 return;

481

483 return llvm::formatv("{} : {}", v, v.getType());

484 };

485 p << ' ' << keyword << '('

486 << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';

487 }

488

489

492 gpu::AddressSpace memorySpace) {

493 for (Value v : attributions) {

494 auto type = llvm::dyn_cast(v.getType());

495 if (!type)

496 return op->emitOpError() << "expected memref type in attribution";

497

498

499

500 auto addressSpace =

501 llvm::dyn_cast_or_nullgpu::AddressSpaceAttr(type.getMemorySpace());

502 if (!addressSpace)

503 continue;

504 if (addressSpace.getValue() != memorySpace)

506 << "expected memory space " << stringifyAddressSpace(memorySpace)

507 << " in attribution";

508 }

509 return success();

510 }

511

512

513

514

515

517 Type resType) {

518 using Kind = gpu::AllReduceOperation;

519 if (llvm::is_contained(

520 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},

521 opName)) {

522 if (!isa(resType))

523 return failure();

524 }

525

526 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,

527 Kind::AND, Kind::OR, Kind::XOR},

528 opName)) {

529 if (!isa(resType))

530 return failure();

531 }

532

533 return success();

534 }

535

536 LogicalResult gpu::AllReduceOp::verifyRegions() {

537 if (getBody().empty() != getOp().has_value())

538 return emitError("expected either an op attribute or a non-empty body");

539 if (!getBody().empty()) {

540 if (getBody().getNumArguments() != 2)

541 return emitError("expected two region arguments");

542 for (auto argument : getBody().getArguments()) {

543 if (argument.getType() != getType())

544 return emitError("incorrect region argument type");

545 }

546 unsigned yieldCount = 0;

547 for (Block &block : getBody()) {

548 if (auto yield = dyn_castgpu::YieldOp(block.getTerminator())) {

549 if (yield.getNumOperands() != 1)

550 return emitError("expected one gpu.yield operand");

551 if (yield.getOperand(0).getType() != getType())

552 return emitError("incorrect gpu.yield type");

553 ++yieldCount;

554 }

555 }

556 if (yieldCount == 0)

557 return emitError("expected gpu.yield op in region");

558 } else {

559 gpu::AllReduceOperation opName = *getOp();

561 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)

562 << "` reduction operation is not compatible with type "

564 }

565 }

566

567 return success();

568 }

569

571 auto launchOp = dyn_castgpu::LaunchOp(op->getParentOp());

572 if (!launchOp)

573 return false;

574

575 Region &body = launchOp.getBody();

576 assert(!body.empty() && "Invalid region");

577

578

580 }

581

582 OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor ) {

584 setUniform(true);

585 return getResult();

586 }

587

588 return nullptr;

589 }

590

591

593 AllReduceOperationAttr &attr) {

594 StringRef enumStr;

596 std::optional op =

597 gpu::symbolizeAllReduceOperation(enumStr);

598 if (!op)

601 }

602 return success();

603 }

604

606 AllReduceOperationAttr attr) {

607 if (attr)

608 attr.print(printer);

609 }

610

611

612

613

614

617 if (auto vecTy = dyn_cast(elemType)) {

618 if (vecTy.isScalable())

619 return emitOpError() << "is not compatible with scalable vector types";

620

621 elemType = vecTy.getElementType();

622 }

623

624 gpu::AllReduceOperation opName = getOp();

626 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)

627 << "` reduction operation is not compatible with type "

629 }

630

631 auto clusterSize = getClusterSize();

632 if (clusterSize) {

633 uint32_t size = *clusterSize;

634 if (!llvm::isPowerOf2_32(size)) {

635 return emitOpError() << "cluster size " << size

636 << " is not a power of two";

637 }

638 }

639

640 uint32_t stride = getClusterStride();

641 if (stride != 1 && !clusterSize) {

642 return emitOpError() << "cluster stride can only be specified if cluster "

643 "size is specified";

644 }

645 if (!llvm::isPowerOf2_32(stride)) {

646 return emitOpError() << "cluster stride " << stride

647 << " is not a power of two";

648 }

649

650 return success();

651 }

652

653 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {

654 if (getClusterSize() == 1)

655 return getValue();

656

658 setUniform(true);

659 return getResult();

660 }

661

662 return nullptr;

663 }

664

665

666

667

668

671 if (!op->template hasTraitOpTrait::AttrSizedOperandSegments())

672 return;

673 auto attrName =

675 auto sizeAttr = op->template getAttrOfType(attrName);

676

677

678 if (!sizeAttr)

679 return;

680

682 ++sizes.front();

684 }

685

686

687

688

689

692 Value getBlockSizeX, Value getBlockSizeY,

693 Value getBlockSizeZ, Value dynamicSharedMemorySize,

697 Value clusterSizeY, Value clusterSizeZ) {

699

700

701

702 result.addAttribute(getNumWorkgroupAttributionsAttrName(),

704

705

707 if (asyncTokenType)

709

710

711 result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,

712 getBlockSizeY, getBlockSizeZ});

713 if (clusterSizeX)

715 if (clusterSizeY)

717 if (clusterSizeZ)

719 if (dynamicSharedMemorySize)

720 result.addOperands(dynamicSharedMemorySize);

721

722

723

724

727

728 for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)

730

731 for (Type argTy : workgroupAttributions)

733 for (Type argTy : privateAttributions)

735

737 segmentSizes.front() = asyncDependencies.size();

738 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;

739 segmentSizes[7] = clusterSizeX ? 1 : 0;

740 segmentSizes[8] = clusterSizeY ? 1 : 0;

741 segmentSizes[9] = clusterSizeZ ? 1 : 0;

742 result.addAttribute(getOperandSegmentSizeAttr(),

744 }

745

747 assert(!getBody().empty() && "LaunchOp body must not be empty.");

748 auto args = getBody().getArguments();

749 return KernelDim3{args[0], args[1], args[2]};

750 }

751

752 KernelDim3 LaunchOp::getThreadIds() {

753 assert(!getBody().empty() && "LaunchOp body must not be empty.");

754 auto args = getBody().getArguments();

755 return KernelDim3{args[3], args[4], args[5]};

756 }

757

759 assert(!getBody().empty() && "LaunchOp body must not be empty.");

760 auto args = getBody().getArguments();

761 return KernelDim3{args[6], args[7], args[8]};

762 }

763

765 assert(!getBody().empty() && "LaunchOp body must not be empty.");

766 auto args = getBody().getArguments();

767 return KernelDim3{args[9], args[10], args[11]};

768 }

769

770 std::optional LaunchOp::getClusterIds() {

771 assert(!getBody().empty() && "LaunchOp body must not be empty.");

772 if (!hasClusterSize())

773 return std::nullopt;

774 auto args = getBody().getArguments();

775 return KernelDim3{args[12], args[13], args[14]};

776 }

777

778 std::optional LaunchOp::getClusterSize() {

779 assert(!getBody().empty() && "LaunchOp body must not be empty.");

780 if (!hasClusterSize())

781 return std::nullopt;

782 auto args = getBody().getArguments();

783 return KernelDim3{args[15], args[16], args[17]};

784 }

785

786 KernelDim3 LaunchOp::getGridSizeOperandValues() {

787 auto operands = getOperands().drop_front(getAsyncDependencies().size());

788 return KernelDim3{operands[0], operands[1], operands[2]};

789 }

790

791 KernelDim3 LaunchOp::getBlockSizeOperandValues() {

792 auto operands = getOperands().drop_front(getAsyncDependencies().size());

793 return KernelDim3{operands[3], operands[4], operands[5]};

794 }

795

796 std::optional LaunchOp::getClusterSizeOperandValues() {

797 auto operands = getOperands().drop_front(getAsyncDependencies().size());

798 if (!hasClusterSize())

799 return std::nullopt;

800 return KernelDim3{operands[6], operands[7], operands[8]};

801 }

802

804 if (!(hasClusterSize()) &&

805 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))

806 return emitOpError() << "cluster size must be all present";

807 return success();

808 }

809

810 LogicalResult LaunchOp::verifyRegions() {

811

812

813

814 if (!getBody().empty()) {

815 if (getBody().getNumArguments() <

816 kNumConfigRegionAttributes + getNumWorkgroupAttributions())

817 return emitOpError("unexpected number of region arguments");

818 }

819

820

821 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),

822 GPUDialect::getWorkgroupAddressSpace())) ||

824 GPUDialect::getPrivateAddressSpace())))

825 return failure();

826

827

828

829 for (Block &block : getBody()) {

830 if (block.empty())

831 continue;

832 if (block.back().getNumSuccessors() != 0)

833 continue;

834 if (!isagpu::TerminatorOp(&block.back())) {

835 return block.back()

836 .emitError()

837 .append("expected '", gpu::TerminatorOp::getOperationName(),

838 "' or a terminator with successors")

839 .attachNote(getLoc())

840 .append("in '", LaunchOp::getOperationName(), "' body region");

841 }

842 }

843

844 if (getNumResults() == 0 && getAsyncToken())

845 return emitOpError("needs to be named when async keyword is specified");

846

847 return success();

848 }

849

850

851

852

853

856 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";

857 p << size.x << " = " << operands.x << ", ";

858 p << size.y << " = " << operands.y << ", ";

859 p << size.z << " = " << operands.z << ')';

860 }

861

863 if (getAsyncToken()) {

864 p << " async";

865 if (!getAsyncDependencies().empty())

866 p << " [" << getAsyncDependencies() << ']';

867 }

868

869 if (hasClusterSize()) {

870 p << ' ' << getClustersKeyword();

872 getClusterSizeOperandValues().value(),

873 getClusterIds().value());

874 }

875 p << ' ' << getBlocksKeyword();

877 getBlockIds());

878 p << ' ' << getThreadsKeyword();

880 getThreadIds());

881 if (getDynamicSharedMemorySize())

882 p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '

883 << getDynamicSharedMemorySize();

884

885 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());

886 printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

887

888 p << ' ';

889

890 p.printRegion(getBody(), false);

892 LaunchOp::getOperandSegmentSizeAttr(),

893 getNumWorkgroupAttributionsAttrName()});

894 }

895

896

897

898

899

900

901

902 static ParseResult

907 assert(indices.size() == 3 && "space for three indices expected");

910 false) ||

912 return failure();

913 std::move(args.begin(), args.end(), indices.begin());

914

915 for (int i = 0; i < 3; ++i) {

917 return failure();

918 if (parser.parseOperand(regionSizes[i], false) ||

920 return failure();

921 }

922

924 }

925

926

927

928

929

930

931

932

933

935

937 sizes(LaunchOp::kNumConfigOperands);

938

939

941 LaunchOp::kNumConfigRegionAttributes);

942

943

945 Type asyncTokenType;

946 if (failed(

948 parser.resolveOperands(asyncDependencies, asyncTokenType,

950 return failure();

952 result.types.push_back(asyncTokenType);

953

954 bool hasCluster = false;

955 if (succeeded(

957 hasCluster = true;

958 sizes.resize(9);

959 regionArgs.resize(18);

960 }

963

964

965

966 if (hasCluster) {

968 regionArgsRef.slice(15, 3),

969 regionArgsRef.slice(12, 3)))

970 return failure();

971 }

972

973

974

975

976

977 if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||

979 regionArgsRef.slice(6, 3),

980 regionArgsRef.slice(0, 3)) ||

981 parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||

983 regionArgsRef.slice(9, 3),

984 regionArgsRef.slice(3, 3)) ||

987 return failure();

988

990 bool hasDynamicSharedMemorySize = false;

992 LaunchOp::getDynamicSharedMemorySizeKeyword())) {

993 hasDynamicSharedMemorySize = true;

994 if (parser.parseOperand(dynamicSharedMemorySize) ||

998 return failure();

999 }

1000

1001

1002

1003

1004

1005

1006

1009 LaunchOp::kNumConfigRegionAttributes + 6, index);

1010

1012 for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {

1014 arg.ssaName = std::get<0>(ssaValueAndType);

1015 arg.type = std::get<1>(ssaValueAndType);

1016 regionArguments.push_back(arg);

1017 }

1018

1020

1021 if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),

1022 regionArguments)))

1023 return failure();

1024

1025

1026

1027 unsigned numWorkgroupAttrs = regionArguments.size() -

1028 LaunchOp::kNumConfigRegionAttributes -

1029 (hasCluster ? 6 : 0);

1030 result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),

1032

1033

1034 if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),

1035 regionArguments)))

1036 return failure();

1037

1038

1039

1040

1042 if (parser.parseRegion(*body, regionArguments) ||

1044 return failure();

1045

1047 segmentSizes.front() = asyncDependencies.size();

1048

1049 if (!hasCluster) {

1050 segmentSizes[7] = 0;

1051 segmentSizes[8] = 0;

1052 segmentSizes[9] = 0;

1053 }

1054 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;

1055 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),

1057 return success();

1058 }

1059

1060

1061

1066

1067

1069 bool simplified = false;

1070 auto constPropIdUses = [&](Value id, Value size) {

1071

1073 return;

1074 if (id.getUses().empty())

1075 return;

1076 if (!simplified) {

1077

1080 zero =

1081 rewriter.createarith::ConstantIndexOp(op.getLoc(), 0);

1082 }

1084 simplified = true;

1085 };

1086 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());

1087 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());

1088 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());

1089 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());

1090 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());

1091 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

1092

1093 return success(simplified);

1094 }

1095 };

1096

1097 void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,

1100 }

1101

1102

1103

1105 auto attrName = getNumWorkgroupAttributionsAttrName();

1106 auto attr = (*this)->getAttrOfType(attrName);

1107 (*this)->setAttr(attrName,

1109 return getBody().insertArgument(

1110 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);

1111 }

1112

1113

1114

1116

1117

1118 return getBody().addArgument(type, loc);

1119 }

1120

1121

1122

1123

1124

1126 SymbolRefAttr kernelSymbol, KernelDim3 gridSize,

1130 std::optional clusterSize) {

1131 assert(kernelSymbol.getNestedReferences().size() == 1 &&

1132 "expected a symbol reference with a single nested reference");

1134 if (asyncTokenType)

1136

1137

1140 if (clusterSize.has_value())

1141 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});

1142 if (dynamicSharedMemorySize)

1143 result.addOperands(dynamicSharedMemorySize);

1145

1147 prop.kernel = kernelSymbol;

1148 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);

1149

1150 for (auto &sz : prop.operandSegmentSizes)

1151 sz = 1;

1152 prop.operandSegmentSizes[0] = asyncDependencies.size();

1153 if (!clusterSize.has_value()) {

1154 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;

1155 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;

1156 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;

1157 }

1158 prop.operandSegmentSizes[segmentSizesLen - 3] =

1159 dynamicSharedMemorySize ? 1 : 0;

1160 prop.operandSegmentSizes[segmentSizesLen - 2] =

1161 static_cast<int32_t>(kernelOperands.size());

1162 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;

1163 }

1164

1166 GPUFuncOp kernelFunc, KernelDim3 gridSize,

1170 std::optional clusterSize) {

1171 auto kernelModule = kernelFunc->getParentOfType();

1172 auto kernelSymbol =

1174 {SymbolRefAttr::get(kernelFunc.getNameAttr())});

1175 build(builder, result, kernelSymbol, gridSize, getBlockSize,

1176 dynamicSharedMemorySize, kernelOperands, asyncTokenType,

1177 asyncDependencies, clusterSize);

1178 }

1179

1181 SymbolRefAttr kernel, KernelDim3 gridSize,

1184 std::optional clusterSize) {

1185

1188 if (clusterSize.has_value())

1189 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});

1190 if (dynamicSharedMemorySize)

1191 result.addOperands(dynamicSharedMemorySize);

1193 if (asyncObject)

1196 prop.kernel = kernel;

1197 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);

1198

1199 for (auto &sz : prop.operandSegmentSizes)

1200 sz = 1;

1201 prop.operandSegmentSizes[0] = 0;

1202 if (!clusterSize.has_value()) {

1203 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;

1204 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;

1205 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;

1206 }

1207 prop.operandSegmentSizes[segmentSizesLen - 3] =

1208 dynamicSharedMemorySize ? 1 : 0;

1209 prop.operandSegmentSizes[segmentSizesLen - 2] =

1210 static_cast<int32_t>(kernelOperands.size());

1211 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;

1212 }

1213

1214 StringAttr LaunchFuncOp::getKernelModuleName() {

1215 return getKernel().getRootReference();

1216 }

1217

1218 StringAttr LaunchFuncOp::getKernelName() {

1219 return getKernel().getLeafReference();

1220 }

1221

1222 unsigned LaunchFuncOp::getNumKernelOperands() {

1223 return getKernelOperands().size();

1224 }

1225

1226 Value LaunchFuncOp::getKernelOperand(unsigned i) {

1227 return getKernelOperands()[i];

1228 }

1229

1230 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {

1231 auto operands = getOperands().drop_front(getAsyncDependencies().size());

1232 return KernelDim3{operands[0], operands[1], operands[2]};

1233 }

1234

1235 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {

1236 auto operands = getOperands().drop_front(getAsyncDependencies().size());

1237 return KernelDim3{operands[3], operands[4], operands[5]};

1238 }

1239

1240 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {

1241 assert(hasClusterSize() &&

1242 "cluster size is not set, check hasClusterSize() first");

1243 auto operands = getOperands().drop_front(getAsyncDependencies().size());

1244 return KernelDim3{operands[6], operands[7], operands[8]};

1245 }

1246

1248 auto module = (*this)->getParentOfType();

1249 if (!module)

1250 return emitOpError("expected to belong to a module");

1251

1252 if (!module->getAttrOfType(

1253 GPUDialect::getContainerModuleAttrName()))

1254 return emitOpError("expected the closest surrounding module to have the '" +

1255 GPUDialect::getContainerModuleAttrName() +

1256 "' attribute");

1257

1258 if (hasClusterSize()) {

1259 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||

1260 getClusterSizeZ().getType() != getClusterSizeX().getType())

1261 return emitOpError()

1262 << "expects types of the cluster dimensions must be the same";

1263 }

1264

1265 return success();

1266 }

1267

1268 static ParseResult

1270 std::optionalOpAsmParser::UnresolvedOperand clusterValue,

1271 Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {

1274 return failure();

1275 } else {

1277 }

1278 if (clusterValue.has_value()) {

1279 clusterXTy = clusterYTy = clusterZTy = dimTy;

1280 }

1281 return success();

1282 }

1283

1285 Value clusterValue, Type clusterXTy,

1286 Type clusterYTy, Type clusterZTy) {

1288 printer << ": " << dimTy;

1289 }

1290

1296 return success();

1297

1298 auto parseElement = [&]() -> ParseResult {

1299 return failure(parser.parseOperand(argNames.emplace_back()) ||

1301 };

1302

1304 parseElement, " in argument list");

1305 }

1306

1309 if (operands.empty())

1310 return;

1311 printer << "args(";

1312 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,

1313 [&](const auto &pair) {

1314 auto [operand, type] = pair;

1315 printer << operand << " : " << type;

1316 });

1317 printer << ")";

1318 }

1319

1320

1321

1322

1323

1325 int32_t offset, int32_t width, ShuffleMode mode) {

1326 build(builder, result, value,

1331 mode);

1332 }

1333

1334

1335

1336

1337

1338 namespace {

1339

1340

1341 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,

1343 if (isa_and_nonnull(op->getNextNode())) {

1345 return success();

1346 }

1347 return failure();

1348 }

1349

1350 }

1351

1352 void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,

1354 results.add(eraseRedundantGpuBarrierOps);

1355 }

1356

1357

1358

1359

1360

1361

1362

1364 auto attrName = getNumWorkgroupAttributionsAttrName();

1365 auto attr = (*this)->getAttrOfType(attrName);

1366 (*this)->setAttr(attrName,

1368 return getBody().insertArgument(

1369 getFunctionType().getNumInputs() + attr.getInt(), type, loc);

1370 }

1371

1372

1373

1375

1376

1377 return getBody().addArgument(type, loc);

1378 }

1379

1381 StringRef name, FunctionType type,

1382 TypeRange workgroupAttributions,

1386

1391 result.addAttribute(getNumWorkgroupAttributionsAttrName(),

1396

1397

1398 for (Type argTy : type.getInputs())

1400 for (Type argTy : workgroupAttributions)

1402 for (Type argTy : privateAttributions)

1404 }

1405

1406

1407

1408

1409

1410

1411

1412

1413 static ParseResult

1417

1419 return success();

1420

1421 size_t existingArgs = args.size();

1422 ParseResult result =

1424 true, true);

1425 if (failed(result))

1426 return result;

1427

1428 bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),

1430 return arg.attrs && !arg.attrs.empty();

1431 });

1432 if (!hadAttrs) {

1433 attributionAttrs = nullptr;

1434 return result;

1435 }

1436

1439 for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {

1440 if (!argument.attrs)

1442 else

1443 attributionAttrsVec.push_back(argument.attrs);

1444 }

1445 attributionAttrs = builder.getArrayAttr(attributionAttrsVec);

1446 return result;

1447 }

1448

1449

1450

1451

1452

1453

1458 bool isVariadic;

1459

1460

1461 StringAttr nameAttr;

1464 return failure();

1465

1468 parser, false, entryArgs, isVariadic, resultTypes,

1469 resultAttrs)))

1470 return failure();

1471

1472 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())

1473 return parser.emitError(signatureLocation)

1474 << "gpu.func requires named arguments";

1475

1476

1477

1479

1481 for (auto &arg : entryArgs)

1482 argTypes.push_back(arg.type);

1483 auto type = builder.getFunctionType(argTypes, resultTypes);

1486

1488 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),

1489 getResAttrsAttrName(result.name));

1490

1491 Attribute workgroupAttributionAttrs;

1492

1493 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),

1494 entryArgs, workgroupAttributionAttrs)))

1495 return failure();

1496

1497

1498

1499 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();

1500 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),

1502 if (workgroupAttributionAttrs)

1503 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),

1504 workgroupAttributionAttrs);

1505

1506 Attribute privateAttributionAttrs;

1507

1508 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),

1509 entryArgs, privateAttributionAttrs)))

1510 return failure();

1511 if (privateAttributionAttrs)

1512 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),

1513 privateAttributionAttrs);

1514

1515

1517 result.addAttribute(GPUDialect::getKernelFuncAttrName(),

1519

1520

1522 return failure();

1523

1524

1525

1526 auto *body = result.addRegion();

1527 return parser.parseRegion(*body, entryArgs);

1528 }

1529

1532 ArrayAttr attributes) {

1533 if (values.empty())

1534 return;

1535

1536 p << ' ' << keyword << '(';

1537 llvm::interleaveComma(

1538 llvm::enumerate(values), p, [&p, attributes](auto pair) {

1540 p << v << " : " << v.getType();

1541

1542 size_t attributionIndex = pair.index();

1543 DictionaryAttr attrs;

1544 if (attributes && attributionIndex < attributes.size())

1545 attrs = llvm::cast(attributes[attributionIndex]);

1546 if (attrs)

1548 });

1549 p << ')';

1550 }

1551

1553 p << ' ';

1555

1556 FunctionType type = getFunctionType();

1558 false,

1559 type.getResults());

1560

1561 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),

1562 getWorkgroupAttribAttrs().value_or(nullptr));

1563 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),

1564 getPrivateAttribAttrs().value_or(nullptr));

1565 if (isKernel())

1566 p << ' ' << getKernelKeyword();

1567

1569 p, *this,

1570 {getNumWorkgroupAttributionsAttrName(),

1571 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),

1572 getArgAttrsAttrName(), getResAttrsAttrName(),

1573 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});

1574 p << ' ';

1575 p.printRegion(getBody(), false);

1576 }

1577

1579 StringAttr attrName) {

1580 auto allAttrs = llvm::dyn_cast_or_null(op->getAttr(attrName));

1581 if (!allAttrs || index >= allAttrs.size())

1582 return DictionaryAttr();

1583 return llvm::cast(allAttrs[index]);

1584 }

1585

1586 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {

1587 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());

1588 }

1589

1590 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {

1591 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());

1592 }

1593

1595 DictionaryAttr value, StringAttr attrName) {

1597 auto allAttrs = llvm::dyn_cast_or_null(op->getAttr(attrName));

1599 if (allAttrs)

1600 elements.append(allAttrs.begin(), allAttrs.end());

1601 while (elements.size() <= index)

1603 if (!value)

1605 else

1606 elements[index] = value;

1608 op->setAttr(attrName, newValue);

1609 }

1610

1611 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,

1612 DictionaryAttr value) {

1613 setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());

1614 }

1615

1616 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,

1617 DictionaryAttr value) {

1618 setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());

1619 }

1620

1622 StringAttr name, StringAttr attrsName) {

1624 if (!dict)

1626 return dict.get(name);

1627 }

1628

1629 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,

1630 StringAttr name) {

1631 assert(index < getNumWorkgroupAttributions() &&

1632 "index must map to a workgroup attribution");

1634 getWorkgroupAttribAttrsAttrName());

1635 }

1636

1637 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,

1638 StringAttr name) {

1639 assert(index < getNumPrivateAttributions() &&

1640 "index must map to a private attribution");

1642 getPrivateAttribAttrsAttrName());

1643 }

1644

1646 Attribute value, StringAttr attrsName) {

1650 if (oldDict)

1651 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

1652

1653 bool found = false;

1654 bool mustSort = true;

1655 for (unsigned i = 0, e = elems.size(); i < e; ++i) {

1656 if (elems[i].getName() == name) {

1657 found = true;

1658 if (!value) {

1659 std::swap(elems[i], elems[elems.size() - 1]);

1660 elems.pop_back();

1661 } else {

1662 mustSort = false;

1663 elems[i] = NamedAttribute(elems[i].getName(), value);

1664 }

1665 break;

1666 }

1667 }

1668 if (!found) {

1669 if (!value)

1670 return;

1671 elems.emplace_back(name, value);

1672 }

1673 if (mustSort) {

1674 DictionaryAttr::sortInPlace(elems);

1675 }

1676 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);

1678 }

1679

1680 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,

1682 assert(index < getNumWorkgroupAttributions() &&

1683 "index must map to a workgroup attribution");

1685 getWorkgroupAttribAttrsAttrName());

1686 }

1687

1688 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,

1690 assert(index < getNumPrivateAttributions() &&

1691 "index must map to a private attribution");

1693 getPrivateAttribAttrsAttrName());

1694 }

1695

1696 LogicalResult GPUFuncOp::verifyType() {

1697 if (isKernel() && getFunctionType().getNumResults() != 0)

1698 return emitOpError() << "expected void return type for kernel function";

1699

1700 return success();

1701 }

1702

1703

1704 LogicalResult GPUFuncOp::verifyBody() {

1705 if (empty())

1706 return emitOpError() << "expected body with at least one block";

1707 unsigned numFuncArguments = getNumArguments();

1708 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();

1709 unsigned numBlockArguments = front().getNumArguments();

1710 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)

1711 return emitOpError() << "expected at least "

1712 << numFuncArguments + numWorkgroupAttributions

1713 << " arguments to body region";

1714

1715 ArrayRef funcArgTypes = getFunctionType().getInputs();

1716 for (unsigned i = 0; i < numFuncArguments; ++i) {

1717 Type blockArgType = front().getArgument(i).getType();

1718 if (funcArgTypes[i] != blockArgType)

1719 return emitOpError() << "expected body region argument #" << i

1720 << " to be of type " << funcArgTypes[i] << ", got "

1721 << blockArgType;

1722 }

1723

1724 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),

1725 GPUDialect::getWorkgroupAddressSpace())) ||

1727 GPUDialect::getPrivateAddressSpace())))

1728 return failure();

1729

1730 return success();

1731 }

1732

1733

1734

1735

1736

1738 GPUFuncOp function = (*this)->getParentOfType();

1739

1740 FunctionType funType = function.getFunctionType();

1741

1742 if (funType.getNumResults() != getOperands().size())

1743 return emitOpError()

1744 .append("expected ", funType.getNumResults(), " result operands")

1745 .attachNote(function.getLoc())

1746 .append("return type declared here");

1747

1749 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {

1750 auto [type, operand] = pair.value();

1751 if (type != operand.getType())

1752 return emitOpError() << "unexpected type `" << operand.getType()

1753 << "' for operand #" << pair.index();

1754 }

1755 return success();

1756 }

1757

1758

1759

1760

1761

1763 StringRef name, ArrayAttr targets,

1767 if (targets)

1768 props.targets = targets;

1770 props.offloadingHandler = offloadingHandler;

1771 }

1772

1776 build(builder, result, name,

1777 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),

1778 offloadingHandler);

1779 }

1780

1781 bool GPUModuleOp::hasTarget(Attribute target) {

1782 if (ArrayAttr targets = getTargetsAttr())

1783 return llvm::count(targets.getValue(), target);

1784 return false;

1785 }

1786

1788 ArrayAttr &targetsAttr = getProperties().targets;

1791 }

1792

1794 auto targets = getOperation()->getAttrOfType("targets");

1795

1796 if (!targets)

1797 return success();

1798

1799 for (auto target : targets) {

1800 if (auto verifyTargetAttr =

1801 llvm::dyn_cast(target)) {

1802 if (verifyTargetAttr.verifyTarget(getOperation()).failed())

1803 return failure();

1804 }

1805 }

1806 return success();

1807 }

1808

1809

1810

1811

1813 Attribute offloadingHandler, ArrayAttr objects) {

1817 properties.objects = objects;

1818 if (offloadingHandler)

1819 properties.offloadingHandler = offloadingHandler;

1820 else

1821 properties.offloadingHandler = builder.getAttr(nullptr);

1822 }

1823

1826 build(builder, result, name, offloadingHandler,

1827 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));

1828 }

1829

1834 return failure();

1836 return failure();

1837 }

1838 if (!offloadingHandler)

1839 offloadingHandler = parser.getBuilder().getAttr(nullptr);

1840 return success();

1841 }

1842

1846 printer << '<' << offloadingHandler << '>';

1847 }

1848

1849 //===----------------------------------------------------------------------===//

1850 // GPUMemcpyOp

1851 //===----------------------------------------------------------------------===//

1852

1853 LogicalResult MemcpyOp::verify() {

1854 auto srcType = getSrc().getType();

1855 auto dstType = getDst().getType();

1856

1857 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))

1858 return emitOpError("arguments have incompatible element type");

1859

1860 if (failed(verifyCompatibleShape(srcType, dstType)))

1861 return emitOpError("arguments have incompatible shape");

1862

1863 return success();

1864 }

1865

1866 namespace {

1867

1868

1869

1870 struct EraseTrivialCopyOp : public OpRewritePattern {

1871 using OpRewritePattern::OpRewritePattern;

1872

1873 LogicalResult matchAndRewrite(MemcpyOp op,

1874 PatternRewriter &rewriter) const override {

1875 Value dest = op.getDst();

1876 Operation *destDefOp = dest.getDefiningOp();

1877 // `dest` must be defined by an op having Allocate memory effect in order to

1878 // perform the folding.

1879 if (!destDefOp ||

1880 !hasSingleEffectMemoryEffects::Allocate(destDefOp, dest))

1881 return failure();

1882 // We can erase `op` iff `dest` has no other use apart from its

1883 // use by `op` and dealloc ops.

1884 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {

1885 return user != op &&

1886 !hasSingleEffectMemoryEffects::Free(user, dest);

1887 }))

1888 return failure();

1889 // We can perform the folding if and only if op has a single async

1890 // dependency and produces an async token as result, or if it does not have

1891 // any async dependency and does not produce any async token result.

1892 if (op.getAsyncDependencies().size() > 1 ||

1893 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||

1894 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))

1895 return failure();

1896 rewriter.replaceOp(op, op.getAsyncDependencies());

1897 return success();

1898 }

1899 };

1900

1901 } // end anonymous namespace

1902

1903 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,

1904 MLIRContext *context) {

1905 results.add(context);

1906 }

1907

1908 //===----------------------------------------------------------------------===//

1909 // GPU_SubgroupMmaLoadMatrixOp

1910 //===----------------------------------------------------------------------===//

1911

1912 LogicalResult SubgroupMmaLoadMatrixOp::verify() {

1913 auto srcType = getSrcMemref().getType();

1914 auto resType = getRes().getType();

1915 auto resMatrixType = llvm::castgpu::MMAMatrixType(resType);

1916 auto operand = resMatrixType.getOperand();

1917 auto srcMemrefType = llvm::cast(srcType);

1918

1919 if (!srcMemrefType.isLastDimUnitStride())

1920 return emitError(

1921 "expected source memref most minor dim must have unit stride");

1922

1923 if (operand != "AOp" && operand != "BOp" && operand != "COp")

1924 return emitError("only AOp, BOp and COp can be loaded");

1925

1926 return success();

1927 }

1928

1929 //===----------------------------------------------------------------------===//

1930 // GPU_SubgroupMmaStoreMatrixOp

1931 //===----------------------------------------------------------------------===//

1932

1933 LogicalResult SubgroupMmaStoreMatrixOp::verify() {

1934 auto srcType = getSrc().getType();

1935 auto dstType = getDstMemref().getType();

1936 auto srcMatrixType = llvm::castgpu::MMAMatrixType(srcType);

1937 auto dstMemrefType = llvm::cast(dstType);

1938

1939 if (!dstMemrefType.isLastDimUnitStride())

1940 return emitError(

1941 "expected destination memref most minor dim must have unit stride");

1942

1943 if (srcMatrixType.getOperand() != "COp")

1944 return emitError(

1945 "expected the operand matrix being stored to have 'COp' operand type");

1946

1947 return success();

1948 }

1949

1950 //===----------------------------------------------------------------------===//

1951 // GPU_SubgroupMmaComputeOp

1952 //===----------------------------------------------------------------------===//

1953

1954 LogicalResult SubgroupMmaComputeOp::verify() {

1955 enum OperandMap { A, B, C };

1956 SmallVector<MMAMatrixType, 3> opTypes;

1957 opTypes.push_back(llvm::cast(getOpA().getType()));

1958 opTypes.push_back(llvm::cast(getOpB().getType()));

1959 opTypes.push_back(llvm::cast(getOpC().getType()));

1960

1961 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||

1962 opTypes[C].getOperand() != "COp")

1963 return emitError("operands must be in the order AOp, BOp, COp");

1964

1965 ArrayRef<int64_t> aShape, bShape, cShape;

1966 aShape = opTypes[A].getShape();

1967 bShape = opTypes[B].getShape();

1968 cShape = opTypes[C].getShape();

1969

1970 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||

1971 bShape[1] != cShape[1])

1972 return emitError("operand shapes do not satisfy matmul constraints");

1973

1974 return success();

1975 }

1976

1977 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,

1978 SmallVectorImpl<::mlir::OpFoldResult> &results) {

1979 return memref::foldMemRefCast(*this);

1980 }

1981

1982 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,

1983 SmallVectorImpl<::mlir::OpFoldResult> &results) {

1984 return memref::foldMemRefCast(*this);

1985 }

1986

1987 //===----------------------------------------------------------------------===//

1988 // GPU_WaitOp

1989 //===----------------------------------------------------------------------===//

1990

1991 namespace {

1992

1993

1994

1995

1996 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern {

1997 public:

1998 using OpRewritePattern::OpRewritePattern;

1999

2000 LogicalResult matchAndRewrite(WaitOp op,

2001 PatternRewriter &rewriter) const final {

2002 auto predicate = [](Value value) {

2003 auto waitOp = value.getDefiningOp();

2004 return waitOp && waitOp->getNumOperands() == 0;

2005 };

2006 if (llvm::none_of(op.getAsyncDependencies(), predicate))

2007 return failure();

2008 SmallVector validOperands;

2009 for (Value operand : op->getOperands()) {

2010 if (predicate(operand))

2011 continue;

2012 validOperands.push_back(operand);

2013 }

2014 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });

2015 return success();

2016 }

2017 };

2018

2019

2020

2021

2022

2023

2024

2025

2026 struct SimplifyGpuWaitOp : public OpRewritePattern {

2027 public:

2028 using OpRewritePattern::OpRewritePattern;

2029

2030 LogicalResult matchAndRewrite(WaitOp op,

2031 PatternRewriter &rewriter) const final {

2032 // Erase gpu.wait ops that neither have any async dependencies nor return

2033 // any async token.

2034 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {

2035 rewriter.eraseOp(op);

2036 return success();

2037 }

2038 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.

2039 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&

2040 op.getAsyncToken()) {

2041 rewriter.replaceOp(op, op.getAsyncDependencies());

2042 return success();

2043 }

2044 // Erase %t = gpu.wait async ... ops, where %t has no uses.

2045 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {

2046 rewriter.eraseOp(op);

2047 return success();

2048 }

2049 return failure();

2050 }

2051 };

2052

2053 } // end anonymous namespace

2054

2055 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,

2056 MLIRContext *context) {

2057 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);

2058 }

2059

2060 //===----------------------------------------------------------------------===//

2061 // GPU_AllocOp

2062 //===----------------------------------------------------------------------===//

2063

2064 LogicalResult AllocOp::verify() {

2065 auto memRefType = llvm::cast(getMemref().getType());

2066

2067 if (getDynamicSizes().size() != memRefType.getNumDynamicDims())

2068 return emitOpError("dimension operand count does not equal memref "

2069 "dynamic dimension count");

2070

2071 unsigned numSymbols = 0;

2072 if (!memRefType.getLayout().isIdentity())

2073 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();

2074 if (getSymbolOperands().size() != numSymbols) {

2075 return emitOpError(

2076 "symbol operand count does not equal memref symbol count");

2077 }

2078

2079 return success();

2080 }

2081

2082 namespace {

2083

2084

2085

2086 struct SimplifyDimOfAllocOp : public OpRewritePatternmemref::DimOp {

2087 using OpRewritePatternmemref::DimOp::OpRewritePattern;

2088

2089 LogicalResult matchAndRewrite(memref::DimOp dimOp,

2090 PatternRewriter &rewriter) const override {

2091 std::optional<int64_t> index = dimOp.getConstantIndex();

2092 if (!index)

2093 return failure();

2094

2095 auto memrefType = llvm::dyn_cast(dimOp.getSource().getType());

2096 if (!memrefType || index.value() >= memrefType.getRank() ||

2097 !memrefType.isDynamicDim(index.value()))

2098 return failure();

2099

2100 auto alloc = dimOp.getSource().getDefiningOp();

2101 if (!alloc)

2102 return failure();

2103

2104 Value substituteOp = *(alloc.getDynamicSizes().begin() +

2105 memrefType.getDynamicDimIndex(index.value()));

2106 rewriter.replaceOp(dimOp, substituteOp);

2107 return success();

2108 }

2109 };

2110

2111 } // namespace

2112

2113 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,

2114 MLIRContext *context) {

2115 results.add(context);

2116 }

2117

2118 //===----------------------------------------------------------------------===//

2119 // GPU object attribute

2120 //===----------------------------------------------------------------------===//

2121

2122 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,

2123 Attribute target, CompilationTarget format,

2124 StringAttr object, DictionaryAttr properties,

2125 KernelTableAttr kernels) {

2126 if (!target)

2127 return emitError() << "the target attribute cannot be null";

2128 if (target.hasPromiseOrImplementsInterface())

2129 return success();

2130 return emitError() << "the target attribute must implement or promise the "

2131 "`gpu::TargetAttrInterface`";

2132 }

2133

2134 namespace {

2135 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,

2136 StringAttr &object) {

2137 std::optional formatResult;

2138 StringRef enumKeyword;

2139 auto loc = odsParser.getCurrentLocation();

2140 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))

2141 formatResult = CompilationTarget::Fatbin;

2142 if (!formatResult &&

2143 (formatResult =

2144 gpu::symbolizeEnumgpu::CompilationTarget(enumKeyword)) &&

2145 odsParser.parseEqual())

2146 return odsParser.emitError(loc, "expected an equal sign");

2147 if (!formatResult)

2148 return odsParser.emitError(loc, "expected keyword for GPU object format");

2149 FailureOr objectResult =

2150 FieldParser::parse(odsParser);

2151 if (failed(objectResult))

2152 return odsParser.emitError(odsParser.getCurrentLocation(),

2153 "failed to parse GPU_ObjectAttr parameter "

2154 "'object' which is to be a `StringAttr`");

2155 format = *formatResult;

2156 object = *objectResult;

2157 return success();

2158 }

2159

2160 void printObject(AsmPrinter &odsParser, CompilationTarget format,

2161 StringAttr object) {

2162 if (format != CompilationTarget::Fatbin)

2163 odsParser << stringifyEnum(format) << " = ";

2164 odsParser << object;

2165 }

2166 } // namespace

2167

2168 //===----------------------------------------------------------------------===//

2169 // GPU select object attribute

2170 //===----------------------------------------------------------------------===//

2171

2172 LogicalResult

2173 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,

2174 Attribute target) {

2175 // Check `target`, it can be null, an integer attr or a GPU Target attribute.

2176 if (target) {

2177 if (auto intAttr = mlir::dyn_cast(target)) {

2178 if (intAttr.getInt() < 0) {

2179 return emitError() << "the object index must be positive";

2180 }

2181 } else if (!target.hasPromiseOrImplementsInterface()) {

2182 return emitError()

2183 << "the target attribute must be a GPU Target attribute";

2184 }

2185 }

2186 return success();

2187 }

2188

2189 //===----------------------------------------------------------------------===//

2190 // DynamicSharedMemoryOp

2191 //===----------------------------------------------------------------------===//

2192

2193 LogicalResult gpu::DynamicSharedMemoryOp::verify() {

2194 if (!getOperation()->getParentWithTraitOpTrait::SymbolTable())

2195 return emitOpError() << "must be inside an op with symbol table";

2196

2197 MemRefType memrefType = getResultMemref().getType();

2198 // Check address space

2199 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {

2200 return emitOpError() << "address space must be "

2201 << gpu::AddressSpaceAttr::getMnemonic() << "<"

2202 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";

2203 }

2204 if (memrefType.hasStaticShape()) {

2205 return emitOpError() << "result memref type must be memref<?xi8, "

2206 "#gpu.address_space>";

2207 }

2208 return success();

2209 }

2210

2211 //===----------------------------------------------------------------------===//

2212 // GPU WarpExecuteOnLane0Op

2213 //===----------------------------------------------------------------------===//

2214

2215 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {

2216 p << "(" << getLaneid() << ")";

2217

2218 SmallVector coreAttr = {getWarpSizeAttrName()};

2219 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());

2220 p << "[" << llvm::cast(warpSizeAttr).getInt() << "]";

2221

2222 if (!getArgs().empty())

2223 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";

2224 if (!getResults().empty())

2225 p << " -> (" << getResults().getTypes() << ')';

2226 p << " ";

2227 p.printRegion(getRegion(),

2228 /*printEntryBlockArgs=*/true,

2229 /*printBlockTerminators=*/!getResults().empty());

2230 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);

2231 }

2232

2233 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,

2234 OperationState &result) {

2235 // Create the region.

2236 result.regions.reserve(1);

2237 Region *warpRegion = result.addRegion();

2238

2239 auto &builder = parser.getBuilder();

2240 OpAsmParser::UnresolvedOperand laneId;

2241

2242 // Parse predicate operand.

2243 if (parser.parseLParen() ||

2244 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||

2245 parser.parseRParen())

2246 return failure();

2247

2248 int64_t warpSize;

2249 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||

2250 parser.parseRSquare())

2251 return failure();

2252 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),

2253 builder.getContext())),

2254 builder.getI64IntegerAttr(warpSize));

2255

2256 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))

2257 return failure();

2258

2259 llvm::SMLoc inputsOperandsLoc;

2260 SmallVectorOpAsmParser::UnresolvedOperand inputsOperands;

2261 SmallVector inputTypes;

2262 if (succeeded(parser.parseOptionalKeyword("args"))) {

2263 if (parser.parseLParen())

2264 return failure();

2265

2266 inputsOperandsLoc = parser.getCurrentLocation();

2267 if (parser.parseOperandList(inputsOperands) ||

2268 parser.parseColonTypeList(inputTypes) || parser.parseRParen())

2269 return failure();

2270 }

2271 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,

2272 result.operands))

2273 return failure();

2274

2275 // Parse optional results type list.

2276 if (parser.parseOptionalArrowTypeList(result.types))

2277 return failure();

2278 // Parse the region.

2279 if (parser.parseRegion(*warpRegion, /*arguments=*/{},

2280 /*argTypes=*/{}))

2281 return failure();

2282 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

2283

2284 // Parse the optional attribute list.

2285 if (parser.parseOptionalAttrDict(result.attributes))

2286 return failure();

2287 return success();

2288 }

2289

2290 void WarpExecuteOnLane0Op::getSuccessorRegions(

2291 RegionBranchPoint point, SmallVectorImpl &regions) {

2292 if (!point.isParent()) {

2293 regions.push_back(RegionSuccessor(getResults()));

2294 return;

2295 }

2296

2297 // The warp region is always executed

2298 regions.push_back(RegionSuccessor(&getWarpRegion()));

2299 }

2300

2301 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,

2302 TypeRange resultTypes, Value laneId,

2303 int64_t warpSize) {

2304 build(builder, result, resultTypes, laneId, warpSize,

2305 /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);

2306 }

2307

2308 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,

2309 TypeRange resultTypes, Value laneId,

2310 int64_t warpSize, ValueRange args,

2311 TypeRange blockArgTypes) {

2312 result.addOperands(laneId);

2313 result.addAttribute(getAttributeNames()[0],

2314 builder.getI64IntegerAttr(warpSize));

2315 result.addTypes(resultTypes);

2316 result.addOperands(args);

2317 assert(args.size() == blockArgTypes.size());

2318 OpBuilder::InsertionGuard guard(builder);

2319 Region *warpRegion = result.addRegion();

2320 Block *block = builder.createBlock(warpRegion);

2321 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))

2322 block->addArgument(type, arg.getLoc());

2323 }

2324

2325

2326

2327 static LogicalResult verifyDistributedType(Type expanded, Type distributed,

2328 int64_t warpSize, Operation *op) {

2329 // If the types matches there is no distribution.

2330 if (expanded == distributed)

2331 return success();

2332 auto expandedVecType = llvm::dyn_cast(expanded);

2333 auto distributedVecType = llvm::dyn_cast(distributed);

2334 if (!expandedVecType || !distributedVecType)

2335 return op->emitOpError("expected vector type for distributed operands.");

2336 if (expandedVecType.getRank() != distributedVecType.getRank() ||

2337 expandedVecType.getElementType() != distributedVecType.getElementType())

2338 return op->emitOpError(

2339 "expected distributed vectors to have same rank and element type.");

2340

2341 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);

2342 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {

2343 int64_t eDim = expandedVecType.getDimSize(i);

2344 int64_t dDim = distributedVecType.getDimSize(i);

2345 if (eDim == dDim)

2346 continue;

2347 if (eDim % dDim != 0)

2348 return op->emitOpError()

2349 << "expected expanded vector dimension #" << i << " (" << eDim

2350 << ") to be a multipler of the distributed vector dimension ("

2351 << dDim << ")";

2352 scales[i] = eDim / dDim;

2353 }

2354 if (std::accumulate(scales.begin(), scales.end(), 1,

2355 std::multiplies<int64_t>()) != warpSize)

2356 return op->emitOpError()

2357 << "incompatible distribution dimensions from " << expandedVecType

2358 << " to " << distributedVecType << " with warp size = " << warpSize;

2359

2360 return success();

2361 }

2362

2363 LogicalResult WarpExecuteOnLane0Op::verify() {

2364 if (getArgs().size() != getWarpRegion().getNumArguments())

2365 return emitOpError(

2366 "expected same number op arguments and block arguments.");

2367 auto yield =

2368 cast(getWarpRegion().getBlocks().begin()->getTerminator());

2369 if (yield.getNumOperands() != getNumResults())

2370 return emitOpError(

2371 "expected same number of yield operands and return values.");

2372 int64_t warpSize = getWarpSize();

2373 for (auto [regionArg, arg] :

2374 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {

2375 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),

2376 warpSize, getOperation())))

2377 return failure();

2378 }

2379 for (auto [yieldOperand, result] :

2380 llvm::zip_equal(yield.getOperands(), getResults())) {

2381 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),

2382 warpSize, getOperation())))

2383 return failure();

2384 }

2385 return success();

2386 }

2387 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {

2388 return succeeded(

2389 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));

2390 }

2391

2392 //===----------------------------------------------------------------------===//

2393 // GPU KernelMetadataAttr

2394 //===----------------------------------------------------------------------===//

2395

2396 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,

2397 DictionaryAttr metadata) {

2398 assert(kernel && "invalid kernel");

2399 return get(kernel.getNameAttr(), kernel.getFunctionType(),

2400 kernel.getAllArgAttrs(), metadata);

2401 }

2402

2403 KernelMetadataAttr

2404 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,

2405 FunctionOpInterface kernel,

2406 DictionaryAttr metadata) {

2407 assert(kernel && "invalid kernel");

2408 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),

2409 kernel.getAllArgAttrs(), metadata);

2410 }

2411

2412 KernelMetadataAttr

2413 KernelMetadataAttr::appendMetadata(ArrayRef attrs) const {

2414 if (attrs.empty())

2415 return *this;

2416 NamedAttrList attrList;

2417 if (DictionaryAttr dict = getMetadata())

2418 attrList.append(dict);

2419 attrList.append(attrs);

2420 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),

2421 attrList.getDictionary(getContext()));

2422 }

2423

2424 LogicalResult

2425 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,

2426 StringAttr name, Type functionType,

2427 ArrayAttr argAttrs, DictionaryAttr metadata) {

2428 if (name.empty())

2429 return emitError() << "the kernel name can't be empty";

2430 if (argAttrs) {

2431 if (llvm::any_of(argAttrs, [](Attribute attr) {

2432 return !llvm::isa(attr);

2433 }))

2434 return emitError()

2435 << "all attributes in the array must be a dictionary attribute";

2436 }

2437 return success();

2438 }

2439

2440 //===----------------------------------------------------------------------===//

2441 // GPU KernelTableAttr

2442 //===----------------------------------------------------------------------===//

2443

2444 KernelTableAttr KernelTableAttr::get(MLIRContext *context,

2445 ArrayRef kernels,

2446 bool isSorted) {

2447 // Note that `is_sorted` is always only invoked once even with assertions ON.

2448 assert((!isSorted || llvm::is_sorted(kernels)) &&

2449 "expected a sorted kernel array");

2450 // Immediately return the attribute if the array is sorted.

2451 if (isSorted || llvm::is_sorted(kernels))

2452 return Base::get(context, kernels);

2453 // Sort the array.

2454 SmallVector kernelsTmp(kernels);

2455 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());

2456 return Base::get(context, kernelsTmp);

2457 }

2458

2459 KernelTableAttr KernelTableAttr::getChecked(

2460 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,

2461 ArrayRef kernels, bool isSorted) {

2462 // Note that `is_sorted` is always only invoked once even with assertions ON.

2463 assert((!isSorted || llvm::is_sorted(kernels)) &&

2464 "expected a sorted kernel array");

2465 // Immediately return the attribute if the array is sorted.

2466 if (isSorted || llvm::is_sorted(kernels))

2467 return Base::getChecked(emitError, context, kernels);

2468 // Sort the array.

2469 SmallVector kernelsTmp(kernels);

2470 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());

2471 return Base::getChecked(emitError, context, kernelsTmp);

2472 }

2473

2474 LogicalResult

2475 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,

2476 ArrayRef kernels) {

2477 if (kernels.size() < 2)

2478 return success();

2479 // Check that the kernels are uniquely named.

2480 if (std::adjacent_find(kernels.begin(), kernels.end(),

2481 [](KernelMetadataAttr l, KernelMetadataAttr r) {

2482 return l.getName() == r.getName();

2483 }) != kernels.end()) {

2484 return emitError() << "expected all kernels to be uniquely named";

2485 }

2486 return success();

2487 }

2488

2489 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {

2490 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);

2491 return found ? *iterator : KernelMetadataAttr();

2492 }

2493

2494 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {

2495 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);

2496 return found ? *iterator : KernelMetadataAttr();

2497 }

2498

2499 //===----------------------------------------------------------------------===//

2500 // GPU target options

2501 //===----------------------------------------------------------------------===//

2502

2503 TargetOptions::TargetOptions(

2504 StringRef toolkitPath, ArrayRef librariesToLink,

2505 StringRef cmdOptions, StringRef elfSection,

2506 CompilationTarget compilationTarget,

2507 function_ref<SymbolTable *()> getSymbolTableCallback,

2508 function_ref<void(llvm::Module &)> initialLlvmIRCallback,

2509 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,

2510 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,

2511 function_ref<void(StringRef)> isaCallback)

2512 : TargetOptions(TypeID::get(), toolkitPath, librariesToLink,

2513 cmdOptions, elfSection, compilationTarget,

2514 getSymbolTableCallback, initialLlvmIRCallback,

2515 linkedLlvmIRCallback, optimizedLlvmIRCallback,

2516 isaCallback) {}

2517

2518 TargetOptions::TargetOptions(

2519 TypeID typeID, StringRef toolkitPath, ArrayRef librariesToLink,

2520 StringRef cmdOptions, StringRef elfSection,

2521 CompilationTarget compilationTarget,

2522 function_ref<SymbolTable *()> getSymbolTableCallback,

2523 function_ref<void(llvm::Module &)> initialLlvmIRCallback,

2524 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,

2525 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,

2526 function_ref<void(StringRef)> isaCallback)

2527 : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),

2528 cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),

2529 compilationTarget(compilationTarget),

2530 getSymbolTableCallback(getSymbolTableCallback),

2531 initialLlvmIRCallback(initialLlvmIRCallback),

2532 linkedLlvmIRCallback(linkedLlvmIRCallback),

2533 optimizedLlvmIRCallback(optimizedLlvmIRCallback),

2534 isaCallback(isaCallback), typeID(typeID) {}

2535

2536 TypeID TargetOptions::getTypeID() const { return typeID; }

2537

2538 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }

2539

2540 ArrayRef TargetOptions::getLibrariesToLink() const {

2541 return librariesToLink;

2542 }

2543

2544 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }

2545

2546 StringRef TargetOptions::getELFSection() const { return elfSection; }

2547

2548 SymbolTable *TargetOptions::getSymbolTable() const {

2549 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;

2550 }

2551

2552 function_ref<void(llvm::Module &)>

2553 TargetOptions::getInitialLlvmIRCallback() const {

2554 return initialLlvmIRCallback;

2555 }

2556

2557 function_ref<void(llvm::Module &)>

2558 TargetOptions::getLinkedLlvmIRCallback() const {

2559 return linkedLlvmIRCallback;

2560 }

2561

2562 function_ref<void(llvm::Module &)>

2563 TargetOptions::getOptimizedLlvmIRCallback() const {

2564 return optimizedLlvmIRCallback;

2565 }

2566

2567 function_ref<void(StringRef)> TargetOptions::getISACallback() const {

2568 return isaCallback;

2569 }

2570

2571 CompilationTarget TargetOptions::getCompilationTarget() const {

2572 return compilationTarget;

2573 }

2574

2575 CompilationTarget TargetOptions::getDefaultCompilationTarget() {

2576 return CompilationTarget::Fatbin;

2577 }

2578

2579 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>

2580 TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {

2581 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;

2582 llvm::StringSaver stringSaver(options.first);

2583 StringRef opts = cmdOptions;

2584 // For a correct tokenization of the command line options `opts` must be

2585 // unquoted, otherwise the tokenization function returns a single string: the

2586 // unquoted `cmdOptions` -which is not the desired behavior.

2587 // Remove any quotes if they are at the beginning and end of the string:

2588 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')

2589 opts.consume_front("\""), opts.consume_back("\"");

2590 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')

2591 opts.consume_front("'"), opts.consume_back("'");

2592 #ifdef _WIN32

2593 llvm:๐Ÿ†‘:TokenizeWindowsCommandLine(opts, stringSaver, options.second,

2594 false);

2595 #else

2596 llvm:๐Ÿ†‘:TokenizeGNUCommandLine(opts, stringSaver, options.second,

2597 false);

2598 #endif

2600 }

2601

2602 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>

2605 }

2606

2607 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>

2609 size_t startPos = cmdOptions.find(startsWith);

2610 if (startPos == std:๐Ÿงต:npos)

2612

2613 auto tokenized =

2616 return tokenized;

2617 }

2618

2620

2621 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"

2622 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

2623

2624 #define GET_ATTRDEF_CLASSES

2625 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

2626

2627 #define GET_OP_CLASSES

2628 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

2629

2630 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)

static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)

Parses an optional list of async operands with an optional leading keyword.

static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)

static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)

static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)

Prints optional async dependencies with its leading keyword.

static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)

static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)

static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)

Prints a GPU function memory attribution.

static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)

static bool canMakeGroupOpUniform(Operation *op)

static std::string getSparseHandleKeyword(SparseHandleKind kind)

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)

static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)

static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)

Parses a GPU function memory attribution.

static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)

static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)

static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)

static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)

static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)

static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)

static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)

Verifies a GPU function memory attribution.

static MLIRContext * getContext(OpFoldResult val)

static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)

Utility to check that all of the operations within 'src' can be inlined.

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

static std::string diag(const llvm::Value &value)

static llvm::ManagedStatic< PassManagerOptions > options

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)

#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)

This base class exposes generic asm parser hooks, usable across the various derived parsers.

ParseResult parseSymbolName(StringAttr &result)

Parse an @-identifier and store it (without the '@' symbol) in a string attribute.

@ Paren

Parens surrounding zero or more operands.

@ OptionalSquare

Square brackets supporting zero or more ops, or nothing.

virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0

Parse a list of comma-separated items with an optional delimiter.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

MLIRContext * getContext() const

virtual Location getEncodedSourceLoc(SMLoc loc)=0

Re-encode the given source location as an MLIR location and return it.

virtual ParseResult parseRParen()=0

Parse a ) token.

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

virtual ParseResult parseOptionalColon()=0

Parse a : token if present.

virtual ParseResult parseLess()=0

Parse a '<' token.

virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0

Parse a dimension list of a tensor or memref type.

virtual ParseResult parseEqual()=0

Parse a = token.

virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0

Parse a named dictionary into 'result' if the attributes keyword is present.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token and store it into the argument.

virtual SMLoc getNameLoc() const =0

Return the location of the original name token.

virtual ParseResult parseOptionalString(std::string *string)=0

Parse a quoted string token if present.

virtual ParseResult parseOptionalLess()=0

Parse a '<' token if present.

virtual ParseResult parseGreater()=0

Parse a '>' token.

virtual ParseResult parseLParen()=0

Parse a ( token.

virtual ParseResult parseType(Type &result)=0

Parse a type.

virtual ParseResult parseComma()=0

Parse a , token.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

virtual ParseResult parseAttribute(Attribute &result, Type type={})=0

Parse an arbitrary attribute of a given type and return it in result.

This base class exposes generic asm printer hooks, usable across the various derived printers.

virtual void printSymbolName(StringRef symbolRef)

Print the given string as a symbol reference, i.e.

Attributes are known-constant values of operations.

This class represents an argument of a Block.

Block represents an ordered list of Operations.

BlockArgument addArgument(Type type, Location loc)

Add one value to the argument list.

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getI32IntegerAttr(int32_t value)

DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)

FunctionType getFunctionType(TypeRange inputs, TypeRange results)

IntegerAttr getI64IntegerAttr(int64_t value)

Ty getType(Args &&...args)

Get or construct an instance of the type Ty with provided arguments.

StringAttr getStringAttr(const Twine &bytes)

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)

NamedAttribute getNamedAttr(StringRef name, Attribute val)

Attr getAttr(Args &&...args)

Get or construct an instance of the attribute Attr with provided arguments.

The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

This is the interface that must be implemented by the dialects of operations to be inlined.

DialectInlinerInterface(Dialect *dialect)

This is a utility class for mapping one set of IR entities to another.

This class represents a diagnostic that is inflight and set to be reported.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

void push_back(NamedAttribute newAttribute)

Add an attribute with the specified name.

NamedAttribute represents a combination of a name and an Attribute value.

StringAttr getName() const

Return the name of the attribute.

Attribute getValue() const

Return the value of the attribute.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual size_t getNumResults() const =0

Return the number of declared SSA results.

virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0

Parses a region.

virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0

Parse zero or more arguments with a specified surrounding delimiter.

virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0

Resolve an operand to an SSA value, emitting an error on failure.

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0

Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0

Prints a region.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

static StringRef getOperandSegmentSizeAttr()

This class implements the operand iterators for the Operation class.

Operation is the basic unit of execution within MLIR.

void insertOperands(unsigned index, ValueRange operands)

Insert the given operands into the operand list at the given 'index'.

AttrClass getAttrOfType(StringAttr name)

MLIRContext * getContext()

Return the context this operation is associated with.

Location getLoc()

The source location the operation was defined or derived from.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

Block * getBlock()

Returns the operation block that contains this operation.

void setAttr(StringAttr name, Attribute value)

If an attribute with the specified name exists, change it to the new value.

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class contains a list of basic blocks and a link to the parent operation it is attached to.

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

static StringRef getSymbolAttrName()

Return the name of the attribute used for symbol names.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

MLIRContext * getContext() const

Return the MLIRContext in which this type was uniqued.

bool isSignedInteger() const

Return true if this is a signed integer type (with the specified width).

bool isUnsignedInteger() const

Return true if this is an unsigned integer type (with the specified width).

bool isInteger() const

Return true if this is an integer type (with the specified width).

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

A utility result that is used to signal how to proceed with an ongoing walk:

static ConcreteT get(MLIRContext *ctx, Args &&...args)

Get or create a new ConcreteT instance within the ctx.

ImplType * getImpl() const

Utility for easy access to the storage instance.

MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.

ArrayRef< int64_t > getShape() const

Get shape of the matrix.

static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)

Get MMAMatrixType and verify construction Invariants.

Type getElementType() const

Get elementType of a single element.

static bool isValidElementType(Type elementType)

Check if a type is a valid MMAMatrixType elementType.

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)

Verify that shape and elementType are actually allowed for the MMAMatrixType.

StringRef getOperand() const

The general form of operation this type supports is given by the equation C += A*B.

static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)

Get MMAMatrixType at a particular location and verify construction Invariants.

unsigned getNumDims() const

Get number of dims.

This class serves as an opaque interface for passing options to the TargetAttrInterface methods.

std::string cmdOptions

An optional set of command line options to be used by the compilation process.

std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const

Returns a tokenization of the command line options.

std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)

Returns a tokenization of the substr of the command line options that starts with startsWith and ends...

void printType(Type type, AsmPrinter &printer)

Prints an LLVM Dialect type.

void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)

Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)

Utility method to generate a callback that can be used to generate a diagnostic when checking the con...

ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)

Parses a function signature using parser.

void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})

Prints the list of function attributes prefixed with the "attributes" keyword.

void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)

Prints the signature of the function-like operation op.

void addAsyncDependency(Operation *op, Value token)

llvm::StringMap< llvm::SmallString< 8 > > dictionary

A dictionary stores a mapping of template variable names to their assigned string values.

Kind

An enumeration of the kinds of predicates.

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)

Given the dimToLvl map, returns the block sizes in a vector.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

detail::constant_int_predicate_matcher m_One()

Matches a constant scalar / vector splat / tensor splat integer one.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)

This parses a single MLIR type to an MLIR context if it was valid.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.

LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override

UnresolvedOperand ssaName

This is the representation of an operand reference.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

This represents an operation in an abstracted form, suitable for use with the builder APIs.

T & getOrAddProperties()

Get (or create) a properties of the provided type to be set on the operation on creation.

SmallVector< Value, 4 > operands

void addOperands(ValueRange newOperands)

void addAttributes(ArrayRef< NamedAttribute > newAttributes)

Add an array of named attributes.

void addAttribute(StringRef name, Attribute attr)

Add an attribute with the specified name.

SmallVector< Type, 4 > types

Types of the results of this operation.

Region * addRegion()

Create a region that should be attached to the operation.

Utility class for the GPU dialect to represent triples of Values accessible through ....