MLIR: lib/Dialect/XeGPU/IR/XeGPUOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

18

19#include "llvm/Support/Debug.h"

20

21#define DEBUG_TYPE "xegpu"

22

23using namespace mlir;

25

27 Attribute attr = memrefTy.getMemorySpace();

28 if (auto intAttr = llvm::dyn_cast(attr))

29 return intAttr.getInt() == 3;

30 if (auto memrefSpace = llvm::dyn_cast(attr))

31 return memrefSpace.getValue() == MemorySpace::SLM;

32 if (auto xevmSpace = llvm::dyn_castxevm::AddrSpaceAttr(attr))

33 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;

34 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);

35}

36

37template

38static std::string makeString(T array, bool breakline = false) {

39 std::string buf;

40 buf.clear();

41 llvm::raw_string_ostream os(buf);

42 os << "[";

43 for (size_t i = 1; i < array.size(); i++) {

44 os << array[i - 1] << ", ";

45 if (breakline)

46 os << "\n\t\t";

47 }

48 os << array.back() << "]";

49 return buf;

50}

51

54 if (auto ty = llvm::dyn_cast(type))

56 else

57 shape.push_back(1);

59}

60

62 if (!attr)

63 return true;

64 auto kind = attr.getValue();

65 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||

66 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;

67}

68

70 if (!attr)

71 return true;

72 auto kind = attr.getValue();

73 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||

74 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;

75}

76

77static LogicalResult

79 TensorDescType tdescTy,

81

82 if (!tdescTy.isScattered())

83 return emitError() << "Expects a scattered TensorDesc.";

84

85 auto chunkSize = tdescTy.getChunkSizeAsInt();

86 if (!valueTy) {

87 if (chunkSize > 1)

88 return emitError() << "Expecting chunk size == 1 for scalar result";

89 if (dyn_cast(maskTy))

90 return emitError() << "Expecting a vector type result.";

92 }

93

95 auto valueShape = getShapeOf(valueTy);

96 auto tdescShape = getShapeOf(tdescTy);

97

98 if (valueTy.getElementType() != tdescTy.getElementType())

100 << "Value should have the same element type as TensorDesc.";

101

103 if (chunkSize > 1)

104 expectedMaskShape.pop_back();

105 if (expectedMaskShape != maskShape)

107 << "Mask should match TensorDesc except the chunk size dim.";

108

109

110 if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {

111 if (tdescTy.getLayoutAttr())

112 return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";

114 }

115

116 if (tdescShape != valueShape)

118 << " is neither a valid distribution for SIMT nor "

119 "consistent with the tensor descriptor for SIMD "

120 << tdescTy;

122}

123

124static LogicalResult

126 VectorType valueTy, int64_t chunkSize,

128

129 auto maskVecTy = dyn_cast(maskTy);

130 auto offsetsVecTy = dyn_cast(offsetsTy);

131 if (!valueTy) {

132 if (chunkSize > 1)

133 return emitError() << "Expecting chunk size == 1 for scalar result";

134 if (maskVecTy || offsetsVecTy)

135 return emitError() << "Expecting scalar mask and offsets.";

136 else if (maskVecTy && offsetsVecTy)

137 return emitError() << "Expecting a vector type result.";

139 }

140

141 auto valueSize = valueTy.getNumElements();

142

143 if (!maskVecTy && !offsetsVecTy) {

144 if (valueSize != chunkSize)

145 return emitError() << "value elements must match chunk size "

146 << chunkSize;

148 }

150 auto valueShape = getShapeOf(valueTy);

151

152 if (!maskVecTy)

153 return emitError() << "Expecting a vector type mask.";

154 int64_t maskSize = maskVecTy.getNumElements();

155

156 if (chunkSize > 1) {

157 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))

158 return emitError() << "value elements must match chunk size "

159 << chunkSize;

160 } else {

161 if (valueSize != maskSize)

163 << "Mask should match value except the chunk size dim.";

164 }

