MLIR: lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

23 #include

24

25 namespace mlir {

26 namespace linalg {

28 return llvm::all_of(

29 attr, [](const APInt &element) { return element.getSExtValue() == 1; });

30 }

31

33 if (isa(x.getType()))

34 return builder.createarith::AddIOp(loc, x, y);

35 if (isa(x.getType()))

36 return builder.createcomplex::AddOp(loc, x, y);

37 return builder.createarith::AddFOp(loc, x, y);

38 }

39

42

47 if (isa(accType))

48 return builder.createcomplex::MulOp(loc, xConvert, yConvert);

49 if (isa(accType))

50 return builder.createarith::MulIOp(loc, xConvert, yConvert);

51 return builder.createarith::MulFOp(loc, xConvert, yConvert);

52 }

53

54

57 assert(!factors.empty() && "empty factor list");

59 for (int64_t f : factors)

61 FailureOr<SmallVector> multiIndex =

63 assert(!failed(multiIndex) && "Failed to linearize img2col index");

64 return *multiIndex;

65 }

66

67

68

69

71 Value fIndex, int64_t stride) {

76 }

77

78 FailureOr<std::pair<Operation *, Operation *>>

80 auto inputType = cast(convOp.getInputs()[0].getType());

81 auto filterType = cast(convOp.getInputs()[1].getType());

82 auto outputType = cast(convOp.getOutputs()[0].getType());

83

84 if (!filterType.hasStaticShape())

86 convOp, "expected a static shape for the filter");

87

88 if (!inputType.hasStaticShape())

90 "expected a static shape for the input");

91

92

95 "expected all ones for dilations");

96

98 Value input = convOp.getInputs()[0];

99 Value filter = convOp.getInputs()[1];

100 Value output = convOp.getOutputs()[0];

101

104

105 int64_t n = outputShape[0];

106 int64_t oh = outputShape[1];

107 int64_t ow = outputShape[2];

108 int64_t oc = outputShape[3];

109 int64_t fh = filterShape[0];

110 int64_t fw = filterShape[1];

111 int64_t ic = filterShape[2];

112

113 Location loc = convOp.getLoc();

114

115

117 auto reshapedFilterType =

119 Value reshapedFilter = rewriter.createtensor::CollapseShapeOp(

120 loc, reshapedFilterType, filter, filterReassocIndices);

121

123 RankedTensorType reshapedOutputType =

125 Value reshapedOutput = rewriter.createtensor::CollapseShapeOp(

126 loc, reshapedOutputType, output, outputReassocIndices);

127

129 Value colTensor = rewriter.createtensor::EmptyOp(

130 loc, colTensorShape, inputType.getElementType());

131

132

133 auto nloops = colTensorShape.size();

134

135 auto parallel = utils::IteratorType::parallel;

136 auto reduction = utils::IteratorType::reduction;

138

141

142 auto img2ColTensor = rewriter.createlinalg::GenericOp(

143 loc, colTensor.getType(),

144 ValueRange{}, colTensor, img2colIndexingMaps,

145 img2colIterators,

147

148 Value bIndex = nestedBuilder.createlinalg::IndexOp(loc, 0);

149 Value mIndex = nestedBuilder.createlinalg::IndexOp(loc, 1);

150 Value kIndex = nestedBuilder.createlinalg::IndexOp(loc, 2);

151

152

155 auto ohIndex = mIndices[0];

156 auto owIndex = mIndices[1];

157

159 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});

160 auto fhIndex = kIndices[0];

161 auto fwIndex = kIndices[1];

162 auto icIndex = kIndices[2];

163

164

167 convOp.getStrides().getValues<int64_t>()[0]);

170 convOp.getStrides().getValues<int64_t>()[1]);

171

172

173 SmallVector extractionIndices{bIndex, hIndex, wIndex, icIndex};

174 Value inputVal = nestedBuilder.createtensor::ExtractOp(

175 loc, input, extractionIndices);

176 nestedBuilder.createlinalg::YieldOp(nestedLoc, inputVal);

177 });

178

179

180

181

182

184 bindDims(context, bDim, mDim, nDim, kDim);

185 auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);

186 auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);

187 auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);

189 parallel, reduction};

190

191 auto genericOp = rewriter.createlinalg::GenericOp(

192 loc, reshapedOutputType,

193 ValueRange{img2ColTensor.getResult(0), reshapedFilter},

194 ValueRange{reshapedOutput},

198 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);

199 Value add = createAdd(loc, mul, args[2], nestedBuilder);

200 nestedBuilder.createlinalg::YieldOp(nestedLoc, add);

201 });

202 Value result = genericOp.getResults().front();

203

