MLIR: lib/Dialect/Linalg/IR/LinalgInterfaces.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

22 #include "llvm/ADT/STLExtras.h"

23 #include "llvm/ADT/SetOperations.h"

24 #include "llvm/ADT/SmallBitVector.h"

25 #include "llvm/ADT/SmallVector.h"

26 #include "llvm/Support/Casting.h"

27 #include "llvm/Support/raw_ostream.h"

28 #include

29 #include

30 #include

31

32 using namespace mlir;

34

35

36 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

37

38

39

40

41

45 for (auto &opOperand : linalgOp->getOpOperands()) {

46 if (llvm::is_contained(droppedOperands, &opOperand))

47 continue;

48 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));

49 }

50 if (indexingMaps.empty()) {

51

52

53 return linalgOp.getNumLoops() == 0;

54 }

56 indexingMaps, linalgOp.getContext())) != AffineMap();

57 }

58

59

60

61

62

64

65 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())

66 return false;

67

68 auto mapRange = op.getIndexingMapsArray();

69 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||

70 !mapRange.back().isIdentity()) {

71 return false;

72 }

73

74 return llvm::hasSingleElement(op.getBlock()->getOperations());

75 }

76

77

78

79

81

82 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||

83 !op.isSingleYieldOp())

84 return std::nullopt;

85

86

87 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||

88 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))

89 return std::nullopt;

90

91 OpOperand *value = op.getDpsInputOperand(0);

92 if (!op.isScalar(value))

93 return std::nullopt;

94 return value->get();

95 }

96

97

98

99

100 std::optional<SmallVector<int64_t>>

102

103 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||

104 !op.isSingleYieldOp())

105 return std::nullopt;

106

107 auto srcTy = op.getDpsInputOperand(0)->get().getType();

108 auto dstTy = op.getDpsInitOperand(0)->get().getType();

109 if (!isa<MemRefType, RankedTensorType>(srcTy) ||

110 !isa<MemRefType, RankedTensorType>(dstTy))

111 return std::nullopt;

112

113

114

115

116 auto dstMap = op.getIndexingMapsArray()[1];

117 if (!dstMap.isIdentity())

118 return std::nullopt;

119

121 auto srcMap = op.getIndexingMapsArray()[0];

122

123 if (srcMap.getResults().size() >= dstMap.getResults().size())

124 return std::nullopt;

125

126

127 for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {

128 auto expr = llvm::dyn_cast(srcMap.getResults()[i]);

129 if (!expr)

130 return std::nullopt;

131 int64_t pos = expr.getPosition();

132 if (i > 0 && pos <= position[i - 1])

133 return std::nullopt;

134 position.push_back(expr.getPosition());

135 }

136

138 auto numDims = srcMap.getNumDims();

139

140 for (auto dim : llvm::seq<int64_t>(0, numDims)) {

141 if (!llvm::is_contained(position, dim))

142 broadcastedDims.push_back(dim);

143 }

144 return broadcastedDims;

145 }

146

147

148

149

150 std::optional<SmallVector<int64_t>>

152

153

154

155 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||

156 !op.isSingleYieldOp())

157 return std::nullopt;

158

159 auto mapRange = op.getIndexingMapsArray();

160 if (mapRange.size() != 2)

161 return std::nullopt;

162

163 auto mapOfInput = mapRange.front();

164 auto mapOfResult = mapRange.back();

165

166

167

168 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())

169 return std::nullopt;

170

172 for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {

173 auto expr = llvm::cast(mapOfInput.getResults()[i]);

174 permutation[expr.getPosition()] = i;

175 }

176 return permutation;

177 }

178

179

180

181

183 unsigned arity) {

184

185 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)

186 return false;

187

188

189 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||

190 !llvm::all_of(op.getIndexingMapsArray(),

191 [](AffineMap map) { return map.isIdentity(); }))

192 return false;

193

194

195 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))

196 return false;

197

198

199

200

201

202 Block *body = op.getBody();

204 return false;

205

208 return false;

209

210 auto yieldOp = dyn_castlinalg::YieldOp(body->back());

211 if (!yieldOp || yieldOp.getNumOperands() != 1 ||

212 yieldOp->getOperand(0).getDefiningOp() != oper)

213 return false;

214 return true;

215 }

216

218

220 return false;

221

222

223 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))

224 return false;

225 return true;

226 }

227

230 return false;

231

