MLIR: lib/Dialect/Shard/Transforms/Partition.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

29#include "llvm/ADT/STLExtras.h"

30#include "llvm/ADT/SmallVector.h"

31#include "llvm/Support/Casting.h"

32#include

33#include

34#include

35

37

38template <typename SourceAxes, typename TargetAxes>

40 const TargetAxes &targetAxes) {

41 return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {

42 return sourceAxes.contains(targetAxis);

43 });

44}

45

51 llvm::to_vector(sourceSharding.getSplitAxes());

52 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=

53 splitTensorAxis) {

55 }

56 auto targetSplitAxes =

57 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

58 targetSplitAxes.push_back(splitGridAxis);

59 targetShardingSplitAxes[splitTensorAxis] =

62}

63

64

65

66

67static std::tuple<TypedValue, Sharding>

73 AllSliceOp::create(builder, sourceShard, grid,

75 .getResult();

77 builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);

78 return {targetShard, targetSharding};

79}

80

81

82

83

84

85

86static std::optional<std::tuple<int64_t, GridAxis>>

89 for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();

90 ++tensorAxis) {

91 if (sourceSharding.getSplitAxes().size() > tensorAxis) {

92 if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=

93 targetSharding.getSplitAxes()[tensorAxis].size()) {

94 continue;

95 }

96 if (!llvm::equal(

97 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),

98 llvm::make_range(

100 .asArrayRef()

101 .begin(),

102 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -

103 1))) {

104 continue;

105 }

106 } else {

107 if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {

108 continue;

109 }

110 }

111 return std::make_tuple(

112 tensorAxis,

113 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());

114 }

115 return std::nullopt;

116}

117

118static std::optional<std::tuple<TypedValue, Sharding>>

122 if (auto detectRes =

124 auto [tensorAxis, gridAxis] = detectRes.value();

126 tensorAxis, gridAxis);

127 }

128

129 return std::nullopt;

130}

131

132

133

134

135static std::optional<std::tuple<int64_t, GridAxis>>

138 for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();

139 ++tensorAxis) {

140 if (targetSharding.getSplitAxes().size() > tensorAxis) {

141 if (sourceSharding.getSplitAxes()[tensorAxis].size() !=

142 targetSharding.getSplitAxes()[tensorAxis].size() + 1)

143 continue;

144 if (!llvm::equal(

145 llvm::make_range(

147 .asArrayRef()

148 .begin(),

149 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -

150 1),

151 targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))

152 continue;

153 } else {

154 if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)

155 continue;

156 }

157 return std::make_tuple(

158 tensorAxis,

159 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());

160 }

161 return std::nullopt;

162}

163

166 int64_t splitTensorAxis) {

168 llvm::to_vector(sourceSharding.getSplitAxes());

169 assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >

170 splitTensorAxis);

171 auto targetSplitAxes =

172 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

173

174 targetSplitAxes.pop_back();

175 targetShardingSplitAxes[splitTensorAxis] =

178}

179

181 ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {

183 targetShape[splitTensorAxis] =

184 gatherDimension(targetShape[splitTensorAxis], splitCount);

185 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());

186}

187

191 GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {

194

198 sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);

199 Value allGatherResult = AllGatherOp::create(

200 builder,

201 RankedTensorType::get(allGatherResultShape.getShape(),

202 allGatherResultShape.getElementType()),

204 APInt(64, splitTensorAxis));

205 ShapedType targetShape =

206 shardShapedType(sourceUnshardedShape, grid, targetSharding);

208 tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();

209 return {targetShard, targetSharding};

210}

211

212static std::optional<std::tuple<TypedValue, Sharding>>

215 ShapedType sourceUnshardedShape,

217 if (auto detectRes =

219 auto [tensorAxis, gridAxis] = detectRes.value();

221 sourceUnshardedShape, sourceShard, grid,

222 tensorAxis, gridAxis);

223 }

224

225 return std::nullopt;

226}

227

228

229

230

231

232

233static std::optional<std::tuple<int64_t, int64_t, GridAxis>>

