MLIR: lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

11

16 #include "llvm/ADT/ArrayRef.h"

17 #include "llvm/ADT/STLExtras.h"

18 #include "llvm/ADT/SmallSet.h"

19 #include "llvm/Support/Debug.h"

20

21 #include

22

23 #define DEBUG_TYPE "sharding-interface"

24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

25

26 using namespace mlir;

28

29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

30

31

32

33

34

35 static LogicalResult

40 auto binOpExpr = cast(expr);

44 return failure();

46 return failure();

47 return success();

48 }

50 auto binOpExpr = cast(expr);

56 dimExpr = lhs;

59 dimExpr = rhs;

60 } else {

61 return failure();

62 }

63 unsigned position = cast(dimExpr).getPosition();

64 if ((size_t)position >= seenIds.size() || seenIds[position])

65 return failure();

66 seenIds[position] = true;

67 return success();

68 }

70 unsigned position = cast(expr).getPosition();

71 if ((size_t)position >= seenIds.size() || seenIds[position])

72 return failure();

73 seenIds[position] = true;

74 return success();

75 }

76 default:

77 return failure();

78 }

79 }

80

81 static FailureOr<llvm::SmallSet<unsigned, 2>>

85 return failure();

86

87 llvm::SmallSet<unsigned, 2> positions;

89 if (it.value())

90 positions.insert((unsigned)it.index());

91 }

92 return positions;

93 }

94

95 template

99 for (const auto &v : vec) {

101 }

102 return res;

103 }

104

105

106

107

108

109 FailureOr<std::pair<bool, MeshSharding>>

111 Value val = cast(result);

112 bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {

113 auto shardOp = llvm::dyn_castmesh::ShardOp(user);

114 if (!shardOp)

115 return false;

116 return !shardOp.getAnnotateForUsers();

117 });

118

119 if (anyShardedForDef) {

120

121

123 return failure();

124 auto shardOp = llvm::castmesh::ShardOp(*val.getUsers().begin());

125 return std::make_pair(false, MeshSharding(shardOp.getSharding()));

126 }

127

128 bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {

129 auto shardOp = llvm::dyn_castmesh::ShardOp(user);

130 if (!shardOp)

131 return false;

132 return shardOp.getAnnotateForUsers();

133 });

134 if (anyShardedForUsers) {

137 ShardOp shardOp = llvm::dyn_cast(user);

138 if (shardOp)

139 shardOps.push_back(shardOp);

140 }

141 MeshSharding shardForDef = shardOps[0].getSharding();

142 for (size_t i = 1; i < shardOps.size(); ++i) {

143

144

145 assert(shardForDef == shardOps[i].getSharding() &&

146 "only support all shard ops have the same mesh sharding attr");

147 }

148 return std::make_pair(true, shardForDef);

149 }

150 return failure();

151 }

152

153 FailureOr<std::pair<bool, MeshSharding>>

156 if (ShardOp shardOp = val.getDefiningOp())

157 return std::make_pair(shardOp.getAnnotateForUsers(),

159

160 return failure();

161 }

162

163

164

165

166

167 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {

169

170

172 if (!llvm::isa(type) && !type.isIntOrIndexOrFloat())

173 return failure();

175 if (!llvm::isa(type) && !type.isIntOrIndexOrFloat())

176 return failure();

177

178

180 if (maps.empty())

181 return failure();

184 if (numOperands + numResults != maps.size())

185 return failure();

186

188 auto resultType = dyn_cast(result.getType());

189 if (!resultType)

190 return failure();

191 AffineMap map = maps[numOperands + result.getResultNumber()];

193 return failure();

194 }

195 }

196

197 return success();

198 }

199

200

201

202

203

// Writes a human-readable dump to `os` for the operation behind this
// interface: a header line, the operation itself, its loop iterator types,
// and its indexing maps. Nothing is returned; the only effect is the text
// streamed into `os`.
204 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {

205 os << "print loop types and indexing maps for: \n";

// Print the operation returned by getOperation(), followed by a newline.
206 getOperation()->print(os);

207 os << "\n";

// Loop iterator types are printed on one line inside square brackets; each
// entry is the stringified enum value followed by a single space (so the last
// entry also has a trailing space before the closing bracket).
208 os << "loop types: [";

209 for (utils::IteratorType type : getLoopIteratorTypes()) {

210 os << stringifyEnum(type) << " ";

211 }

212 os << "]\n";

// Indexing maps are printed one per line, in the order getIndexingMaps()
// returns them.
213 os << "indexing maps: \n";

214 for (AffineMap map : getIndexingMaps())

215 os << map << "\n";

// Trailing blank line to separate this dump from whatever is printed next.
216 os << "\n";

217 }