232

233 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);

234 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);

235 if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||

236 !op.payloadUsesValueFromOperand(inputOpOperand1))

237 return false;

238 return true;

239 }

240

241

242

243

244

245

246

247

248

252 auto iface = dyn_cast(op);

253 if (!iface || !iface.hasNoEffect())

254 break;

257 }

258 return value;

259 }

260

263 llvm::raw_ostream &errs) {

265 errs << "no terminator in the block";

266 return false;

267 }

268

270 errs << "expected block with 3 arguments";

271 return false;

272 }

273

276 errs << "expected terminator with 1 operand";

277 return false;

278 }

279

283 errs << "expected reduction op to be binary";

284 return false;

285 }

286

289

290 if (reductionLHS != block.getArgument(2) &&

292 errs << "expected reduction to take block argument #2 as one of the "

293 "operands (modulo unary casts)";

294 return false;

295 }

296

298 isa(reductionLHS) ? reductionRHS : reductionLHS);

300 if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||

302 errs << "expected elementwise op to be binary";

303 return false;

304 }

305

306 if (!isaPair(elementwiseOp, reductionOp)) {

307 errs << "expected reduction/elementwise op kind not satisfied";

308 return false;

309 }

310

313 if ((elementwiseLHS == block.getArgument(0) &&

314 elementwiseRHS == block.getArgument(1)) ||

315 (elementwiseLHS == block.getArgument(1) &&

316 elementwiseRHS == block.getArgument(0))) {

317 return true;

318 }

319

320 errs << "expected elementwise op to apply to block arguments (modulo unary "

321 "casts)";

322 return false;

323 }

324

325

326

327 template <typename AddOpTy, typename MulOpTy, typename... Args>

329 static_assert(sizeof...(Args) % 2 == 0,

330 "expected an even number of template arguments");

331 if (isa(add) && isa(mul))

332 return true;

333

334 if constexpr (sizeof...(Args) > 0)

336 else

337 return false;

338 }

339

340

341

342 template <typename... Args>

345 }

346

347

348

349

350

351

352

353

354 static llvm::SmallDenseSet<int64_t>

357 utils::IteratorType iter) {

358 assert(iterators.size() == indexingMap.getNumDims());

359 llvm::SmallDenseSet<int64_t> res;

361 if (auto d = dyn_cast(e)) {

362 if (iterators[d.getPosition()] == iter &&

364 return e.isFunctionOfDim(d.getPosition());

365 }) == 1)

366 res.insert(d.getPosition());

367 }

368 }

369 return res;

370 }

371

372 namespace {

373 auto par = utils::IteratorType::parallel;

374 auto red = utils::IteratorType::reduction;

375 }

376

377

378

379

380

381 static FailureOr<SmallVectorutils::IteratorType>

384 return failure();

387 if (auto dim = dyn_cast(expr))

388 iterators[dim.getPosition()] = par;

389 return iterators;

390 }

391

392

393

394

395

396

397

398

399

400

401

402

403 static FailureOr

406 llvm::SmallDenseSet<int64_t> a =

408 llvm::SmallDenseSet<int64_t> b =

410 llvm::SmallDenseSet<int64_t> c =

412

413

414 llvm::SmallDenseSet<int64_t> ac = a;

415 llvm::set_intersect(ac, c);

416 llvm::set_subtract(ac, b);

417

418 llvm::SmallDenseSet<int64_t> bc = b;

419 llvm::set_intersect(bc, c);

420 llvm::set_subtract(bc, a);

421

422 llvm::SmallDenseSet<int64_t> batches = a;

423 llvm::set_intersect(batches, b);

424 llvm::set_intersect(batches, c);

425

426

427 llvm::SmallDenseSet<int64_t> ra =

429 llvm::SmallDenseSet<int64_t> rb =

431 llvm::set_intersect(ra, rb);

432

433

439 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());

440 llvm::sort(dimensions.m.begin(), dimensions.m.end());

441 llvm::sort(dimensions.n.begin(), dimensions.n.end());

442 llvm::sort(dimensions.k.begin(), dimensions.k.end());

443 return dimensions;

444 }

445

446 FailureOr

448 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)

449 return failure();

451 linalgOp.getIteratorTypesArray());

452 }

453

454 FailureOr

456 if (indexingMaps.size() != 3)

457 return failure();

459 if (failed(iterators))

460 return failure();

462 }