166 if (maskSize == 1)

168 if (chunkSize > 1)

169 expectedMaskShape.pop_back();

170 if (expectedMaskShape != maskShape)

171 return emitError() << "Mask should match value except the chunk size dim.";

172

174}

175

176LogicalResult

178 UnitAttr subgroup_block_io, DistributeLayoutAttr layout,

180

181 if (!dataTy) {

182 if (subgroup_block_io)

183 return emitError() << "subgroup_block_io "

184 "are only allowed when result is a VectorType.";

185 else

187 }

188

189 if (mdescTy.getRank() != 2)

190 return emitError() << "mem_desc must be 2D.";

191

194

196 ArrayAttr strideAttr = mdescTy.getStrideAttr();

198 for (Attribute attr : strideAttr.getValue()) {

199 strides.push_back(cast(attr).getInt());

200 }

201 if (subgroup_block_io && layout) {

202 auto laneData = layout.getEffectiveLaneDataAsInt();

203 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();

204 if (!laneData.empty()) {

205 bool isLaneDataContiguous =

206 std::all_of(laneData.begin(), std::prev(laneData.end()),

207 [](int x) { return x == 1; });

208 if (!isLaneDataContiguous)

209 return emitError() << "With subgroup_block_io, accessed data must be "

210 "contiguous and coalesced.";

211 for (size_t i = 0; i < laneData.size(); ++i) {

212 if (laneLayout[i] != blockShape[i])

213 return emitError() << "With subgroup_block_io, the block shape must "

214 "match the lane layout.";

215 if (laneLayout[i] != 1 && strides[i] != 1)

216 return emitError() << "With subgroup_block_io, the distributed "

217 "dimensions must be contiguous.";

218 }

219 }

220 }

221 if (dataShape.size() == 2) {

222 if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),

223 [](auto p) { return std::get<0>(p) > std::get<1>(p); }))

224 return emitError() << "data shape must not exceed mem_desc shape.";

225 } else {

226

227

228 if (subgroup_block_io && !blockShape.size())

229 return emitError() << "mem_desc must have block attribute when "

230 "subgroup_block_io is set.";

231

232

233 if (subgroup_block_io && mdescTy.isColMajor())

234 return emitError() << "mem_desc should be row major when "

235 "subgroup_block_io is set.";

236 }

237

239}

240

241

242

243

244

247 [[maybe_unused]] auto ty = source.getType();

248 assert(ty.hasStaticShape() && "expecting a memref with static shape");

249

250 build(builder, state, tdesc, source, ValueRange({}) ,

251 ValueRange({}) ,

252 ValueRange({}) ,

256}

257

263 assert((isa<IntegerType, MemRefType>(srcTy)) &&

264 "Source has to be either int or memref.");

265

268

271

274

277

278 if (auto memrefTy = dyn_cast(srcTy)) {

279 auto memrefShape = memrefTy.getShape();

280 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

281

282

283

284

285 if (staticShape == memrefShape && staticStrides == memrefStrides &&

286 dynamicShape.empty() && dynamicStrides.empty()) {

289 }

290 }

291

292 build(builder, state, tdesc, source, ValueRange({}), dynamicShape,

294 staticStridesAttr);

295}

296

300 [[maybe_unused]] auto ty = source.getType();

301 assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

302

306

307 build(builder, state, tdesc, source, dynamicOffsets ,

308 ValueRange({}) ,

309 ValueRange({}) ,

311 {} , {} );

312}

313

319 assert(shape.empty() && !offsets.empty() && !strides.empty() &&

320 shape.size() == strides.size() && shape.size() == offsets.size());

321

323 assert((isa<IntegerType, MemRefType>(srcTy)) &&

324 "Source has to be either int or memref.");

325

329

333

337

341