218

219

220

221

222

223 namespace {

224

225

226 static LogicalResult fillShardingOption(Operation *op,

230 unsigned loopIdx) {

231 if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||

232 (!shardingOption.shardingArray[loopIdx].empty() &&

233 shardingOption.shardingArray[loopIdx] != meshAxes)) {

234 LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "

235 << loopIdx << "\n");

236 return failure();

237 }

238 for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {

239 if (i == loopIdx)

240 continue;

241

242 for (MeshAxis axis : meshAxes) {

243 if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {

244 LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "

245 << axis << " duplicate");

246 return failure();

247 }

248 }

249 }

250 if (mesh)

251 shardingOption.mesh = mesh;

252 if (shardingOption.shardingArray[loopIdx].empty())

253 shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),

254 meshAxes.end());

255 return success();

256 }

257

258 }

259

260 FailureOr

264 ShardingInterface shardingOp = llvm::cast(op);

266

267 if (failed(shardingOp.verifyShardingInterfaceImpl()))

268 return op->emitOpError() << "invalid sharding interface implementation";

270 shardingOp.getLoopIteratorTypes();

273 shardingOption.shardingArray.resize(loopTypes.size());

275 llvm::SmallSet<unsigned, 4> visitedLoopIndices;

276 bool anyShardingInResultsOrOperands = false;

277

278

279 for (auto shardingIt : llvm::enumerate(resultShardings)) {

281 if (!shardAttr)

282 continue;

283 AffineMap map = maps[numOperands + shardingIt.index()];

284 anyShardingInResultsOrOperands = true;

287 } else {

288

289

290

294 auto dim = cast(expr);

295 unsigned index = dim.getPosition();

296 visitedLoopIndices.insert(index);

297 if (failed(fillShardingOption(op, shardingOption,

299 return failure();

300 }

301 }

302

303

304

306 if (!partialAxes.empty()) {

307 if (!partialMeshAxes.empty())

308 return op->emitOpError() << "at most one result with partial axes is "

309 "supported at present";

310 partialMeshAxes.append(partialAxes.begin(), partialAxes.end());

311

312

313 for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {

315 visitedLoopIndices.insert(loopIdx);

316 }

317 }

318 }

319

320

321 for (auto shardingIt : llvm::enumerate(operandShardings)) {

323 if (!shardAttr)

324 continue;

325

326 anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty();

327 AffineMap map = maps[shardingIt.index()];

329

330

331

332

333

334

338 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =

340 if (failed(loopIndices))

342 << "operand's affine expression is restricted to const_i * "

343 "dim_i + const_j + dim_j + ...";

344 if (loopIndices->empty())

345 continue;

346 if (loopIndices->size() == 1) {

347 unsigned loopIdx = *loopIndices->begin();

348 visitedLoopIndices.insert(loopIdx);

349 if (failed(fillShardingOption(op, shardingOption,

350 shardAttr.getMeshAttr(), axes, loopIdx)))

351 return failure();

352 }

353

354

355

356 if (loopIndices->size() > 1) {

357 bool seenLoopIndices = false;

358 for (unsigned loopIdx : *loopIndices) {

359 if (visitedLoopIndices.contains(loopIdx)) {

360 seenLoopIndices = true;

361 break;

362 }

363 }

364 if (!seenLoopIndices)

366 << "the operand " << shardingIt.index()

367 << " has multiple loop indices in a dimension, but none of "

368 "them could be found in the exactly specified annotation "

369 "of op results or operands.";

370 }

371 }

372 }

373

374

375 if (!partialMeshAxes.empty()) {

376 bool anyNonEmptyReductionLoop = llvm::any_of(

378 SmallVector &subArray = it.value();

379 int64_t idx = it.index();

380 return isReductionLoop(loopTypes[idx]) && !subArray.empty();

381 });

382 if (!anyNonEmptyReductionLoop) {

383 bool filled = false;

384 for (size_t idx = 0; idx < loopTypes.size(); ++idx) {

386 std::ignore = fillShardingOption(op, shardingOption, nullptr,

387 partialMeshAxes, idx);

388 filled = true;

389 break;

390 }

391 }