463

472 };

473 }

474

478 auto linalgOp = dyn_castlinalg::LinalgOp(op);

479 if (!linalgOp)

480 return MatchContractionResult::NotLinalgOp;

481 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)

482 return MatchContractionResult::WrongNumOperands;

483 auto mapRange = linalgOp.getIndexingMapsArray();

484 if (linalgOp.getNumReductionLoops() == 0)

485 return MatchContractionResult::NoReduction;

486 if (llvm::any_of(mapRange,

487 [](AffineMap m) { return !m.isProjectedPermutation(); }))

488 return MatchContractionResult::NotProjectedPermutations;

489

490

492 arith::MulFOp, arith::AddFOp,

493 arith::MulIOp, arith::AddIOp,

494 complex::MulOp, complex::AddOp,

495 arith::AndIOp, arith::OrIOp>(

496 *linalgOp.getBlock())) {

497 return MatchContractionResult::NotAddMul;

498 }

499

500

501 if (dimensions) {

503 assert(succeeded(res) && "unexpected failure to infer contraction dims");

504 *dimensions = *res;

505 }

506 return MatchContractionResult::Success;

507 }

508

509 StringRef

511 switch (res) {

512 case MatchContractionResult::NotLinalgOp:

513 return "expected a LinalgOp";

514 case MatchContractionResult::WrongNumOperands:

515 return "expected op with 2 inputs and 1 output";

516 case MatchContractionResult::NoReduction:

517 return "expected at least 1 reduction";

518 case MatchContractionResult::NotProjectedPermutations:

519 return "expected indexing maps to be projected permutations";

520 case MatchContractionResult::NotAddMul:

521 return "expected add/mul op in the body";

522 case MatchContractionResult::Success:

523 return "";

524 }

525 llvm_unreachable("unhandled MatchContractionResult case");

526 }

527

529 if (!linalgOp)

530 return false;

531 Operation *op = linalgOp.getOperation();

532 return isa(op) ||

535 }

536

537

538

539

540

541

542

543

544

545

546

547

548

549

552 if (res != MatchContractionResult::Success)

554 return success();

555 }

556

557

558

559

560

561

562

563 template

565 return isa(lhs) ? cast(lhs) : (isa(rhs) ? cast(rhs) : nullptr);

566 }

567

568 namespace {

569

570

571

572

573

574

575

576

577 struct ConvAccessExprWalker

578 : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {

579

580 llvm::SmallDenseSet<int64_t> convolvedDims;

581

582 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;

583

584 llvm::SmallDenseSet<int64_t> unConvolvedDims;

585

586 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;

587

588

589

590 void clearMultiUseDims(AffineMap map) {

591 for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {

593 return e.isFunctionOfDim(dimPos);

594 }) > 1) {

595 convolvedDims.erase(dimPos);

596 unConvolvedDims.erase(dimPos);

597

598

599 auto it = convolvedDimMapping.find(dimPos);

600 if (it != convolvedDimMapping.end()) {

601 int64_t pairedDim = it->second;

602 convolvedDims.erase(pairedDim);

603 unConvolvedDims.erase(pairedDim);

604 strideAndDilationMapping.erase(pairedDim);

605 convolvedDimMapping.erase(dimPos);

606 convolvedDimMapping.erase(pairedDim);

607 }

608 }

609 }

610 }

611

612 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {

613 unsigned position = dimExpr.getPosition();

614 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {

615 return failure();

616 }

617 unConvolvedDims.insert(position);

618 return success();

619 }

620

621 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

622

623 LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

624

625 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {

626

628 return failure();

629 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());

630 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());

631 if (failed(lhsDimPos) || failed(rhsDimPos))

632 return failure();

633 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;

634 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;

635 return success();

636 }

637

638 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {

639 if (auto dimExpr = dyn_cast(expr)) {

641 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))

642 return failure();

643

644 strideAndDilationMapping[dim] =

646 convolvedDims.insert(dim);

647 return dim;

648 }

649 if (auto symbolMulExpr = dyn_cast(expr)) {

651 return failure();

652 auto lhsExpr = symbolMulExpr.getLHS();

653 auto rhsExpr = symbolMulExpr.getRHS();

654

656 getAffineExprOfType(lhsExpr, rhsExpr);

657

658 if (!mulExpr) {

659 mulExpr = getAffineExprOfType(lhsExpr, rhsExpr);

660 }

661 auto dimExpr = getAffineExprOfType(lhsExpr, rhsExpr);

662 if (!mulExpr || !dimExpr)

663 return failure();

665 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))

