MLIR: lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

18 #include "llvm/ADT/STLExtras.h"

19 #include "llvm/ADT/SmallVector.h"

20 #include "llvm/ADT/iterator_range.h"

21 #include "llvm/Support/Debug.h"

22 #include "llvm/Support/raw_ostream.h"

23 #include

24 #include

25

26 namespace mlir {

27 namespace mesh {

28 #define GEN_PASS_DEF_SHARDINGPROPAGATION

29 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

30 }

31 }

32

33 #define DEBUG_TYPE "sharding-propagation"

34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

35

36 using namespace mlir;

38

43 };

44

45 #ifdef LLVM_DEBUG

46

47 template

48 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,

50 template <typename... Ts>

51 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,

52 const std::tuple<Ts...> &t);

53 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,

55

/// Writes every element of `range` to `stream` as a bracketed,
/// comma-separated list, e.g. `[a, b, c]`. An empty range prints `[]`.
///
/// `Stream` is any type supporting `operator<<` for string literals and for
/// the element type of `range`. Returns the result of the final insertion so
/// calls can be chained.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  // Emit the separator *between* elements only, so that no dangling ", " is
  // left after the last element.
  bool isFirst = true;
  for (auto &&element : range) {
    if (!isFirst)
      stream << ", ";
    isFirst = false;
    stream << element;
  }
  return stream << "]";
}

65

66 template

67 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,

69 return printRange(stream, vec);

70 }

71

72 [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,

74 return stream << "{empty = " << v.empty << ", mesh" << v.mesh

75 << ", shardingArray = " << v.shardingArray << "}";

76 }

77

/// Writes the fields of `tuple` to `stream` as a brace-enclosed,
/// comma-separated list, e.g. `{a, b}`.
///
/// `Is...` must be the index sequence `0 .. sizeof...(Ts) - 1`; callers
/// normally pass `std::index_sequence_for<Ts...>{}`. Empty tuples are rejected
/// at compile time. Returns the result of the final insertion so calls can be
/// chained.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  // Fold over the indices: every field except the first is preceded by the
  // separator, so nothing trails the last field. `std::get` needs the explicit
  // index `Is` to select the field.
  ((stream << (Is == 0 ? "" : ", ") << std::get<Is>(tuple)), ...);
  return stream << "}";
}

89

90 template <typename... Ts>

91 static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,

92 const std::tuple<Ts...> &t) {

93 return printTuple(stream, t, std::index_sequence_for<Ts...>{});

94 }

95

96 [[maybe_unused]] static llvm::raw_ostream &

98 return stream << static_cast(v);

99 }

100

101 #endif

102

103

104

105

106

107

108

109

110

115 std::vector curShardingAttrs;

116

117 std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {

118 if (i == mustShardings.size()) {

119 allShardingAttrs.push_back(std::vector(curShardingAttrs));

120 return;

121 }

122

123 if (mustShardings[i]) {

124 curShardingAttrs.push_back(mustShardings[i]);

125 dfsCreateShardingAttrs(i + 1);

126 curShardingAttrs.pop_back();

127 return;

128 }

129

130 if (optionalShardings[i]) {

131 curShardingAttrs.push_back(optionalShardings[i]);

132 dfsCreateShardingAttrs(i + 1);

133 curShardingAttrs.pop_back();

134 curShardingAttrs.push_back({});

135 dfsCreateShardingAttrs(i + 1);

136 curShardingAttrs.pop_back();

137 return;

138 }

139

140 curShardingAttrs.push_back({});

141 dfsCreateShardingAttrs(i + 1);

142 curShardingAttrs.pop_back();

143 };

144

145 dfsCreateShardingAttrs(0);

146 return allShardingAttrs;

147 }

148

149

150

151

152

153

154

155

156

157

158

160 Operation *op, const std::vector &operandAndResultShardings) {

162

163 size_t operandsCount = op->getOperands().size();

164 auto operandShardings =

165 llvm::make_range(operandAndResultShardings.begin(),

166 operandAndResultShardings.begin() + operandsCount);

167 auto resultShardings =

168 llvm::make_range(operandAndResultShardings.begin() + operandsCount,

169 operandAndResultShardings.end());

170

171 for (auto [operand, sharding] :

172 llvm::zip_equal(op->getOperands(), operandShardings)) {

173 ShardOp shardOp = llvm::dyn_cast_or_null(operand.getDefiningOp());

174 if (!shardOp) {

175 continue;

176 }

177 bool needsResharding = sharding != shardOp.getSharding();

178 bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();

179 if (needsResharding) {

180 if (isExplicitAnnotationForThisOp) {

181

183 }

185 }

186 }

187

188 for (auto [result, sharding] :

189 llvm::zip_equal(op->getResults(), resultShardings)) {

190 for (auto user : result.getUsers()) {

191 ShardOp shardOp = llvm::dyn_cast(user);

192 if (!shardOp) {

193 continue;

194 }

195 bool needsResharding = sharding != shardOp.getSharding();

196 bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();

197 if (needsResharding) {

198 if (isExplicitAnnotationForThisOp) {

199

201 }

203 }

204 }

205 }

206

207 return res;

208 }

209

210

211

212

213

214

215

216

218 ShardingInterface shardingOp,

219 ArrayRef<std::vector> possibleOperandShardingAttrs,

