MLIR: lib/Dialect/Utils/ReshapeOpsUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

14#include "llvm/ADT/ArrayRef.h"

15#include "llvm/ADT/SmallVector.h"

16

17#include

18#include

19

20using namespace mlir;

21

22std::optional<SmallVector>

24 ShapedType targetType) {

25 if (sourceType.getRank() > targetType.getRank())

27 targetType.getShape());

28 if (sourceType.getRank() < targetType.getRank())

30 sourceType.getShape());

31 return std::nullopt;

32}

33

34namespace {

35

36

37

// ReassociationIndexRange: an inclusive range [leftIdx, rightIdx] of source
// dimension indices. size() counts both endpoints (rightIdx - leftIdx + 1),
// and the index-collecting loop at the end of the struct runs up to and
// including rightIdx.
38struct ReassociationIndexRange {

39

40

41

// Both endpoints default to 0, so a default-constructed range holds the
// single index 0.
42 int64_t leftIdx = 0, rightIdx = 0;

43

44

// Well-formed means non-negative and not inverted: 0 <= leftIdx <= rightIdx.
45 LogicalResult verify() const {

46 return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();

47 }

48

49

50

// True when this range lies entirely inside `outerRange` (inclusive bounds).
51 bool isInRange(const ReassociationIndexRange &outerRange) const {

52 return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;

53 }

54

// Number of indices covered; asserts the range is well-formed first.
55 unsigned size() const {

56 assert(succeeded(verify()));

57 return rightIdx - leftIdx + 1;

58 }

59 bool containsSingleIndex() const { return size() == 1; }

60

61

// getNonOverlappingIndicesWith: collects the indices covered by only one of
// `this` and `rhs` (the NOTEs below mark where the code departs from that).
// Its return type and the declaration of `result` are on lines not visible
// in this view.
63 getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {

// Disjoint case with `this` wholly left of `rhs`: return every index of
// `this` followed by every index of `rhs`.
// NOTE(review): only this ordering is special-cased. If `rhs` lay wholly
// left of `this`, the general path below would also emit gap indices that
// belong to neither range, and could emit an index twice. The caller in this
// file passes (forward-pass range, reverse-pass range) for the same target
// dim; presumably the reverse range never lies wholly left of the forward
// one -- confirm.
64 if (rightIdx < rhs.leftIdx) {

65

66 auto jointFullIndices = getFullIndices();

67 jointFullIndices.append(rhs.getFullIndices());

68 return jointFullIndices;

69 }

71

// Left chunk: [min(leftIdx, rhs.leftIdx), max(leftIdx, rhs.leftIdx)).
// llvm::seq excludes its end point.
72 int64_t leftStart = std::min(leftIdx, rhs.leftIdx);

73 int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);

74 llvm::append_range(result, llvm::seq(leftStart, leftEnd));

75

76

// Right chunk: (min(rightIdx, rhs.rightIdx), max(rightIdx, rhs.rightIdx)].
// NOTE(review): the guard is `<`, so when exactly one index lies right of the
// overlap (rightStart == rightEnd), nothing is appended, even though
// llvm::seq_inclusive(a, a) yields {a}. For the caller in this file this
// appears harmless: both search passes tile the source dims contiguously, so
// such an index also lands in the next group's left chunk. Confirm before
// reusing this helper elsewhere.
77 int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;

78 int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);

79 if (rightStart < rightEnd)

80 llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));

82 }

83

84

// Index-collecting helper. Its signature is not visible in this view; it is
// presumably getFullIndices(), which is called above. It emits leftIdx,
// leftIdx + 1, ..., rightIdx in increasing order.
87 for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {

88 result.push_back(idx);

89 }

91 }

92};

93}

94

95

96

97

98

99

100

