#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

MLIR: lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp File Reference (original) (raw)

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#include <utility>
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

Go to the source code of this file.

Functions
static LogicalResult checkOperandAffineExprRecursively (AffineExpr expr, SmallVectorImpl< bool > &seenIds)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr (AffineExpr expr, unsigned numDims)
template<typename T>
SmallVector< MeshAxesAttr > fromArrayOfVector (MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
MeshSharding getSharding (OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static FailureOr< MeshSharding > getSharding (OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map)
static LogicalResult addShardOp (OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static LogicalResult addShardOp (OpBuilder &b, OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map)
static bool isValueCompatibleWithFullReplicationSharding (Value value, MeshSharding sharding)
template<typename ValueRange , typename MeshShardingRage >
static bool areValuesCompatibleWithFullReplicationShardings (ValueRange &&values, MeshShardingRage &&shardings)
static void updateMeshAxisAssignmentForLoopIterators (ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)

DBGS

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

DEBUG_TYPE

#define DEBUG_TYPE "sharding-interface"

addShardOp() [1/2]

addShardOp() [2/2]

areValuesCompatibleWithFullReplicationShardings()

template<typename ValueRange , typename MeshShardingRage >

static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingRage &&shardings)

checkOperandAffineExpr()

static FailureOr<llvm::SmallSet<unsigned, 2>> checkOperandAffineExpr(AffineExpr expr, unsigned numDims)

checkOperandAffineExprRecursively()

fromArrayOfVector()

getSharding() [1/2]

Definition at line 446 of file ShardingInterface.cpp.

References checkOperandAffineExpr(), mlir::detail::enumerate(), fromArrayOfVector(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::mesh::MeshSharding::get(), mlir::Value::getContext(), mlir::AffineMap::getNumDims(), mlir::AffineMap::getResults(), mlir::Value::getType(), mlir::Type::isIntOrIndexOrFloat(), mlir::mesh::ShardingOption::mesh, mlir::mesh::removeTrailingEmptySubArray(), and mlir::mesh::ShardingOption::shardingArray.

getSharding() [2/2]

Definition at line 404 of file ShardingInterface.cpp.

References mlir::detail::enumerate(), fromArrayOfVector(), mlir::mesh::MeshSharding::get(), mlir::Value::getContext(), mlir::AffineMap::getResults(), mlir::Value::getType(), mlir::mesh::isReductionLoop(), mlir::mesh::ShardingOption::mesh, mlir::mesh::removeTrailingEmptySubArray(), and mlir::mesh::ShardingOption::shardingArray.

Referenced by addShardOp(), mlir::mesh::detail::defaultGetShardingAnnotations(), and mlir::mesh::getMeshSharding().

isValueCompatibleWithFullReplicationSharding()

static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding)

updateMeshAxisAssignmentForLoopIterators()