666 return failure();

667 strideAndDilationMapping[dim] = mulExpr;

668 convolvedDims.insert(dim);

669 return dim;

670 }

671 return failure();

672 }

673 };

674 }

675

678 "expected map to have projected permutations");

679 llvm::SmallDenseSet<int64_t> preservedDims;

681 preservedDims.insert(cast(expr).getPosition());

682 return preservedDims;

683 }

684

688 for (auto e : exprs) {

689 auto constantExpr = dyn_cast(e);

690 assert(constantExpr && "Found non-constant stride/dilation");

691 vals.push_back(constantExpr.getValue());

692 }

693 return vals;

694 }

695

696

697

698

699

700

701

702

703 static FailureOr

705 ConvAccessExprWalker &inputExprWalker,

706 bool allowEmptyConvolvedDims) {

707 auto filterMap =

708 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));

709 auto outputMap =

710 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));

712 filterMap, linalgOp.getIteratorTypesArray(), par);

714 outputMap, linalgOp.getIteratorTypesArray(), par);

715

716

717 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;

718 llvm::set_intersect(batch, outputDims);

719 llvm::set_subtract(batch, filterDims);

720

721

722 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;

723 llvm::set_intersect(oi, outputDims);

724

725

726 llvm::SmallDenseSet<int64_t> oc = filterDims;

727 llvm::set_intersect(oc, outputDims);

728 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);

729

730

731 llvm::SmallDenseSet<int64_t> depth = filterDims;

732 llvm::set_intersect(depth, outputDims);

733 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

734

735 llvm::SmallDenseSet<int64_t> filterReducedDims =

737 linalgOp.getIteratorTypesArray(), red);

738

739

740 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;

741 llvm::set_intersect(fl, filterReducedDims);

742

743

744 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;

745 llvm::set_intersect(ic, filterReducedDims);

746

747 if (oi.empty() && !allowEmptyConvolvedDims)

748 return failure();

749

750

760 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());

761 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());

762 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());

763 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());

764 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());

765 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());

766

767

768 auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");

769 if (!nativeStrides) {

771 for (unsigned oiDim : dimensions.outputImage)

772 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);

774 } else {

775 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());

776 }

777 auto nativeDilations =

779 if (!nativeDilations) {

781 for (unsigned flDim : dimensions.filterLoop)

782 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);

784 } else {

785 dimensions.dilations =

786 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());

787 }

788 return dimensions;

789 }

790

791

792

793

794

795

796

797

798

799

800

801

802

803

804

805

806

807

808

809

810

811

812

813

814

815 FailureOr

817 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)

818 return failure();

819

820 auto indexingMaps = linalgOp.getIndexingMapsArray();

821

822

823 ConvAccessExprWalker inputExprWalker;

824 for (AffineExpr expr : indexingMaps[0].getResults())

825 (void)inputExprWalker.visit(expr);

826 inputExprWalker.clearMultiUseDims(indexingMaps[0]);

827

829 false);

830 }

831

843 };

844 }

845

849 bool allowEmptyConvolvedDims) {

850 auto linalgOp = dyn_castlinalg::LinalgOp(op);

851 if (!linalgOp)

852 return MatchConvolutionResult::NotLinalgOp;

853 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)

854 return MatchConvolutionResult::WrongNumOperands;

855

856 auto indexingMaps = linalgOp.getIndexingMapsArray();

857

858

859 ConvAccessExprWalker inputExprWalker;

860 if (llvm::any_of(indexingMaps[0].getResults(),

861 [&inputExprWalker](AffineExpr expr) {

862 return failed(inputExprWalker.visit(expr));

863 })) {

864 return MatchConvolutionResult::WrongInputIndexingMap;

865 }

866

867

868 if (!indexingMaps[1].isProjectedPermutation() ||

869 !indexingMaps.back().isProjectedPermutation())

870 return MatchConvolutionResult::NotProjectedPermutations;

871

872 auto iteratorTypes = linalgOp.getIteratorTypesArray();

873

874 llvm::SmallDenseSet<int64_t> outputDims =

876 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);

877

878

879

880

881

882

883

884

885

886

887

888

889

890 llvm::SmallDenseSet<int64_t> allLoopDims;