392 if (!filled)

393 return op->emitOpError() << "no matched reduction loop found for the "

394 "result's partial type";

395 }

396 }

398 if (!anyShardingInResultsOrOperands)

399 shardingOption.empty = true;

400 return shardingOption;

401 }

402

403

407 auto resultType = cast(result.getType());

410

411

414

415

416 auto dim = cast(expr);

417 unsigned loopIdx = dim.getPosition();

418 if (loopIdx < shardingOption.shardingArray.size())

419 splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);

420 }

421

422

423

425 size_t reductionLoopKindsIdx = 0;

426 for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {

427 utils::IteratorType iType = std::get<0>(it);

429 ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];

430 ++reductionLoopKindsIdx;

431 if (!partialAxes.empty())

432 assert(partialType == curPartialType &&

433 "Only one reduction type is supported");

434 partialType = curPartialType;

436 partialAxes.append(axis);

437 }

438 }

439

443 partialAxes, partialType);

444 }

445

449 Value operandValue = opOperand.get();

450 auto operandType = dyn_cast(operandValue.getType());

451 if (!operandType) {

454 return failure();

455 }

456

457 if (operandType.getRank() == 0) {

459 }

463 int64_t idx = it.index();

465 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =

467 if (failed(loopIndices))

468 return failure();

470 for (unsigned loopIdx : *loopIndices) {

471 if ((size_t)loopIdx < shardingOption.shardingArray.size() &&

473 shardedLoopIndices.push_back(loopIdx);

474 }

475

476 if (shardedLoopIndices.size() > 1)

477 return failure();

478 if (shardedLoopIndices.size() == 1) {

479 splitAxes[idx].append(

480 shardingOption.shardingArray[shardedLoopIndices[0]]);

481 }

482 }

483

486 shardingOption.mesh,

488 }

489

490 FailureOr<std::vector>

493 std::vector res;

494

495 ShardingInterface shardingOp = llvm::cast(op);

497 shardingOp.getLoopIteratorTypes();

499 shardingOp.getReductionLoopIteratorKinds();

502

504 FailureOr shardingAttr = getSharding(

505 opOperand, shardingOption, maps[opOperand.getOperandNumber()]);

506 if (failed(shardingAttr))

507 return failure();

508 res.push_back(*shardingAttr);

509 }

510

512 res.push_back(getSharding(result, shardingOption,

513 maps[numOperands + result.getResultNumber()],

514 loopTypes, reductionKinds));

515 }

516

517 return res;

518 }

519

520

521

522

523

524

525

532 getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);

534

535 return success();

536 }

537

538

539

543

544 FailureOr sharding =

545 getSharding(opOperand, shardingOption, map);

546 if (failed(sharding)) {

547 return failure();

548 }

551

552 return success();

553 }

554

557 assert(!shardingOption.empty && shardingOption.mesh);

558

559 ShardingInterface shardingOp = llvm::cast(op);

561 shardingOp.getLoopIteratorTypes();

563 shardingOp.getReductionLoopIteratorKinds();

566

567

569 if (failed(addShardOp(b, result, shardingOption,

570 maps[numOperands + result.getResultNumber()],

571 loopTypes, reductionKinds)))

572 return failure();

573 }

574

575

577 if (failed(addShardOp(b, opOperand, shardingOption,

578 maps[opOperand.getOperandNumber()])))

579 return failure();

580 }

581

582 return success();

583 }

584

585 #ifndef NDEBUG

586 static bool

589 if (isa(value.getType())) {

591 }

592

593 return !sharding;

594 }

595

596 template <typename ValueRange, typename MeshShardingRage>

597 static bool

599 MeshShardingRage &&shardings) {

600 if (std::size(values) != std::size(shardings)) {

601 return false;

602 }

603 return llvm::all_of(

604 llvm::zip_equal(std::forward(values),

605 std::forward(shardings)),

606 [](auto valueAndSharding) {

608 std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));

609 });

610 }

611 #endif

612

618 assert(spmdizedOperands.size() == operandShardings.size());

620 operandShardings));

622 resultShardings));

623

624 builder.clone(op, spmdizationMap);

625 }

626

