MLIR: lib/Dialect/Utils/ReshapeOpsUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

14 #include "llvm/ADT/ArrayRef.h"

15 #include "llvm/ADT/SmallVector.h"

16 #include "llvm/Support/LogicalResult.h"

17

18 #include

19 #include

20

21 using namespace mlir;

22

23 std::optional<SmallVector>

25 ShapedType targetType) {

26 if (sourceType.getRank() > targetType.getRank())

28 targetType.getShape());

29 if (sourceType.getRank() < targetType.getRank())

31 sourceType.getShape());

32 return std::nullopt;

33 }

34

35 namespace {

36

37

38

39 struct ReassociationIndexRange {

40

41

42

43 int64_t leftIdx = 0, rightIdx = 0;

44

45

46 LogicalResult verify() const {

47 return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();

48 }

49

50

51

52 bool isInRange(const ReassociationIndexRange &outerRange) const {

53 return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;

54 }

55

56 unsigned size() const {

57 assert(succeeded(verify()));

58 return rightIdx - leftIdx + 1;

59 }

60 bool containsSingleIndex() const { return size() == 1; }

61

62

64 getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {

65 if (rightIdx < rhs.leftIdx) {

66

67 auto jointFullIndices = getFullIndices();

68 jointFullIndices.append(rhs.getFullIndices());

69 return jointFullIndices;

70 }

72

73 int64_t leftStart = std::min(leftIdx, rhs.leftIdx);

74 int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);

75 llvm::append_range(result, llvm::seq(leftStart, leftEnd));

76

77

78 int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;

79 int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);

80 if (rightStart < rightEnd)

81 llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));

82 return result;

83 }

84

85

88 for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {

89 result.push_back(idx);

90 }

91 return result;

92 }

93 };

94 }

95

96

97

98

99

100

101

102 static FailureOr

104 int64_t sourceStartIdx,

105 bool matchGreedily = false) {

106 const unsigned numSourceDims = sourceShape.size();

107 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};

108 std::optional resultRange = std::nullopt;

109

110 ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};

111 for (; iterationRange.isInRange(sourceShapeAsRange);

112 iterationRange.rightIdx++) {

113 int64_t sourceSize = sourceShape[iterationRange.rightIdx];

114 if (sourceSize == ShapedType::kDynamic) {

115 resultRange = iterationRange;

116 break;

117 }

118 }

119 if (!resultRange)

120 return failure();

121 if (matchGreedily)

122 resultRange->rightIdx = sourceShapeAsRange.rightIdx;

123 return *resultRange;

124 }

125

126

127

128

129

130

131 static FailureOr

133 int64_t sourceStartIdx, int64_t targetSize,

134 bool matchGreedily = false) {

135 const unsigned numSourceDims = sourceShape.size();

136 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};

137 std::optional resultRange = std::nullopt;

138

139 ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};

140 int64_t prodOfCollapsedDims = 1;

141 while (iterationRange.isInRange(sourceShapeAsRange)) {

142 int64_t sourceSize = sourceShape[iterationRange.rightIdx];

143 if (sourceSize == ShapedType::kDynamic) {

144

145

146

147 prodOfCollapsedDims = 1;

148 iterationRange = {iterationRange.rightIdx + 1,

149 iterationRange.rightIdx + 1};

150 continue;

151 }

152 prodOfCollapsedDims *= sourceSize;

153

154

155

156 while (prodOfCollapsedDims > targetSize &&

157 !iterationRange.containsSingleIndex()) {

158 int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];

159 prodOfCollapsedDims /= frontSourceSize;

160

161 iterationRange.leftIdx++;

162 }

163

164

165 if (prodOfCollapsedDims == targetSize) {

166 resultRange = iterationRange;

167 break;

168 }

169

170 iterationRange.rightIdx++;

171 }

172 if (!resultRange)

173 return failure();