342 if (auto memrefTy = dyn_cast(srcTy)) {

343 auto memrefShape = memrefTy.getShape();

344 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

345

346

347

348

349 if (staticShape == memrefShape && staticStrides == memrefStrides &&

350 dynamicShape.empty() && dynamicStrides.empty()) {

353 }

354 }

355

356 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,

357 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);

358}

359

360LogicalResult CreateNdDescOp::verify() {

362 bool invalidRank = rank != getMixedStrides().size();

363 bool invalidElemTy = false;

364

365

366

367

368

369 auto srcMemorySpace = getSourceMemorySpace();

370 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());

371 if (srcMemorySpace != tdescMemorySpace)

372 return emitOpError("Memory space mismatch.")

373 << " Source: " << srcMemorySpace

374 << ", TensorDesc: " << tdescMemorySpace;

375

376 if (size_t offsetRank = getMixedOffsets().size())

377 invalidRank |= (offsetRank != rank);

378

379

380

381 if (auto memrefTy = dyn_cast(getSourceType()))

382 invalidElemTy |= memrefTy.getElementType() != getElementType();

383

384 if (llvm::isa(getSourceType())) {

385

386 if (getMixedStrides().empty() || getMixedSizes().empty())

387 return emitOpError("expecting strides and shape to be present for "

388 "integer source.");

389 }

390

391 if (invalidRank)

393 "Expecting the rank of shape, strides, offsets, and source (if source "

394 "is a memref) should match with each other.");

395

396

399 "Expecting the TensorDesc rank is not greater than the "

400 "ranks of shape, strides, offsets or the memref source.");

401

402 if (invalidElemTy)

403 return emitOpError("TensorDesc should have the same element "

404 "type with the source if it is a memref.\n");

405

406 if (getType().isScattered())

407 return emitOpError("Expects a non-scattered TensorDesc.\n");

408

410}

411

417

419 auto parseIntegerOrValue = [&]() {

422

423 if (res.has_value() && succeeded(res.value())) {

424 values.push_back(operand);

425 integerVals.push_back(ShapedType::kDynamic);

426 if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))

427 return failure();

428 } else {

431 return failure();

432 integerVals.push_back(integer);

433 }

435 };

436

437

442 << "expected a list of SSA values or integers";

445 }

446

448}

449

453 if (!integers || integers.empty())

454 return;

457}

458

459

460

461

463 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,

464 xegpu::CachePolicyAttr l2_hint,

465 xegpu::CachePolicyAttr l3_hint) {

466

468 l1_hint, l2_hint, l3_hint, nullptr);

469}

470

473 xegpu::CachePolicyAttr l1_hint,

474 xegpu::CachePolicyAttr l2_hint,

475 xegpu::CachePolicyAttr l3_hint,

476 xegpu::DistributeLayoutAttr layout) {

480

482

483 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,

484 l2_hint, l3_hint, layout);

485}

486

487LogicalResult PrefetchNdOp::verify() {

488 auto tdescTy = getTensorDescType();

489 if (tdescTy.isScattered())

490 return emitOpError("Expects a non-scattered TensorDesc.\n");

491

493 return emitOpError("invalid l1_hint: ") << getL1HintAttr();

494

496 return emitOpError("invalid l2_hint: ") << getL2HintAttr();

497

499 return emitOpError("invalid l3_hint: ") << getL3HintAttr();

500

501 int64_t tDescRank = tdescTy.getRank();

502 int64_t offsetSize = getMixedOffsets().size();

503 if (offsetSize != 0 && offsetSize != tDescRank)

505 "Mismatched ranks between offsets and tensor descriptor");

506

508}

509

510

511

512

513

515 Value tensorDesc, UnitAttr packed,

517 xegpu::CachePolicyAttr l1_hint,

518 xegpu::CachePolicyAttr l2_hint,

519 xegpu::CachePolicyAttr l3_hint) {

520

521 return build(builder, state, retType, tensorDesc, ValueRange(),

523 l3_hint, nullptr);

524}

525

529 xegpu::CachePolicyAttr l1_hint,

