MLIR: lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

16

22 #include "llvm/ADT/STLExtras.h"

23 #include "llvm/Support/Debug.h"

24 #include

25

26 namespace mlir {

27 namespace xegpu {

28 #define GEN_PASS_DEF_XEGPUUNROLL

29 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

30 }

31 }

32

33 #define DEBUG_TYPE "xegpu-unroll"

34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

35 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

36

37 using namespace mlir;

38

39 namespace {

40

41 template

46

47 protected:

48

49

52 LDBG("Get unroll shape for: " << *op);

53

54 if (options.filterConstraint && failed(options.filterConstraint(op))) {

55 LDBG("--no filter constraint -> BAIL");

56 return std::nullopt;

57 }

58

59 assert(options.nativeShape &&

60 "expects the native shape for native shape call back function.");

61 auto nativeShape = options.nativeShape(op);

62 return nativeShape;

63 }

64

67 return options.getUnrolledTypes(type, tileShape);

68 }

69

70

71

74 if (auto vecTy = dyn_cast(destTy)) {

75 assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&

76 "Expecting blockSize size to match the rank of destTy.");

77 auto shape = vecTy.getShape();

79 }

80

81 if (isaxegpu::TensorDescType(destTy)) {

86 auto castOp = rewriter.create(

88 return castOp.getResult(0);

89 }

90

91 llvm_unreachable("Unexpected destTy.");

93 }

94

95

96

100 if (auto vecTy = dyn_cast(src.getType())) {

101 assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&

102 "Expecting blockSize size to match the rank of src.");

104 blockSize);

105 }

106

107 if (isaxegpu::TensorDescType(src.getType())) {

112 auto castOp = rewriter.create(

114 return castOp.getResults();

115 }

116

117 llvm_unreachable("Unexpected src type.");

119 }

120

121 private:

122 const char *const packAttrName = "__xegpu_blocking_pack__";

123 const char *const unpackAttrName = "__xegpu_blocking_unpack__";

124 const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

125

127 };

128

129 struct UnrollCreateNdOp : public UnrollPatternxegpu::CreateNdDescOp {

130 using UnrollPatternxegpu::CreateNdDescOp::UnrollPattern;

131 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,

134 xegpu::TensorDescType tdescTy = op.getType();

135 int64_t rank = tdescTy.getRank();

137

138 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

139 if (!targetShape)

140 return failure();

141

142 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

143

146 if (maybeInt) {

147 return rewriter.createarith::ConstantIndexOp(loc, *maybeInt + b);

148 } else {

149 auto aV = llvm::cast(a);

150 auto bV = rewriter.createarith::ConstantIndexOp(loc, b);

151 return rewriter.createOrFoldarith::AddIOp(loc, aV, bV);

152 }

153 };

154

156

157

158

160 llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));

161 auto validIdxes =

162 llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

163

167

168 for (auto [idx, oldOff, offset] :

169 llvm::zip(validIdxes, oldOffsets, offsets))

170 mixedOffsets[idx] = addi(oldOff, offset);

171

172 auto newOp = rewriter.createxegpu::CreateNdDescOp(

173 loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),

174 op.getMixedStrides());

175 newOps.push_back(newOp);

176 }

177 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);

179

180 return success();

181 }

182 };

183

184 struct UnrollUpdateNdOffsetOp : public UnrollPatternxegpu::UpdateNdOffsetOp {

185 using UnrollPatternxegpu::UpdateNdOffsetOp::UnrollPattern;

186 LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,

189 xegpu::TensorDescType tdescTy = op.getTensorDescType();

190

191 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

192 if (!targetShape)

193 return failure();

194

196 getUnrolledTypes(tdescTy, *targetShape);

198 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

199

201 for (auto t : convertedTdesc) {

202 auto newOp = rewriter.createxegpu::UpdateNdOffsetOp(

203 loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());

204 newOps.push_back(newOp);

205 }

206 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

208 return success();

209 }

210 };

211

212 struct UnrollPrefetchNdOp : public UnrollPatternxegpu::PrefetchNdOp {

213 using UnrollPatternxegpu::PrefetchNdOp::UnrollPattern;

214 LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,

217 xegpu::TensorDescType tdescTy = op.getTensorDescType();

218

219 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

220 if (!targetShape)

221 return failure();

222

224 getUnrolledTypes(tdescTy, *targetShape);

226 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

227

228 for (auto t : convertedTdesc)

229 rewriter.createxegpu::PrefetchNdOp(loc, TypeRange(), t, op->getAttrs());

230

232 return success();

233 }

234 };

235