// findReassociationRangeForDynamicDim(sourceShape, sourceStartIdx,
// matchGreedily = false): grows a range rightwards from sourceStartIdx until
// it reaches the first dynamic source dim. leftIdx is never advanced, so any
// static dims passed over on the way stay inside the range.
// Returns failure() if no dynamic dim exists at or after sourceStartIdx.
// With matchGreedily, the found range is stretched to the last source dim.
// NOTE(review): `numSourceDims - 1` is unsigned arithmetic and wraps if
// sourceShape is empty. Presumably callers only pass a non-empty shape (the
// collapse search asserts numSourceDims > numTargetDims) -- confirm.
101static FailureOr

104 bool matchGreedily = false) {

105 const unsigned numSourceDims = sourceShape.size();

106 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};

107 std::optional resultRange = std::nullopt;

108

// Scan rightwards, one source dim at a time, while still inside the shape.
109 ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};

110 for (; iterationRange.isInRange(sourceShapeAsRange);

111 iterationRange.rightIdx++) {

112 int64_t sourceSize = sourceShape[iterationRange.rightIdx];

113 if (sourceSize == ShapedType::kDynamic) {

114 resultRange = iterationRange;

115 break;

116 }

117 }

118 if (!resultRange)

119 return failure();

// Greedy mode swallows every remaining source dim into this range.
120 if (matchGreedily)

121 resultRange->rightIdx = sourceShapeAsRange.rightIdx;

122 return *resultRange;

123}

124

125

126

127

128

129

// findReassociationRangeForSize(sourceShape, sourceStartIdx, targetSize,
// matchGreedily = false): starting at sourceStartIdx, slides a window
// [leftIdx, rightIdx] over sourceShape, looking for a run of static dims
// whose product equals targetSize.
// Returns failure() if the window runs off the end of sourceShape without a
// match. With matchGreedily, the match is extended over any size-1 dims that
// immediately follow it.
130static FailureOr

133 bool matchGreedily = false) {

134 const unsigned numSourceDims = sourceShape.size();

135 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};

136 std::optional resultRange = std::nullopt;

137

138 ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};

// Product of the source dims currently inside the window.
139 int64_t prodOfCollapsedDims = 1;

140 while (iterationRange.isInRange(sourceShapeAsRange)) {

141 int64_t sourceSize = sourceShape[iterationRange.rightIdx];

// A dynamic dim cannot take part in a static-size match: restart the window
// just past it with an empty product.
142 if (sourceSize == ShapedType::kDynamic) {

143

144

145

146 prodOfCollapsedDims = 1;

147 iterationRange = {iterationRange.rightIdx + 1,

148 iterationRange.rightIdx + 1};

149 continue;

150 }

151 prodOfCollapsedDims *= sourceSize;

152

153

154

// Drop dims from the left while the product overshoots targetSize, keeping
// at least one index in the window.
// NOTE(review): the divisor is a dim already folded into the product. If it
// were 0, the product would be 0, and for a non-negative targetSize this loop
// would not run, so no division by zero is reachable. This assumes static
// target sizes are non-negative -- confirm.
155 while (prodOfCollapsedDims > targetSize &&

156 !iterationRange.containsSingleIndex()) {

157 int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];

158 prodOfCollapsedDims /= frontSourceSize;

159

160 iterationRange.leftIdx++;

161 }

162

163

// Exact match: the current window is the answer.
164 if (prodOfCollapsedDims == targetSize) {

165 resultRange = iterationRange;

166 break;

167 }

168

// No exact match yet: grow the window by one dim to the right.
169 iterationRange.rightIdx++;

170 }

171 if (!resultRange)

172 return failure();

173 if (matchGreedily) {

// iterationRange still equals *resultRange here, because the loop broke
// right after assigning it. So this walks the dims just past the match and
// absorbs each one whose size is 1.
174

175

176

177 iterationRange.rightIdx++;

178 while (iterationRange.isInRange(sourceShapeAsRange) &&

179 sourceShape[iterationRange.rightIdx] == 1) {

180 resultRange = iterationRange;

181 iterationRange.rightIdx++;

182 }

183 }

184 return *resultRange;

185}

186

187

188

189

190

191

192

193

194

195

196

197

198

199