891 for (auto outputExpr : indexingMaps.back().getResults()) {

892 int64_t outputDim = cast(outputExpr).getPosition();

893 if (inputExprWalker.unConvolvedDims.count(outputDim) &&

894 !filterDims.count(outputDim)) {

895

896 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)

897 return MatchConvolutionResult::OutputDimsNotParallel;

898 allLoopDims.insert(outputDim);

899 continue;

900 }

901 if (inputExprWalker.convolvedDims.count(outputDim) &&

902 !filterDims.count(outputDim)) {

903

904 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)

905 return MatchConvolutionResult::OutputDimsNotParallel;

906 allLoopDims.insert(outputDim);

907 continue;

908 }

909 if (!inputExprWalker.convolvedDims.count(outputDim) &&

910 !inputExprWalker.unConvolvedDims.count(outputDim) &&

911 filterDims.count(outputDim)) {

912

913 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)

914 return MatchConvolutionResult::OutputDimsNotParallel;

915 allLoopDims.insert(outputDim);

916 continue;

917 }

918 if (inputExprWalker.unConvolvedDims.count(outputDim) &&

919 filterDims.count(outputDim)) {

920

921 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)

922 return MatchConvolutionResult::OutputDimsNotParallel;

923 allLoopDims.insert(outputDim);

924 continue;

925 }

926 return MatchConvolutionResult::NonConvolutionLoop;

927 }

928 for (auto filterExpr : indexingMaps[1].getResults()) {

929 int64_t filterDim = cast(filterExpr).getPosition();

930 if (outputDims.count(filterDim) &&

931 !inputExprWalker.unConvolvedDims.count(filterDim) &&

932 !inputExprWalker.convolvedDims.count(filterDim)) {

933

934 continue;

935 }

936 if (inputExprWalker.convolvedDims.count(filterDim) &&

937 !outputDims.count(filterDim)) {

938

939 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)

940 return MatchConvolutionResult::NonOutputDimNotReduction;

941 if (allLoopDims.count(filterDim))

942 return MatchConvolutionResult::NonConvolutionLoop;

943 allLoopDims.insert(filterDim);

944 continue;

945 }

946 if (inputExprWalker.unConvolvedDims.count(filterDim) &&

947 !outputDims.count(filterDim)) {

948

949 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)

950 return MatchConvolutionResult::NonOutputDimNotReduction;

951 if (allLoopDims.count(filterDim))

952 return MatchConvolutionResult::NonConvolutionLoop;

953 allLoopDims.insert(filterDim);

954 continue;

955 }

956 if (inputExprWalker.unConvolvedDims.count(filterDim) &&

957 outputDims.count(filterDim)) {

958

959 continue;

960 }

961 return MatchConvolutionResult::NonConvolutionLoop;

962 }

963

964 if (allLoopDims.size() != linalgOp.getNumLoops())

965 return MatchConvolutionResult::NonConvolutionLoop;

966

967 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())

968 return MatchConvolutionResult::EmptyConvolvedDims;

969

970 if (dimensions) {

972 linalgOp, inputExprWalker, allowEmptyConvolvedDims);

973 assert(succeeded(res) && "unexpected failure to infer convolution dims");

974 *dimensions = *res;

975 }

976

977 return MatchConvolutionResult::Success;

978 }

979

980 StringRef

982 switch (res) {

983 case MatchConvolutionResult::NotLinalgOp:

984 return "expected a LinalgOp";

985 case MatchConvolutionResult::WrongNumOperands:

986 return "expected op with 2 inputs and 1 output";

987 case MatchConvolutionResult::WrongInputIndexingMap:

988 return "unexpected input index map for convolutions";

989 case MatchConvolutionResult::NotProjectedPermutations:

990 return "expected output/filter indexing maps to be projected permutations";

991 case MatchConvolutionResult::NonConvolutionLoop:

992 return "unexpected loop dimension for convolution op";

993 case MatchConvolutionResult::OutputDimsNotParallel:

994 return "expected all iterators used to access outputs to be parallel";

995 case MatchConvolutionResult::NonOutputDimNotReduction:

996 return "expected all iterators not used to access outputs to be reduction";

997 case MatchConvolutionResult::EmptyConvolvedDims:

998 return "expected convolved dim to be non-empty";

999 case MatchConvolutionResult::Success:

1000 return "";

1001 }

