MLIR: lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

21

22 namespace mlir {

23 namespace vector {

24 #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION

25 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"

26 }

27 }

28

29 #define DEBUG_TYPE "vector-multi-reduction"

30

31 using namespace mlir;

32

33 namespace {

34

35

36

37

38

// Rewrite pattern for vector::MultiDimReductionOp that transposes the source
// (and the mask, when the op is masked) so that all reduction dimensions end
// up either trailing ("inner") or leading ("outer"), as selected by
// `useInnerDimsForReduction`, and then re-creates the reduction on the
// transposed value.
// NOTE(review): this listing is incomplete -- the embedded line numbers skip
// values (40, 42, 45-47, 52, 54, ...) and template argument lists appear to
// have been stripped (e.g. `castvector::MaskableOpInterface`). The comments
// below describe only what the visible lines establish.
39 class InnerOuterDimReductionConversion

41 public:

43

44 explicit InnerOuterDimReductionConversion(

// The flag is true exactly when the requested lowering is InnerReduction.
48 useInnerDimsForReduction(

49 options == vector::VectorMultiReductionLowering::InnerReduction) {}

50

// Returns failure() when there are no parallel dims or when the dims are
// already in the requested order; otherwise emits the transpose(s) and the new
// reduction and returns success().
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,

53

55 auto maskableOp =

56 castvector::MaskableOpInterface(multiReductionOp.getOperation());

// When the op is masked, the enclosing masking op is the root; otherwise the
// reduction op itself is.
58 if (maskableOp.isMasked()) {

60 rootOp = maskableOp.getMaskingOp();

61 } else {

62 rootOp = multiReductionOp;

63 }

64

65 auto src = multiReductionOp.getSource();

66 auto loc = multiReductionOp.getLoc();

67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();

68

69

// Split the source dimensions: every index in [0, srcRank) that is not a
// reduction dim is a parallel dim.
// NOTE(review): the declaration of `parallelDims` is not present in this
// listing.
70 ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();

71 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),

72 reductionDims.end());

73 int64_t reductionSize = reductionDims.size();

75 for (int64_t i = 0; i < srcRank; ++i)

76 if (!reductionDimsSet.contains(i))

77 parallelDims.push_back(i);

78

79

80

// Do not match when every dimension is reduced.
81 if (parallelDims.empty())

82 return failure();

// Inner-reduction mode: the parallel dims are already 0..N-1 (so the
// reduction dims already trail); do not match.
83 if (useInnerDimsForReduction &&

84 (parallelDims ==

85 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))

86 return failure();

87

// Outer-reduction mode: the parallel dims already occupy the positions
// directly after the reduction dims (so the reduction dims already lead); do
// not match.
88 if (!useInnerDimsForReduction &&

89 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(

90 reductionDims.size(),

91 parallelDims.size() + reductionDims.size()))))

92 return failure();

93

// Build the permutation: parallel dims followed by reduction dims for inner
// reduction, reduction dims followed by parallel dims otherwise.
// NOTE(review): the declaration of `indices` is not present in this listing.
95 if (useInnerDimsForReduction) {

96 indices.append(parallelDims.begin(), parallelDims.end());

97 indices.append(reductionDims.begin(), reductionDims.end());

98 } else {

99 indices.append(reductionDims.begin(), reductionDims.end());

100 indices.append(parallelDims.begin(), parallelDims.end());

101 }

102

103

// The mask, when present, is transposed with the same permutation as the
// source; it stays a null Value for unmasked ops.
104 Value transposedMask;

105 if (maskableOp.isMasked()) {

106 transposedMask = rewriter.createvector::TransposeOp(

107 loc, maskableOp.getMaskingOp().getMask(), indices);

108 }

109

110

111 auto transposeOp = rewriter.createvector::TransposeOp(loc, src, indices);