630 &meshAxesAssignmentForLoopIterators) {

631 AffineDimExpr affineDimExpr = cast(indexingExpr);

632 unsigned loopIteratorIdx = affineDimExpr.getPosition();

633 if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {

634 assert(llvm::equal(meshAxesAssignmentForTensorAxis,

635 *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));

636 } else {

637 meshAxesAssignmentForLoopIterators[loopIteratorIdx] =

638 llvm::to_vector(meshAxesAssignmentForTensorAxis);

639 }

640 }

641

648 meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());

649 std::vector operatorAndResultShardings;

650 operatorAndResultShardings.reserve(operandShardings.size() +

651 resultShardings.size());

652 llvm::append_range(operatorAndResultShardings, operandShardings);

653 for (auto [sharding, affineMap] :

654 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {

655 if (!sharding) {

656 continue;

657 }

658 for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :

659 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {

661 meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,

662 meshAxisAssignmentForLoopIterators);

663 }

664

665 for (unsigned i = sharding.getSplitAxes().size();

666 i < affineMap.getNumResults(); ++i) {

668 {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);

669 }

670 }

671

673 llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),

675 if (!axes) {

677 };

678 return std::move(*axes);

679 });

680 return res;

681 }

682

686 for (auto [loopIteratorType, meshAxisAssignment] :

687 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {

688 if (loopIteratorType == utils::IteratorType::reduction &&

689 !meshAxisAssignment.empty()) {

690 return true;

691 }

692 }

693 return false;

694 }

695

700 for (auto [loopIteratorType, meshAxisAssignment] :

701 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {

702 if (loopIteratorType == utils::IteratorType::reduction) {

703 llvm::append_range(meshAxes, meshAxisAssignment);

704 }

705 }

706 return meshAxes;

707 }

708

714

715 Operation *newOp = builder.clone(op, spmdizationMap);

716

717 for (auto [oldResult, newResult, sharding] :

720 newResult.getType(),

721 getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));

722 }

723 }

SmallVector< MeshAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)

static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)

static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingRage &&shardings)

static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding)

static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)

static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)

static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)

MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)

A dimensional identifier appearing in an affine expression.

unsigned getPosition() const

Base type for affine expression.

AffineExprKind getKind() const

Return the classification for this type.

A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.

bool isProjectedPermutation(bool allowZeroInResults=false) const

Returns true if the AffineMap represents a subset (i.e., a projection) of a symbol-less permutation map.

unsigned getNumDims() const

ArrayRef< AffineExpr > getResults() const

A symbol reference with a reference path containing a single element.

This is a utility class for mapping one set of IR entities to another.

IRValueT get() const

Return the current value being used by this operand.

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the provided map (leaving them alone if no entry is present).

This class represents an operand of an operation.

This is a value defined by a result of an operation.

Operation is the basic unit of execution within MLIR.

unsigned getNumOperands()

operand_type_range getOperandTypes()

MutableArrayRef< OpOperand > getOpOperands()

result_type_range getResultTypes()

operand_range getOperands()

Returns an iterator on the underlying Values.

result_range getResults()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

unsigned getNumResults()

Return the number of results held by this operation.

This class represents a collection of SymbolTables.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

bool isIntOrIndexOrFloat() const

Return true if this is an integer (of any signedness), index, or float type.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

MLIRContext * getContext() const

Utility to get the associated MLIRContext that this value is defined in.

Type getType() const

Return the type of this value.

user_range getUsers() const

bool hasOneUse() const

Returns true if this value has exactly one use.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)

Builder from ArrayRef.

::mlir::FlatSymbolRefAttr getMeshAttr() const

ArrayRef< MeshAxesAttr > getSplitAxes() const

ArrayRef< MeshAxis > getPartialAxes() const

static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

mesh::ReductionKind ReductionKind

mesh::MeshSharding MeshSharding

FailureOr< std::vector< MeshSharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)

LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)

FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings)

bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)

SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)

void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)

mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)

void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)

Type shardType(Type type, MeshOp mesh, MeshSharding sharding)

void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder, ShardOp &newShardOp)

ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)

bool isReductionLoop(utils::IteratorType iType)

bool isFullReplication(MeshSharding sharding)

void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)

FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)

void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)

Include the generated interface declarations.

@ Mul

RHS of mul is always a constant or a symbolic expression.

@ DimId

Dimensional identifier.

@ Constant

Constant integer.

ShardingArray shardingArray