MLIR: lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp Source File
//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exceptions
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations to smaller
// shapes. Which ops get unrolled, and the target (native) shape they are
// unrolled to, is controlled by the callbacks in xegpu::UnrollOptions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace {

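/// Base class for all XeGPU unroll patterns. It wraps the user-provided
/// UnrollOptions: `getTargetShape` consults the filter constraint and the
/// native-shape callback, while `pack`/`unpack` bridge between the original
/// full-sized values and the per-tile values, using vector
/// extract/insert_strided_slice for vectors and tagged
/// unrealized_conversion_cast ops for tensor descriptors.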
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for `op`, or std::nullopt if the op should not
  /// (or cannot) be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--failed filter constraint -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects the native shape callback to be provided.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

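/// Unrolls xegpu.create_nd_tdesc by walking every tile offset of the
/// descriptor shape (via StaticTileOffsetRange), adding it to the op's
/// innermost mixed offsets, and creating one smaller descriptor per tile.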
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D sources with n > rank, only the trailing `rank` offsets take
    // part in the unrolling; the leading offsets are kept as-is.
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

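/// Unrolls xegpu.update_nd_offset by packing the tensor descriptor into
/// per-tile descriptors and applying the same offset update to each.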
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

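/// Unrolls xegpu.prefetch_nd into one prefetch per tile; the original op is
/// erased since it produces no results.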
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

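/// Unrolls xegpu.load_nd into per-tile loads and reassembles the loaded
/// tiles into the original vector shape.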
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

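/// Unrolls xegpu.store_nd by splitting both the stored value and the tensor
/// descriptor into tiles and emitting one store per tile.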
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

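/// Unrolls xegpu.dpas into an M x N x K grid of block-sized dpas ops,
/// accumulating partial results over the K dimension (a classic blocked
/// matmul decomposition).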
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // The target shape is expected to hold exactly 3 elements, representing
    // M, K, and N respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if any operand has an invalid blocking size (empty)
    // or if every original shape already matches its blocking size (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with the corresponding C tile

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace

void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
      patterns.getContext(), options);
}
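
// Illustrative usage sketch (not part of this file): a driver would populate
// xegpu::UnrollOptions with the callbacks this file consumes (`nativeShape`,
// `getUnrolledTypes`, and optionally `filterConstraint`) and then apply the
// patterns with the greedy rewrite driver. This assumes the callbacks are
// directly assignable std::function members, and the 8x16 tile shape below is
// a made-up placeholder, not a hardware-derived value.
//
//   void applyXeGPUUnrollExample(Operation *root) {
//     xegpu::UnrollOptions options;
//     // Unroll every supported op to 8x16 tiles (placeholder choice).
//     options.nativeShape =
//         [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//       return SmallVector<int64_t>{8, 16};
//     };
//     // One type per tile; a real callback would derive the tile count from
//     // the ratio of the original shape to `tileShape`.
//     options.getUnrolledTypes = [](ShapedType type,
//                                   ArrayRef<int64_t> tileShape) {
//       return SmallVector<Type>(1, type.clone(tileShape));
//     };
//     RewritePatternSet patterns(root->getContext());
//     xegpu::populateXeGPUUnrollPatterns(patterns, options);
//     (void)applyPatternsGreedily(root, std::move(patterns));
//   }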