174 if (matchGreedily) {

175

176

177

178 iterationRange.rightIdx++;

179 while (iterationRange.isInRange(sourceShapeAsRange) &&

180 sourceShape[iterationRange.rightIdx] == 1) {

181 resultRange = iterationRange;

182 iterationRange.rightIdx++;

183 }

184 }

185 return *resultRange;

186 }

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201 static FailureOr<SmallVector>

204 unsigned numSourceDims = sourceShape.size(),

205 numTargetDims = targetShape.size();

206 assert(numSourceDims > numTargetDims);

207 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};

208

210 reassocRanges.reserve(numTargetDims);

211

212

213

214 std::optional<int64_t> prevTargetSize = std::nullopt;

215 for (unsigned targetDimIdx = 0, sourceDimIdx = 0;

216 targetDimIdx < numTargetDims; ++targetDimIdx) {

217 int64_t targetSize = targetShape[targetDimIdx];

218

219

220 bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;

221 FailureOr sourceRange;

222 if (targetSize == ShapedType::kDynamic) {

224 sourceShape, sourceDimIdx, shouldMatchGreedily);

225 } else {

227 sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);

228 }

229

230

231 if (failed(sourceRange) || failed(sourceRange->verify()) ||

232 !sourceRange->isInRange(sourceShapeAsRange))

233 return failure();

234 if (sourceRange->leftIdx > sourceDimIdx) {

235

236

237 if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)

238 return failure();

239 reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;

240 }

241

242

243 prevTargetSize = targetSize;

244 sourceDimIdx = sourceRange->rightIdx + 1;

245 reassocRanges.push_back(*sourceRange);

246 }

247

248

249

250 if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)

251 return failure();

252 return reassocRanges;

253 }

254

255

256

257 static FailureOr<SmallVector>

260 bool iterateRightToLeft) {

261 if (!iterateRightToLeft)

263

264

265

266

267

268 std::vector<int64_t> sourceToReverse = sourceShape.vec(),

269 targetToReverse = targetShape.vec();

270 std::reverse(sourceToReverse.begin(), sourceToReverse.end());

271 std::reverse(targetToReverse.begin(), targetToReverse.end());

272 auto invertedRanges =

274 if (failed(invertedRanges))

275 return failure();

277 unsigned numSourceDims = sourceShape.size();

278

279

280 for (auto &range : rangesToInvert) {

281 int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;

282 range.leftIdx = numSourceDims - 1 - invRightIdx;

283 range.rightIdx = numSourceDims - 1 - invLeftIdx;

284 }

285

286

287 std::reverse(rangesToInvert.begin(), rangesToInvert.end());

288 return rangesToInvert;

289 }

290

291 std::optional<SmallVector>

294 unsigned numSourceDims = sourceShape.size(),

295 numTargetDims = targetShape.size();

296

297

298

299

300 if (numSourceDims <= numTargetDims)

301 return std::nullopt;

302

303

304

305 if (numTargetDims == 0) {

306 for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;

307 ++sourceDimIdx) {

308 int64_t sourceSize = sourceShape[sourceDimIdx];

309 if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)

310 return std::nullopt;

311 }

313 }

314

315

316 FailureOr<SmallVector> maybeForwardRanges =

318 if (failed(maybeForwardRanges))

319 return std::nullopt;

320 auto &ranges = *maybeForwardRanges;

321

322

323

324

325

326

327

328

329 FailureOr<SmallVector> maybeReverseRanges =

331 true);

332 if (failed(maybeReverseRanges))

333 return std::nullopt;

334 auto &reverseRanges = *maybeReverseRanges;

335

336 if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)

337 return std::nullopt;

338

339

340

342 for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;

343 ++targetDimIdx) {

344 ReassociationIndexRange &range = ranges[targetDimIdx];

345 ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];

346

348 range.getNonOverlappingIndicesWith(reverseRange);

349

350

351 for (int64_t sourceDimIdx : nonMatchingIndices) {

352 if (sourceShape[sourceDimIdx] != 1)

353 return std::nullopt;

354 }