// Mark the reduced positions of the transposed vector: the last
// `reductionSize` dims for inner reduction, the first `reductionSize` dims
// otherwise.
// NOTE(review): the declaration of `reductionMask` is not present in this
// listing; the loop index is `int` while `reductionSize` is `int64_t`.
113 for (int i = 0; i < reductionSize; ++i) {

114 if (useInnerDimsForReduction)

115 reductionMask[srcRank - i - 1] = true;

116 else

117 reductionMask[i] = true;

118 }

119

// Re-create the reduction on the transposed source, keeping the original
// accumulator and combining kind.
120 Operation *newMultiRedOp = rewriter.createvector::MultiDimReductionOp(

121 multiReductionOp.getLoc(), transposeOp.getResult(),

122 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());

// NOTE(review): the right-hand side of this assignment and the statement that
// consumes `rootOp` are not present in this listing.
123 newMultiRedOp =

125

127 return success();

128 }

129

130 private:

// True when reduction dims must be moved to the innermost positions; false
// when they must be moved to the outermost positions.
131 const bool useInnerDimsForReduction;

132 };

133

134

135

// Rewrite pattern for vector::MultiDimReductionOp that collapses all parallel
// dims into one dim and all reduction dims into one dim via
// vector::ShapeCastOp, creates the reduction on the collapsed vector, and
// shape-casts the result back to the parallel shape.
// NOTE(review): this listing is incomplete -- the embedded line numbers skip
// values and template argument lists appear to have been stripped (e.g.
// `SmallVector reductionMask`, `llvm::cast(vectorMask.getType())`). The
// comments below describe only what the visible lines establish.
136 class ReduceMultiDimReductionRank

138 public:

140

141 explicit ReduceMultiDimReductionRank(

// The flag is true exactly when the requested lowering is InnerReduction.
145 useInnerDimsForReduction(

146 options == vector::VectorMultiReductionLowering::InnerReduction) {}

147

// Returns failure() for sources of rank < 2, sources with more than one
// scalable dim, rank-2 sources that already have one reduced and one parallel
// dim, and sources whose parallel dims are not contiguous in the position
// required by the lowering mode.
148 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,

150

152 auto maskableOp =

153 castvector::MaskableOpInterface(multiReductionOp.getOperation());

// When the op is masked, the enclosing masking op is the root; otherwise the
// reduction op itself is.
155 if (maskableOp.isMasked()) {

157 rootOp = maskableOp.getMaskingOp();

158 } else {

159 rootOp = multiReductionOp;

160 }

161

162 auto srcRank = multiReductionOp.getSourceVectorType().getRank();

163 auto srcShape = multiReductionOp.getSourceVectorType().getShape();

164 auto srcScalableDims =

165 multiReductionOp.getSourceVectorType().getScalableDims();

166 auto loc = multiReductionOp.getLoc();

167

168

// Only sources of rank 2 or higher are handled.
169 if (srcRank < 2)

170 return failure();

171

172

173

// At most one scalable dimension is supported.
174 if (llvm::count(srcScalableDims, true) > 1)

175 return failure();

176

177

// A rank-2 source whose two mask entries differ already has exactly one
// reduced and one parallel dim; do not match it.
178 SmallVector reductionMask = multiReductionOp.getReductionMask();

179 if (srcRank == 2 && reductionMask.front() != reductionMask.back())

180 return failure();

181

182

// Partition dims into reduction and parallel groups, recording each dim's
// index and extent. Reduction scalability is folded into a single flag, while
// parallel scalability is kept per dim.
// NOTE(review): the container declarations and the loop header that defines
// `it` are not present in this listing.
186 bool isReductionDimScalable = false;

188 int64_t i = it.index();

189 bool isReduction = it.value();

190 if (isReduction) {

191 reductionDims.push_back(i);

192 reductionShapes.push_back(srcShape[i]);

193 isReductionDimScalable |= srcScalableDims[i];

194 } else {

195 parallelDims.push_back(i);

196 parallelShapes.push_back(srcShape[i]);

197 parallelScalableDims.push_back(srcScalableDims[i]);

198 }

199 }

