MLIR: include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_

10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_

11

16

17 namespace mlir {

18

19 class Operation;

20 class IRMapping;

21 class SymbolTableCollection;

22

23 namespace mesh {

24

25

26

27

29 ArrayRef operandShardings,

30 ArrayRef resultShardings,

31 ArrayRefutils::IteratorType loopIteratorTypes,

32 ArrayRef indexingMaps);

33

35 ArrayRefutils::IteratorType loopIteratorTypes,

36 ArrayRef<SmallVector> meshAxisAssignmentForLoopIterators);

37

38

40 ArrayRefutils::IteratorType loopIteratorTypes,

41 ArrayRef<SmallVector> meshAxisAssignmentForLoopIterators);

42

43

44

46 ArrayRef spmdizedOperands,

47 ArrayRef operandShardings,

48 ArrayRef resultShardings,

49 IRMapping &spmdizationMap,

50 SymbolTableCollection &symbolTable,

51 OpBuilder &builder);

52

53

54

55 template

57 : public ShardingInterface::ExternalModel<

58 IndependentParallelIteratorDomainShardingInterface, Op> {

63 populateIteratorTypes(t, iterTypes);

64 }

66 populateIteratorTypes(t, iterTypes);

67 }

68 return iterTypes;

69 }

70

72

74 }

75

83 resultShardings, spmdizationMap,

84 symbolTable, builder);

85 return success();

86 }

87

88 private:

89 void

90 populateIteratorTypes(Type t,

92 RankedTensorType rankedTensorType = dyn_cast(t);

93 if (!rankedTensorType) {

94 return;

95 }

96

97 iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());

98 for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {

99 iterTypes.push_back(utils::IteratorType::parallel);

100 }

101 }

102 };

103

104

105 template

107 : public ShardingInterface::ExternalModel<

108 ElementwiseShardingInterface, ElemwiseOp> {

111 auto type = dyn_cast(val.getType());

112 if (!type)

113 return {};

115 utils::IteratorType::parallel);

116 return types;

117 }

118

122 auto type = dyn_cast(val.getType());

123 if (!type)

124 return {};

125 int64_t rank = type.getRank();

129 return maps;

130 }

131

139 resultShardings, spmdizationMap,

140 symbolTable, builder);

141 return success();

142 }

143 };

144

145 }

146 }

147

148 #endif

static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)

Returns an AffineMap with 'numDims' identity result dim exprs.

This is a utility class for mapping one set of IR entities to another.

MLIRContext is the top-level object for a collection of MLIR operations.

This class helps build Operations.

Operation is the basic unit of execution within MLIR.

Value getOperand(unsigned idx)

MLIRContext * getContext()

Return the context this operation is associated with.

unsigned getNumOperands()

operand_type_range getOperandTypes()

result_type_range getResultTypes()

unsigned getNumResults()

Return the number of results held by this operation.

This class represents a collection of SymbolTables.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)

SmallVector< SmallVector< MeshAxis > > ShardingArray

SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)

ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)

void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)

Include the generated interface declarations.

SmallVector< AffineMap > getIndexingMaps(Operation *op) const

LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const

SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *op) const

SmallVector< AffineMap > getIndexingMaps(Operation *op) const

LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const

SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *operation) const