236 for (size_t sourceTensorAxis = 0;

237 sourceTensorAxis < sourceSharding.getSplitAxes().size();

238 ++sourceTensorAxis) {

239 for (size_t targetTensorAxis = 0;

240 targetTensorAxis < targetSharding.getSplitAxes().size();

241 ++targetTensorAxis) {

242 if (sourceTensorAxis == targetTensorAxis)

243 continue;

244 if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||

245 targetSharding.getSplitAxes()[targetTensorAxis].empty() ||

246 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=

247 targetSharding.getSplitAxes()[targetTensorAxis]

248 .asArrayRef()

249 .back())

250 continue;

251 if (!llvm::equal(

252 llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]

253 .asArrayRef()

254 .begin(),

255 sourceSharding.getSplitAxes()[sourceTensorAxis]

256 .asArrayRef()

257 .end() -

258 1),

259 llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]

260 .asArrayRef()

261 .begin(),

262 targetSharding.getSplitAxes()[targetTensorAxis]

263 .asArrayRef()

264 .end() -

265 1)))

266 continue;

267 return std::make_tuple(

268 sourceTensorAxis, targetTensorAxis,

269 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());

270 }

271 }

272 return std::nullopt;

273}

274

277 int64_t sourceTensorAxis,

278 int64_t targetTensorAxis) {

280 llvm::to_vector(sourceSharding.getSplitAxes());

281 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=

282 targetTensorAxis) {

284 }

285

286 auto sourceSplitAxes =

287 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());

288 assert(!sourceSplitAxes.empty());

289 auto gridAxis = sourceSplitAxes.back();

290 sourceSplitAxes.pop_back();

291 targetShardingSplitAxes[sourceTensorAxis] =

293

294 auto targetSplitAxes =

295 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());

296 targetSplitAxes.push_back(gridAxis);

297 targetShardingSplitAxes[targetTensorAxis] =

299

301}

302

305 int64_t sourceTensorAxis,

306 int64_t targetTensorAxis) {

308 targetShape[sourceTensorAxis] =

309 gatherDimension(targetShape[sourceTensorAxis], splitCount);

310 targetShape[targetTensorAxis] =

311 shardDimension(targetShape[targetTensorAxis], splitCount);

312 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());

313}

314

315static std::tuple<TypedValue, Sharding>

318 ShapedType sourceUnshardedShape,

320 int64_t sourceTensorAxis,

324

326 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);

328 sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,

329 targetTensorAxis);

330 Value allToAllResult = AllToAllOp::create(

331 builder,

332 RankedTensorType::get(allToAllResultShape.getShape(),

333 allToAllResultShape.getElementType()),

335 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));

336 ShapedType targetShape =

337 shardShapedType(sourceUnshardedShape, grid, targetSharding);

339 tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();

340 return {targetShard, targetSharding};

341}

342

343static std::optional<std::tuple<TypedValue, Sharding>>

347 ShapedType sourceUnshardedShape,

349 if (auto detectRes =

351 auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();

353 builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,

354 sourceTensorAxis, targetTensorAxis, gridAxis);

355 }

356

357 return std::nullopt;

358}

359

360

361

362

363

364static std::optional<std::tuple<TypedValue, Sharding>>

367 ShapedType sourceUnshardedShape,

369

370

371 if (!sourceSharding.equalSplitAxes(targetSharding) ||

375 return std::nullopt;

376 }

377

380 assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());

381 assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&

382 ShapedType::isStaticShape(tgtHaloSizes) &&

383 sourceShard.getType().hasStaticShape()) &&

384 "dynamic shapes/halos are not supported yet for shard-partition");

385 auto rank = sourceShard.getType().getRank();

386 auto splitAxes = sourceSharding.getSplitAxes();

388 strides(rank, 1), outShape(sourceShard.getType().getShape()),

389 coreShape(sourceShard.getType().getShape());

390

391

392