355 reassociationMap[targetDimIdx] = range.getFullIndices();

356 }

357 return reassociationMap;

358 }

359

360 std::optional<SmallVector>

366

367

368 if (producerReassociations.size() == consumerReassociations.size())

369 return std::nullopt;

370 if (producerReassociations.size() < consumerReassociations.size())

371 std::swap(producerReassociations, consumerReassociations);

372

373

374

375 if (consumerReassociations.empty())

376 return composedIndices;

377

378 size_t consumerDims = std::accumulate(

379 consumerReassociations.begin(), consumerReassociations.end(), 0,

381 return all + indices.size();

382 });

383 if (producerReassociations.size() != consumerDims)

384 return std::nullopt;

385

388 for (int64_t consumerIndex : consumerIndices) {

389 llvm::append_range(reassociations, producerReassociations[consumerIndex]);

390 }

391 composedIndices.push_back(std::move(reassociations));

392 }

393 return composedIndices;

394 }

395

400 for (const auto &indices : reassociationIndices) {

402 reassociationMap.reserve(indices.size());

403 for (int64_t index : indices)

405 reassociationMaps.push_back(std::move(reassociationMap));

406 }

407 return reassociationMaps;

408 }

409

410 template

412 unsigned pos = 0;

413 for (const auto &exprs : exprArrays) {

414 for (auto expr : exprs) {

416 if (auto d = dyn_cast(e))

417 pos = std::max(pos, d.getPosition());

418 });

419 }

420 }

421 return pos;

422 }

423

427 llvm::to_vector<4>(llvm::map_range(

430 }));

432 }

433

437 for (const auto &exprs : reassociationExprs) {

439 indices.reserve(exprs.size());

440 for (const auto &expr : exprs)

441 indices.push_back(cast(expr).getPosition());

442 reassociationIndices.push_back(indices);

443 }

444 return reassociationIndices;

445 }

446

449 unsigned maxDim = getMaxPosOfType(reassociation);

450 assert(getMaxPosOfType(reassociation) == 0 &&

451 "Expected symbol-less expressions");

453 maps.reserve(reassociation.size());

454 for (const auto &exprs : reassociation) {

455 assert(!exprs.empty());

457 }

458 return maps;

459 }

460

462 int *invalidIndex) {

463 if (reassociation.empty())

464 return true;

465 unsigned nDims = reassociation[0].getNumDims();

466 unsigned nextExpectedDim = 0;

468 auto m = it.value();

469 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {

470 if (invalidIndex)

471 *invalidIndex = it.index();

472 return false;

473 }

474 for (auto e : m.getResults()) {

475 auto d = dyn_cast(e);

476 if (!d || d.getPosition() != nextExpectedDim++) {

477 if (invalidIndex)

478 *invalidIndex = it.index();

479 return false;

480 }

481 }

482 }

483 if (nextExpectedDim != nDims) {

484 if (invalidIndex)

485 *invalidIndex = reassociation.size() - 1;

486 return false;

487 }

488 return true;

489 }

490

495 unsigned expandedDimStart = 0;

496 for (const auto &map : llvm::enumerate(reassociationMaps)) {

497 bool foundDynamicShape = false;

498 int64_t linearizedStaticShape = 1;

499

501 expandedShape.slice(expandedDimStart, map.value().size()))) {

502 if (ShapedType::isDynamic(dim.value()))

503 foundDynamicShape = true;

504 else

505 linearizedStaticShape *= dim.value();

506 }

507 if (foundDynamicShape) {

508 if (!ShapedType::isDynamic(collapsedShape[map.index()])) {

510 "expected dimension " + Twine(map.index()) +

511 " of collapsed type to be dynamic since one or more of the "

512 "corresponding dimensions in the expanded type is dynamic");

513 }

514 } else {

515 if (collapsedShape[map.index()] != linearizedStaticShape) {

516 return emitError("expected dimension " + Twine(map.index()) +

517 " of collapsed type to be static value of " +

518 Twine(linearizedStaticShape));

519 }

520 }

