MLIR: lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
21
22 namespace mlir {
23 namespace vector {
24 #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
26 }
27 }
28
29 #define DEBUG_TYPE "vector-multi-reduction"
30
31 using namespace mlir;
32
33 namespace {
34
35
36
37
38
// Rewrite pattern on vector::MultiDimReductionOp that transposes the source so
// that the reduced dimensions are grouped at one end (innermost or outermost,
// depending on the lowering strategy given at construction) and then emits a
// new MultiDimReductionOp on the transposed value.
// NOTE(review): this listing is a scraped rendering. Several original lines
// are absent (base-class clause, constructor parameters, declarations of
// `rootOp`, `parallelDims`, `indices`, `reductionMask`) and template angle
// brackets were dropped (e.g. `castvector::...`, `createvector::...`).
// Confirm every detail against the upstream file.
39 class InnerOuterDimReductionConversion
41 public:
43
// The flag is true exactly when `options` equals
// VectorMultiReductionLowering::InnerReduction.
44 explicit InnerOuterDimReductionConversion(
48 useInnerDimsForReduction(
49 options == vector::VectorMultiReductionLowering::InnerReduction) {}
50
// Returns failure() when there are no parallel dims or when the dims are
// already laid out as the chosen strategy wants; otherwise creates the
// transpose(s) and the new reduction and returns success().
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
53
55 auto maskableOp =
56 castvector::MaskableOpInterface(multiReductionOp.getOperation());
// For a masked reduction `rootOp` is the op returned by getMaskingOp();
// otherwise it is the reduction itself.
58 if (maskableOp.isMasked()) {
60 rootOp = maskableOp.getMaskingOp();
61 } else {
62 rootOp = multiReductionOp;
63 }
64
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
68
69
// Split the source dims into reduction dims (as reported by the op) and
// parallel dims (every other index in [0, srcRank)).
70 ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
71 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
72 reductionDims.end());
73 int64_t reductionSize = reductionDims.size();
75 for (int64_t i = 0; i < srcRank; ++i)
76 if (!reductionDimsSet.contains(i))
77 parallelDims.push_back(i);
78
79
80
// Every dim is reduced: this pattern does not apply.
81 if (parallelDims.empty())
82 return failure();
// Inner strategy: parallel dims already are 0..N-1, so the reduction dims
// are already the trailing ones and no transpose is needed.
83 if (useInnerDimsForReduction &&
84 (parallelDims ==
85 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
86 return failure();
87
// Outer strategy: parallel dims already come right after the reduction dims,
// so the reduction dims are already the leading ones.
88 if (!useInnerDimsForReduction &&
89 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
90 reductionDims.size(),
91 parallelDims.size() + reductionDims.size()))))
92 return failure();
93
// Build the permutation: parallel dims first for the inner strategy,
// reduction dims first for the outer strategy.
95 if (useInnerDimsForReduction) {
96 indices.append(parallelDims.begin(), parallelDims.end());
97 indices.append(reductionDims.begin(), reductionDims.end());
98 } else {
99 indices.append(reductionDims.begin(), reductionDims.end());
100 indices.append(parallelDims.begin(), parallelDims.end());
101 }
102
103
// A masked reduction gets its mask transposed with the same permutation.
104 Value transposedMask;
105 if (maskableOp.isMasked()) {
106 transposedMask = rewriter.createvector::TransposeOp(
107 loc, maskableOp.getMaskingOp().getMask(), indices);
108 }
109
110
111 auto transposeOp = rewriter.createvector::TransposeOp(loc, src, indices);
// Mark the reduced positions of the transposed vector: the last
// `reductionSize` dims for the inner strategy, the first ones otherwise.
113 for (int i = 0; i < reductionSize; ++i) {
114 if (useInnerDimsForReduction)
115 reductionMask[srcRank - i - 1] = true;
116 else
117 reductionMask[i] = true;
118 }
119
// The new reduction reads the transposed source and keeps the original
// accumulator and combining kind.
120 Operation *newMultiRedOp = rewriter.createvector::MultiDimReductionOp(
121 multiReductionOp.getLoc(), transposeOp.getResult(),
122 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
// NOTE(review): the right-hand side of this assignment and the statement on
// the next original line are missing from the listing; presumably they
// re-apply `transposedMask` and replace `rootOp` — confirm upstream.
123 newMultiRedOp =
125
127 return success();
128 }
129
130 private:
// True when reduction dims must end up innermost; set once in the constructor.
131 const bool useInnerDimsForReduction;
132 };
133
134
135
// Rewrite pattern on vector::MultiDimReductionOp that collapses all parallel
// dims into one dim and all reduction dims into one dim via vector.shape_cast,
// emits a MultiDimReductionOp on the flattened vector, and casts the result
// back to the parallel shape when there is one.
// NOTE(review): this listing is a scraped rendering. The base-class clause,
// constructor parameters, the enumerate-loop header, several local
// declarations (`rootOp`, `reductionDims`, `parallelDims`, `mask`,
// `vectorShape`, `scalableDims`, the casted types) and the final replace calls
// are absent, and template angle brackets were dropped. Confirm upstream.
136 class ReduceMultiDimReductionRank
138 public:
140
// The flag is true exactly when `options` equals
// VectorMultiReductionLowering::InnerReduction.
141 explicit ReduceMultiDimReductionRank(
145 useInnerDimsForReduction(
146 options == vector::VectorMultiReductionLowering::InnerReduction) {}
147
// Returns failure() for rank < 2, for more than one scalable dim, for rank-2
// sources with exactly one reduced dim, and when the parallel dims are not
// contiguous at the end required by the strategy; success() otherwise.
148 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
150
152 auto maskableOp =
153 castvector::MaskableOpInterface(multiReductionOp.getOperation());
// For a masked reduction `rootOp` is the op returned by getMaskingOp();
// otherwise it is the reduction itself.
155 if (maskableOp.isMasked()) {
157 rootOp = maskableOp.getMaskingOp();
158 } else {
159 rootOp = multiReductionOp;
160 }
161
162 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
163 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
164 auto srcScalableDims =
165 multiReductionOp.getSourceVectorType().getScalableDims();
166 auto loc = multiReductionOp.getLoc();
167
168
// Nothing to flatten below rank 2.
169 if (srcRank < 2)
170 return failure();
171
172
173
// Sources with more than one scalable dim are rejected.
174 if (llvm::count(srcScalableDims, true) > 1)
175 return failure();
176
177
// A rank-2 source whose two dims differ in "reduced" status is rejected;
// only rank-2 sources with both or neither dim reduced continue.
178 SmallVector reductionMask = multiReductionOp.getReductionMask();
179 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
180 return failure();
181
182
// Classify each source dim as reduction or parallel, recording its index and
// size; scalability is OR-ed into one flag for reduction dims and stored per
// dim for parallel dims.
// NOTE(review): the loop header that defines `it` is missing here.
186 bool isReductionDimScalable = false;
188 int64_t i = it.index();
189 bool isReduction = it.value();
190 if (isReduction) {
191 reductionDims.push_back(i);
192 reductionShapes.push_back(srcShape[i]);
193 isReductionDimScalable |= srcScalableDims[i];
194 } else {
195 parallelDims.push_back(i);
196 parallelShapes.push_back(srcShape[i]);
197 parallelScalableDims.push_back(srcScalableDims[i]);
198 }
199 }
200
201
// Flattened sizes are the products of the collected shapes; a value of 0
// means "no dim of that category".
202 int flattenedParallelDim = 0;
203 int flattenedReductionDim = 0;
204 if (!parallelShapes.empty()) {
205 flattenedParallelDim = 1;
206 for (auto d : parallelShapes)
207 flattenedParallelDim *= d;
208 }
209 if (!reductionShapes.empty()) {
210 flattenedReductionDim = 1;
211 for (auto d : reductionShapes)
212 flattenedReductionDim *= d;
213 }
214
215 assert((flattenedParallelDim || flattenedReductionDim) &&
216 "expected at least one parallel or reduction dim");
217
218
219
// Inner strategy: parallel dims must be exactly 0, 1, 2, ... in order.
220 int64_t counter = 0;
221 if (useInnerDimsForReduction &&
222 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
223 return failure();
224
// Outer strategy: parallel dims must be consecutive starting right after the
// reduction dims, i.e. at index reductionDims.size().
225 counter = reductionDims.size();
226 if (!useInnerDimsForReduction &&
227 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
228 return failure();
229
230
231
// Describe the flattened vector: an optional parallel entry followed by an
// optional reduction entry (mask bit, size, scalable flag for each).
235 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
236 if (flattenedParallelDim) {
237 mask.push_back(false);
238 vectorShape.push_back(flattenedParallelDim);
239 scalableDims.push_back(isParallelDimScalable);
240 }
241 if (flattenedReductionDim) {
242 mask.push_back(true);
243 vectorShape.push_back(flattenedReductionDim);
244 scalableDims.push_back(isReductionDimScalable);
245 }
// Outer strategy with both entries present: put the reduction entry first.
// NOTE(review): original line 248 is missing between these swaps; presumably
// it swaps `vectorShape` as well — confirm upstream.
246 if (!useInnerDimsForReduction && vectorShape.size() == 2) {
247 std::swap(mask.front(), mask.back());
249 std::swap(scalableDims.front(), scalableDims.back());
250 }
251
// A masked reduction gets its mask shape-cast to `maskCastedType`.
// NOTE(review): the declaration of `maskCastedType` is only partly visible.
252 Value newVectorMask;
253 if (maskableOp.isMasked()) {
254 Value vectorMask = maskableOp.getMaskingOp().getMask();
257 llvm::cast(vectorMask.getType()).getElementType());
258 newVectorMask =
259 rewriter.createvector::ShapeCastOp(loc, maskCastedType, vectorMask);
260 }
261
// Shape-cast the source to the flattened shape, keeping its element type.
263 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
264 scalableDims);
265 Value cast = rewriter.createvector::ShapeCastOp(
266 loc, castedType, multiReductionOp.getSource());
267
// With parallel dims present, the accumulator is shape-cast to a 1-D type of
// size `flattenedParallelDim`; otherwise it is used unchanged.
268 Value acc = multiReductionOp.getAcc();
269 if (flattenedParallelDim) {
271 {flattenedParallelDim},
272 multiReductionOp.getSourceVectorType().getElementType(),
273 {isParallelDimScalable});
274 acc = rewriter.createvector::ShapeCastOp(loc, accType, acc);
275 }
276
277
278 Operation *newMultiDimRedOp = rewriter.createvector::MultiDimReductionOp(
279 loc, cast, acc, mask, multiReductionOp.getKind());
// NOTE(review): the right-hand side of this assignment is missing from the
// listing; presumably it re-applies `newVectorMask` — confirm upstream.
280 newMultiDimRedOp =
282
283
284
// No parallel dims: the pattern finishes here.
// NOTE(review): the statement before `return` (original line 286) is missing.
285 if (parallelShapes.empty()) {
287 return success();
288 }
289
290
// Otherwise the result is cast to a type built from `parallelShapes`.
// NOTE(review): the declaration of `outputCastedType` and the call that
// consumes the arguments below are only partly visible.
292 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
293 parallelScalableDims);
295 rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
296 return success();
297 }
298
299 private:
// True when reduction dims are expected innermost; set once in the constructor.
300 const bool useInnerDimsForReduction;
301 };
302
303
304
// Rewrite pattern on rank-2 vector::MultiDimReductionOp that reduces dim 0
// only: it extracts each slice along dim 0 and folds it into the accumulator
// with makeArithReduction.
// NOTE(review): this listing is a scraped rendering. The base-class clause,
// the rest of the matchAndRewrite signature, the `srcShape` and `rootOp`
// declarations and one guard condition are absent, and template angle brackets
// were dropped. Confirm upstream.
305 struct TwoDimMultiReductionToElementWise
308
// Returns failure() unless the source is rank 2 with dim 0 reduced and dim 1
// not reduced (plus one further guard whose condition is not visible here).
309 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
311 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
312
313 if (srcRank != 2)
314 return failure();
315
316 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
317 return failure();
318
319 auto loc = multiReductionOp.getLoc();
321 multiReductionOp.getSourceVectorType().getShape();
322
// NOTE(review): the condition guarding this early exit is missing.
325 return failure();
326
328 auto maskableOp =
329 castvector::MaskableOpInterface(multiReductionOp.getOperation());
// `mask` stays null for an unmasked reduction; for a masked one it is the
// masking op's mask and `rootOp` is that masking op.
331 Value mask = nullptr;
332 if (maskableOp.isMasked()) {
334 rootOp = maskableOp.getMaskingOp();
335 mask = maskableOp.getMaskingOp().getMask();
336 } else {
337 rootOp = multiReductionOp;
338 }
339
// Start from the accumulator and combine one dim-0 slice per iteration.
340 Value result = multiReductionOp.getAcc();
341 for (int64_t i = 0; i < srcShape[0]; i++) {
342 auto operand = rewriter.createvector::ExtractOp(
343 loc, multiReductionOp.getSource(), i);
344 Value extractMask = nullptr;
345 if (mask) {
346 extractMask = rewriter.createvector::ExtractOp(loc, mask, i);
347 }
// The `nullptr` argument sits in makeArithReduction's fastmath slot; the
// matching slice of the mask (or null) is passed as its mask.
348 result =
349 makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
350 result, nullptr, extractMask);
351 }
352
353 rewriter.replaceOp(rootOp, result);
354 return success();
355 }
356 };
357
358
359
// Rewrite pattern on rank-2 vector::MultiDimReductionOp that reduces dim 1
// only: for each index along dim 0 it creates a vector::ReductionOp and
// inserts the result into an output value.
// NOTE(review): this listing is a scraped rendering. The base-class clause,
// the rest of the matchAndRewrite signature, the `rootOp` declaration, the
// operands of the two ExtractOp calls and the mask-handling statement are
// absent, and template angle brackets were dropped. Confirm upstream.
360 struct TwoDimMultiReductionToReduction
363
// Returns failure() unless the source is rank 2 with dim 1 reduced and dim 0
// not reduced; otherwise replaces `rootOp` and returns success().
364 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
366 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
367 if (srcRank != 2)
368 return failure();
369
370 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
371 return failure();
372
373
375 auto maskableOp =
376 castvector::MaskableOpInterface(multiReductionOp.getOperation());
// For a masked reduction `rootOp` is the op returned by getMaskingOp();
// otherwise it is the reduction itself.
378 if (maskableOp.isMasked()) {
380 rootOp = maskableOp.getMaskingOp();
381 } else {
382 rootOp = multiReductionOp;
383 }
384
// The output starts as a zero constant of the destination type; every
// position is overwritten by the InsertOp in the loop below.
385 auto loc = multiReductionOp.getLoc();
386 Value result = rewriter.createarith::ConstantOp(
387 loc, multiReductionOp.getDestType(),
388 rewriter.getZeroAttr(multiReductionOp.getDestType()));
389 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
390
391 for (int i = 0; i < outerDim; ++i) {
// NOTE(review): the operand lines of both ExtractOp calls are missing;
// presumably they extract element `i` of the source and of the accumulator.
392 auto v = rewriter.createvector::ExtractOp(
394 auto acc = rewriter.createvector::ExtractOp(
396 Operation *reductionOp = rewriter.createvector::ReductionOp(
397 loc, multiReductionOp.getKind(), v, acc);
398
399
// NOTE(review): the body of the masked branch is truncated; only the
// extraction of a mask value is visible.
400 if (maskableOp.isMasked()) {
401 Value mask = rewriter.createvector::ExtractOp(
404 }
405
406 result = rewriter.createvector::InsertOp(loc, reductionOp->getResult(0),
407 result, i);
408 }
409
410 rewriter.replaceOp(rootOp, result);
411 return success();
412 }
413 };
414
415
416
417
418
419
// Rewrite pattern on rank-1 vector::MultiDimReductionOp that lifts the
// reduction to rank 2: the source is shape-cast to {1, N}, the accumulator
// (and mask, if any) is broadcast, and a new MultiDimReductionOp is created.
// NOTE(review): this listing is a scraped rendering. The base-class clause,
// the rest of the matchAndRewrite signature, the declarations of `rootOp`,
// `mask`, `castedType`, `accType`, `reductionMask`, `castMask`,
// `castMaskType` and the final replace statements are absent, and template
// angle brackets were dropped. Confirm upstream.
420 struct OneDimMultiReductionToTwoDim
423
// Returns failure() unless the source is rank 1; success() otherwise.
424 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
426 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
427
428 if (srcRank != 1)
429 return failure();
430
431
433 auto maskableOp =
434 castvector::MaskableOpInterface(multiReductionOp.getOperation());
// For a masked reduction `rootOp` is the masking op and `mask` its mask;
// otherwise `rootOp` is the reduction itself.
437 if (maskableOp.isMasked()) {
439 rootOp = maskableOp.getMaskingOp();
440 mask = maskableOp.getMaskingOp().getMask();
441 } else {
442 rootOp = multiReductionOp;
443 }
444
// Target shape is {1, last source dim}; the new leading dim is never
// scalable and the trailing dim keeps the source's scalable flag.
445 auto loc = multiReductionOp.getLoc();
446 auto srcVectorType = multiReductionOp.getSourceVectorType();
447 auto srcShape = srcVectorType.getShape();
449 ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
450 ArrayRef{false, srcVectorType.getScalableDims().back()});
451
// NOTE(review): the initializer of `accType` is missing from the listing.
452 auto accType =
454 assert(!llvm::isa(multiReductionOp.getDestType()) &&
455 "multi_reduction with a single dimension expects a scalar result");
456
457
458
460
461
// Lift the operands: shape_cast for the source, broadcast for the
// accumulator and (when masked) the mask.
462 Value cast = rewriter.createvector::ShapeCastOp(
463 loc, castedType, multiReductionOp.getSource());
464 Value castAcc = rewriter.createvector::BroadcastOp(
465 loc, accType, multiReductionOp.getAcc());
467 if (maskableOp.isMasked()) {
468 auto maskType = llvm::cast(mask.getType());
471 maskType.getElementType(),
472 ArrayRef{false, maskType.getScalableDims().back()});
473 castMask = rewriter.createvector::BroadcastOp(loc, castMaskType, mask);
474 }
475
476 Operation *newOp = rewriter.createvector::MultiDimReductionOp(
477 loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
// NOTE(review): the statements between the op creation and `return` are
// missing; presumably they re-apply `castMask` and replace `rootOp` with an
// extract of the new result — confirm upstream.
479
482 return success();
483 }
484 };
485
// Pass wrapper derived from the generated LowerVectorMultiReductionBase. It
// stores the requested lowering strategy and declares the vector dialect as a
// dependent dialect.
// NOTE(review): this listing is a scraped rendering; most of the body of
// runOnOperation() (original lines 494-498 and 501) is missing. Confirm
// against the upstream file before relying on it.
486 struct LowerVectorMultiReductionPass
487 : public vector::impl::LowerVectorMultiReductionBase<
488 LowerVectorMultiReductionPass> {
// Stores `option` in the pass's `loweringStrategy` member.
489 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
490 this->loweringStrategy = option;
491 }
492
// Uses `loweringStrategy` as an argument to a call whose start is not visible
// and calls signalPassFailure(); the condition guarding that call is missing.
493 void runOnOperation() override {
496
499 this->loweringStrategy);
500
502 signalPassFailure();
503 }
504
// Registers the vector dialect with `registry`.
// The parameter was mis-encoded as "®istry" and the template argument of
// insert had lost its angle brackets; both are restored here.
505 void getDependentDialects(DialectRegistry &registry) const override {
506 registry.insert<vector::VectorDialect>();
507 }
508 };
509
510 }
511
// Body of the pattern-population routine: adds patterns to `patterns`, with
// the choice of the last one depending on whether `options` is InnerReduction.
// NOTE(review): the function signature (original lines 512-514), the argument
// line of the first add call (516) and the template arguments of the other
// add calls are missing from this scraped listing; presumably this is
// populateVectorMultiReductionLoweringPatterns — confirm upstream.
515 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
517 patterns.add(patterns.getContext(), benefit);
// Exactly one of the two adds below runs, selected by the lowering strategy.
518 if (options == VectorMultiReductionLowering ::InnerReduction)
519 patterns.add(patterns.getContext(),
520 benefit);
521 else
522 patterns.add(patterns.getContext(),
523 benefit);
524 }
525
// Factory tail: returns a std::make_unique result constructed from `option`.
// NOTE(review): the first line of the signature (original line 526) and the
// template argument of make_unique are missing from this scraped listing;
// presumably this is createLowerVectorMultiReductionPass building a
// LowerVectorMultiReductionPass — confirm upstream.
527 vector::VectorMultiReductionLowering option) {
528 return std::make_unique(option);
529 }
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
TypedAttr getZeroAttr(Type type)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...