// findReassociationRangesForCollapse(sourceShape, targetShape): walks the
// target dims left to right and gives each one a contiguous range of source
// dims, starting right after the previous range.
// Dynamic and static target sizes are matched by two different helpers.
// Their names are on lines not visible in this view; the argument lists match
// findReassociationRangeForDynamicDim and findReassociationRangeForSize.
// Only the last target dim matches greedily. The search fails unless the
// ranges tile the source dims from index 0 through the last one.
// Precondition (asserted): sourceShape has strictly higher rank than
// targetShape.
// NOTE(review): the final `reassocRanges.back()` requires at least one target
// dim. Presumably the rank-0 target case is handled by the caller before this
// point is reached -- confirm.
200static FailureOr<SmallVector>

203 unsigned numSourceDims = sourceShape.size(),

204 numTargetDims = targetShape.size();

205 assert(numSourceDims > numTargetDims);

206 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};

207

209 reassocRanges.reserve(numTargetDims);

210

211

212

// Size of the previous target dim; std::nullopt before the first one.
213 std::optional<int64_t> prevTargetSize = std::nullopt;

214 for (unsigned targetDimIdx = 0, sourceDimIdx = 0;

215 targetDimIdx < numTargetDims; ++targetDimIdx) {

216 int64_t targetSize = targetShape[targetDimIdx];

217

218

219 bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;

220 FailureOr sourceRange;

221 if (targetSize == ShapedType::kDynamic) {

223 sourceShape, sourceDimIdx, shouldMatchGreedily);

224 } else {

226 sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);

227 }

228

229

230 if (failed(sourceRange) || failed(sourceRange->verify()) ||

231 !sourceRange->isInRange(sourceShapeAsRange))

232 return failure();

// The match skipped some source dims. They may only be folded into the
// previous group, and only if the previous target dim was dynamic. For the
// first target dim (no previous one) this fails.
233 if (sourceRange->leftIdx > sourceDimIdx) {

234

235

236 if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)

237 return failure();

238 reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;

239 }

240

241

// The next target dim's search starts right after this range.
242 prevTargetSize = targetSize;

243 sourceDimIdx = sourceRange->rightIdx + 1;

244 reassocRanges.push_back(*sourceRange);

245 }

246

247

248

// The last range must reach the last source dim.
249 if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)

250 return failure();

251 return reassocRanges;

252}

253

254

255

// Runs the collapse range search either left to right or, when
// iterateRightToLeft is set, right to left. The signature's parameter lines,
// and the statement taken when iterateRightToLeft is false, are on lines not
// visible in this view.
// Right-to-left mode reverses both shapes, presumably runs the left-to-right
// search on the reversed copies, and then maps the resulting ranges back to
// original coordinates.
256static FailureOr<SmallVector>

259 bool iterateRightToLeft) {

260 if (!iterateRightToLeft)

262

263

264

265

266

267 std::vector<int64_t> sourceToReverse = sourceShape.vec(),

268 targetToReverse = targetShape.vec();

269 std::reverse(sourceToReverse.begin(), sourceToReverse.end());

270 std::reverse(targetToReverse.begin(), targetToReverse.end());

271 auto invertedRanges =

273 if (failed(invertedRanges))

274 return failure();

276 unsigned numSourceDims = sourceShape.size();

277

278

// Mirror each range back: reversed index i is original index
// numSourceDims - 1 - i, so a reversed range [l, r] becomes
// [numSourceDims - 1 - r, numSourceDims - 1 - l].
279 for (auto &range : rangesToInvert) {

280 int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;

281 range.leftIdx = numSourceDims - 1 - invRightIdx;

282 range.rightIdx = numSourceDims - 1 - invLeftIdx;

283 }

284

285

// In original coordinates the ranges come out ordered right to left; restore
// left-to-right order.
286 std::reverse(rangesToInvert.begin(), rangesToInvert.end());

287 return rangesToInvert;

288}

289

// getReassociationIndicesForCollapse(sourceShape, targetShape): returns a
// reassociation that collapses sourceShape into targetShape, or std::nullopt
// when none is found.
// The range search is run in both directions. The two results may only
// disagree on source dims of size exactly 1, and the forward (left-to-right)
// ranges are the ones returned.
290std::optional<SmallVector>