521 expandedDimStart += map.value().size();

522 }

523 return success();

524 }

525

527 if (auto memrefType = dyn_cast(type))

528 return !memrefType.getLayout().isIdentity();

529 return false;

530 }

531

532 llvm::SmallBitVector

535 assert(sliceParams.size() == sliceInputShape.size() &&

536 "only supports non rank-reducing case");

537 llvm::SmallBitVector mask(sliceInputShape.size());

538 unsigned idx = 0;

539 for (const auto &[offset, size, stride] : sliceParams) {

543 (!strideConst || *strideConst != 1) ||

544 (!offsetConst || *offsetConst != 0);

545 idx++;

546 }

547 return mask;

548 }

549

552 llvm::SmallBitVector result(reassociationIndices.size());

553 for (const auto &it : llvm::enumerate(reassociationIndices))

554 result[it.index()] = it.value().size() > 1;

555 return result;

556 }

557

560 unsigned loopIdx = 0;

564 offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());

565 for (const auto &it : llvm::enumerate(reassociationIndices)) {

566

567

568

569 if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {

570 llvm::append_range(

571 offsetsSizesAndStrides,

572 llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {

574 }));

575 continue;

576 }

577

578

579

580

581 if (linearizedDimensions[it.index()]) {

582 llvm::append_range(offsetsSizesAndStrides,

583 llvm::map_range(it.value(), [&](int64_t idx) -> Range {

584 return {zeroAttr, collapseShapeInputShape[idx],

585 oneAttr};

586 }));

587 continue;

588 }

589

590

591 offsetsSizesAndStrides.push_back(sliceParams[it.index()]);

592 }

593 return offsetsSizesAndStrides;

594 }

595

597 SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,

602 insertParams.reserve(linearizedDimensions.size());

603 unsigned loopIdx = 0;

604 for (unsigned i = 0; i < linearizedDimensions.size(); i++) {

605 if (linearizedDimensions[i] && slicedDimensions[i]) {

606 insertParams.push_back(Range{tileIndices[loopIdx++], one, one});

607 continue;

608 }

609 insertParams.push_back(Range{zero, sliceParams[i].size, one});

610 }

611 return insertParams;

612 }

613

614

615

616

619

620 std::optional<int64_t> dimIndex;

621 if (indices.size() < 2)

622 return std::nullopt;

623 for (int64_t idx : indices) {

624 if (shape[idx] != 1) {

625 if (dimIndex != std::nullopt)

626 return std::nullopt;

627 dimIndex = idx;

628 }

629 }

630 return dimIndex;

631 }

632

633

634

635

637 RankedTensorType sourceType,

640 for (const auto &indices : reassociationIndices)

641 trivialSegments.push_back(

643 return trivialSegments;

644 }

645

646

647

648 static FailureOr<SmallVector<std::optional<int64_t>>>

650 RankedTensorType sourceType,

654 if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) {

655 return idx.has_value();

656 }))

657 return failure();

658 return trivialSegments;

659 }

660

661 FailureOr

662 mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(

663 RankedTensorType sourceType,

665 FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =

667 reassociationIndices);

668 if (failed(trivialSegments))

669 return failure();

670

671

673 for (const auto &[nonUnitDim, indices] :

674 llvm::zip(*trivialSegments, reassociationIndices)) {

675 if (nonUnitDim) {

676 sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));

677 continue;

678 }

679 llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {

680 return sourceType.getDimSize(idx);

681 }));

682 }

683 auto sliceType =

685

686

687 if (sliceShape.size() == reassociationIndices.size())

688 return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,

689 std::nullopt};

690

691

692

693

696 int64_t groupIdx = 0;

697 for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {

698 reassociation.push_back(dimIdx);

699 if ((*trivialSegments)[groupIdx] ||

700 reassociation.size() == reassociationIndices[groupIdx].size()) {

701 newReassociationIndices.push_back(reassociation);

702 reassociation.clear();

703 groupIdx++;

704 }

705 }

