MLIR: lib/Dialect/Shard/Transforms/Partition.cpp Source File
//===- Partition.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shard/Transforms/Partition.h"

#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>

namespace mlir::shard {

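// Returns true if all target axes are contained in the source axes set.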
37
38template <typename SourceAxes, typename TargetAxes>
40 const TargetAxes &targetAxes) {
41 return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
42 return sourceAxes.contains(targetAxis);
43 });
44}
45
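// Compute the sharding that results from splitting tensor axis
// `splitTensorAxis` along grid axis `splitGridAxis`: the grid axis is
// appended to the split-axes list of that tensor axis.
// E.g. split axes [[0, 1]] with splitGridAxis = 2 become [[0, 1, 2]].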
static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                              Sharding sourceSharding,
                                              int64_t splitTensorAxis,
                                              GridAxis splitGridAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitGridAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);
  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}

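// Split a replicated tensor along a grid axis, e.g. [[0, 1]] -> [[0, 1, 2]].
// The local shard is sliced with an all_slice op. Returns the new shard
// together with its sharding.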
static std::tuple<TypedValue<ShapedType>, Sharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          Sharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, GridOp grid,
                          int64_t splitTensorAxis, GridAxis splitGridAxis) {
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      AllSliceOp::create(builder, sourceShard, grid,
                         ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
          .getResult());
  Sharding targetSharding = targetShardingInSplitLastAxis(
      builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
  return {targetShard, targetSharding};
}

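// Detect if the resharding is of type e.g. [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding tensor axis and grid axis.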
static std::optional<std::tuple<int64_t, GridAxis>>
detectSplitLastAxisInResharding(Sharding sourceSharding,
                                Sharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                             Sharding sourceSharding, Sharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, gridAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard,
                                     grid, tensorAxis, gridAxis);
  }

  return std::nullopt;
}

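// Detect if the resharding is of type e.g. [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding tensor axis and grid axis.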
static std::optional<std::tuple<int64_t, GridAxis>>
detectUnsplitLastAxisInResharding(Sharding sourceSharding,
                                  Sharding targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

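// Compute the sharding that results from dropping the last split axis of
// tensor axis `splitTensorAxis`, e.g. [[0, 1, 2]] -> [[0, 1]].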
static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                Sharding sourceSharding,
                                                int64_t splitTensorAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);
  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}

static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

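// Unsplit the last axis of a tensor dimension: all_gather the shards along
// the grid axis, then cast the result to the expected target shape.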
static std::tuple<TypedValue<ShapedType>, Sharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            Sharding sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, GridOp grid,
                            int64_t splitTensorAxis, GridAxis splitGridAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Sharding targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
  Value allGatherResult = AllGatherOp::create(
      builder,
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, grid, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      tensor::CastOp::create(builder, targetShape, allGatherResult)
          .getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                               Sharding sourceSharding,
                               Sharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, gridAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, grid,
                                       tensorAxis, gridAxis);
  }

  return std::nullopt;
}

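// Detect if the resharding moves the last split axis from one tensor axis to
// another, e.g. [[0, 1], []] -> [[0], [1]].
// If detected, returns the source tensor axis, the target tensor axis and the
// grid axis that moves.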
static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
detectMoveLastSplitAxisInResharding(Sharding sourceSharding,
                                    Sharding targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}

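// Compute the sharding that results from moving the last split axis of
// `sourceTensorAxis` to the end of the split axes of `targetTensorAxis`.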
static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                             Sharding sourceSharding,
                                             int64_t sourceTensorAxis,
                                             int64_t targetTensorAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto gridAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      GridAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(gridAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);

  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}

static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      gatherDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

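// Move the last split axis from one tensor axis to another using an
// all_to_all op, then cast the result to the expected target shape.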
static std::tuple<TypedValue<ShapedType>, Sharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                              Sharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, GridAxis gridAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Sharding targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = AllToAllOp::create(
      builder,
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, grid, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      tensor::CastOp::create(builder, targetShape, allToAllResult)
          .getResult());
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                                 Sharding sourceSharding,
                                 Sharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes = detectMoveLastSplitAxisInResharding(sourceSharding,
                                                           targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, gridAxis);
  }

  return std::nullopt;
}

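// Detect a change in halo sizes (only) and handle it by copying the "core"
// of the source shard into the core of the target shard, followed by an
// update_halo op. Returns std::nullopt if the resharding is not a pure halo
// update.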
static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                          Sharding sourceSharding, Sharding targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where the halo sizes differ but everything
  // else stays the same (from source to target sharding).
  if (!sourceSharding.equalSplitAxes(targetSharding) ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
          ShapedType::isStaticShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for shard-partition");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine the "core" of the source and target shards, i.e. the local
  // tensor data without halos.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Extract the core of the source and copy it into the core of the target.
  auto noVals = ValueRange{};
  auto initVal =
      tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape,
                              sourceShard.getType().getElementType());
  auto core = tensor::ExtractSliceOp::create(
      builder, sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = tensor::InsertSliceOp::create(
      builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
      tgtCoreOffs, coreShape, strides);

  // Finally, update the halo regions.
  auto updateHaloResult =
      UpdateHaloOp::create(
          builder, sourceShard.getLoc(),
          RankedTensorType::get(outShape,
                                sourceShard.getType().getElementType()),
          initOprnd, grid.getSymName(),
          GridAxesArrayAttr::get(builder.getContext(),
                                 sourceSharding.getSplitAxes()),
          targetSharding.getDynamicHaloSizes(),
          targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}

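// Handle resharding on a 1D grid by trying the move-last-split-axis,
// split-last-axis and unsplit-last-axis patterns in turn. Asserts if no
// pattern applies.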
static TypedValue<ShapedType>
reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
                Sharding sourceSharding, Sharding targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(grid.getRank() == 1 && "Only 1D grids are currently supported.");

  if (sourceSharding == targetSharding) {
    return sourceShard;
  }

  TypedValue<ShapedType> targetShard;
  Sharding actualTargetSharding;
  if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      sourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, grid, sourceSharding, targetSharding,
            sourceUnshardedValue.getType(), sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = trySplitLastAxisInResharding(
                   builder, grid, sourceSharding, targetSharding,
                   sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, grid, sourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}

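// Main resharding entry point: returns the source shard unchanged if both
// shardings are equal (or both fully replicated), tries a pure halo update
// first, and otherwise falls back to resharding on a 1D grid.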
TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid,
                               Sharding sourceSharding,
                               Sharding targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {
  // If source and target sharding are the same, there is nothing to do.
  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
                                           isFullReplication(targetSharding))) {
    return sourceShard;
  }

  // Try to handle the case where the resharding is needed because the halo
  // sizes are different. Supports arbitrary grid dimensionality.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, grid, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value());
  }

  // Resort to handling only 1D grids, since the general case is complicated
  // if it needs to be communication-efficient in terms of minimizing the data
  // transferred between devices.
  return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
                 source.getSrc(), sourceShardValue);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  GridOp srcGrid = getGrid(source, symbolTableCollection);
  assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));
  return reshard(builder, srcGrid, source, target, sourceShardValue);
}

void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<shard::ShardDialect, tensor::TensorDialect>();
}

#define GEN_PASS_DEF_PARTITION
#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"

using UnshardedToShardedValueMap = DenseMap<Value, Value>;

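// Get the types of block arguments for a partitioned block. Reads the
// sharding annotation (the unique shard user) of each ranked tensor argument
// to deduce its sharded type.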
static SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
          return arg.getType();
        }

        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        GridOp grid = getGrid(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
                                          shardOp.getSharding()));
      });
  return res;
}

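// Partition an operation with known operand and result shardings. Ops that
// do not implement ShardingInterface are conservatively treated as fully
// replicated.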
static LogicalResult
partitionOperation(Operation &op, ArrayRef<Value> partitionedOperands,
                   ArrayRef<Sharding> operandShardings,
                   ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // If there is no sharding interface we are conservative and assume that
    // the op should be fully replicated.
    partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings,
                                      resultShardings, partitionMap,
                                      symbolTableCollection, builder);
  } else {
    if (failed(shardingInterface.partition(
            partitionedOperands, operandShardings, resultShardings,
            partitionMap, symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
    return partitionMap.contains(result);
  }));

  return success();
}

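// Get the sharding of every operand. Non-tensor and 0d tensor operands get a
// default (empty) sharding; all other operands must be produced by shard ops.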
static std::vector<Sharding> getOperandShardings(Operation &op) {
  std::vector<Sharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
      return Sharding();
    }

    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return Sharding(shardOp.getSharding());
  });
  return res;
}

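// Get the sharding of every result, read from its unique shard op user if
// present. 0d tensor results without an explicit sharding fall back to the
// grid of a sharding op found on the operands, if any.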
static std::vector<Sharding> getResultShardings(Operation &op) {
  std::vector<Sharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(
      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
        if (!result.hasOneUse() || result.use_empty()) {
          return Sharding();
        }
        TypedValue<RankedTensorType> rankedTensor =
            dyn_cast<TypedValue<RankedTensorType>>(result);
        if (!rankedTensor) {
          return Sharding();
        }
        Operation *userOp = *result.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
        if (shardOp) {
          return Sharding(shardOp.getSharding());
        }
        if (rankedTensor.getType().getRank() == 0) {
          // This is a 0d tensor result without explicit sharding.
          // Find a grid symbol to annotate it with from the operands, if
          // possible.
          for (auto operand : op.getOperands()) {
            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
              return Sharding(sharding.getGridAttr());
            }
          }
        }
        return Sharding();
      });
  return res;
}

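// Partition a shard op. If its source is produced by another shard op, the
// two annotations may differ and a resharding is inserted; otherwise the
// already partitioned source value is reused as-is.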
static LogicalResult
partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  Value targetPartitionValue;

  // Check if 2 shard ops are chained. If not, there is no need for resharding
  // as the source and target shards have the same sharding.
  ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
  if (!srcShardOp) {
    targetPartitionValue = partitionMap.lookup(shardOp.getSrc());
  } else {
    // Insert resharding.
    TypedValue<ShapedType> srcPartitionValue =
        cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
    targetPartitionValue = reshard(builder, srcShardOp, shardOp,
                                   srcPartitionValue, symbolTableCollection);
  }

  assert(!partitionMap.contains(shardOp.getResult()));
  partitionMap.map(shardOp.getResult(), targetPartitionValue);
  return success();
}

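// Partition an arbitrary operation: sharding annotation ops are handled
// specially, everything else is partitioned based on the shardings of its
// operands and results.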
static LogicalResult
partitionOperation(Operation &op, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  if (isa<ShardingOp>(op)) {
    return success();
  }
  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
    auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
    if (!shardOp) {
      return op.emitError("expected a shard op as source of get_sharding");
    }
    auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
    partitionMap.map(op.getResult(0), newSharding->getResult(0));
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return partitionOperation(shardOp, partitionMap, symbolTableCollection,
                              builder);
  }

  SmallVector<Value> partitionedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
                  [&partitionMap](Value operand) {
                    assert(partitionMap.contains(operand));
                    return partitionMap.lookup(operand);
                  });
  return partitionOperation(op, partitionedOperands, getOperandShardings(op),
                            getResultShardings(op), partitionMap,
                            symbolTableCollection, builder);
}

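// Partition a block: create a new block with sharded argument types, map the
// old arguments to the new ones and partition every operation in order.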
static LogicalResult
partitionBlock(Block &block, IRMapping &partitionMap,
               SymbolTableCollection &symbolTableCollection,
               OpBuilder &builder) {
  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, partitionedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    partitionMap.map(unshardedBlockArg, partitionedBlockArg);
  }

  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  for (Operation &op : block.getOperations()) {
    if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
                                  builder))) {
      return failure();
    }
  }

  return success();
}

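// Partition a function: replace every block containing shard ops with a
// partitioned clone, then retype the function to match the new return op.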
static LogicalResult
partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
                SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks to not mess up the iteration when adding new
  // blocks.
  SmallVector<Block *> originalBlocks;
  for (Block &b : op.getBlocks()) {
    if (llvm::any_of(b.getOperations(),
                     [](Operation &op) { return isa<ShardOp>(op); })) {
      originalBlocks.push_back(&b);
    }
  }

  for (Block *block : originalBlocks) {
    if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
                              builder))) {
      return failure();
    }
  }

  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and change the function results signature to its
  // operands signature.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }

    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  if (returnOp) {
    op.setType(FunctionType::get(
        op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
        returnOp->getOperandTypes()));
  }

  return success();
}

namespace {

struct Partition : public impl::PartitionBase<Partition> {
  void runOnOperation() override {
    IRMapping partitionMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(partitionFuncOp(getOperation(), partitionMap,
                               symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<shard::ShardDialect>();
  }
};

} // namespace

} // namespace mlir::shard