530 xegpu::CachePolicyAttr l2_hint,

531 xegpu::CachePolicyAttr l3_hint,

532 xegpu::DistributeLayoutAttr layout) {

536

538

539 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,

540 packed, transpose, l1_hint, l2_hint, l3_hint,

541 layout);

542}

543

544LogicalResult LoadNdOp::verify() {

545 auto tdescTy = getTensorDescType();

546 auto valueTy = getType();

547

548 if (tdescTy.isScattered())

549 return emitOpError("Expects a non-scattered TensorDesc.\n");

550

551 if (tdescTy.getRank() > 2)

552 return emitOpError("Expects a 1D or 2D TensorDesc.\n");

553

554 if (!valueTy)

555 return emitOpError("Invalid result, it should be a VectorType.\n");

556

558 return emitOpError("invalid l1_hint: ") << getL1HintAttr();

559

561 return emitOpError("invalid l2_hint: ") << getL2HintAttr();

562

564 return emitOpError("invalid l3_hint: ") << getL3HintAttr();

565

566 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();

567 int valueElems = valueTy.getNumElements();

568

569

570

571

572 if (valueElems < tdescElems && valueTy.getRank() == 1) {

573

574 if (tdescTy.getLayoutAttr())

576 << "TensorDesc doesn't need LayoutAttr for SIMT code";

577

578

579

580

581 if (tdescElems % valueElems)

584 << " is not a valid distribution for tensor descriptor "

585 << tdescTy;

586

588 }

589

590

591 auto tdescShape = getShapeOf(tdescTy);

592 auto valueShape = getShapeOf(valueTy);

593

594 if (getTranspose()) {

595 auto trans = getTranspose().value();

596

597 if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))

599 else

600 mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";

601 }

602

603 if (getPacked()) {

604 if (tdescTy.getRank() == 2) {

605 const int axis = 0;

606 auto vnni_factor = valueShape.back();

607 tdescShape[axis] /= vnni_factor;

608 tdescShape.push_back(vnni_factor);

609 } else {

611 << "Invalid Packed Attr. It is ignored (available for 2D "

612 "TensorDesc only).";

613 }

614 }

615

616 auto array_len = tdescTy.getArrayLength();

617 if (array_len > 1)

618 tdescShape.insert(tdescShape.begin(), array_len);

619

620 if (tdescShape != valueShape)

622 << " is not consistent with tensor descriptor "

623 << tdescTy;

624

625 int64_t tDescRank = tdescTy.getRank();

626 int64_t offsetSize = getMixedOffsets().size();

627 if (offsetSize != 0 && offsetSize != tDescRank)

629 "Mismatched ranks between offsets and tensor descriptor");

630

632}

633

634

635

636

637

639 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,

640 xegpu::CachePolicyAttr l2_hint,

641 xegpu::CachePolicyAttr l3_hint) {

642

643 return build(builder, state, value, tensorDesc, ValueRange(),

645 nullptr);

646}

647

650 xegpu::CachePolicyAttr l1_hint,

651 xegpu::CachePolicyAttr l2_hint,

652 xegpu::CachePolicyAttr l3_hint,

653 xegpu::DistributeLayoutAttr layout) {

657

659

660 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,

661 l1_hint, l2_hint, l3_hint, layout);

662}

663