204 auto reshapedResult = rewriter.createtensor::ExpandShapeOp(

205 loc, outputType, result, outputReassocIndices);

206

208

209 return std::make_pair(img2ColTensor.getOperation(),

210 reshapedResult.getOperation());

211 }

212

213 FailureOr<std::pair<Operation *, Operation *>>

215 linalg::DepthwiseConv2DNhwcHwcOp convOp) {

216 auto inputType = cast(convOp.getInputs()[0].getType());

217 auto filterType = cast(convOp.getInputs()[1].getType());

218 auto outputType = cast(convOp.getOutputs()[0].getType());

219

220 if (!filterType.hasStaticShape())

222 convOp, "expected a static shape for the filter");

223

224 if (!inputType.hasStaticShape())

226 "expected a static shape for the input");

227

228

231 "expected all ones for dilations");

232

233 Location loc = convOp.getLoc();

234

236 auto operandTensorType = cast(operand.getType());

237 auto nloops = indices.size();

239

241 llvm::map_range(indices, [&](int64_t index) -> AffineExpr {

243 }));

244

246 indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));

247

248 Value outputTensor = rewriter.createtensor::EmptyOp(

249 loc, targetShape, operandTensorType.getElementType());

250

252 nloops, utils::IteratorType::parallel);

253

258

259 auto transposedOp = rewriter.createlinalg::GenericOp(

260 loc, outputTensor.getType(),

261 operand, outputTensor, indexingMaps,

262 loopAttributeTypes,

264 nestedBuilder.createlinalg::YieldOp(nestedLoc, args[0]);

265 });

266

267 return transposedOp.getResult(0);

268 };

269

270 Value input = convOp.getInputs()[0];

271 Value filter = convOp.getInputs()[1];

272 Value output = convOp.getOutputs()[0];

273

274

275 Value inputT = transposeOperand(input, {0, 3, 1, 2});

276 Value filterT = transposeOperand(filter, {2, 0, 1});

278 cast(filterT.getType()).getShape();

280

281 int n = outputShape[0];

282 int oh = outputShape[1];

283 int ow = outputShape[2];

284 int c = outputShape[3];

285 int fh = filterTShape[1];

286 int fw = filterTShape[2];

287

289 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});

290

291 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;

292 bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);

293

295 convOp.getStrides().getValues<int64_t>()[0]);

297 convOp.getStrides().getValues<int64_t>()[1]);

298

300 owDim * swSym + kwDim};

301

302 auto nloops = colTensorShape.size();

303

305 nloops, utils::IteratorType::parallel);

306

310

311 Value colTensor = rewriter.createtensor::EmptyOp(

312 loc, colTensorShape, inputType.getElementType());

313

314 auto img2ColTensor = rewriter.createlinalg::GenericOp(

315 loc, colTensor.getType(),

316 inputT, colTensor, indexingMaps,

317 loopAttributeTypes,

319 nestedBuilder.createlinalg::YieldOp(nestedLoc, args[0]);

320 });

321

323 {0, 1}, {2, 3}, {4, 5}};

326 {2, 3}};

327

329 {n * c, oh * ow, fh * fw}, inputType.getElementType());

330 auto reshapedFilterTensorType =

332 auto reshapedOutputTensorType =

334

335 Value reshapedImg2ColTensor = rewriter.createtensor::CollapseShapeOp(

336 loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),

337 img2ColTensorReassocIndices);

338 Value reshapedFilterTensor = rewriter.createtensor::CollapseShapeOp(

339 loc, reshapedFilterTensorType, filterT, filterReassociationIndice);

340 Value reshapedoutputTensor = rewriter.createtensor::CollapseShapeOp(

341 loc, reshapedOutputTensorType, transposedOutputTensor,

342 outputReassociationIndice);

343

344 auto batchMatVecResult = rewriter.createlinalg::BatchMatvecOp(

345 loc, TypeRange{reshapedoutputTensor.getType()},

346 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},

348

350 {2, 3}};

351

352 auto batchMatVecResultReshaped = rewriter.createtensor::ExpandShapeOp(

353 loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),

354 batchMatVecReassociationIndice);

355

356 Value transposedResult =

357 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});

358

360 return std::make_pair(img2ColTensor.getOperation(),

361 transposedResult.getDefiningOp());

362 }

363

364 FailureOr<std::pair<Operation *, Operation *>>

366 auto inputType = cast(convOp.getInputs()[0].getType());

367 auto filterType = cast(convOp.getInputs()[1].getType());

368 auto outputType = cast(convOp.getOutputs()[0].getType());

369

370 if (!filterType.hasStaticShape())

372 convOp, "expected a static shape for the filter");

373

374 if (!inputType.hasStaticShape())