200

201

// Flattened extents are the product of the extents in each group; a group
// with no dims keeps the value 0.
// NOTE(review): the products are accumulated in `int`; confirm the shape
// extents cannot exceed that range.
202 int flattenedParallelDim = 0;

203 int flattenedReductionDim = 0;

204 if (!parallelShapes.empty()) {

205 flattenedParallelDim = 1;

206 for (auto d : parallelShapes)

207 flattenedParallelDim *= d;

208 }

209 if (!reductionShapes.empty()) {

210 flattenedReductionDim = 1;

211 for (auto d : reductionShapes)

212 flattenedReductionDim *= d;

213 }

214

215 assert((flattenedParallelDim || flattenedReductionDim) &&

216 "expected at least one parallel or reduction dim");

217

218

219

// Inner-reduction mode: the parallel dims must be exactly 0, 1, 2, ...
// `counter` is advanced by the predicate as it walks `parallelDims`.
220 int64_t counter = 0;

221 if (useInnerDimsForReduction &&

222 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))

223 return failure();

224

// Outer-reduction mode: the parallel dims must be exactly R, R+1, ... where R
// is the number of reduction dims.
225 counter = reductionDims.size();

226 if (!useInnerDimsForReduction &&

227 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))

228 return failure();

229

230

231

// Describe the collapsed vector: one entry for the flattened parallel dim (not
// reduced) and one for the flattened reduction dim (reduced), each only if the
// corresponding group is non-empty.
// NOTE(review): the declarations of `mask`, `vectorShape` and `scalableDims`
// are not present in this listing.
235 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);

236 if (flattenedParallelDim) {

237 mask.push_back(false);

238 vectorShape.push_back(flattenedParallelDim);

239 scalableDims.push_back(isParallelDimScalable);

240 }

241 if (flattenedReductionDim) {

242 mask.push_back(true);

243 vectorShape.push_back(flattenedReductionDim);

244 scalableDims.push_back(isReductionDimScalable);

245 }

// For outer reduction with both groups present, put the reduction entry first.
// NOTE(review): listing line 248 is missing between these two swaps.
246 if (!useInnerDimsForReduction && vectorShape.size() == 2) {

247 std::swap(mask.front(), mask.back());

249 std::swap(scalableDims.front(), scalableDims.back());

250 }

251

// When the op is masked, shape-cast the mask as well; `newVectorMask` stays a
// null Value otherwise.
// NOTE(review): the declaration of `maskCastedType` is only partially present.
252 Value newVectorMask;

253 if (maskableOp.isMasked()) {

254 Value vectorMask = maskableOp.getMaskingOp().getMask();

257 llvm::cast(vectorMask.getType()).getElementType());

258 newVectorMask =

259 rewriter.createvector::ShapeCastOp(loc, maskCastedType, vectorMask);

260 }

261

// Shape-cast the source to the collapsed shape.
// NOTE(review): the first line of the `castedType` declaration is missing.
263 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),

264 scalableDims);

265 Value cast = rewriter.createvector::ShapeCastOp(

266 loc, castedType, multiReductionOp.getSource());

267

// With parallel dims present, the accumulator is shape-cast to a 1-D type of
// extent `flattenedParallelDim`; otherwise it is used unchanged.
268 Value acc = multiReductionOp.getAcc();

269 if (flattenedParallelDim) {

271 {flattenedParallelDim},

272 multiReductionOp.getSourceVectorType().getElementType(),

273 {isParallelDimScalable});

274 acc = rewriter.createvector::ShapeCastOp(loc, accType, acc);

275 }

276

277

// Create the reduction on the collapsed operands with the original kind.
278 Operation *newMultiDimRedOp = rewriter.createvector::MultiDimReductionOp(

279 loc, cast, acc, mask, multiReductionOp.getKind());

// NOTE(review): the right-hand side of this assignment is not present in this
// listing.
280 newMultiDimRedOp =

