MLIR: lib/Dialect/Mesh/Transforms/Spmdization.cpp Source File

//===- Spmdization.cpp --------------------------------------------- C++ -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"

#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>

namespace mlir::mesh {

// True if every target partial axis is also a source partial axis.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}

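// All-reduce a source shard whose sharding carries partial (pending-reduction)
// axes that the target sharding does not. For example, reducing
// partial = sum[0] down to no partial axes emits a mesh.all_reduce over mesh
// axis 0. Returns the (possibly new) shard together with its updated sharding.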
static std::tuple<TypedValue<ShapedType>, MeshSharding>
handlePartialAxesDuringResharding(OpBuilder &builder,
                                  MeshSharding sourceSharding,
                                  MeshSharding targetSharding,
                                  TypedValue<ShapedType> sourceShard) {
  if (sourceSharding.getPartialAxes().empty() &&
      targetSharding.getPartialAxes().empty()) {
    return {sourceShard, sourceSharding};
  }
  assert(targetSharding.getPartialAxes().empty() ||
         (!sourceSharding.getPartialAxes().empty() &&
          sourceSharding.getPartialType() == targetSharding.getPartialType()));
  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
  using AxisSet = llvm::SmallDenseSet<Axis>;
  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
                                       sourceSharding.getPartialAxes().end());
  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
                                       targetSharding.getPartialAxes().end());
  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
                                  targetShardingPartialAxesSet));

  // All-reduce over the axes that are partial in the source but not in the
  // target.
  SmallVector<MeshAxis> allReduceMeshAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(allReduceMeshAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return !targetShardingPartialAxesSet.contains(a);
                });
  if (allReduceMeshAxes.empty()) {
    return {sourceShard, sourceSharding};
  }

  builder.setInsertionPointAfterValue(sourceShard);
  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
      builder
          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                               sourceSharding.getMeshAttr().getLeafReference(),
                               allReduceMeshAxes, sourceShard,
                               sourceSharding.getPartialType())
          .getResult());

  // The axes that are partial in both shardings carry over to the result.
  SmallVector<MeshAxis> remainingPartialAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(remainingPartialAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return targetShardingPartialAxesSet.contains(a);
                });
  MeshSharding resultSharding =
      MeshSharding::get(sourceSharding.getMeshAttr(),
                        sourceSharding.getSplitAxes(), remainingPartialAxes,
                        sourceSharding.getPartialType());
  return {resultValue, resultSharding};
}

static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                                  MeshSharding sourceSharding,
                                                  int64_t splitTensorAxis,
                                                  MeshAxis splitMeshAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitMeshAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(sourceSharding.getMeshAttr(),
                           targetShardingSplitAxes);
}

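// Split a tensor axis over one more mesh axis, e.g. [[0, 1]] -> [[0, 1, 2]],
// by slicing each shard locally (mesh.all_slice). Returns the resulting shard
// together with its sharding.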
static std::tuple<TypedValue<ShapedType>, MeshSharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          MeshSharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, MeshOp mesh,
                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder
          .create<AllSliceOp>(sourceShard, mesh,
                              ArrayRef<MeshAxis>(splitMeshAxis),
                              splitTensorAxis)
          .getResult());
  MeshSharding targetSharding = targetShardingInSplitLastAxis(
      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
  return {targetShard, targetSharding};
}

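// Detect a resharding that appends exactly one mesh axis to the split axes of
// one tensor axis, e.g. [[0, 1]] -> [[0, 1, 2]]. Insertions elsewhere, such
// as [[0, 1]] -> [[0, 2, 1]], are not detected. Returns the matching
// (tensor axis, mesh axis) pair, if any.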
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshSharding sourceSharding,
                                MeshSharding targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < targetSharding.getSplitAxes().size(); ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                             MeshSharding sourceSharding,
                             MeshSharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard,
                                     mesh, tensorAxis, meshAxis);
  }

  return std::nullopt;
}

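// Detect the inverse pattern: exactly one mesh axis is removed from the end
// of one tensor axis' split axes, e.g. [[0, 1, 2]] -> [[0, 1]]. Returns the
// matching (tensor axis, mesh axis) pair, if any.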
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
                                  MeshSharding targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < sourceSharding.getSplitAxes().size(); ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                    MeshSharding sourceSharding,
                                                    int64_t splitTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(sourceSharding.getMeshAttr(),
                           targetShardingSplitAxes);
}

static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

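// Undo the innermost split of a tensor axis by all-gathering the shards over
// the released mesh axis. The gathered dimension grows by the mesh axis size,
// e.g. a per-device dimension of 2 sharded over 4 devices gathers back to 8.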
static std::tuple<TypedValue<ShapedType>, MeshSharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            MeshSharding sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
  Value allGatherResult = builder.create<AllGatherOp>(
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, mesh,
                                       tensorAxis, meshAxis);
  }

  return std::nullopt;
}

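// Detect a resharding that moves the innermost split of one tensor axis to
// another tensor axis, e.g. [[0, 1], []] -> [[0], [1]]. Returns the matching
// (source tensor axis, target tensor axis, mesh axis) triple, if any.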
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
                                    MeshSharding targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}

static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                                 MeshSharding sourceSharding,
                                                 int64_t sourceTensorAxis,
                                                 int64_t targetTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto meshAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      MeshAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(meshAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);

  return MeshSharding::get(sourceSharding.getMeshAttr(),
                           targetShardingSplitAxes);
}

375

377 int64_t splitCount,

378 int64_t sourceTensorAxis,

379 int64_t targetTensorAxis) {

381 targetShape[sourceTensorAxis] =

382 gatherDimension(targetShape[sourceTensorAxis], splitCount);

383 targetShape[targetTensorAxis] =

384 shardDimension(targetShape[targetTensorAxis], splitCount);

385 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());

386 }

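// Move the innermost split from one tensor axis to another with a single
// mesh.all_to_all: the source axis is gathered (concatenated) while the
// target axis is split, so the per-device element count stays unchanged.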
static std::tuple<TypedValue<ShapedType>, MeshSharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                              MeshSharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, MeshAxis meshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshSharding targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = builder.create<AllToAllOp>(
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                 MeshSharding sourceSharding,
                                 MeshSharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, meshAxis);
  }

  return std::nullopt;
}

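// Detect a change in halo sizes (only) and emit the necessary operations if
// so. Changed halo sizes require copying the "core" of the source shard into
// the "core" of the destination shard, followed by a halo exchange.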
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                          MeshSharding sourceSharding,
                          MeshSharding targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where halo sizes differ but everything else
  // stays the same (from source to destination sharding).
  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
      !sourceSharding.getPartialAxes().empty() ||
      !targetSharding.getPartialAxes().empty() ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
          !ShapedType::isDynamicShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine the core of the source and destination shards, i.e. the local
  // data without halos.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract the core of the source shard and copy it into an appropriately
  // sized destination shard.
  auto noVals = ValueRange{};
  auto initVal = builder.create<tensor::EmptyOp>(
      sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
  auto core = builder.create<tensor::ExtractSliceOp>(
      sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = builder.create<tensor::InsertSliceOp>(
      sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
      coreShape, strides);

  // Finally exchange the halo regions.
  auto updateHaloResult =
      builder
          .create<UpdateHaloOp>(
              sourceShard.getLoc(),
              RankedTensorType::get(outShape,
                                    sourceShard.getType().getElementType()),
              initOprnd, mesh.getSymName(),
              MeshAxesArrayAttr::get(builder.getContext(),
                                     sourceSharding.getSplitAxes()),
              targetSharding.getDynamicHaloSizes(),
              targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}

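// Resharding on a 1D mesh: first fold away partial axes, then apply the first
// matching single-communication pattern (move-last-split, split-last, or
// unsplit-last). Sharded tensor axes must divide evenly by the mesh axis size.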
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
                MeshSharding sourceSharding, MeshSharding targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");

  auto [reducedSourceShard, reducedSourceSharding] =
      handlePartialAxesDuringResharding(builder, sourceSharding,
                                        targetSharding, sourceShard);

  if (reducedSourceSharding == targetSharding) {
    return reducedSourceShard;
  }

  TypedValue<ShapedType> targetShard;
  MeshSharding actualTargetSharding;
  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      reducedSourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, mesh, reducedSourceSharding, targetSharding,
            sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = trySplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}

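// Entry point for resharding a single value: fast-path identical shardings,
// then halo-size-only changes (any mesh rank), then the general 1D-mesh
// patterns. An illustrative round trip on a 2-device 1D mesh (IR syntax
// abbreviated): resharding a tensor<2xf32> from split axes [[0]] to [[]]
// materializes roughly
//   %g = mesh.all_gather %shard on @mesh_1d mesh_axes = [0] gather_axis = 0
//       : tensor<1xf32> -> tensor<2xf32>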
TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {
  // If source and destination sharding are the same, no need to do anything.
  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
                                           isFullReplication(targetSharding))) {
    return sourceShard;
  }

  // Tries to handle the case where the resharding is needed because the halo
  // sizes are different. Supports arbitrary mesh dimensionality.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, mesh, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value());
  }

  // Resort to handling only 1D meshes since the general case is complicated
  // if it needs to be communication efficient in terms of minimizing the data
  // transferred between devices.
  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
                 cast<TypedValue<ShapedType>>(source.getSrc()),
                 sourceShardValue);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  MeshOp srcMesh = getMesh(source, symbolTableCollection);
  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
  return reshard(builder, srcMesh, source, target, sourceShardValue);
}

void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}

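// A minimal caller-side sketch of the public reshard API above (illustrative
// only; `meshOp`, `srcSharding`, `tgtSharding`, `unshardedVal`, and
// `shardVal` are hypothetical values assumed to already be in scope):
//
//   ImplicitLocOpBuilder b(unshardedVal.getLoc(), builder);
//   TypedValue<ShapedType> newShard =
//       reshard(b, meshOp, srcSharding, tgtSharding,
//               /*sourceUnshardedValue=*/unshardedVal,
//               /*sourceShard=*/shardVal);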
#define GEN_PASS_DEF_SPMDIZATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

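// Compute the block argument types after spmdization by reading each
// argument's mesh.shard annotation; arguments that are not ranked tensors or
// are 0-d tensors keep their original type.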
SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
          return arg.getType();
        }

        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        MeshOp mesh = getMesh(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
                                          shardOp.getSharding()));
      });
  return res;
}

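// Spmdize one operation: dispatch to its ShardingInterface implementation if
// it has one, otherwise conservatively treat the op as fully replicated.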
static LogicalResult spmdizeOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // If there is no sharding interface we are conservative and assume that
    // the op should be fully replicated on all devices.
    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
                                    resultShardings, spmdizationMap,
                                    symbolTableCollection, builder);
  } else {
    if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
                                         resultShardings, spmdizationMap,
                                         symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
    return spmdizationMap.contains(result);
  }));

  return success();
}

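// Collect the sharding annotation of every operand. Operands that are not
// ranked tensors, or are 0-d tensors, are not required to carry an annotation
// and get a default MeshSharding.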
static std::vector<MeshSharding> getOperandShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
      return MeshSharding();
    }

    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return MeshSharding(shardOp.getSharding());
  });
  return res;
}

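// Collect the sharding annotation of every result, read off the mesh.shard
// user of each result. 0-d tensor results without an explicit annotation fall
// back to a mesh found on the operands' sharding ops, if any.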
static std::vector<MeshSharding> getResultShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(
      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
        if (!result.hasOneUse() || result.use_empty()) {
          return MeshSharding();
        }
        TypedValue<RankedTensorType> rankedTensor =
            dyn_cast<TypedValue<RankedTensorType>>(result);
        if (!rankedTensor) {
          return MeshSharding();
        }
        Operation *userOp = *result.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
        if (shardOp) {
          return MeshSharding(shardOp.getSharding());
        }
        if (rankedTensor.getType().getRank() == 0) {
          // This is a 0-d tensor result without explicit sharding.
          // Find a mesh symbol from the operands, if any.
          // Shardings without a mesh are not always fully supported yet.
          for (auto operand : op.getOperands()) {
            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
              return MeshSharding(sharding.getMeshAttr());
            }
          }
        }
        return MeshSharding();
      });
  return res;
}

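// Spmdize a mesh.shard op. Two chained shard ops signal a resharding from the
// first annotation to the second; a lone shard op simply forwards the
// already-spmdized value.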
static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check if 2 shard ops are chained. If not, there is no need for resharding
  // as the source and target share the same sharding.
  ShardOp srcShardOp =
      dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
  if (!srcShardOp) {
    targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
  } else {
    // Insert resharding.
    TypedValue<ShapedType> srcSpmdValue =
        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}

static LogicalResult
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  if (isa<ShardingOp>(op)) {
    return success();
  }
  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
    auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
    if (!shardOp) {
      return op.emitError("expected a shard op as source of get_sharding");
    }
    auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
    spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
                            builder);
  }

  SmallVector<Value> spmdizedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
                  [&spmdizationMap](Value operand) {
                    assert(spmdizationMap.contains(operand));
                    return spmdizationMap.lookup(operand);
                  });
  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
                          getResultShardings(op), spmdizationMap,
                          symbolTableCollection, builder);
}

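// Spmdize a block: create a new block with sharded argument types, map the
// old arguments to the new ones, then spmdize every operation in order.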
static LogicalResult
spmdizeBlock(Block &block, IRMapping &spmdizationMap,
             SymbolTableCollection &symbolTableCollection,
             OpBuilder &builder) {
  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, spmdizedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
  }

  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  for (Operation &op : block.getOperations()) {
    if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
                                builder))) {
      return failure();
    }
  }

  return success();
}

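// Spmdize a whole function: rewrite each block containing mesh.shard ops,
// erase the originals, and update the function signature from the new
// entry-block argument types and the return op's operand types.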
static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  for (Block &b : op.getBlocks()) {
    if (llvm::any_of(b.getOperations(),
                     [](Operation &op) { return isa<ShardOp>(op); })) {
      originalBlocks.push_back(&b);
    }
  }

  for (Block *block : originalBlocks) {
    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
                            builder))) {
      return failure();
    }
  }

  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function results signature to its
  // operands signature.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  if (returnOp) {
    op.setType(FunctionType::get(
        op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
        returnOp->getOperandTypes()));
  }

  return success();
}

855

856 namespace {

857

858 struct Spmdization : public impl::SpmdizationBase {

859 void runOnOperation() override {

862 if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,

863 symbolTableCollection))) {

864 return signalPassFailure();

865 }

866 }

867

868 void getDependentDialects(DialectRegistry &registry) const override {

870 registry.insertmesh::MeshDialect();

871 }

872 };

873

874 }

875

876 }