1002 llvm_unreachable("unhandled MatchConvolutionResult case");

1003 }

1004

1006 bool allowEmptyConvolvedDims) {

1008 linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==

1010 }

1011

1014 if (res != MatchConvolutionResult::Success)

1016 return success();

1017 }

1018

1019

1020

1021

1022

1024 Success = 0,

1025 NotLinalgOp,

1026 WrongNumOperands,

1028 };

1029

1031 auto linalgOp = dyn_castlinalg::LinalgOp(op);

1032 if (!linalgOp)

1034 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)

1036

1037 OpOperand *value = linalgOp.getDpsInputOperand(0);

1038 if (!linalgOp.isScalar(value))

1040

1042 }

1043

1047 return op->emitError("expected a LinalgOp");

1049 return op->emitError("expected op with 1 input and 1 output");

1051 return op->emitError("expected op with scalar input");

1052

1053 return success();

1054 }

1055

1056

1057

1058

1059

1063 for (OpOperand &opOperand : getOperation()->getOpOperands()) {

1064 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)

1066 }

1067 return res;

1068 }

1069

1072 assert(!hasDynamicShape() && "expected operands to have static shapes");

1073 for (OpOperand &opOperand : getOperation()->getOpOperands())

1074 llvm::append_range(res, getShape(&opOperand));

1075 return res;

1076 }

1077

1079 AffineMap map = getLoopsToShapesMap();

1081 auto viewSizes = createFlatListOfOperandDims(b, loc);

1083 for (unsigned idx = 0; idx < numRes; ++idx) {

1084 auto result = map.getResult(idx);

1085 if (auto d = dyn_cast(result)) {

1086 if (res[d.getPosition()].offset)

1087 continue;

1088 res[d.getPosition()] =

1090 }

1091 }

1092 return res;

1093 }

1094

1095

1096

1100 : positions(std::move(positions)) {}

1101

1104 }

1105

1107 return positions.test(dimExpr.getPosition());

1108 }

1109

1111

1113

1114 private:

1115 llvm::SmallBitVector positions;

1116 };

1117

1118 static std::pair<int64_t, int64_t>

1120 int64_t inputRankSum = 0;

1121 int64_t outputRankSum = 0;

1122 for (OpOperand *input : op.getDpsInputOperands())

1123 inputRankSum += op.getRank(input);

1124 for (OpOperand &output : op.getDpsInitsMutable())

1125 outputRankSum += op.getRank(&output);

1126 return {inputRankSum, inputRankSum + outputRankSum};

1127 }

1128

1129 LogicalResult

1132

1133

1134

1135

1136

1137

1138

1139

1140

1141 AffineMap loopsToShapesMap = getLoopsToShapesMap();

1142

1143

1144

1146

1147

1148

1150 resultShapesSubMapPos.first,

1151 resultShapesSubMapPos.second - resultShapesSubMapPos.first);

1152 AffineMap resultShapesFromInputShapesMap =

1153 loopToResultsShapeMap.compose(getShapesToLoopsMap());

1154

1155

1156

1157 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());

1158 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);

1160 Location loc = getOperation()->getLoc();

1164 rewriter, loc, resultShapesFromInputShapesMap,

1165 createFlatListOfOperandDims(b, loc));

1166 int64_t pos = 0;

1168 for (OpOperand &opOperand : getDpsInitsMutable()) {

1170 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {

1171 auto shapedType = llvm::cast(opOperand.get().getType());

1172 if (!shapedType.isDynamicDim(dim)) {

1173

1174 shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));

1175 } else {

1176

1177 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])

1179 : allResultDimValues[pos];

1181 }

1182 pos++;

1183 }

1184 reifiedReturnShapes.emplace_back(std::move(shapes));

1185 }

1186 return success();

1187 }

1188

1189

1190

1191 int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {

1193 auto dpsIface = cast(*this->getOperation());

1194 if (!dpsIface.isDpsInput(opOperand))

1195 return operandNumber;

1196 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();

1197 assert(!dpsIface.isDpsInit(opOperand));

1198

1199

1200 return cast(*this->getOperation())

1201 .getNumDpsInputs() +

1202 operandNumber - start;

1203 }

1204

1206 LinalgOp linalgOp = cast(op);

1207

1208 if (!linalgOp.hasPureTensorSemantics() &&

1209 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)

1210 return op->emitOpError("expected to have pure tensor or buffer semantics");