282

283

284

// No parallel dims: succeed without casting the result back.
// NOTE(review): listing line 286 inside this branch is missing.
285 if (parallelShapes.empty()) {

287 return success();

288 }

289

290

// Otherwise the result is cast back to a vector of the original parallel
// shape (with the original per-dim scalability) in place of `rootOp`.
// NOTE(review): the lines that begin these two statements are missing.
292 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),

293 parallelScalableDims);

295 rootOp, outputCastedType, newMultiDimRedOp->getResult(0));

296 return success();

297 }

298

299 private:

// True when the reduction dims are expected innermost; false when they are
// expected outermost.
300 const bool useInnerDimsForReduction;

301 };

302

303

304

// Rewrite pattern that lowers a rank-2 vector::MultiDimReductionOp reducing
// dimension 0 (and not dimension 1) into a chain of arithmetic reductions: the
// accumulator is combined in turn with each slice extracted along dimension 0.
// NOTE(review): this listing is incomplete -- the embedded line numbers skip
// values and template argument lists appear to have been stripped. The
// comments below describe only what the visible lines establish.
305 struct TwoDimMultiReductionToElementWise

308

309 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,

311 auto srcRank = multiReductionOp.getSourceVectorType().getRank();

312

// Only rank-2 sources are handled.
313 if (srcRank != 2)

314 return failure();

315

// Match only when dimension 0 is reduced and dimension 1 is not.
316 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))

317 return failure();

318

319 auto loc = multiReductionOp.getLoc();

// NOTE(review): the left-hand side of this statement (presumably the
// `srcShape` declaration used by the loop below) is missing from the listing.
321 multiReductionOp.getSourceVectorType().getShape();

322

// NOTE(review): the condition guarding this early failure() is not present in
// this listing.
325 return failure();

326

328 auto maskableOp =

329 castvector::MaskableOpInterface(multiReductionOp.getOperation());

// `mask` stays null for unmasked ops. When the op is masked, the enclosing
// masking op is the root to replace.
331 Value mask = nullptr;

332 if (maskableOp.isMasked()) {

334 rootOp = maskableOp.getMaskingOp();

335 mask = maskableOp.getMaskingOp().getMask();

336 } else {

337 rootOp = multiReductionOp;

338 }

339

// Start from the accumulator and fold in one slice per index of dimension 0.
340 Value result = multiReductionOp.getAcc();

341 for (int64_t i = 0; i < srcShape[0]; i++) {

342 auto operand = rewriter.createvector::ExtractOp(

343 loc, multiReductionOp.getSource(), i);

// The matching slice of the mask is extracted only when a mask exists.
344 Value extractMask = nullptr;

345 if (mask) {

346 extractMask = rewriter.createvector::ExtractOp(loc, mask, i);

347 }

// The `nullptr` argument is passed in the position after the accumulator;
// the per-slice mask (null when unmasked) is passed last.
348 result =

349 makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,

350 result, nullptr, extractMask);

351 }

352

353 rewriter.replaceOp(rootOp, result);

354 return success();

355 }

356 };

357

358

359

// Rewrite pattern that lowers a rank-2 vector::MultiDimReductionOp reducing
// dimension 1 (and not dimension 0) into one vector::ReductionOp per index of
// dimension 0, inserting each reduction result into the result value.
// NOTE(review): this listing is incomplete -- the embedded line numbers skip
// values and template argument lists appear to have been stripped. The
// comments below describe only what the visible lines establish.
360 struct TwoDimMultiReductionToReduction

363

364 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,

366 auto srcRank = multiReductionOp.getSourceVectorType().getRank();

// Only rank-2 sources are handled.
367 if (srcRank != 2)

368 return failure();

369

// Match only when dimension 1 is reduced and dimension 0 is not.
370 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))

371 return failure();

372

373

375 auto maskableOp =

376 castvector::MaskableOpInterface(multiReductionOp.getOperation());

