MLIR: lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include
24 #include
25
26 namespace mlir {
27 namespace mesh {
28 #define GEN_PASS_DEF_SHARDINGPROPAGATION
29 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
30 }
31 }
32
33 #define DEBUG_TYPE "sharding-propagation"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
35
36 using namespace mlir;
38
43 };
44
45 #ifdef LLVM_DEBUG
46
47 template
48 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
50 template <typename... Ts>
51 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
52 const std::tuple<Ts...> &t);
53 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
55
/// Streams `range` as a bracketed, comma-separated list, e.g. "[a, b, c]".
///
/// `Stream` is any stream type that supports `operator<<` for C strings and
/// for the range's element type (e.g. llvm::raw_ostream). Elements are
/// separated by ", " with no trailing separator; an empty range prints "[]".
///
/// Returns `stream` itself rather than the result of the last `<<`, so the
/// declared `Stream &` return type holds even for stream types whose
/// `operator<<` returns a base-class reference (e.g. std::ostringstream).
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  bool first = true;
  for (const auto &v : range) {
    // Emit the separator before every element except the first, so the list
    // never ends with a dangling ", ".
    if (!first)
      stream << ", ";
    stream << v;
    first = false;
  }
  stream << "]";
  return stream;
}
65
66 template
67 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
69 return printRange(stream, vec);
70 }
71
72 [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
74 return stream << "{empty = " << v.empty << ", mesh" << v.mesh
75 << ", shardingArray = " << v.shardingArray << "}";
76 }
77
/// Streams the elements of `tuple` as a brace-enclosed, comma-separated list,
/// e.g. "{a, b}".
///
/// Callers are expected to pass `std::index_sequence_for<Ts...>{}` (i.e.
/// indices 0..N-1 in order). The separator logic below assumes the first
/// printed index is 0. The tuple is taken by const reference so its elements
/// are not copied just to be printed. Returns `stream` itself.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  // Fold over the indices. Every element except index 0 is prefixed with
  // ", ", so no trailing separator is produced.
  ((stream << (Is == 0 ? "" : ", ") << std::get<Is>(tuple)), ...);
  stream << "}";
  return stream;
}
89
90 template <typename... Ts>
91 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
92 const std::tuple<Ts...> &t) {
93 return printTuple(stream, t, std::index_sequence_for<Ts...>{});
94 }
95
96 [[maybe_unused]] static llvm::raw_ostream &
98 return stream << static_cast(v);
99 }
100
101 #endif
102
103
104
105
106
107
108
109
110
115 std::vector curShardingAttrs;
116
117 std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
118 if (i == mustShardings.size()) {
119 allShardingAttrs.push_back(std::vector(curShardingAttrs));
120 return;
121 }
122
123 if (mustShardings[i]) {
124 curShardingAttrs.push_back(mustShardings[i]);
125 dfsCreateShardingAttrs(i + 1);
126 curShardingAttrs.pop_back();
127 return;
128 }
129
130 if (optionalShardings[i]) {
131 curShardingAttrs.push_back(optionalShardings[i]);
132 dfsCreateShardingAttrs(i + 1);
133 curShardingAttrs.pop_back();
134 curShardingAttrs.push_back({});
135 dfsCreateShardingAttrs(i + 1);
136 curShardingAttrs.pop_back();
137 return;
138 }
139
140 curShardingAttrs.push_back({});
141 dfsCreateShardingAttrs(i + 1);
142 curShardingAttrs.pop_back();
143 };
144
145 dfsCreateShardingAttrs(0);
146 return allShardingAttrs;
147 }
148
149
150
151
152
153
154
155
156
157
158
160 Operation *op, const std::vector &operandAndResultShardings) {
162
163 size_t operandsCount = op->getOperands().size();
164 auto operandShardings =
165 llvm::make_range(operandAndResultShardings.begin(),
166 operandAndResultShardings.begin() + operandsCount);
167 auto resultShardings =
168 llvm::make_range(operandAndResultShardings.begin() + operandsCount,
169 operandAndResultShardings.end());
170
171 for (auto [operand, sharding] :
172 llvm::zip_equal(op->getOperands(), operandShardings)) {
173 ShardOp shardOp = llvm::dyn_cast_or_null(operand.getDefiningOp());
174 if (!shardOp) {
175 continue;
176 }
177 bool needsResharding = sharding != shardOp.getSharding();
178 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
179 if (needsResharding) {
180 if (isExplicitAnnotationForThisOp) {
181
183 }
185 }
186 }
187
188 for (auto [result, sharding] :
189 llvm::zip_equal(op->getResults(), resultShardings)) {
190 for (auto user : result.getUsers()) {
191 ShardOp shardOp = llvm::dyn_cast(user);
192 if (!shardOp) {
193 continue;
194 }
195 bool needsResharding = sharding != shardOp.getSharding();
196 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
197 if (needsResharding) {
198 if (isExplicitAnnotationForThisOp) {
199
201 }
203 }
204 }
205 }
206
207 return res;
208 }
209
210
211
212
213
214
215
216
218 ShardingInterface shardingOp,
219 ArrayRef<std::vector> possibleOperandShardingAttrs,
220 ArrayRef<std::vector> possibleResultShardingAttrs) {
222 shardingOptionsAndReshardingRequirements;
223
226 possibleOperandShardingAttrs) {
227 FailureOr shardingOption =
228 shardingOp.getShardingOption(operandShardings, resultShardings);
229 if (failed(shardingOption) || shardingOption->empty) {
230 continue;
231 }
232
233
234
235
236
237 FailureOr<std::vector> operandAndResultShardings =
238 shardingOp.getShardingAnnotations(*shardingOption);
239 if (failed(operandAndResultShardings)) {
240 return failure();
241 }
242
243
244
245
249
250 return *shardingOption;
251 }
252
253 shardingOptionsAndReshardingRequirements.emplace_back(
254 std::move(*shardingOption), reshardingRquirement);
255 }
256 }
257
258 if (shardingOptionsAndReshardingRequirements.empty()) {
260 }
261
262 std::partial_sort(
263 shardingOptionsAndReshardingRequirements.begin(),
264 shardingOptionsAndReshardingRequirements.begin() + 1,
265 shardingOptionsAndReshardingRequirements.end(),
266 [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
267 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
268 return std::get(a) <
269 std::get(b);
270 });
271
272 LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
273 << shardingOptionsAndReshardingRequirements << "\n";);
274
275 return std::get(
276 shardingOptionsAndReshardingRequirements.front());
277 }
278
279
280
281
282
283
285 ShardingInterface shardingOp = llvm::dyn_cast(op);
288 llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
289 return success();
290
291 if (!shardingOp) {
292 op->emitOpError() << "sharding interface is not implemented.";
293 return failure();
294 }
295
296
297 std::vector allowConflictsResultShardings;
298 allowConflictsResultShardings.resize(op->getNumResults());
299 std::vector resultMustShardings;
302 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
304 if (failed(maybeShardAttr))
305 continue;
306 if (!maybeShardAttr->first)
307 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
308 else
309 allowConflictsResultShardings[result.getResultNumber()] =
310 maybeShardAttr->second;
311 }
312
313
314 std::vector allowConflictsOperandShardings;
315 allowConflictsOperandShardings.resize(op->getNumOperands());
316 std::vector operandMustShardings;
319 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
321 if (failed(maybeShardAttr))
322 continue;
323
324 if (maybeShardAttr->first)
325 operandMustShardings[opOperand.getOperandNumber()] =
326 maybeShardAttr->second;
327 else
328 allowConflictsOperandShardings[opOperand.getOperandNumber()] =
329 maybeShardAttr->second;
330 }
331
332
335 allowConflictsOperandShardings);
338 allowConflictsResultShardings);
340 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
341
342 if (failed(shardingOption)) {
343 op->emitOpError() << "fail to get sharding option.";
344 return failure();
345 }
346
347 LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
348
349
350 if (shardingOption->empty)
351 return success();
352
353 if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
354 op->emitOpError() << "fail to set sharding annotations.";
355 return failure();
356 }
357 return success();
358 }
359
360
361
362
364 : public mesh::impl::ShardingPropagationBase {
366 FunctionOpInterface funcOp = getOperation();
368 Region ®ion = funcOp.getFunctionBody();
371 funcOp.emitOpError() << "only one block is supported!";
372 return signalPassFailure();
373 }
375
376 LLVM_DEBUG(
377 DBGS() << "print all the ops' iterator types and indexing maps in the "
378 "block.\n";
381 if (auto shardingOp = llvm::dyn_cast(&op))
382 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
383 });
384
385
386 for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
387 if (failed(visitOp(&op, builder)))
388 return signalPassFailure();
389
390 LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
391 << funcOp << "\n");
392 LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
393
394
395 for (Operation &op : llvm::make_early_inc_range(block))
396 if (failed(visitOp(&op, builder)))
397 return signalPassFailure();
398 }
399 };
ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const std::vector< MeshSharding > &operandAndResultShardings)
static LogicalResult visitOp(Operation *op, OpBuilder &builder)
static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< std::vector< MeshSharding >> possibleOperandShardingAttrs, ArrayRef< std::vector< MeshSharding >> possibleResultShardingAttrs)
@ RESHARDING_FOR_EXPLICIT_ANNOTATIONS
@ NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS
static SmallVector< std::vector< MeshSharding > > getOrderedPossibleShardingAttrs(ArrayRef< MeshSharding > mustShardings, ArrayRef< MeshSharding > optionalShardings)
Block represents an ordered list of Operations.
OpListType & getOperations()
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for a sub-set of ops that are known to be constant-like.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
unsigned getNumOperands()
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
void runOnOperation() override
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
ShardingArray shardingArray
static ShardingOption makeEmpty()