393 for (auto i = 0u; i < rank; ++i) {

394 if (i < splitAxes.size() && !splitAxes[i].empty()) {

395 if (!srcHaloSizes.empty()) {

396 coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];

397 srcCoreOffs[i] = srcHaloSizes[i * 2];

398 }

399 tgtCoreOffs[i] = tgtHaloSizes[i * 2];

400 outShape[i] =

401 coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];

402 }

403 }

404

405

407 auto initVal =

408 tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape,

409 sourceShard.getType().getElementType());

410 auto core = tensor::ExtractSliceOp::create(

411 builder, sourceShard.getLoc(),

412 RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),

413 sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);

414 auto initOprnd = tensor::InsertSliceOp::create(

415 builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,

416 tgtCoreOffs, coreShape, strides);

417

418

419 auto updateHaloResult =

420 UpdateHaloOp::create(

421 builder, sourceShard.getLoc(),

422 RankedTensorType::get(outShape,

423 sourceShard.getType().getElementType()),

424 initOprnd, grid.getSymName(),

425 GridAxesArrayAttr::get(builder.getContext(),

429 .getResult();

431 targetSharding);

432}

433

434

435

436

442 assert(sourceShard.getType() ==

443 shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));

444 [[maybe_unused]] ShapedType targetShardType =

445 shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);

446 assert(sourceShard.getType().getRank() == targetShardType.getRank());

447 assert(grid.getRank() == 1 && "Only 1D grides are currently supported.");

448

449 if (sourceSharding == targetSharding) {

450 return sourceShard;

451 }

452

454 Sharding actualTargetSharding;

460 builder, grid, sourceSharding, targetSharding,

461 sourceUnshardedValue.getType(), sourceShard)) {

462 std::tie(targetShard, actualTargetSharding) = tryRes.value();

463 } else if (auto tryRes =

465 targetSharding, sourceShard)) {

466 std::tie(targetShard, actualTargetSharding) = tryRes.value();

468 builder, grid, sourceSharding, targetSharding,

469 sourceUnshardedValue.getType(), sourceShard)) {

470 std::tie(targetShard, actualTargetSharding) = tryRes.value();

471 }

472 }

473 assert(targetShard && "Did not find any pattern to apply.");

474 assert(actualTargetSharding == targetSharding);

475 assert(targetShard.getType() == targetShardType);

476 return targetShard;

477}

478

483

484 if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&

486 return sourceShard;

487 }

488

489

490

492 builder, grid, sourceSharding, targetSharding,

493 sourceUnshardedValue.getType(), sourceShard)) {

494 return std::get<0>(tryRes.value());

495 }

496

497

498

499

500 return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,

501 sourceUnshardedValue, sourceShard);

502}

503

507 assert(source.getResult() == target.getSrc());

508 auto sourceSharding = source.getSharding();

509 auto targetSharding = target.getSharding();

511 return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,

512 source.getSrc(), sourceShardValue);

513}

514

519 GridOp srcGrid = getGrid(source, symbolTableCollection);

520 assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));

521 return reshard(builder, srcGrid, source, target, sourceShardValue);

522}

523

525 registry.insert<shard::ShardDialect, tensor::TensorDialect>();

526}

527

528#define GEN_PASS_DEF_PARTITION

529#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"

530

532

533

534

535

540 llvm::transform(

541 block.getArguments(), std::back_inserter(res),

543 auto rankedTensorArg = dyn_cast<TypedValue>(arg);

544 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {

545 return arg.getType();

546 }

547

548 assert(rankedTensorArg.hasOneUse());

550 ShardOp shardOp = llvm::dyn_cast(useOp);

551 assert(shardOp);

552 GridOp grid = getGrid(shardOp, symbolTableCollection);

553 return cast(shardShapedType(rankedTensorArg.getType(), grid,

554 shardOp.getSharding()));

555 });

556 return res;

557}

558

559static LogicalResult

565 ShardingInterface shardingInterface = llvm::dyn_cast(op);

566 if (!shardingInterface) {

567

568

570 resultShardings, partitionMap,

571 symbolTableCollection, builder);