376 "expected a static shape for the input");

377

378

381 "expected all ones for dilations");

382

383 Value input = convOp.getInputs()[0];

384 Value filter = convOp.getInputs()[1];

385 Value output = convOp.getOutputs()[0];

386

387 auto filterShape = filterType.getShape();

388 auto outputShape = outputType.getShape();

389

390 int64_t n = outputShape[0];

391 int64_t oc = outputShape[1];

392 int64_t oh = outputShape[2];

393 int64_t ow = outputShape[3];

394 int64_t ic = filterShape[1];

395 int64_t fh = filterShape[2];

396 int64_t fw = filterShape[3];

397

398 auto loc = convOp.getLoc();

400

402 auto reshapedFilterType =

404 Value reshapedFilter = rewriter.createtensor::CollapseShapeOp(

405 loc, reshapedFilterType, filter, filterReassocIndices);

406

408 auto reshapedOutputType =

410 Value reshapedOutput = rewriter.createtensor::CollapseShapeOp(

411 loc, reshapedOutputType, output, outputReassocIndices);

412

413

415 Value colTensor = rewriter.createtensor::EmptyOp(

416 loc, colTensorShape, inputType.getElementType());

417

418 auto nloops = colTensorShape.size();

419

420 auto parallel = utils::IteratorType::parallel;

421 auto reduction = utils::IteratorType::reduction;

423

426

427 auto img2ColTensor = rewriter.createlinalg::GenericOp(

428 loc, colTensor.getType(),

429 ValueRange{}, colTensor, img2colIndexingMaps,

430 img2colIterators,

432

433 Value bIndex = nestedBuilder.createlinalg::IndexOp(loc, 0);

434 Value kIndex = nestedBuilder.createlinalg::IndexOp(loc, 1);

435 Value nIndex = nestedBuilder.createlinalg::IndexOp(loc, 2);

436

437

439 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});

440 auto icIndex = kIndices[0];

441 auto fhIndex = kIndices[1];

442 auto fwIndex = kIndices[2];

443

446 auto ohIndex = nIndices[0];

447 auto owIndex = nIndices[1];

448

449

452 convOp.getStrides().getValues<int64_t>()[0]);

455 convOp.getStrides().getValues<int64_t>()[1]);

456

457

458 SmallVector extractionIndices{bIndex, icIndex, hIndex, wIndex};

459 Value inputVal = nestedBuilder.createtensor::ExtractOp(

460 loc, input, extractionIndices);

461 nestedBuilder.createlinalg::YieldOp(nestedLoc, inputVal);

462 });

463

464

465

466

467

469 bindDims(context, bDim, mDim, nDim, kDim);

470 auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);

471 auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);

472 auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);

474 parallel, reduction};

475 auto genericOp = rewriter.createlinalg::GenericOp(

476 loc, reshapedOutputType,

477 ValueRange{reshapedFilter, img2ColTensor.getResult(0)},

478 ValueRange{reshapedOutput},

482 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);

483 Value add = createAdd(loc, mul, args[2], nestedBuilder);

484 nestedBuilder.createlinalg::YieldOp(nestedLoc, add);

485 });

486 Value result = genericOp.getResults().front();

487

488 auto reshapedResult = rewriter.createtensor::ExpandShapeOp(

489 loc, outputType, result, outputReassocIndices);

490

492

493 return std::make_pair(img2ColTensor.getOperation(),

494 reshapedResult.getOperation());

495 }

496

497 FailureOr<std::pair<Operation *, Operation *>>

499 auto inputType = cast(convOp.getInputs()[0].getType());

500 auto filterType = cast(convOp.getInputs()[1].getType());

501 auto outputType = cast(convOp.getOutputs()[0].getType());

502

503 if (!filterType.hasStaticShape())

505 convOp, "expected a static shape for the filter");

506

507 if (!inputType.hasStaticShape())

509 "expected a static shape for the input");

510

511

514 "expected all ones for dilations");

515

517 Value input = convOp.getInputs()[0];

518 Value filter = convOp.getInputs()[1];

519 Value output = convOp.getOutputs()[0];

520

523

524 int64_t n = outputShape[0];

525 int64_t oh = outputShape[1];

526 int64_t ow = outputShape[2];

527 int64_t oc = outputShape[3];

528 int64_t fh = filterShape[1];

529 int64_t fw = filterShape[2];

530 int64_t ic = filterShape[3];

531

532 Location loc = convOp.getLoc();

533

534

535

537 auto reshapedFilterType =

539 Value reshapedFilter = rewriter.createtensor::CollapseShapeOp(

540 loc, reshapedFilterType, filter, filterReassocIndices);

541