664LogicalResult StoreNdOp::verify() {

665 auto dstTy = getTensorDescType();

667

668 if (dstTy.isScattered())

669 return emitOpError("Expects a non-scattered TensorDesc.\n");

670

671 if (dstTy.getRank() > 2)

672 return emitOpError("Expects a 1D or 2D TensorDesc.\n");

673

674 if (!valTy)

675 return emitOpError("Expecting a VectorType result.\n");

676

678 return emitOpError("invalid l1_hint: ") << getL1HintAttr();

679

681 return emitOpError("invalid l2_hint: ") << getL2HintAttr();

682

684 return emitOpError("invalid l3_hint: ") << getL3HintAttr();

685

686 auto array_len = dstTy.getArrayLength();

687 if (array_len > 1)

688 return emitOpError("array length is not supported by store_nd.\n");

689

690 auto tdescElems = dstTy.getNumElements();

691 auto valueElems = valTy.getNumElements();

692

693

694

695

696 if (valTy.getRank() == 1 && valueElems < tdescElems) {

697

698 if (dstTy.getLayoutAttr())

700 << "TensorDesc doesn't need LayoutAttr for SIMT code";

701

702 if (tdescElems % valueElems)

705 << " is not a valid distribution for tensor descriptor " << dstTy;

706

708 }

709

710

713 if (tdescShape != valueShape)

715 << " is not consistent with tensor descriptor "

716 << dstTy;

717

718 int64_t tDescRank = dstTy.getRank();

719 int64_t offsetSize = getMixedOffsets().size();

720 if (offsetSize != 0 && offsetSize != tDescRank)

722 "Mismatched ranks between offsets and tensor descriptor");

723

725}

726

727

728

729

730LogicalResult UpdateNdOffsetOp::verify() {

731 auto ty = getTensorDescType();

732 if (ty.isScattered())

733 return emitOpError("Expects a non-scattered TensorDesc.\n");

734

735

736 if (ty.getRank() != (int64_t)getNumOffsets()) {

737 return emitOpError("Invalid number of offsets.");

738 }

740}

741

742

743

744

745

747 TensorDescType TensorDesc, Value source,

749 auto loc = source.getLoc();

751 auto type = VectorType::get(size, builder.getIndexType());

753 auto offset = vector::FromElementsOp::create(builder, loc, type, values);

754 build(builder, state, TensorDesc, source, offset);

755}

756

758 TensorDescType TensorDesc, Value source,

761 build(builder, state, TensorDesc, source, ofrs);

762}

763

764LogicalResult CreateDescOp::verify() {

765 auto tdescTy = getTensorDescType();

766

767 if (!tdescTy.isScattered())

768 return emitOpError("Expects a scattered TensorDesc.\n");

769

770

771

772

773

774 auto srcMemorySpace = getSourceMemorySpace();

775 auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());

776 if (srcMemorySpace != tdescMemorySpace)

777 return emitOpError("Memory space mismatch.")

778 << " Source: " << srcMemorySpace

779 << ", TensorDesc: " << tdescMemorySpace;

780

781

782 auto chunkSize = tdescTy.getChunkSizeAsInt();

784 if (chunkSize != 1)

785 shape.push_back(chunkSize);

786

787 auto tdescShape = getShapeOf(tdescTy);

788 if (shape != tdescShape)

789 return emitOpError("Incorrect TensorDesc shape. ")

791

793}

794

795

796

797

798LogicalResult PrefetchOp::verify() {

799 auto tdescTy = getTensorDescType();

800

801 if (!tdescTy && !getOffsets())

803

804 if (tdescTy && getOffsets())

805 return emitOpError("offsets not allowed.");

806

807 if (tdescTy && !tdescTy.isScattered())

808 return emitOpError("Expects a scattered TensorDesc.");

809

811 return emitOpError("invalid l1_hint: ") << getL1HintAttr();

812

814 return emitOpError("invalid l2_hint: ") << getL2HintAttr();

815

817 return emitOpError("invalid l3_hint: ") << getL3HintAttr();

818

819 auto srcTy = getSourceType();

820 if (srcTy.isInteger() && !getOffsetAlignByteAttr())

821 return emitOpError("offset_align_byte is required with integer source.");

822

823 if (getOffsetAlignByteAttr() && !srcTy.isInteger())

824 return emitOpError("offset_align_byte only allowed with integer source.");

825

827}

828

830 xegpu::CachePolicyAttr l1_hint,

831 xegpu::CachePolicyAttr l2_hint,