572 } else {

573 if (failed(shardingInterface.partition(

574 partitionedOperands, operandShardings, resultShardings,

575 partitionMap, symbolTableCollection, builder))) {

576 return failure();

577 }

578 }

579

581 return partitionMap.contains(result);

582 }));

583

585}

586

587

588

590 std::vector res;

592 llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {

593 TypedValue rankedTensor =

594 dyn_cast<TypedValue>(operand);

595 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {

596 return Sharding();

597 }

598

600 assert(definingOp);

601 ShardOp shardOp = llvm::cast(definingOp);

602 return Sharding(shardOp.getSharding());

603 });

604 return res;

605}

606

607

608

610 std::vector res;

612 llvm::transform(

614 if (!result.hasOneUse() || result.use_empty()) {

615 return Sharding();

616 }

619 if (!rankedTensor) {

621 }

623 ShardOp shardOp = llvm::dyn_cast(userOp);

624 if (shardOp) {

625 return Sharding(shardOp.getSharding());

626 }

627 if (rankedTensor.getType().getRank() == 0) {

628

629

630

632 if (auto sharding = operand.getDefiningOp()) {

633 return Sharding(sharding.getGridAttr());

634 }

635 }

636 }

638 });

639 return res;

640}

641

642static LogicalResult

646 Value targetPartitionValue;

647

648

649

650 ShardOp srcShardOp = shardOp.getSrc().getDefiningOp();

651 if (!srcShardOp) {

652 targetPartitionValue = partitionMap.lookup(shardOp.getSrc());

653 } else {

654

656 cast<TypedValue>(partitionMap.lookup(srcShardOp));

657 targetPartitionValue = reshard(builder, srcShardOp, shardOp,

658 srcPartitionValue, symbolTableCollection);

659 }

660

661 assert(!partitionMap.contains(shardOp.getResult()));

662 partitionMap.map(shardOp.getResult(), targetPartitionValue);

664}

665

666static LogicalResult

670 if (isa(op)) {

672 }

673 if (auto getShardingOp = dyn_cast(op)) {

674 auto shardOp = getShardingOp.getSource().getDefiningOp();

675 if (!shardOp) {

676 return op.emitError("expected a shard op as source of get_sharding");

677 }

678 auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());

679 partitionMap.map(op.getResult(0), newSharding->getResult(0));

681 }

682

683 ShardOp shardOp = llvm::dyn_cast(op);

684 if (shardOp) {

685 return partitionOperation(shardOp, partitionMap, symbolTableCollection,

686 builder);

687 }

688

690 llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),

691 [&partitionMap](Value operand) {

692 assert(partitionMap.contains(operand));

693 return partitionMap.lookup(operand);

694 });

697 symbolTableCollection, builder);

698}

699

700static LogicalResult

704

706 llvm::transform(block.getArguments(), std::back_inserter(argLocations),

711 for (auto [unshardedBlockArg, partitionedBlockArg] :

713 partitionMap.map(unshardedBlockArg, partitionedBlockArg);

714 }

715

719 if (failed(partitionOperation(op, partitionMap, symbolTableCollection,

720 builder))) {

721 return failure();

722 }

723 }

724

726}

727

728static LogicalResult

731 OpBuilder builder(op.getFunctionBody());

732

733

734

736 for (Block &b : op.getBlocks()) {

737 if (llvm::any_of(b.getOperations(),

738 [](Operation &op) { return isa(op); })) {

739 originalBlocks.push_back(&b);

740 }

741 }

742

743 for (Block *block : originalBlocks) {

744 if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,

745 builder))) {

746 return failure();

747 }

748 }

749

750 for (Block *block : originalBlocks) {

751 block->erase();

752 }

753

754

755

757 for (Block &block : op.getFunctionBody()) {

758 if (block.empty()) {

759 continue;

760 }

761

763 returnOp = &block.back();

764 break;

765 }

766 }