220 ArrayRef<std::vector> possibleResultShardingAttrs) {

222 shardingOptionsAndReshardingRequirements;

223

226 possibleOperandShardingAttrs) {

227 FailureOr shardingOption =

228 shardingOp.getShardingOption(operandShardings, resultShardings);

229 if (failed(shardingOption) || shardingOption->empty) {

230 continue;

231 }

232

233

234

235

236

237 FailureOr<std::vector> operandAndResultShardings =

238 shardingOp.getShardingAnnotations(*shardingOption);

239 if (failed(operandAndResultShardings)) {

240 return failure();

241 }

242

243

244

245

249

250 return *shardingOption;

251 }

252

253 shardingOptionsAndReshardingRequirements.emplace_back(

254 std::move(*shardingOption), reshardingRquirement);

255 }

256 }

257

258 if (shardingOptionsAndReshardingRequirements.empty()) {

260 }

261

262 std::partial_sort(

263 shardingOptionsAndReshardingRequirements.begin(),

264 shardingOptionsAndReshardingRequirements.begin() + 1,

265 shardingOptionsAndReshardingRequirements.end(),

266 [](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,

267 const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {

268 return std::get(a) <

269 std::get(b);

270 });

271

272 LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "

273 << shardingOptionsAndReshardingRequirements << "\n";);

274

275 return std::get(

276 shardingOptionsAndReshardingRequirements.front());

277 }

278

279

280

281

282

283

285 ShardingInterface shardingOp = llvm::dyn_cast(op);

288 llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))

289 return success();

290

291 if (!shardingOp) {

292 op->emitOpError() << "sharding interface is not implemented.";

293 return failure();

294 }

295

296

297 std::vector allowConflictsResultShardings;

298 allowConflictsResultShardings.resize(op->getNumResults());

299 std::vector resultMustShardings;

302 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =

304 if (failed(maybeShardAttr))

305 continue;

306 if (!maybeShardAttr->first)

307 resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;

308 else

309 allowConflictsResultShardings[result.getResultNumber()] =

310 maybeShardAttr->second;

311 }

312

313

314 std::vector allowConflictsOperandShardings;

315 allowConflictsOperandShardings.resize(op->getNumOperands());

316 std::vector operandMustShardings;

319 FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =

321 if (failed(maybeShardAttr))

322 continue;

323

324 if (maybeShardAttr->first)

325 operandMustShardings[opOperand.getOperandNumber()] =

326 maybeShardAttr->second;

327 else

328 allowConflictsOperandShardings[opOperand.getOperandNumber()] =

329 maybeShardAttr->second;

330 }

331

332

335 allowConflictsOperandShardings);

338 allowConflictsResultShardings);

340 shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);

341

342 if (failed(shardingOption)) {

343 op->emitOpError() << "fail to get sharding option.";

344 return failure();

345 }

346

347 LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");

348

349

350 if (shardingOption->empty)

351 return success();

352

353 if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {

354 op->emitOpError() << "fail to set sharding annotations.";

355 return failure();

356 }

357 return success();

358 }

359

360

361

362

364 : public mesh::impl::ShardingPropagationBase {

366 FunctionOpInterface funcOp = getOperation();

368 Region &region = funcOp.getFunctionBody();

371 funcOp.emitOpError() << "only one block is supported!";

372 return signalPassFailure();

373 }

375

376 LLVM_DEBUG(

377 DBGS() << "print all the ops' iterator types and indexing maps in the "

378 "block.\n";

381 if (auto shardingOp = llvm::dyn_cast(&op))

382 shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());

383 });

384

385

386 for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))

387 if (failed(visitOp(&op, builder)))

388 return signalPassFailure();

389

390 LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"

391 << funcOp << "\n");

392 LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));

393

394

395 for (Operation &op : llvm::make_early_inc_range(block))

396 if (failed(visitOp(&op, builder)))

397 return signalPassFailure();

398 }

399 };

ReshardingRquirementKind getReshardingRquirementKind(Operation *op, const std::vector< MeshSharding > &operandAndResultShardings)

static LogicalResult visitOp(Operation *op, OpBuilder &builder)

static FailureOr< ShardingOption > selectShardingOption(ShardingInterface shardingOp, ArrayRef< std::vector< MeshSharding >> possibleOperandShardingAttrs, ArrayRef< std::vector< MeshSharding >> possibleResultShardingAttrs)

@ RESHARDING_FOR_EXPLICIT_ANNOTATIONS

@ NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS

static SmallVector< std::vector< MeshSharding > > getOrderedPossibleShardingAttrs(ArrayRef< MeshSharding > mustShardings, ArrayRef< MeshSharding > optionalShardings)

Block represents an ordered list of Operations.

OpListType & getOperations()

MLIRContext is the top-level object for a collection of MLIR operations.

This class helps build Operations.

This class represents an operand of an operation.

This is a value defined by a result of an operation.

This class provides the API for a sub-set of ops that are known to be constant-like.

This class provides the API for ops that are known to be terminators.

Operation is the basic unit of execution within MLIR.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g.

unsigned getNumOperands()

MutableArrayRef< OpOperand > getOpOperands()

operand_range getOperands()

Returns an iterator on the underlying Value's.

result_range getResults()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

unsigned getNumResults()

Return the number of results held by this operation.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

bool hasOneBlock()

Return true if this region has exactly one block.

FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)

Include the generated interface declarations.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)

void runOnOperation() override

Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.

ShardingArray shardingArray

static ShardingOption makeEmpty()