293 unsigned numSourceDims = sourceShape.size(),

294 numTargetDims = targetShape.size();

295

296

297

298

// A collapse must strictly reduce the rank.
299 if (numSourceDims <= numTargetDims)

300 return std::nullopt;

301

302

303

// Rank-0 target: every source dim must be 1 or dynamic. The statement that
// returns for the accepted case is on a line not visible in this view.
304 if (numTargetDims == 0) {

305 for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;

306 ++sourceDimIdx) {

307 int64_t sourceSize = sourceShape[sourceDimIdx];

308 if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)

309 return std::nullopt;

310 }

312 }

313

314

315 FailureOr<SmallVector> maybeForwardRanges =

317 if (failed(maybeForwardRanges))

318 return std::nullopt;

319 auto &ranges = *maybeForwardRanges;

320

321

322

323

324

325

326

327

// Same search in the opposite direction (the trailing `true` argument).
328 FailureOr<SmallVector> maybeReverseRanges =

330 true);

331 if (failed(maybeReverseRanges))

332 return std::nullopt;

333 auto &reverseRanges = *maybeReverseRanges;

334

// Defensive: a successful search pushes exactly one range per target dim.
335 if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)

336 return std::nullopt;

337

338

339

// Where the two passes disagree on a group's boundaries, every disputed
// source dim must be exactly 1. In other words, only unit dims may be
// assigned to either of two neighbouring groups.
341 for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;

342 ++targetDimIdx) {

343 ReassociationIndexRange &range = ranges[targetDimIdx];

344 ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];

345

347 range.getNonOverlappingIndicesWith(reverseRange);

348

349

350 for (int64_t sourceDimIdx : nonMatchingIndices) {

351 if (sourceShape[sourceDimIdx] != 1)

352 return std::nullopt;

353 }

354 reassociationMap[targetDimIdx] = range.getFullIndices();

355 }

356 return reassociationMap;

357}

358

359std::optional<SmallVector>

365

366

367 if (producerReassociations.size() == consumerReassociations.size())

368 return std::nullopt;

369 if (producerReassociations.size() < consumerReassociations.size())

370 std::swap(producerReassociations, consumerReassociations);

371

372

373

374 if (consumerReassociations.empty())

375 return composedIndices;

376

377 size_t consumerDims =

378 llvm::accumulate(consumerReassociations, size_t(0),

380 return all + indices.size();

381 });

382 if (producerReassociations.size() != consumerDims)

383 return std::nullopt;

384

387 for (int64_t consumerIndex : consumerIndices) {

388 llvm::append_range(reassociations, producerReassociations[consumerIndex]);

389 }

390 composedIndices.push_back(std::move(reassociations));

391 }

392 return composedIndices;

393}

394

399 for (const auto &indices : reassociationIndices) {

401 reassociationMap.reserve(indices.size());

404 reassociationMaps.push_back(std::move(reassociationMap));

405 }

406 return reassociationMaps;

407}

408

409template

411 unsigned pos = 0;

412 for (const auto &exprs : exprArrays) {

413 for (auto expr : exprs) {

415 if (auto d = dyn_cast(e))

416 pos = std::max(pos, d.getPosition());

417 });

418 }

419 }

420 return pos;

421}

422

426 llvm::to_vector<4>(llvm::map_range(

428 return cast(b.getI64ArrayAttr(indices));

429 }));

430 return b.getArrayAttr(reassociationAttr);

431}

432

436 for (const auto &exprs : reassociationExprs) {

438 indices.reserve(exprs.size());

439 for (const auto &expr : exprs)

440 indices.push_back(cast(expr).getPosition());

441 reassociationIndices.push_back(indices);

442 }

443 return reassociationIndices;

444}

445

450 "Expected symbol-less expressions");

452 maps.reserve(reassociation.size());

453 for (const auto &exprs : reassociation) {

454 assert(!exprs.empty());

456 }

457 return maps;

458}

459