832 xegpu::CachePolicyAttr l3_hint) {

833 build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,

834 IntegerAttr{}, nullptr);

835}

836

837

838

839

840LogicalResult LoadGatherOp::verify() {

841 auto tdescTy = getTensorDescType();

842 auto maskTy = getMaskType();

844

845 if (!tdescTy && !getOffsets())

847

848 if (tdescTy && getOffsets())

849 return emitOpError("offsets not allowed.");

850

851 if (tdescTy && !tdescTy.isScattered())

852 return emitOpError("Expects a scattered TensorDesc.");

853

855 return emitOpError("invalid l1_hint: ") << getL1HintAttr();

856

858 return emitOpError("invalid l2_hint: ") << getL2HintAttr();

859

861 return emitOpError("invalid l3_hint: ") << getL3HintAttr();

862

863 if (tdescTy)

866 auto srcTy = getSourceType();

867 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));

868 auto memTy = dyn_cast(srcTy);

869

870 if (memTy && (getElementType() != memTy.getElementType()))

871 return emitError() << "Value should have the same element type as MemRef.";

872

873 auto offsetsTy = getOffsets().getType();

876}

877

880 xegpu::CachePolicyAttr l1_hint,

881 xegpu::CachePolicyAttr l2_hint,

882 xegpu::CachePolicyAttr l3_hint) {

883 build(builder, state, valueType, source, Value(), mask, IntegerAttr(),

884 l1_hint, l2_hint, l3_hint, nullptr);

885}

886

890 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,

891 xegpu::CachePolicyAttr l2_hint,

892 xegpu::CachePolicyAttr l3_hint) {

893 auto loc = source.getLoc();

895 auto type = VectorType::get(size, builder.getIndexType());

897 auto offset = vector::FromElementsOp::create(builder, loc, type, values);

898

899 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,

900 l2_hint, l3_hint, nullptr);

901}

902

906 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,

907 xegpu::CachePolicyAttr l2_hint,

908 xegpu::CachePolicyAttr l3_hint,

909 DistributeLayoutAttr layout) {

910 auto loc = source.getLoc();

912 auto type = VectorType::get(size, builder.getIndexType());

914 auto offset = vector::FromElementsOp::create(builder, loc, type, values);

915

916 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,

917 l2_hint, l3_hint, layout);

918}

919

920

921

922

923LogicalResult StoreScatterOp::verify() {

924 auto tdescTy = getTensorDescType();

925 auto maskTy = getMaskType();

927

928 if (!tdescTy && !getOffsets())

930

931 if (tdescTy && getOffsets())

932 return emitOpError("offsets not allowed.");

933

934 if (tdescTy && !tdescTy.isScattered())

935 return emitOpError("Expects a scattered TensorDesc.");

936

938 return emitOpError("invalid l1_hint: ") << getL1HintAttr();

939

941 return emitOpError("invalid l2_hint: ") << getL2HintAttr();

942

944 return emitOpError("invalid l3_hint: ") << getL3HintAttr();

945

946 if (tdescTy)

949

950 auto destTy = getDestType();

951 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));

952 auto memTy = dyn_cast(destTy);

953

954 if (memTy && (getElementType() != memTy.getElementType()))

955 return emitError() << "Value should have the same element type as MemRef.";

956

957 auto offsetsTy = getOffsets().getType();

960}

961

964 xegpu::CachePolicyAttr l1_hint,

965 xegpu::CachePolicyAttr l2_hint,

966 xegpu::CachePolicyAttr l3_hint) {

967 build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,

968 l2_hint, l3_hint, nullptr);

969}

970

974 IntegerAttr chunk_size,

975 xegpu::CachePolicyAttr l1_hint,

976 xegpu::CachePolicyAttr l2_hint,

977 xegpu::CachePolicyAttr l3_hint) {

978 auto loc = dest.getLoc();

980 auto type = VectorType::get(size, builder.getIndexType());

982 auto offset = vector::FromElementsOp::create(builder, loc, type, values);

983

984

985 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,

986 l3_hint, nullptr);