706

707 return CollapseShapeRankReducingSliceSimplificationInfo{

708 sliceType, newReassociationIndices};

709 }

710

711 PackingMetadata mlir::computePackingMetadata(int64_t packedRank,

713 PackingMetadata res;

714 res.insertPositions.reserve(innerDimPos.size());

715

716

717

718

719

720

721

722

723

724

725

726

727

728

729 int64_t offset = 1;

730 for (int64_t pos : innerDimPos) {

731 int64_t numInsertedBefore = llvm::count_if(

732 innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });

733 res.insertPositions.push_back(pos + numInsertedBefore + offset);

734 }

735

737 res.insertPositions.end());

738 res.reassociations.reserve(packedRank);

739 for (int64_t i = 1; i <= packedRank; ++i) {

740 res.outerPositions.push_back(i - 1);

741 if (!posSet.contains(i)) {

743 continue;

744 }

746 ++i;

747 }

748 return res;

749 }

750

753 std::optional cst) {

754 if (source && source.isSplat() && result.hasStaticShape() &&

757

758 return {};

759 }

static MLIRContext * getContext(OpFoldResult val)

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)

unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)

static FailureOr< ReassociationIndexRange > findReassociationRangeForSize(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, int64_t targetSize, bool matchGreedily=false)

Starting from sourceStartIdx, searches sourceShape for the first sequence of static dimensions whose product equals targetSize; with matchGreedily, trailing unit dimensions are also absorbed into the match.

static SmallVector< std::optional< int64_t > > getCollapseShapeTrivialSegments(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)

static FailureOr< ReassociationIndexRange > findReassociationRangeForDynamicDim(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, bool matchGreedily=false)

Starting from sourceStartIdx, searches sourceShape for the first sequence that can be collapsed into a dynamic dimension, i.e. the shortest prefix ending at the first dynamic source dimension.

static std::optional< int64_t > getUniqueNonUnitDim(ArrayRef< int64_t > indices, ArrayRef< int64_t > shape)

Returns the index of the only non-unit dimension among indices of shape, if such a dimension exists and the group contains at least two indices; std::nullopt otherwise.

static FailureOr< SmallVector< ReassociationIndexRange > > findReassociationRangesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)

Attempts to find a valid collapsing reassociation of sourceShape into targetShape through a simple traversal; on success, returns one source index range per target dimension, covering the source shape without overlaps.

static FailureOr< SmallVector< std::optional< int64_t > > > canCollapseShapeBeSimplifiedByRankReducingSlice(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)

Returns true if any of the segments of the reassociation indices for a collapsing reshape can be simp...

Base type for affine expression.

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

Attributes are known-constant values of operations.

This class is a general helper class for creating context-global objects like types,...

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)

An attribute that represents a reference to a dense vector or tensor object.

std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const

Return the splat value for this attribute.

DenseElementsAttr resizeSplat(ShapedType newType)

Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...

bool isSplat() const

Returns true if this attribute corresponds to a splat, i.e.

MLIRContext is the top-level object for a collection of MLIR operations.

This class represents a single result from folding an operation.

Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Include the generated interface declarations.

llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)

The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...

bool hasNonIdentityLayout(Type type)

Returns true iff the type is a MemRefType and has a non-identity layout.

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)

Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)

Constructs affine maps out of Array<Array>.

SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)

Convert Array<Array> to Array<Array<int64_t>>.

LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)

Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...

std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)

Return the reassociations maps to use to reshape given the source type and the target type when possi...

std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)

Returns the reassociation maps to collapse sourceShape to targetShape if possible.

ArrayRef< int64_t > ReassociationIndicesRef

SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)

Convert reassociation indices to affine expressions.

bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)

Return true if the reassociation specification is valid, false otherwise.

std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)

Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)

Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)

Wraps a list of reassociations in an ArrayAttr.

Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...