461 int *invalidIndex) {

// isReassociationValid(reassociation, invalidIndex = nullptr): returns true
// when the maps form a valid reassociation. That requires all of:
//   - every map has the same number of dims as the first map and no symbols;
//   - every result casts to a dim expression (the dyn_cast target type is
//     elided in this view; presumably AffineDimExpr);
//   - taken in order across all maps, the results enumerate
//     d0, d1, ..., d(nDims-1), each exactly once.
// An empty list is valid.
// On failure, if invalidIndex is non-null, it receives the index of the
// offending map, or of the last map when some dims are never used.
462 if (reassociation.empty())

463 return true;

464 unsigned nDims = reassociation[0].getNumDims();

465 unsigned nextExpectedDim = 0;

466 for (const auto &it : llvm::enumerate(reassociation)) {

467 auto m = it.value();

468 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {

469 if (invalidIndex)

470 *invalidIndex = it.index();

471 return false;

472 }

// Dim positions must continue the running sequence across map boundaries.
473 for (auto e : m.getResults()) {

474 auto d = dyn_cast(e);

475 if (!d || d.getPosition() != nextExpectedDim++) {

476 if (invalidIndex)

477 *invalidIndex = it.index();

478 return false;

479 }

480 }

481 }

// Every dim of the common domain must have been consumed.
482 if (nextExpectedDim != nDims) {

483 if (invalidIndex)

484 *invalidIndex = reassociation.size() - 1;

485 return false;

486 }

487 return true;

488}

489

// reshapeLikeShapesAreCompatible: checks that collapsedShape is consistent
// with expandedShape under reassociationMaps. The parameter lines are not
// visible in this view.
// For each reassociation group, the group's expanded dims are located by a
// running offset (expandedDimStart):
//   - if any of them is dynamic, the corresponding collapsed dim must be
//     dynamic;
//   - otherwise the collapsed dim must equal the product of the group's
//     static sizes. A dynamic collapsed dim whose group is fully static
//     therefore fails this check too, assuming the kDynamic sentinel never
//     equals a real product.
// Mismatches are reported as errors that name the collapsed dim index.
490LogicalResult mlir::reshapeLikeShapesAreCompatible(

494 unsigned expandedDimStart = 0;

495 for (const auto &map : llvm::enumerate(reassociationMaps)) {

496 bool foundDynamicShape = false;

497 int64_t linearizedStaticShape = 1;

498

// Scan this group's expanded dims: note any dynamic one and multiply the
// static ones.
499 for (const auto &dim : llvm::enumerate(

500 expandedShape.slice(expandedDimStart, map.value().size()))) {

501 if (ShapedType::isDynamic(dim.value()))

502 foundDynamicShape = true;

503 else

504 linearizedStaticShape *= dim.value();

505 }

506 if (foundDynamicShape) {

507 if (ShapedType::isStatic(collapsedShape[map.index()])) {

509 "expected dimension " + Twine(map.index()) +

510 " of collapsed type to be dynamic since one or more of the "

511 "corresponding dimensions in the expanded type is dynamic");

512 }

513 } else {

514 if (collapsedShape[map.index()] != linearizedStaticShape) {

515 return emitError("expected dimension " + Twine(map.index()) +

516 " of collapsed type to be static value of " +

517 Twine(linearizedStaticShape));

518 }

519 }

// Advance to the first expanded dim of the next group.
520 expandedDimStart += map.value().size();

521 }

523}

524

// hasNonIdentityLayout: returns true only when `type` casts to a memref-like
// type whose layout is not the identity. The dyn_cast target type is elided
// in this view; presumably it is MemRefType.
// Every other type, including any type that fails the cast, yields false.
525bool mlir::hasNonIdentityLayout(Type type) {

526 if (auto memrefType = dyn_cast(type))

527 return !memrefType.getLayout().isIdentity();

528 return false;

529}

530

531llvm::SmallBitVector

534 assert(sliceParams.size() == sliceInputShape.size() &&

535 "only supports non rank-reducing case");

536 llvm::SmallBitVector mask(sliceInputShape.size());