987}

988

989void StoreScatterOp::build(

992 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,

993 xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {

994 auto loc = dest.getLoc();

996 auto type = VectorType::get(size, builder.getIndexType());

998 auto offset = vector::FromElementsOp::create(builder, loc, type, values);

999

1000

1001 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,

1002 l3_hint, layout);

1003}

1004

1005

1006

1007

1011 auto tdescTy = mlir::dyn_cast(tensorDesc.getType());

1012 assert(tdescTy && "Expecting the source is a TensorDescType value.");

1013 auto loc = tensorDesc.getLoc();

1014 int64_t size = static_cast<int64_t>(offsets.size());

1015 auto type = VectorType::get({size}, builder.getIndexType());

1017 auto offset = vector::FromElementsOp::create(builder, loc, type, values);

1018 build(builder, state, tdescTy, tensorDesc, offset);

1019}

1020

1024 build(builder, state, tensorDesc, ofrs);

1025}

1026

1027LogicalResult UpdateOffsetOp::verify() {

1028 auto tdescTy = getTensorDescType();

1029 if (!tdescTy.isScattered())

1030 return emitOpError("Expects a scattered TensorDesc.\n");

1031

1034 if (tdescTy.getChunkSizeAsInt() > 1)

1035 expectedOffsetShape.pop_back();

1036

1037 if (expectedOffsetShape != offsetShape)

1039 "Offsets should match TensorDesc except the chunk size dim.");

1040

1042}

1043

1044

1045

1046

1047LogicalResult DpasOp::verify() {

1048 int64_t lhsRank = getLhsType().getRank();

1049 int64_t rhsRank = getRhsType().getRank();

1050 int64_t resRank = getResultType().getRank();

1051 auto lhsShape = getLhsType().getShape();

1052 auto rhsShape = getRhsType().getShape();

1053 auto resShape = getResultType().getShape();

1054

1055 if (getAcc() && getAcc().getType() != getResultType())

1056 return emitOpError("Expecting the acc type to be the same as result.");

1057

1058

1059

1060

1061 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {

1062 auto numElems = getRhsType().getNumElements();

1063 auto elemTy = getRhsType().getElementType();

1064 auto factor = 32 / elemTy.getIntOrFloatBitWidth();

1065 if (numElems % factor != 0)

1066 return emitOpError("Expecting B operand to be a multiple of 32 bits.");

1068 }

1069

1070

1071 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)

1073 "expecting lhs and result to be a 2D vector, and rhs to be either "

1074 "2D or 3D (packed) vector.");

1075 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];

1076 if (bK != lhsShape[1])

1077 return emitOpError("K-dimension mismatch.");

1078 if (lhsShape[0] != resShape[0])

1079 return emitOpError("M-dimension mismatch.");

1080 if (rhsShape[1] != resShape[1])

1081 return emitOpError("N-dimension mismatch.");

1082

1084}

1085

1086

1087

1088

1089LogicalResult ConvertLayoutOp::verify() {

1090 auto srcLayout = getInputLayout();

1091 auto resLayout = getTargetLayout();

1092 if (!srcLayout)

1093 return emitOpError("expected input layout.");

1094 if (!resLayout)

1095 return emitOpError("expected target layout.");

1096

1097

1098

1099 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&

1100 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))

1101 return emitOpError("expected input layout and target layout be WgLayout or "

1102 "SgLayout at the same time.");

1103

1104 auto shape = getSource().getType().getShape();

1105 if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))

1107 "invalid input layout, data cannot be evenly distributed.");

1108

1109 if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))

1111 "invalid target layout, data cannot be evenly distributed.");

1112

1113 return mlir::success();

1114}

1115

1116OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {

1117 if (getInputLayout() == getTargetLayout())

1118 return getSource();

1119 return {};

1120}

1121

