MLIR: lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
11
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
20
21 #include
22
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25
26 using namespace mlir;
28
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
30
31
32
33
34
35 static LogicalResult
40 auto binOpExpr = cast(expr);
44 return failure();
46 return failure();
47 return success();
48 }
50 auto binOpExpr = cast(expr);
56 dimExpr = lhs;
59 dimExpr = rhs;
60 } else {
61 return failure();
62 }
63 unsigned position = cast(dimExpr).getPosition();
64 if ((size_t)position >= seenIds.size() || seenIds[position])
65 return failure();
66 seenIds[position] = true;
67 return success();
68 }
70 unsigned position = cast(expr).getPosition();
71 if ((size_t)position >= seenIds.size() || seenIds[position])
72 return failure();
73 seenIds[position] = true;
74 return success();
75 }
76 default:
77 return failure();
78 }
79 }
80
81 static FailureOr<llvm::SmallSet<unsigned, 2>>
85 return failure();
86
87 llvm::SmallSet<unsigned, 2> positions;
89 if (it.value())
90 positions.insert((unsigned)it.index());
91 }
92 return positions;
93 }
94
95 template
99 for (const auto &v : vec) {
101 }
102 return res;
103 }
104
105
106
107
108
109 FailureOr<std::pair<bool, MeshSharding>>
111 Value val = cast(result);
112 bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
113 auto shardOp = llvm::dyn_castmesh::ShardOp(user);
114 if (!shardOp)
115 return false;
116 return !shardOp.getAnnotateForUsers();
117 });
118
119 if (anyShardedForDef) {
120
121
123 return failure();
124 auto shardOp = llvm::castmesh::ShardOp(*val.getUsers().begin());
125 return std::make_pair(false, MeshSharding(shardOp.getSharding()));
126 }
127
128 bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
129 auto shardOp = llvm::dyn_castmesh::ShardOp(user);
130 if (!shardOp)
131 return false;
132 return shardOp.getAnnotateForUsers();
133 });
134 if (anyShardedForUsers) {
137 ShardOp shardOp = llvm::dyn_cast(user);
138 if (shardOp)
139 shardOps.push_back(shardOp);
140 }
141 MeshSharding shardForDef = shardOps[0].getSharding();
142 for (size_t i = 1; i < shardOps.size(); ++i) {
143
144
145 assert(shardForDef == shardOps[i].getSharding() &&
146 "only support all shard ops have the same mesh sharding attr");
147 }
148 return std::make_pair(true, shardForDef);
149 }
150 return failure();
151 }
152
153 FailureOr<std::pair<bool, MeshSharding>>
156 if (ShardOp shardOp = val.getDefiningOp())
157 return std::make_pair(shardOp.getAnnotateForUsers(),
159
160 return failure();
161 }
162
163
164
165
166
167 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
169
170
172 if (!llvm::isa(type) && !type.isIntOrIndexOrFloat())
173 return failure();
175 if (!llvm::isa(type) && !type.isIntOrIndexOrFloat())
176 return failure();
177
178
180 if (maps.empty())
181 return failure();
184 if (numOperands + numResults != maps.size())
185 return failure();
186
188 auto resultType = dyn_cast(result.getType());
189 if (!resultType)
190 return failure();
191 AffineMap map = maps[numOperands + result.getResultNumber()];
193 return failure();
194 }
195 }
196
197 return success();
198 }
199
200
201
202
203
// Writes a human-readable dump of this op's loop iterator types and indexing
// maps to `os`.
//
// Output layout, in order:
//   - a header line, followed by the printed operation and a newline;
//   - "loop types: [" followed by each iterator type name and a trailing
//     space, then "]";
//   - "indexing maps: " followed by one affine map per line;
//   - a final empty line.
//
// \param os  Stream that receives the text.
204 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
205 os << "print loop types and indexing maps for: \n";
206 getOperation()->print(os);
207 os << "\n";
208 os << "loop types: [";
// Each entry is followed by a space, so the list ends with " ]".
209 for (utils::IteratorType type : getLoopIteratorTypes()) {
210 os << stringifyEnum(type) << " ";
211 }
212 os << "]\n";
213 os << "indexing maps: \n";
// One map per line, in the order returned by getIndexingMaps().
214 for (AffineMap map : getIndexingMaps())
215 os << map << "\n";
216 os << "\n";
217 }
218
219
220
221
222
223 namespace {
224
225
// Records `meshAxes` as the mesh axes assigned to loop iterator `loopIdx` in
// `shardingOption`, and adopts `mesh` as the option's mesh when `mesh` is
// non-null.
//
// Returns failure(), leaving `shardingOption` unmodified, when either
//   - both `shardingOption.mesh` and `mesh` are set and they differ, or the
//     entry for `loopIdx` is already non-empty and differs from `meshAxes`; or
//   - any axis in `meshAxes` already appears in the entry of another loop
//     iterator.
// Otherwise returns success().
//
// NOTE(review): the parameter lines between `op` and `loopIdx` (listing lines
// 227-229) are missing from this listing; `shardingOption`, `mesh` and
// `meshAxes` used below are presumably declared there - confirm against the
// upstream file. `op` is not referenced in the visible body.
226 static LogicalResult fillShardingOption(Operation *op,
230 unsigned loopIdx) {
// Conflict check 1: a different mesh, or a different non-empty axis list
// already stored for this loop iterator.
231 if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
232 (!shardingOption.shardingArray[loopIdx].empty() &&
233 shardingOption.shardingArray[loopIdx] != meshAxes)) {
234 LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
235 << loopIdx << "\n");
236 return failure();
237 }
// Conflict check 2: none of the requested axes may already be used by any
// other loop iterator's entry.
238 for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
239 if (i == loopIdx)
240 continue;
241
242 for (MeshAxis axis : meshAxes) {
243 if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
244 LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
245 << axis << " duplicate");
246 return failure();
247 }
248 }
249 }
// No conflicts: commit. A null `mesh` leaves the stored mesh untouched.
250 if (mesh)
251 shardingOption.mesh = mesh;
// A non-empty entry is known to equal `meshAxes` here (check 1), so only an
// empty entry needs filling.
252 if (shardingOption.shardingArray[loopIdx].empty())
253 shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
254 meshAxes.end());
255 return success();
256 }
257
258 }
259
260 FailureOr
264 ShardingInterface shardingOp = llvm::cast(op);
266
267 if (failed(shardingOp.verifyShardingInterfaceImpl()))
268 return op->emitOpError() << "invalid sharding interface implementation";
270 shardingOp.getLoopIteratorTypes();
273 shardingOption.shardingArray.resize(loopTypes.size());
275 llvm::SmallSet<unsigned, 4> visitedLoopIndices;
276 bool anyShardingInResultsOrOperands = false;
277
278
279 for (auto shardingIt : llvm::enumerate(resultShardings)) {
281 if (!shardAttr)
282 continue;
283 AffineMap map = maps[numOperands + shardingIt.index()];
284 anyShardingInResultsOrOperands = true;
287 } else {
288
289
290
294 auto dim = cast(expr);
295 unsigned index = dim.getPosition();
296 visitedLoopIndices.insert(index);
297 if (failed(fillShardingOption(op, shardingOption,
299 return failure();
300 }
301 }
302
303
304
306 if (!partialAxes.empty()) {
307 if (!partialMeshAxes.empty())
308 return op->emitOpError() << "at most one result with partial axes is "
309 "supported at present";
310 partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
311
312
313 for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
315 visitedLoopIndices.insert(loopIdx);
316 }
317 }
318 }
319
320
321 for (auto shardingIt : llvm::enumerate(operandShardings)) {
323 if (!shardAttr)
324 continue;
325
326 anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();
327 AffineMap map = maps[shardingIt.index()];
329
330
331
332
333
334
338 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
340 if (failed(loopIndices))
342 << "operand's affine expression is restricted to const_i * "
343 "dim_i + const_j + dim_j + ...";
344 if (loopIndices->empty())
345 continue;
346 if (loopIndices->size() == 1) {
347 unsigned loopIdx = *loopIndices->begin();
348 visitedLoopIndices.insert(loopIdx);
349 if (failed(fillShardingOption(op, shardingOption,
350 shardAttr.getMeshAttr(), axes, loopIdx)))
351 return failure();
352 }
353
354
355
356 if (loopIndices->size() > 1) {
357 bool seenLoopIndices = false;
358 for (unsigned loopIdx : *loopIndices) {
359 if (visitedLoopIndices.contains(loopIdx)) {
360 seenLoopIndices = true;
361 break;
362 }
363 }
364 if (!seenLoopIndices)
366 << "the operand " << shardingIt.index()
367 << " has multiple loop indices in a dimension, but none of "
368 "them could be found in the exactly specified annotation "
369 "of op results or operands.";
370 }
371 }
372 }
373
374
375 if (!partialMeshAxes.empty()) {
376 bool anyNonEmptyReductionLoop = llvm::any_of(
378 SmallVector &subArray = it.value();
379 int64_t idx = it.index();
380 return isReductionLoop(loopTypes[idx]) && !subArray.empty();
381 });
382 if (!anyNonEmptyReductionLoop) {
383 bool filled = false;
384 for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
386 std::ignore = fillShardingOption(op, shardingOption, nullptr,
387 partialMeshAxes, idx);
388 filled = true;
389 break;
390 }
391 }
392 if (!filled)
393 return op->emitOpError() << "no matched reduction loop found for the "
394 "result's partial type";
395 }
396 }
398 if (!anyShardingInResultsOrOperands)
399 shardingOption.empty = true;
400 return shardingOption;
401 }
402
403
407 auto resultType = cast(result.getType());
410
411
414
415
416 auto dim = cast(expr);
417 unsigned loopIdx = dim.getPosition();
418 if (loopIdx < shardingOption.shardingArray.size())
419 splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
420 }
421
422
423
425 size_t reductionLoopKindsIdx = 0;
426 for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
427 utils::IteratorType iType = std::get<0>(it);
429 ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
430 ++reductionLoopKindsIdx;
431 if (!partialAxes.empty())
432 assert(partialType == curPartialType &&
433 "Only one reduction type is supported");
434 partialType = curPartialType;
436 partialAxes.append(axis);
437 }
438 }
439
443 partialAxes, partialType);
444 }
445
449 Value operandValue = opOperand.get();
450 auto operandType = dyn_cast(operandValue.getType());
451 if (!operandType) {
454 return failure();
455 }
456
457 if (operandType.getRank() == 0) {
459 }
463 int64_t idx = it.index();
465 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
467 if (failed(loopIndices))
468 return failure();
470 for (unsigned loopIdx : *loopIndices) {
471 if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
473 shardedLoopIndices.push_back(loopIdx);
474 }
475
476 if (shardedLoopIndices.size() > 1)
477 return failure();
478 if (shardedLoopIndices.size() == 1) {
479 splitAxes[idx].append(
480 shardingOption.shardingArray[shardedLoopIndices[0]]);
481 }
482 }
483
486 shardingOption.mesh,
488 }
489
490 FailureOr<std::vector>
493 std::vector res;
494
495 ShardingInterface shardingOp = llvm::cast(op);
497 shardingOp.getLoopIteratorTypes();
499 shardingOp.getReductionLoopIteratorKinds();
502
504 FailureOr shardingAttr = getSharding(
505 opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
506 if (failed(shardingAttr))
507 return failure();
508 res.push_back(*shardingAttr);
509 }
510
512 res.push_back(getSharding(result, shardingOption,
513 maps[numOperands + result.getResultNumber()],
514 loopTypes, reductionKinds));
515 }
516
517 return res;
518 }
519
520
521
522
523
524
525
532 getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
534
535 return success();
536 }
537
538
539
543
544 FailureOr sharding =
545 getSharding(opOperand, shardingOption, map);
546 if (failed(sharding)) {
547 return failure();
548 }
551
552 return success();
553 }
554
557 assert(!shardingOption.empty && shardingOption.mesh);
558
559 ShardingInterface shardingOp = llvm::cast(op);
561 shardingOp.getLoopIteratorTypes();
563 shardingOp.getReductionLoopIteratorKinds();
566
567
569 if (failed(addShardOp(b, result, shardingOption,
570 maps[numOperands + result.getResultNumber()],
571 loopTypes, reductionKinds)))
572 return failure();
573 }
574
575
577 if (failed(addShardOp(b, opOperand, shardingOption,
578 maps[opOperand.getOperandNumber()])))
579 return failure();
580 }
581
582 return success();
583 }
584
585 #ifndef NDEBUG
586 static bool
589 if (isa(value.getType())) {
591 }
592
593 return !sharding;
594 }
595
596 template <typename ValueRange, typename MeshShardingRage>
597 static bool
599 MeshShardingRage &&shardings) {
600 if (std::size(values) != std::size(shardings)) {
601 return false;
602 }
603 return llvm::all_of(
604 llvm::zip_equal(std::forward(values),
605 std::forward(shardings)),
606 [](auto valueAndSharding) {
608 std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
609 });
610 }
611 #endif
612
618 assert(spmdizedOperands.size() == operandShardings.size());
620 operandShardings));
622 resultShardings));
623
624 builder.clone(op, spmdizationMap);
625 }
626
630 &meshAxesAssignmentForLoopIterators) {
631 AffineDimExpr affineDimExpr = cast(indexingExpr);
632 unsigned loopIteratorIdx = affineDimExpr.getPosition();
633 if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
634 assert(llvm::equal(meshAxesAssignmentForTensorAxis,
635 *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
636 } else {
637 meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
638 llvm::to_vector(meshAxesAssignmentForTensorAxis);
639 }
640 }
641
648 meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
649 std::vector operatorAndResultShardings;
650 operatorAndResultShardings.reserve(operandShardings.size() +
651 resultShardings.size());
652 llvm::append_range(operatorAndResultShardings, operandShardings);
653 for (auto [sharding, affineMap] :
654 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
655 if (!sharding) {
656 continue;
657 }
658 for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
659 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
661 meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
662 meshAxisAssignmentForLoopIterators);
663 }
664
665 for (unsigned i = sharding.getSplitAxes().size();
666 i < affineMap.getNumResults(); ++i) {
668 {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
669 }
670 }
671
673 llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
675 if (!axes) {
677 };
678 return std::move(*axes);
679 });
680 return res;
681 }
682
686 for (auto [loopIteratorType, meshAxisAssignment] :
687 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
688 if (loopIteratorType == utils::IteratorType::reduction &&
689 !meshAxisAssignment.empty()) {
690 return true;
691 }
692 }
693 return false;
694 }
695
700 for (auto [loopIteratorType, meshAxisAssignment] :
701 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
702 if (loopIteratorType == utils::IteratorType::reduction) {
703 llvm::append_range(meshAxes, meshAxisAssignment);
704 }
705 }
706 return meshAxes;
707 }
708
714
715 Operation *newOp = builder.clone(op, spmdizationMap);
716
717 for (auto [oldResult, newResult, sharding] :
720 newResult.getType(),
721 getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
722 }
723 }
SmallVector< MeshAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingRage &&shardings)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef.
::mlir::FlatSymbolRefAttr getMeshAttr() const
ArrayRef< MeshAxesAttr > getSplitAxes() const
ArrayRef< MeshAxis > getPartialAxes() const
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
FailureOr< std::vector< MeshSharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder, ShardOp &newShardOp)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
bool isReductionLoop(utils::IteratorType iType)
bool isFullReplication(MeshSharding sharding)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
ShardingArray shardingArray