537 unsigned idx = 0;

538 for (const auto &[offset, size, stride] : sliceParams) {

542 (!strideConst || *strideConst != 1) ||

543 (!offsetConst || *offsetConst != 0);

544 idx++;

545 }

546 return mask;

547}

548

551 llvm::SmallBitVector result(reassociationIndices.size());

552 for (const auto &it : llvm::enumerate(reassociationIndices))

553 result[it.index()] = it.value().size() > 1;

555}

556

559 unsigned loopIdx = 0;

560 auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);

561 auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);

563 offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());

564 for (const auto &it : llvm::enumerate(reassociationIndices)) {

565

566

567

568 if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {

569 llvm::append_range(

570 offsetsSizesAndStrides,

571 llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {

573 }));

574 continue;

575 }

576

577

578

579

580 if (linearizedDimensions[it.index()]) {

581 llvm::append_range(offsetsSizesAndStrides,

582 llvm::map_range(it.value(), [&](int64_t idx) -> Range {

583 return {zeroAttr, collapseShapeInputShape[idx],

584 oneAttr};

585 }));

586 continue;

587 }

588

589

590 offsetsSizesAndStrides.push_back(sliceParams[it.index()]);

591 }

592 return offsetsSizesAndStrides;

593}

594

596SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,

598 auto one = IntegerAttr::get(IndexType::get(ctx), 1);

599 auto zero = IntegerAttr::get(IndexType::get(ctx), 0);

601 insertParams.reserve(linearizedDimensions.size());

602 unsigned loopIdx = 0;

603 for (unsigned i = 0; i < linearizedDimensions.size(); i++) {

604 if (linearizedDimensions[i] && slicedDimensions[i]) {

605 insertParams.push_back(Range{tileIndices[loopIdx++], one, one});

606 continue;

607 }

608 insertParams.push_back(Range{zero, sliceParams[i].size, one});

609 }

610 return insertParams;

611}

612

613

614

615

618

619 std::optional<int64_t> dimIndex;

621 return std::nullopt;

623 if (shape[idx] != 1) {

624 if (dimIndex != std::nullopt)

625 return std::nullopt;

626 dimIndex = idx;

627 }

628 }

629 return dimIndex;

630}

631

632

633

634

636 RankedTensorType sourceType,

639 for (const auto &indices : reassociationIndices)

640 trivialSegments.push_back(

642 return trivialSegments;

643}

644

645

646

647static FailureOr<SmallVector<std::optional<int64_t>>>

649 RankedTensorType sourceType,

653 if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) {

654 return idx.has_value();

655 }))

656 return failure();

657 return trivialSegments;

658}

659

660FailureOr

661mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(

662 RankedTensorType sourceType,

664 FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =

666 reassociationIndices);

667 if (failed(trivialSegments))

668 return failure();

669

670

672 for (const auto &[nonUnitDim, indices] :

673 llvm::zip(*trivialSegments, reassociationIndices)) {

674 if (nonUnitDim) {

675 sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));

676 continue;

677 }

678 llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {

679 return sourceType.getDimSize(idx);

680 }));

681 }

682 auto sliceType =

683 RankedTensorType::get(sliceShape, sourceType.getElementType());

684

685

686 if (sliceShape.size() == reassociationIndices.size())

687 return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,

688 std::nullopt};

689

690

691

692

696 for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {

697 reassociation.push_back(dimIdx);

698 if ((*trivialSegments)[groupIdx] ||

699 reassociation.size() == reassociationIndices[groupIdx].size()) {

700 newReassociationIndices.push_back(reassociation);

701 reassociation.clear();

702 groupIdx++;

703 }

704 }

705

706 return CollapseShapeRankReducingSliceSimplificationInfo{

707 sliceType, newReassociationIndices};

708}

709

710PackingMetadata mlir::computePackingMetadata(int64_t packedRank,

712 PackingMetadata res;

713 res.insertPositions.reserve(innerDimPos.size());

714

715

716

717

718

719

720

721

722

723

724

725

726

727

729 for (int64_t pos : innerDimPos) {

730 int64_t numInsertedBefore = llvm::count_if(

731 innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });

732 res.insertPositions.push_back(pos + numInsertedBefore + offset);

733 }