// When the op is masked, the enclosing masking op is the root; otherwise the
// reduction op itself is.
378 if (maskableOp.isMasked()) {

380 rootOp = maskableOp.getMaskingOp();

381 } else {

382 rootOp = multiReductionOp;

383 }

384

// The result starts as a zero constant of the destination type and is filled
// in position by position inside the loop.
385 auto loc = multiReductionOp.getLoc();

386 Value result = rewriter.createarith::ConstantOp(

387 loc, multiReductionOp.getDestType(),

388 rewriter.getZeroAttr(multiReductionOp.getDestType()));

// NOTE(review): the extent of dimension 0 is stored in an `int`; confirm the
// shape value cannot be truncated.
389 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

390

391 for (int i = 0; i < outerDim; ++i) {

// NOTE(review): the argument lists of the two ExtractOp creations below are
// missing from the listing; `v` and `acc` are the operands of the reduction.
392 auto v = rewriter.createvector::ExtractOp(

394 auto acc = rewriter.createvector::ExtractOp(

396 Operation *reductionOp = rewriter.createvector::ReductionOp(

397 loc, multiReductionOp.getKind(), v, acc);

398

399

// NOTE(review): for masked ops a mask value is extracted here; the remainder
// of this branch is missing from the listing.
400 if (maskableOp.isMasked()) {

401 Value mask = rewriter.createvector::ExtractOp(

404 }

405

// Store this iteration's reduction result at position i of the result.
406 result = rewriter.createvector::InsertOp(loc, reductionOp->getResult(0),

407 result, i);

408 }

409

410 rewriter.replaceOp(rootOp, result);

411 return success();

412 }

413 };

414

415

416

417

418

419

// Rewrite pattern for rank-1 vector::MultiDimReductionOp: the source is
// shape-cast, the accumulator (and the mask, when masked) is broadcast, and a
// new vector::MultiDimReductionOp is created on those operands.
// NOTE(review): this listing is incomplete -- the embedded line numbers skip
// values and template argument lists appear to have been stripped (e.g.
// `llvm::isa(...)`, `llvm::cast(mask.getType())`). The comments below describe
// only what the visible lines establish.
420 struct OneDimMultiReductionToTwoDim

423

424 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,

426 auto srcRank = multiReductionOp.getSourceVectorType().getRank();

427

// Only rank-1 sources are handled.
428 if (srcRank != 1)

429 return failure();

430

431

433 auto maskableOp =

434 castvector::MaskableOpInterface(multiReductionOp.getOperation());

// When the op is masked, the enclosing masking op is the root and its mask is
// captured; otherwise the reduction op itself is the root.
// NOTE(review): the declarations of `rootOp` and `mask` are not present in
// this listing.
437 if (maskableOp.isMasked()) {

439 rootOp = maskableOp.getMaskingOp();

440 mask = maskableOp.getMaskingOp().getMask();

441 } else {

442 rootOp = multiReductionOp;

443 }

444

445 auto loc = multiReductionOp.getLoc();

446 auto srcVectorType = multiReductionOp.getSourceVectorType();

447 auto srcShape = srcVectorType.getShape();

// Arguments describing a 2-D type: shape {1, last source extent}, the source
// element type, and scalability {false, last source scalability}.
// NOTE(review): the line that begins this statement (presumably the
// `castedType` declaration used below) is missing from the listing.
449 ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),

450 ArrayRef{false, srcVectorType.getScalableDims().back()});

451

// NOTE(review): the initializer of `accType` is missing from the listing.
452 auto accType =

454 assert(!llvm::isa(multiReductionOp.getDestType()) &&

455 "multi_reduction with a single dimension expects a scalar result");

456

457

458

460

461

// Cast the source to the 2-D type and broadcast the accumulator to `accType`.
462 Value cast = rewriter.createvector::ShapeCastOp(

463 loc, castedType, multiReductionOp.getSource());