236 struct UnrollLoadNdOp : public UnrollPatternxegpu::LoadNdOp {

237 using UnrollPatternxegpu::LoadNdOp::UnrollPattern;

238 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,

240

242 VectorType valueTy = op.getType();

243 xegpu::TensorDescType tdescTy = op.getTensorDescType();

244

245 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

246 if (!targetShape)

247 return failure();

248

249 Type elemTy = tdescTy.getElementType();

250 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

251

253 getUnrolledTypes(tdescTy, *targetShape);

255 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

256

258 for (auto t : convertedTdescs) {

259 auto newOp =

260 rewriter.createxegpu::LoadNdOp(loc, newValueTy, t, op->getAttrs());

261 newOps.push_back(newOp);

262 }

263

264 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

265

267 return success();

268 }

269 };

270

271 struct UnrollStoreNdOp : public UnrollPatternxegpu::StoreNdOp {

272 using UnrollPatternxegpu::StoreNdOp::UnrollPattern;

273 LogicalResult matchAndRewrite(xegpu::StoreNdOp op,

276 VectorType valueTy = op.getValueType();

277 xegpu::TensorDescType tdescTy = op.getTensorDescType();

278

279 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

280 if (!targetShape)

281 return failure();

282

284 getUnrolledTypes(valueTy, *targetShape);

286 getUnrolledTypes(tdescTy, *targetShape);

287

289 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

291 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

292

293 for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))

294 rewriter.createxegpu::StoreNdOp(loc, v, t, op.getL1HintAttr(),

295 op.getL2HintAttr(), op.getL3HintAttr());

296

298 return success();

299 }

300 };

301

302 struct UnrollDpasOp : public UnrollPatternxegpu::DpasOp {

303 using UnrollPatternxegpu::DpasOp::UnrollPattern;

304 LogicalResult matchAndRewrite(xegpu::DpasOp op,

307

308

309 if (llvm::any_of(op->getOperandTypes(), [&](Type type) {

310 auto vecTy = dyn_cast(type);

311 return !vecTy || vecTy.getRank() != 2;

312 }))

313 return failure();

314

315

316

317 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

318 if (!targetShape || targetShape->size() != 3)

319 return failure();

320 auto M = (*targetShape)[0];

321 auto K = (*targetShape)[1];

322 auto N = (*targetShape)[2];

323

324 int64_t aBlockSize[2] = {M, K};

325 int64_t bBlockSize[2] = {K, N};

326 int64_t cBlockSize[2] = {M, N};

327

330 VectorType type = val.getType();

331 std::optional<SmallVector<int64_t>> grids =

333 assert(grids && "Expecting grids to be computed.");

335 if (numNewOps == 1)

337 VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());

340 pack(val, convertedTypes, blockSize, loc, rewriter);

341 return values;

342 };

343

344 auto a = op.getLhs();

345 auto b = op.getRhs();

346 auto c = op.getAcc();

347

348 auto aShape = a.getType().getShape();

349 auto bShape = b.getType().getShape();

350

352 aVals = packWrapper(a, aBlockSize);

353 bVals = packWrapper(b, bBlockSize);

354

355 if (c)

356 cVals = packWrapper(c, cBlockSize);

357

358

359

362 if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||

363 llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))

364 return failure();

365

366 VectorType resultTy = op.getResult().getType();

367 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

368

369 int64_t mIters = aShape[0] / M;

370 int64_t kIters = aShape[1] / K;

371 int64_t nIters = bShape[1] / N;

372

374 for (int64_t i = 0; i < mIters; ++i) {

375 for (int64_t j = 0; j < nIters; ++j) {

377 if (c)

378 tmpC = cVals[i * nIters + j];

379

380 for (int64_t k = 0; k < kIters; ++k) {

381 Value aVec = aVals[i * kIters + k];

382 Value bVec = bVals[k * nIters + j];

384 if (tmpC)

385 operands.push_back(tmpC);

386

387 tmpC = rewriter.createxegpu::DpasOp(loc, vecTy, operands,

388 op->getAttrs());

389 }

390 newOps.push_back(tmpC);

391 }

392 }

393 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);

395 return success();

396 }

397 };

398

399 }

400

403 patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,

404 UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(

406 }

static llvm::ManagedStatic< PassManagerOptions > options

static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)

Return the target shape for unrolling for the given op.

DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)

StringAttr getStringAttr(const Twine &bytes)

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

NamedAttribute represents a combination of a name and an Attribute value.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

Operation is the basic unit of execution within MLIR.

This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)

Implement packing of a single LinalgOp by packedSizes.

Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)

Create a vector of shape from a set of values using vector.insert_stride_slice.

void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)

Collect a set of patterns to unroll xegpu operations to a smaller shapes.

SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)

Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.

Include the generated interface declarations.

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue

If Ty is mlir::Type this will select Value instead of having a wrapper around it.

int64_t computeProduct(ArrayRef< int64_t > basis)

Self-explicit.

const FrozenRewritePatternSet & patterns

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)

Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

Options to control the XeGPU unrolling.

Eliminates variable at the specified position using Fourier-Motzkin variable elimination.