MLIR: lib/Dialect/XeGPU/IR/XeGPUOps.cpp Source File
//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exceptions
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

using namespace mlir;
using namespace mlir::xegpu;

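// Returns true if the memref lives in shared local memory (SLM). The
// memory-space attribute may be encoded several ways, all accepted here:
// the raw GPU address space 3, the XeGPU MemorySpace enum, the XeVM
// address-space enum, or the GPU dialect's workgroup space.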
static bool isSharedMemory(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == 3;
  if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
    return memrefSpace.getValue() == MemorySpace::SLM;
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}

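// Renders an index array as a bracketed, comma-separated string for use in
// diagnostics, e.g. "[8, 16]" (illustrative). Note the array must be
// non-empty, since array.back() is read unconditionally.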
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  buf.clear();
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  os << array.back() << "]";
  return buf;
}

static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}

static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
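// The two helpers below validate scattered (gather/scatter) accesses. Both
// accept a SIMD form, where the value covers the whole TensorDesc or
// buffer, and a SIMT form, where each lane holds only its chunk-size slice
// and the mask shape collapses accordingly.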

static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  auto chunkSize = tdescTy.getChunkSizeAsInt();
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (dyn_cast<VectorType>(maskTy))
      return emitError() << "Expecting a vector type result.";
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match TensorDesc except the chunk size dim.";

  // A 1D value with chunk-size elements is a valid SIMT distribution.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    return success();
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}

static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                 VectorType valueTy, int64_t chunkSize,
                                 function_ref<InFlightDiagnostic()> emitError) {

  auto maskVecTy = dyn_cast<VectorType>(maskTy);
  auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (maskVecTy || offsetsVecTy)
      return emitError() << "Expecting scalar mask and offsets.";
    return success();
  }

  auto valueSize = valueTy.getNumElements();

  // Scalar mask and offsets with a vector value: the value spans one chunk.
  if (!maskVecTy && !offsetsVecTy) {
    if (valueSize != chunkSize)
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);

  if (!maskVecTy)
    return emitError() << "Expecting a vector type mask.";
  int64_t maskSize = maskVecTy.getNumElements();

  if (chunkSize > 1) {
    if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
  } else {
    if (valueSize != maskSize)
      return emitError()
             << "Mask should match value except the chunk size dim.";
  }

  // A single-element mask corresponds to the SIMT distribution of the op.
  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (maskSize == 1)
    return success();
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError() << "Mask should match value except the chunk size dim.";

  return success();
}
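// Shared verifier for load_matrix/store_matrix: checks the vector shape
// against the mem_desc shape, and, when subgroup_block_io is set, that the
// lane layout describes a contiguous, coalesced access.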

LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
                      function_ref<InFlightDiagnostic()> emitError) {

  if (!dataTy) {
    if (subgroup_block_io)
      return emitError() << "subgroup_block_io is only allowed when the "
                            "result is a VectorType.";
    return success();
  }

  if (mdescTy.getRank() != 2)
    return emitError() << "mem_desc must be 2D.";

  ArrayRef<int64_t> dataShape = dataTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
  SmallVector<int64_t> blockShape = mdescTy.getBlockShape();

  ArrayAttr strideAttr = mdescTy.getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }
  if (subgroup_block_io && layout) {
    auto laneData = layout.getEffectiveLaneDataAsInt();
    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
    if (!laneData.empty()) {
      bool isLaneDataContiguous =
          std::all_of(laneData.begin(), std::prev(laneData.end()),
                      [](int x) { return x == 1; });
      if (!isLaneDataContiguous)
        return emitError() << "With subgroup_block_io, accessed data must be "
                              "contiguous and coalesced.";
      for (size_t i = 0; i < laneData.size(); ++i) {
        if (laneLayout[i] != blockShape[i])
          return emitError() << "With subgroup_block_io, the block shape must "
                                "match the lane layout.";
        if (laneLayout[i] != 1 && strides[i] != 1)
          return emitError() << "With subgroup_block_io, the distributed "
                                "dimensions must be contiguous.";
      }
    }
  }
  if (dataShape.size() == 2) {
    if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
      return emitError() << "data shape must not exceed mem_desc shape.";
  } else {
    // 1D data must be blocked to be loaded with subgroup_block_io.
    if (subgroup_block_io && !blockShape.size())
      return emitError() << "mem_desc must have block attribute when "
                            "subgroup_block_io is set.";
    // Only row-major mem_desc is supported for 1D block IO.
    if (subgroup_block_io && mdescTy.isColMajor())
      return emitError() << "mem_desc should be row major when "
                            "subgroup_block_io is set.";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && "expecting a memref with static shape");

  build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        DenseI64ArrayAttr() /* const offsets */,
        DenseI64ArrayAttr() /* const shape */,
        DenseI64ArrayAttr() /* const strides */);
}
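// Illustrative use of the memref-only builder above (operand names and
// shapes are hypothetical, not taken from this file):
//
//   %src = memref.alloc() : memref<24x32xf32>
//   %td  = xegpu.create_nd_tdesc %src
//          : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>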

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // If shape and strides are from the memref and fully static, they can be
    // recovered from the source type, so drop the explicit attributes.
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, DenseI64ArrayAttr(), staticShapeAttr,
        staticStridesAttr);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
        {} /* empty const shape */, {} /* empty const strides */);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // If shape and strides are from the memref and fully static, they can be
    // recovered from the source type, so drop the explicit attributes.
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

LogicalResult CreateNdDescOp::verify() {
  size_t rank = getMixedSizes().size();
  bool invalidRank = rank != getMixedStrides().size();
  bool invalidElemTy = false;

  // Memory space of the created TensorDesc should match the source. Both are
  // treated as global memory by default when no memory-space attribute is
  // given; an integer source is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  if (size_t offsetRank = getMixedOffsets().size())
    invalidRank |= (offsetRank != rank);

  // If the source is a memref, its element type must match the one in the
  // TensorDesc.
  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
    invalidElemTy |= memrefTy.getElementType() != getElementType();

  if (llvm::isa<IntegerType>(getSourceType())) {
    // Strides and shape cannot be inferred from an integer source.
    if (getMixedStrides().empty() || getMixedSizes().empty())
      return emitOpError("expecting strides and shape to be present for "
                         "integer source.");
  }

  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  // Check the rank of the TensorDesc.
  if (getType().getRank() > (int64_t)rank)
    return emitOpError(
        "Expecting the TensorDesc rank is not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type with the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}
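// Custom-directive helpers for an optional "mixed" index list in which each
// element is either a constant integer or an SSA value, e.g. a hypothetical
// list [2, %n : index, 4]. SSA elements are recorded as ShapedType::kDynamic
// in the DenseI64ArrayAttr; the whole list may be absent.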

static ParseResult parseOptionalDynamicIndexList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
  // Only the square-bracket form is used by XeGPU ops.
  assert(delimiter == AsmParser::Delimiter::Square && "unsupported delimiter");

  SmallVector<int64_t, 4> integerVals;
  auto parseIntegerOrValue = [&]() -> ParseResult {
    OpAsmParser::UnresolvedOperand operand;
    auto res = parser.parseOptionalOperand(operand);

    if (res.has_value() && succeeded(res.value())) {
      values.push_back(operand);
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      int64_t integer;
      if (parser.parseInteger(integer))
        return failure();
      integerVals.push_back(integer);
    }
    return success();
  };

  // The index list is optional; bail out quietly if it is absent.
  if (failed(parser.parseOptionalLSquare()))
    return success();
  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None,
                                     parseIntegerOrValue) ||
      parser.parseRSquare())
    return parser.emitError(parser.getNameLoc())
           << "expected a list of SSA values or integers";

  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
  return success();
}

static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                          OperandRange values,
                                          DenseI64ArrayAttr integers) {
  if (!integers || integers.empty())
    return;
  printDynamicIndexList(printer, op, values, integers,
                        /*scalableFlags=*/{}, /*valueTypes=*/{},
                        AsmParser::Delimiter::Square);
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {

  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
               l1_hint, l2_hint, l3_hint, nullptr);
}

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint,
                         xegpu::DistributeLayoutAttr layout) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
        l2_hint, l3_hint, layout);
}

LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, UnitAttr packed,
                     DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {

  return build(builder, state, retType, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
               l3_hint, nullptr);
}

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                     UnitAttr packed, DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint,
                     xegpu::DistributeLayoutAttr layout) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        packed, transpose, l1_hint, l2_hint, l3_hint, layout);
}

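// Notes on the verifier below: a 1D result smaller than the TensorDesc is
// treated as a SIMT per-lane fragment, for which only even divisibility is
// enforced. For SIMD results the expected shape is derived from the
// TensorDesc and then adjusted for transpose, packed (VNNI) layout, and
// array_length.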
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (tdescTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has fewer elements than the TensorDesc,
  // the op is in SIMT mode: each lane loads only its slice of the data.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need a LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // The number of elements in the TensorDesc must be evenly distributable
    // across all lanes.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // Apply the permutation only if it is a valid one for the shape.
    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
      tdescShape = applyPermutation(tdescShape, trans);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {

  return build(builder, state, value, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint, nullptr);
}

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint,
                      xegpu::DistributeLayoutAttr layout) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        l1_hint, l2_hint, l3_hint, layout);
}

LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType();
  auto valTy = dyn_cast<VectorType>(getValue().getType());

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (dstTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // If the value vector is 1D and has fewer elements than the TensorDesc,
  // the op is in SIMT mode: each lane stores only its slice of the data.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need a LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems)
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape)
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;

  int64_t tDescRank = dstTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateNdOffsetOp
//===----------------------------------------------------------------------===//
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets must match the tensor descriptor's rank.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//

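// Illustrative use (hypothetical values): gathering four scattered f32
// elements through per-lane offsets.
//
//   %off = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
//   %td  = xegpu.create_tdesc %src, %off : ui64, vector<4xindex>
//          -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>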
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}

LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Memory space of the created TensorDesc should match the source. Both are
  // treated as global memory by default when no memory-space attribute is
  // given; an integer source is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check the total size: the offsets shape, extended by the chunk size,
  // must match the TensorDesc shape.
  auto chunkSize = tdescTy.getChunkSizeAsInt();
  SmallVector<int64_t> shape(getShapeOf(getOffsets().getType()));
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(tdescShape) << "\n";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto srcTy = getSourceType();
  if (srcTy.isInteger() && !getOffsetAlignByteAttr())
    return emitOpError("offset_align_byte is required with integer source.");

  if (getOffsetAlignByteAttr() && !srcTy.isInteger())
    return emitOpError("offset_align_byte only allowed with integer source.");

  return success();
}

void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
                       xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
        IntegerAttr{}, nullptr);
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
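// load_gather has two forms: a TensorDesc form, where offsets were baked
// into the descriptor by create_tdesc, and a raw-buffer form, where a
// memref/integer source is paired with explicit offsets and an optional
// chunk_size. The verifier dispatches to the matching helper above.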
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = dyn_cast<VectorType>(getValue().getType());

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  auto srcTy = getSourceType();
  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(srcTy);

  if (memTy && (getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy,
                                          chunkSize,
                                          [&]() { return emitOpError(); });
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source, Value mask,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
        l1_hint, l2_hint, l3_hint, nullptr);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, nullptr);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint,
                         DistributeLayoutAttr layout) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, layout);
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
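// store_scatter mirrors load_gather: either a scattered TensorDesc or a raw
// buffer plus explicit per-lane offsets and mask. Validation is shared with
// the gather path.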
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = dyn_cast<VectorType>(getValue().getType());

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  auto destTy = getDestType();
  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(destTy);

  if (memTy && (getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy,
                                          chunkSize,
                                          [&]() { return emitOpError(); });
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest, Value mask,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
        l2_hint, l3_hint, nullptr);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest,
                           ArrayRef<OpFoldResult> offsets, Value mask,
                           IntegerAttr chunk_size,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
        l3_hint, nullptr);
}

void StoreScatterOp::build(
    OpBuilder &builder, OperationState &state, Value value, Value dest,
    ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
    xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
    xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
        l3_hint, layout);
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}

LogicalResult UpdateOffsetOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
  SmallVector<int64_t> offsetShape = getShapeOf(getOffsets().getType());
  if (tdescTy.getChunkSizeAsInt() > 1)
    expectedOffsetShape.pop_back();

  if (expectedOffsetShape != offsetShape)
    return emitOpError(
        "Offsets should match TensorDesc except the chunk size dim.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
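// Shape contract checked below, in short: C[M,N] += A[M,K] * B[K,N].
// A hypothetical well-formed SIMD example:
//   lhs : vector<8x16xf16>   (M=8,  K=16)
//   rhs : vector<16x16xf16>  (K=16, N=16)
//   res : vector<8x16xf32>   (M=8,  N=16)
// A packed (VNNI) rhs may be 3D, e.g. vector<8x16x2xf16>, in which case
// K = rhsShape[0] * rhsShape[2].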
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
  // The semantic check is skipped for lack of architecture information.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD code.
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_ConvertLayoutOp
//===----------------------------------------------------------------------===//
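// convert_layout re-distributes the same data across lanes or subgroups.
// When the input and target layouts are identical the op is a no-op, which
// both fold() and the canonicalization pattern below exploit.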
LogicalResult ConvertLayoutOp::verify() {
  auto srcLayout = getInputLayout();
  auto resLayout = getTargetLayout();
  if (!srcLayout)
    return emitOpError("expected input layout.");
  if (!resLayout)
    return emitOpError("expected target layout.");

  // Both input and target layouts must be workgroup-level or subgroup-level
  // at the same time.
  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
    return emitOpError("expected input layout and target layout to be "
                       "WgLayout or SgLayout at the same time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
    return emitOpError(
        "invalid input layout, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
    return emitOpError(
        "invalid target layout, data cannot be evenly distributed.");

  return mlir::success();
}

OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
  if (getInputLayout() == getTargetLayout())
    return getSource();
  return {};
}

struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInputLayout() == op.getTargetLayout()) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }
    return failure();
  }
};

void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<FoldConvertLayoutOp>(context);
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadMatrixOp
//===----------------------------------------------------------------------===//
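// load_matrix reads a (possibly blocked) tile out of a mem_desc in shared
// local memory. The builders resolve mixed offsets into static/dynamic
// parts; verification is shared with store_matrix via IsValidMatrixOpParams.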
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                         TypedValue<MemDescType> memDesc,
                         llvm::ArrayRef<OpFoldResult> offsets,
                         DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
        /*subgroup_block_io=*/nullptr, layout);
}

LogicalResult LoadMatrixOp::verify() {
  auto resTy = dyn_cast<VectorType>(getRes().getType());
  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
  MemDescType mdescTy = getMemDesc().getType();

  return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
                               getLayoutAttr(), [&]() { return emitError(); });
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreMatrixOp
//===----------------------------------------------------------------------===//
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state,
                          Value data, TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
                          DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
        /*subgroup_block_io=*/nullptr, layout);
}

LogicalResult StoreMatrixOp::verify() {
  auto dataTy = dyn_cast<VectorType>(getData().getType());
  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
  MemDescType mdescTy = getMemDesc().getType();

  return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
                               getLayoutAttr(), [&]() { return emitError(); });
}

namespace mlir {
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
} // namespace mlir
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>