1211

1212

1213

1214 if (linalgOp.hasDynamicIndexingMaps())

1215 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))

1216 return failure();

1217

1218

1219 if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=

1220 linalgOp->getNumOperands())

1221 return op->emitOpError("expected the number of indexing_map (")

1222 << linalgOp.getIndexingMapsArray().size()

1223 << ") to be equal to the number of input/output operands ("

1224 << linalgOp->getNumOperands() << ")";

1225

1226

1227

1228 for (OpOperand &opOperand : linalgOp->getOpOperands()) {

1229 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

1230

1231

1233 return op->emitOpError("unexpected symbols in indexing_map #")

1235

1236

1237 unsigned numLoops = linalgOp.getNumLoops();

1238 if (indexingMap.getNumDims() != numLoops)

1239 return op->emitOpError("expected indexing_map #")

1241 << " dim(s) to match the number of loops";

1242

1243 int64_t rank = linalgOp.getRank(&opOperand);

1244

1246 return op->emitOpError("expected operand rank (")

1247 << rank << ") to match the result rank of indexing_map #"

1250 }

1252 linalgOp.getReductionDims(redDims);

1253

1254 if (!linalgOp.getShapesToLoopsMap())

1255 return op->emitOpError("expected the shape-to-loops map to be non-null");

1256

1257

1260

1261

1262 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {

1263 for (int64_t &range : endLoopRangeValues)

1264 range -= 1;

1265 for (OpOperand &opOperand : linalgOp->getOpOperands()) {

1266 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

1268 indexingMap.compose(startLoopRangeValues);

1270 indexingMap.compose(endLoopRangeValues);

1272 for (auto dim : llvm::seq<int64_t>(0, shape.size())) {

1273

1274 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)

1275 continue;

1276

1277

1278

1279

1280

1281

1282

1283

1284

1285

1286

1287 int64_t inferredDimSize =

1288 std::max(startIndices[dim], endIndices[dim]) + 1;

1289 if (std::min(startIndices[dim], endIndices[dim]) < 0) {

1290 std::string mapStr;

1291 {

1292 llvm::raw_string_ostream os(mapStr);

1293 os << indexingMap;

1294 }

1296 "unexpected result less than 0 at expression #")

1297 << dim << " in " << mapStr;

1298 }

1299 if (isa(indexingMap.getResult(dim))) {

1300 if (inferredDimSize != shape[dim]) {

1301 return op->emitOpError("inferred input/output operand #")

1302 << opOperand.getOperandNumber() << " has shape's dimension #"

1303 << dim << " to be " << inferredDimSize << ", but found "

1304 << shape[dim];

1305 }

1306 } else {

1307 if (inferredDimSize > shape[dim]) {

1308 return op->emitOpError("inferred input/output operand #")

1309 << opOperand.getOperandNumber() << " has shape's dimension #"

1310 << dim << " to be greater than or equal to "

1311 << inferredDimSize << ", but found " << shape[dim];

1312 }

1313 }

1314 }

1315 }

1316 }

1317

1318

1319 if (linalgOp->getNumRegions() != 1 ||

1320 !llvm::hasSingleElement(linalgOp->getRegion(0)))

1321 return op->emitOpError("expects to have 1 region with 1 block");

1322

1323

1324

1325

1326

1327

1328

1329 Block &block = linalgOp->getRegion(0).front();

1330

1331 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())

1332 return op->emitOpError("expected as many non-induction variable region "

1333 "arguments as the number of input/output operands");

1334

1335 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {

1337 if (isa<MemRefType, RankedTensorType>(elementType))

1340 if (elementType != argType)

1341 return op->emitOpError("expected type of bb argument #")

1343 << " to match element or self type of the corresponding operand ("

1344 << elementType << ")";

1345 }

1346

1347 return success();

1348 }

static void visit(Operation *op, DenseSet< Operation * > &visited)

Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.

static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)

Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...

static Value getSourceSkipUnary(Value value)

If the value is defined by a chain of unary side-effect-free operations, go up the use-def chain until the first...

static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)

Of the two given expressions, returns the one that is of type T (lhs takes precedence over rhs).

static bool isPairTemplateImpl(Operation *add, Operation *mul)

Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...

static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)

static MatchFillResult isFillInterfaceImpl(Operation *op)

static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)

Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...

static bool isContractionBody(Block &block)

Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...

static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)

static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)