767 if (returnOp) {

768 op.setType(FunctionType::get(

769 op->getContext(), op.getFunctionBody().front().getArgumentTypes(),

771 }

772

774}

775

776namespace {

777

778struct Partition : public impl::PartitionBase {

779 void runOnOperation() override {

782 if (failed(partitionFuncOp(getOperation(), partitionMap,

783 symbolTableCollection))) {

784 return signalPassFailure();

785 }

786 }

787

788 void getDependentDialects(DialectRegistry &registry) const override {

790 registry.insertshard::ShardDialect();

791 }

792};

793

794}

795

796}

b

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...

This class represents an argument of a Block.

Block represents an ordered list of Operations.

Region * getParent() const

Provide a 'getParent' method for ilist_node_with_parent methods.

OpListType & getOperations()

BlockArgListType getArguments()

MLIRContext * getContext() const

The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.

This is a utility class for mapping one set of IR entities to another.

auto lookup(T from) const

Lookup a mapped value within the map.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

bool contains(T from) const

Checks to see if a mapping for 'from' exists.

ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...

Location getLoc() const

Accessors for the implied location.

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

void setInsertionPointAfterValue(Value val)

Sets the insertion point to the node after the specified value.

This is a value defined by a result of an operation.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

unsigned getNumOperands()

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

operand_type_range getOperandTypes()

operand_range getOperands()

Returns an iterator on the underlying Value's.

user_range getUsers()

Returns a range of all users.

result_range getResults()

unsigned getNumResults()

Return the number of results held by this operation.

This class represents a collection of SymbolTables.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)

static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})

bool equalSplitAxes(const Sharding &rhs) const

ArrayRef< int64_t > getStaticHaloSizes() const

::mlir::FlatSymbolRefAttr getGridAttr() const

ArrayRef< Value > getDynamicHaloSizes() const

ArrayRef< int64_t > getStaticShardedDimsOffsets() const

ArrayRef< GridAxesAttr > getSplitAxes() const

bool equalHaloSizes(const Sharding &rhs) const

static std::optional< std::tuple< int64_t, int64_t, GridAxis > > detectMoveLastSplitAxisInResharding(Sharding sourceSharding, Sharding targetSharding)

Definition Partition.cpp:234

static std::tuple< TypedValue< ShapedType >, Sharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)

Definition Partition.cpp:68

ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)

static std::tuple< TypedValue< ShapedType >, Sharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis)

Definition Partition.cpp:316

static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)

Definition Partition.cpp:180

void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)

static std::tuple< TypedValue< ShapedType >, Sharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)

Definition Partition.cpp:188

static std::optional< std::tuple< int64_t, GridAxis > > detectUnsplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding)

Definition Partition.cpp:136

static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)

Definition Partition.cpp:537

static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)

Definition Partition.cpp:729

static TypedValue< ShapedType > reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)

Definition Partition.cpp:438

static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)

Definition Partition.cpp:344

static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)

Definition Partition.cpp:39

bool isFullReplication(Sharding sharding)

static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)

Definition Partition.cpp:701

static std::vector< Sharding > getOperandShardings(Operation &op)

Definition Partition.cpp:589

static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard)

Definition Partition.cpp:119

DenseMap< Value, Value > UnshardedToShardedValueMap

Definition Partition.cpp:531

static std::vector< Sharding > getResultShardings(Operation &op)

Definition Partition.cpp:609

static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)

Definition Partition.cpp:275

static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis)

Definition Partition.cpp:164

static std::optional< std::tuple< int64_t, GridAxis > > detectSplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding)

Definition Partition.cpp:87

int64_t shardDimension(int64_t dimSize, int64_t shardCount)

TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)

Definition Partition.cpp:504

static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)

Definition Partition.cpp:365

static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)

Definition Partition.cpp:213

static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis)

Definition Partition.cpp:46

void reshardingRegisterDependentDialects(DialectRegistry &registry)

Definition Partition.cpp:524

static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)

Definition Partition.cpp:303

shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)

static LogicalResult partitionOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)

Definition Partition.cpp:560

int64_t gatherDimension(int64_t dimSize, int64_t shardCount)

std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue

If Ty is mlir::Type this will select Value instead of having a wrapper around it.

llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap

This trait indicates that a terminator operation is "return-like".