543 RankedTensorType reshapedOutputType =

545 Value reshapedOutput = rewriter.createtensor::CollapseShapeOp(

546 loc, reshapedOutputType, output, outputReassocIndices);

547

549 Value colTensor = rewriter.createtensor::EmptyOp(

550 loc, colTensorShape, inputType.getElementType());

551

552

553 auto nloops = colTensorShape.size();

554

555 auto parallel = utils::IteratorType::parallel;

556 auto reduction = utils::IteratorType::reduction;

558

561

562 auto img2ColTensor = rewriter.createlinalg::GenericOp(

563 loc, colTensor.getType(),

564 ValueRange{}, colTensor, img2colIndexingMaps,

565 img2colIterators,

567

568 Value bIndex = nestedBuilder.createlinalg::IndexOp(loc, 0);

569 Value mIndex = nestedBuilder.createlinalg::IndexOp(loc, 1);

570 Value kIndex = nestedBuilder.createlinalg::IndexOp(loc, 2);

571

572

575 auto ohIndex = mIndices[0];

576 auto owIndex = mIndices[1];

577

579 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});

580 auto fhIndex = kIndices[0];

581 auto fwIndex = kIndices[1];

582 auto icIndex = kIndices[2];

583

584

587 convOp.getStrides().getValues<int64_t>()[0]);

590 convOp.getStrides().getValues<int64_t>()[1]);

591

592

593 SmallVector extractionIndices{bIndex, hIndex, wIndex, icIndex};

594 Value inputVal = nestedBuilder.createtensor::ExtractOp(

595 loc, input, extractionIndices);

596 nestedBuilder.createlinalg::YieldOp(nestedLoc, inputVal);

597 });

598

599

600

601

603 bindDims(context, bDim, mDim, nDim, kDim);

604 auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);

605 auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);

606 auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);

608 parallel, reduction};

609

610 auto genericOp = rewriter.createlinalg::GenericOp(

611 loc, reshapedOutputType,

612 ValueRange{img2ColTensor.getResult(0), reshapedFilter},

613 ValueRange{reshapedOutput},

617 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);

618 Value add = createAdd(loc, mul, args[2], nestedBuilder);

619 nestedBuilder.createlinalg::YieldOp(nestedLoc, add);

620 });

621 Value result = genericOp.getResults().front();

622

623 auto reshapedResult = rewriter.createtensor::ExpandShapeOp(

624 loc, outputType, result, outputReassocIndices);

625

627

628 return std::make_pair(img2ColTensor.getOperation(),

629 reshapedResult.getOperation());

630 }

631

632 namespace {

633

634 class ConvertConv2DNhwcHwcf final

636 public:

638

639 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,

642 return failure();

643 return success();

644 }

645 };

646

647 class ConvertDepthwiseConv2DNhwcHwc final

648 : public OpRewritePatternlinalg::DepthwiseConv2DNhwcHwcOp {

649 public:

651

652 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,

653 PatternRewriter &rewriter) const override {

655 return failure();

656 return success();

657 }

658 };

659

660 class ConvertConv2DNchwFchw final

661 : public OpRewritePatternlinalg::Conv2DNchwFchwOp {

662 public:

664

665 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,

666 PatternRewriter &rewriter) const override {

668 return failure();

669 return success();

670 }

671 };

672

673 class ConvertConv2DNhwcFhwc final

674 : public OpRewritePatternlinalg::Conv2DNhwcFhwcOp {

675 public:

677

678 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,

679 PatternRewriter &rewriter) const override {

681 return failure();

682 return success();

683 }

684 };

685 }

686

689 patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,

690 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);

691 }

692 }

693 }

Base type for affine expression.

A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.

static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)

Returns an AffineMap with 'numDims' identity result dim exprs.

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

IntegerAttr getIndexAttr(int64_t value)

AffineExpr getAffineConstantExpr(int64_t constant)

AffineExpr getAffineDimExpr(unsigned position)

MLIRContext * getContext() const

An attribute that represents a reference to a dense integer vector or tensor object.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

This class helps build Operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Location getLoc() const

Return the location of this value.

AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...

FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)

Generate the IR to delinearize linearIndex given the basis and return the multi-index.

static SmallVector< Value > unrollIndex(OpBuilder &b, Location loc, Value index, ArrayRef< int64_t > factors)

FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)

Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....

void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)

Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...

static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)

static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)

static bool hasAllOneValues(DenseIntElementsAttr attr)

static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, Value fIndex, int64_t stride)

Include the generated interface declarations.

Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)

Converts a scalar value operand to type toType.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].

AffineMap inversePermutation(AffineMap map)

Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...

const FrozenRewritePatternSet & patterns

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...