static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)

Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...

static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)

Infer the iterator types from the init affine map.

static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)

static ArrayRef< int64_t > getShape(Type type)

Returns the shape of the given type.

Affine binary operation expression.

AffineExpr getLHS() const

AffineExpr getRHS() const

An integer constant appearing in affine expression.

A dimensional identifier appearing in an affine expression.

unsigned getPosition() const

See documentation for AffineExprVisitorBase.

Base type for affine expression.

AffineExprKind getKind() const

Return the classification for this type.

MLIRContext * getContext() const

A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.

AffineMap getSliceMap(unsigned start, unsigned length) const

Returns the map consisting of length expressions starting from start.

bool isProjectedPermutation(bool allowZeroInResults=false) const

Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.

unsigned getNumSymbols() const

unsigned getNumDims() const

ArrayRef< AffineExpr > getResults() const

unsigned getNumResults() const

AffineExpr getResult(unsigned idx) const

AffineMap compose(AffineMap map) const

Returns the AffineMap resulting from composing this with map.

A symbolic identifier appearing in an affine expression.

Block represents an ordered list of Operations.

BlockArgument getArgument(unsigned i)

unsigned getNumArguments()

Operation * getTerminator()

Get the terminator operation of this block.

OpListType & getOperations()

IntegerAttr getIndexAttr(int64_t value)

An attribute that represents a reference to a dense integer vector or tensor object.

IRValueT get() const

Return the current value being used by this operand.

This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

This class helps build Operations.

This class represents a single result from folding an operation.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

This class provides the API for ops that are known to be terminators.

Operation is the basic unit of execution within MLIR.

Value getOperand(unsigned idx)

bool mightHaveTrait()

Returns true if the operation might have the provided trait.

unsigned getNumOperands()

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

unsigned getNumResults()

Return the number of results held by this operation.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Variant of makeComposedFoldedAffineApply suitable for multi-result maps.

MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)

Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...

@ NotProjectedPermutations

bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())

Returns true if the block contains a contraction of the following form:

StringRef getMatchConvolutionMessage(MatchConvolutionResult res)

Returns the error message corresponding to the convolution checking return code.

bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)

Implementation of the method that checks if the given operands can be dropped, i.e. the remaining operands can still compute the loop bounds of the op.

MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)

Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...

LogicalResult verifyContractionInterface(Operation *op)

Verify that op conforms to ContractionOpInterface.

@ NotProjectedPermutations

@ NonOutputDimNotReduction

LogicalResult verifyFillInterface(Operation *op)

Verify that op conforms to the FillOpInterface.

StringRef getMatchContractionMessage(MatchContractionResult res)

Returns the error message corresponding to the contraction checking return code.

LogicalResult verifyStructuredOpInterface(Operation *op)

Verify that op conforms to the invariants of StructuredOpInterface.

LogicalResult verifyConvolutionInterface(Operation *op)

Verify that op conforms to the ConvolutionOpInterface.

std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)

Checks whether genericOp is semantically equivalent to a linalg.transpose.

bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)

Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.

bool isaCopyOpInterface(LinalgOp linalgOp)

Checks whether linalgOp is semantically equivalent to a linalg.copyOp.

FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)

Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...

OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)

Create one memref::DimOp or tensor::DimOp depending on the type of val.

bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)

Checks whether linalgOp conforms to ConvolutionOpInterface.

std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)

Checks whether genericOp is semantically equivalent to a linalg.broadcast.

FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)

Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...

Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)

Create one memref::DimOp or tensor::DimOp depending on the type of val.

bool isaContractionOpInterface(LinalgOp linalgOp)

Checks whether linalgOp conforms to ContractionOpInterface.

std::optional< Value > isaFillOpInterface(GenericOp genericOp)

Checks whether genericOp is semantically equivalent to a linalg.fill.

bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)

Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....

Include the generated interface declarations.

AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)

Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.

LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

Reify the shape of the result of an operation (typically in terms of the shape of its operands).

AffineMap inversePermutation(AffineMap map)

Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...

@ Mul

RHS of mul is always a constant or a symbolic expression.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)

Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...

HasAffineDimExprVisitor(llvm::SmallBitVector positions)

bool visitDimExpr(AffineDimExpr dimExpr)

bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)

bool visitSymbolExpr(AffineSymbolExpr symbolExpr)

bool visitConstantExpr(AffineConstantExpr constExpr)

Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...

Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.

Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.