464 Value castAcc = rewriter.createvector::BroadcastOp(

465 loc, accType, multiReductionOp.getAcc());

// For masked ops the mask is broadcast to a type whose scalability is
// {false, last mask scalability}.
// NOTE(review): the declarations of `castMask` and `castMaskType` are only
// partially present in this listing.
467 if (maskableOp.isMasked()) {

468 auto maskType = llvm::cast(mask.getType());

471 maskType.getElementType(),

472 ArrayRef{false, maskType.getScalableDims().back()});

473 castMask = rewriter.createvector::BroadcastOp(loc, castMaskType, mask);

474 }

475

// Create the reduction on the cast operands with the original combining kind.
// NOTE(review): the declaration of `reductionMask` and the statements between
// this creation and the return are not present in this listing.
476 Operation *newOp = rewriter.createvector::MultiDimReductionOp(

477 loc, cast, castAcc, reductionMask, multiReductionOp.getKind());

479

482 return success();

483 }

484 };

485

// Pass derived from the generated LowerVectorMultiReductionBase that carries a
// VectorMultiReductionLowering strategy.
// NOTE(review): this listing is incomplete -- the embedded line numbers skip
// values and template argument lists appear to have been stripped. The
// comments below describe only what the visible lines establish.
486 struct LowerVectorMultiReductionPass

487 : public vector::impl::LowerVectorMultiReductionBase<

488 LowerVectorMultiReductionPass> {

// Stores the requested strategy in the `loweringStrategy` member.
489 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {

490 this->loweringStrategy = option;

491 }

492

// NOTE(review): most of this body is missing from the listing. The visible
// lines pass `loweringStrategy` as the last argument of a call and invoke
// signalPassFailure(); the condition guarding that call is not visible.
493 void runOnOperation() override {

496

499 this->loweringStrategy);

500

502 signalPassFailure();

503 }

504

// Registers the vector dialect as a dependency of this pass.
505 void getDependentDialects(DialectRegistry &registry) const override {

506 registry.insertvector::VectorDialect();

507 }

508 };

509

510 }

511

// Body of the function that adds the multi-reduction lowering patterns to
// `patterns`.
// NOTE(review): the function signature (listing lines 512-514) and listing
// line 516 are missing, and the pattern type arguments of the three
// `patterns.add(...)` calls below appear to have been stripped. The visible
// lines show that InnerOuterDimReductionConversion and
// ReduceMultiDimReductionRank are added together, one further pattern is added
// unconditionally, and one of two further patterns is added depending on
// whether `options` is InnerReduction.
515 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(

517 patterns.add(patterns.getContext(), benefit);

518 if (options == VectorMultiReductionLowering ::InnerReduction)

519 patterns.add(patterns.getContext(),

520 benefit);

521 else

522 patterns.add(patterns.getContext(),

523 benefit);

524 }

525

// Factory that returns a newly allocated object constructed from `option`.
// NOTE(review): the first line of the signature (listing line 526) is missing,
// and the template argument of `std::make_unique` appears to have been
// stripped, so the constructed type is not visible here.
527 vector::VectorMultiReductionLowering option) {

528 return std::make_unique(option);

529 }

static llvm::ManagedStatic< PassManagerOptions > options

static std::optional< VectorShape > vectorShape(Type type)

static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)

Walks the given type hierarchy with the given indices, potentially down to component granularity,...

TypedAttr getZeroAttr(Type type)

The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

MLIRContext * getContext()

Return the context this operation is associated with.

Location getLoc()

The source location the operation was defined or derived from.

This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements).

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

bool isIntOrIndexOrFloat() const

Return true if this is an integer (of any signedness), index, or float type.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)

Returns the result value of reducing two scalar/vector values with the corresponding arith operation.

std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)

Creates an instance of the vector.multi_reduction lowering pass.

Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())

Creates a vector.mask operation around a maskable operation.

void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)

Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....

Include the generated interface declarations.

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...