734

736 res.insertPositions.end());

737 res.reassociations.reserve(packedRank);

738 for (int64_t i = 1; i <= packedRank; ++i) {

739 res.outerPositions.push_back(i - 1);

740 if (!posSet.contains(i)) {

742 continue;

743 }

745 ++i;

746 }

747 return res;

748}

749

752 std::optional cst) {

753 if (source && source.isSplat() && result.hasStaticShape() &&

756

757 return {};

758}

b

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...

static FailureOr< ReassociationIndexRange > findReassociationRangeForDynamicDim(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, bool matchGreedily=false)

Starting from sourceStartIdx, searches sourceShape for the first sequence that can be collapsed into ...

Definition ReshapeOpsUtils.cpp:102

static SmallVector< std::optional< int64_t > > getCollapseShapeTrivialSegments(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)

Definition ReshapeOpsUtils.cpp:635

static std::optional< int64_t > getUniqueNonUnitDim(ArrayRef< int64_t > indices, ArrayRef< int64_t > shape)

Returns the index of the only non-unit dimension among indices of shape, if such a dimension exists a...

Definition ReshapeOpsUtils.cpp:616

static unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)

Definition ReshapeOpsUtils.cpp:410

static FailureOr< SmallVector< std::optional< int64_t > > > canCollapseShapeBeSimplifiedByRankReducingSlice(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)

Returns true if any of the segments of the reassociation indices for a collapsing reshape can be simp...

Definition ReshapeOpsUtils.cpp:648

static FailureOr< ReassociationIndexRange > findReassociationRangeForSize(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, int64_t targetSize, bool matchGreedily=false)

Starting from sourceStartIdx, searches sourceShape for the first sequence of static dimensions such t...

Definition ReshapeOpsUtils.cpp:131

static FailureOr< SmallVector< ReassociationIndexRange > > findReassociationRangesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)

Attempts to find a valid collapsing reassociation of sourceShape into targetShape through a simple tr...

Definition ReshapeOpsUtils.cpp:201

Base type for affine expression.

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

Attributes are known-constant values of operations.

This class is a general helper class for creating context-global objects like types,...

An attribute that represents a reference to a dense vector or tensor object.

DenseElementsAttr resizeSplat(ShapedType newType)

Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...

std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const

Return the splat value for this attribute.

bool isSplat() const

Returns true if this attribute corresponds to a splat, i.e.

MLIRContext is the top-level object for a collection of MLIR operations.

This class represents a single result from folding an operation.

Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Include the generated interface declarations.

llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)

The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...

Definition ReshapeOpsUtils.cpp:532

ArrayRef< int64_t > ReassociationIndicesRef

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)

Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.

llvm::DenseSet< ValueT, ValueInfoT > DenseSet

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)

Constructs affine maps out of Array<Array>.

Definition ReshapeOpsUtils.cpp:447

SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)

Convert Array<Array> to Array<Array<int64_t>>.

Definition ReshapeOpsUtils.cpp:433

std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)

Return the reassociations maps to use to reshape given the source type and the target type when possi...

Definition ReshapeOpsUtils.cpp:23

std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)

Returns the reassociation maps to collapse sourceShape to targetShape if possible.

Definition ReshapeOpsUtils.cpp:291

SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)

Convert reassociation indices to affine expressions.

Definition ReshapeOpsUtils.cpp:396

bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)

Return true if the reassociation specification is valid, false otherwise.

Definition ReshapeOpsUtils.cpp:460

std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)

Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...

Definition ReshapeOpsUtils.cpp:360

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)

Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...

Definition ReshapeOpsUtils.cpp:549

SmallVector< int64_t, 2 > ReassociationIndices

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)

Wraps a list of reassociations in an ArrayAttr.

Definition ReshapeOpsUtils.cpp:423

llvm::function_ref< Fn > function_ref

Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...