MLIR: lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/TypeSwitch.h"
18
19 #include
20
21 namespace mlir {
22 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
23 #include "mlir/Dialect/Linalg/Passes.h.inc"
24 }
25
26 using namespace mlir;
28
29
32 if (!stride || *stride != 1)
33 return std::nullopt;
35 if (!offset)
36 return std::nullopt;
38 if (!size)
39 return std::nullopt;
40 return (*size - *offset);
41 }
42
43
47 if (dims.size() != tiles.size() || tiles.empty())
48 return false;
49
50 FailureOr contractDims =
52 if (failed(contractDims))
53 return false;
54 unsigned batchDimsOffset = contractDims->batch.size();
55
56
57
59 for (size_t i = 0; i < offsetDims.size(); i++)
60 offsetDims[i] += batchDimsOffset;
61
62 auto tileOp = cast(linalgOp.getOperation());
65 SmallVector iterationDomain = tileOp.getIterationDomain(builder);
66
68 if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
69 return false;
70
71 std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
72 std::optional<int64_t> rangeOnDim =
74
75
76
77 if (!tileSize || !rangeOnDim)
78 return false;
79
80
81 if (*rangeOnDim % *tileSize != 0)
82 return false;
83 }
84
85 return true;
86 }
87
88
89 static FailureOr
91 linalg::PackOp packOp, AffineMap operandMap,
93 bool transposeOuterBlocks, bool transposeInnerBlocks) {
94 assert(operandMap.getNumDims() >= 4 &&
95 "expected at least 4D prepacked matmul");
96 assert(blocksStartDimPos.size() >= 2 &&
97 "expected starting outer and inner block positions");
98
99
100 unsigned outerBlockPos = operandMap.getNumResults() - 4;
101 unsigned innerBlockPos = operandMap.getNumResults() - 2;
102
103
104
105
106
107 bool isOuterTransposed =
108 operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
109 bool isInnerTransposed =
110 operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();
111
112
113
115 if (isInnerTransposed != transposeInnerBlocks)
116 innerPerm = {1, 0};
118 if (isOuterTransposed != transposeOuterBlocks)
119 outerPerm = {1, 0};
120
121
122
124 for (auto i : llvm::seq(0u, outerBlockPos))
125 offsetPerms.push_back(i);
126 for (auto perm : outerPerm)
127 offsetPerms.push_back(perm + outerBlockPos);
128 outerPerm = offsetPerms;
129
130 FailureOr packTransposedMatmul =
132 nullptr, outerPerm, innerPerm);
133
134 return packTransposedMatmul;
135 }
136
137
138 FailureOr
141
142
143 if (auto *batchMatmulOp = dyn_castlinalg::BatchMatmulOp(&linalgOp)) {
144 if (batchMatmulOp->hasUserDefinedMaps()) {
146 *batchMatmulOp,
147 "only batch_matmul ops with non-extended semantics are supported");
148 }
149 }
150
151 if (linalgOp.hasPureBufferSemantics())
152 return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
153
154 std::optional options = controlPackMatmul(linalgOp);
156 return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
157
158 if (options->blockFactors.size() != 3)
159 return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
160
163
164
165 if (->allowPadding &&
168 "expect packing full tiles only");
169 }
170
172
174
175
176
177
178
180 rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
182 if (failed(packedMatmul))
183 return failure();
184
185 assert(packedMatmul->packOps.size() == 3 &&
186 "invalid number of pack ops after matmul packing");
187 assert(packedMatmul->unPackOps.size() == 1 &&
188 "invalid number of unpack ops after matmul packing");
189
190 FailureOr contractDims =
192 if (failed(contractDims))
193 return failure();
194
195 auto genericOp =
196 dyn_castlinalg::GenericOp(packedMatmul->packedLinalgOp.getOperation());
198
199
201 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
202 contractDims->m, options->lhsTransposeOuterBlocks,
203 options->lhsTransposeInnerBlocks);
204 if (failed(packedLhs))
205 return failure();
206
207
208 packedMatmul->packOps[0] = packedLhs->transposedPackOp;
209 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
210
211
213 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
214 contractDims->k, options->rhsTransposeOuterBlocks,
215 options->rhsTransposeInnerBlocks);
216 if (failed(packedRhs))
217 return failure();
218
219
220 packedMatmul->packOps[1] = packedRhs->transposedPackOp;
221 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
222
223 return packedMatmul;
224 }
225
226 namespace {
227 template
231 : OpRewritePattern(context, benefit), controlFn(std::move(fun)) {}
232
233 LogicalResult matchAndRewrite(OpTy linalgOp,
235 FailureOr packedMatmul =
237 if (failed(packedMatmul))
238 return failure();
239 return success();
240 }
241
242 private:
244 };
245
246 template <>
247 struct BlockPackMatmullinalg::GenericOp
252 controlFn(std::move(fun)) {}
253
254 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
256
259 }
260
262 auto infer = [&](MapList m) {
264 };
265
267 bindDims(linalgOp->getContext(), i, j, k);
269
270
271 if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
272 maps == infer({{k, i}, {k, j}, {i, j}}) ||
273 maps == infer({{i, k}, {j, k}, {i, j}}))) {
274 return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
275 }
276
277 FailureOr packedMatmul =
279 if (failed(packedMatmul))
280 return failure();
281 return success();
282 }
283
284 private:
286 };
287
288
289 struct LinalgBlockPackMatmul
290 : public impl::LinalgBlockPackMatmulBase {
291 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
292
293 void runOnOperation() override {
296
301 options.allowPadding = allowPadding;
302 options.mnkPaddedSizesNextMultipleOf =
304 if (!mnkOrder.empty())
306 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
307 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
308 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
309 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
311 };
312
315 return signalPassFailure();
316 }
317 };
318 }
319
322 patterns.add<BlockPackMatmullinalg::GenericOp,
323 BlockPackMatmullinalg::MatmulOp,
324 BlockPackMatmullinalg::BatchMatmulOp,
325 BlockPackMatmullinalg::MatmulTransposeAOp,
326 BlockPackMatmullinalg::BatchMatmulTransposeAOp,
327 BlockPackMatmullinalg::MatmulTransposeBOp,
328 BlockPackMatmullinalg::BatchMatmulTransposeBOp>(
329 patterns.getContext(), controlFn);
330 }
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return the constant range span, or nullopt otherwise.
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.