1126 if (op.getInputLayout() == op.getTargetLayout()) {

1127 rewriter.replaceOp(op, op.getSource());

1129 }

1130 return failure();

1131 }

1132};

1133

1137}

1138

1139

1140

1141

1145 DistributeLayoutAttr layout) {

1150

1151

1152 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,

1153 nullptr, layout);

1154}

1155

1156LogicalResult LoadMatrixOp::verify() {

1157

1158 auto resTy = dyn_cast(getRes().getType());

1159 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();

1160 MemDescType mdescTy = getMemDesc().getType();

1161

1163 getLayoutAttr(), [&]() { return emitError(); });

1164}

1165

1166

1167

1168

1172 DistributeLayoutAttr layout) {

1177 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,

1178 nullptr, layout);

1179}

1180

1181LogicalResult StoreMatrixOp::verify() {

1182

1183 auto dataTy = dyn_cast(getData().getType());

1184 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();

1185 MemDescType mdescTy = getMemDesc().getType();

1187 getLayoutAttr(), [&]() { return emitError(); });

1188}

1189

1190namespace mlir {

1191#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>

1192}

1193#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>

1194#define GET_OP_CLASSES

1195#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>

p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")

Given a list of lists of parsed operands, populates uniqueOperands with unique operands.

static Type getElementType(Type type)

Determine the element type of type.

static Type getValueType(Attribute attr)

static ArrayRef< int64_t > getShape(Type type)

Returns the shape of the given type.

static SmallVector< int64_t > getShapeOf(Type type)

Definition XeGPUOps.cpp:52

LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref< InFlightDiagnostic()> emitError)

Definition XeGPUOps.cpp:177

static std::string makeString(T array, bool breakline=false)

Definition XeGPUOps.cpp:38

static bool isWriteHintOrNone(const CachePolicyAttr &attr)

Definition XeGPUOps.cpp:69

static bool isReadHintOrNone(const CachePolicyAttr &attr)

Definition XeGPUOps.cpp:61

static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)

Definition XeGPUOps.cpp:125

static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers)

Definition XeGPUOps.cpp:450

static bool isSharedMemory(const MemRefType &memrefTy)

Definition XeGPUOps.cpp:26

static ParseResult parseOptionalDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)

Definition XeGPUOps.cpp:412

static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, function_ref< InFlightDiagnostic()> emitError)

Definition XeGPUOps.cpp:78

Delimiter

These are the supported delimiters around operand lists and region argument lists,...

@ Square

Square brackets surrounding zero or more operands.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0

Parse a list of comma-separated items with an optional delimiter.

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

virtual ParseResult parseRSquare()=0

Parse a ] token.

ParseResult parseInteger(IntT &result)

Parse an integer value from the stream.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getNameLoc() const =0

Return the location of the original name token.

virtual ParseResult parseOptionalLSquare()=0

Parse a [ token if present.

Attributes are known-constant values of operations.

DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)

MLIRContext * getContext() const

This class represents a diagnostic that is inflight and set to be reported.

MLIRContext is the top-level object for a collection of MLIR operations.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single operand if present.

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

This class helps build Operations.

This class represents a single result from folding an operation.

This class implements the operand iterators for the Operation class.

Operation is the basic unit of execution within MLIR.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isInteger() const

Return true if this is an integer type (with the specified width).

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Location getLoc() const

Return the location of this value.

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given memref value.

Include the generated interface declarations.

InFlightDiagnostic emitWarning(Location loc)

Utility method to emit a warning message using this location.

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue

If Ty is mlir::Type this will select Value instead of having a wrapper around it.

const FrozenRewritePatternSet & patterns

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

llvm::function_ref< Fn > function_ref

void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)

Printer hooks for custom directive in assemblyFormat.

Definition XeGPUOps.cpp:1122

LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override

Definition XeGPUOps.cpp:1124

This is the representation of an operand reference.

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

This represents an operation in an abstracted form, suitable for use with the builder APIs.