MLIR: lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp File Reference (original) (raw)

#include <utility>
#include "Detail/DimLvlMapParser.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"

Go to the source code of this file.

Namespaces
Macros
#define GET_ATTRDEF_CLASSES
#define GET_TYPEDEF_CLASSES
#define GET_ATTRDEF_LIST
#define GET_TYPEDEF_LIST
#define GET_OP_LIST
#define GET_OP_CLASSES
Functions
static mlir::ParseResult parseLevelRange (AsmParser &parser, Level &lvlLo, Level &lvlHi)
Parses a level range in the form "$lo `to` $hi", or simply "$lo" if $hi - $lo = 1. More...
static void printLevelRange (AsmPrinter &p, Level lo, Level hi)
Prints a level range in the form "$lo `to` $hi", or simply "$lo" if $hi - $lo = 1. More...
llvm::hash_code mlir::sparse_tensor::hash_value (LevelType lt)
static constexpr bool acceptBitWidth (unsigned bitWidth)
static SmallVector< Size > getSparseFieldShape (const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t >> dimShape)
static ParseResult parseOptionalStaticSlice (int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier (SparseTensorEncodingAttr enc)
Normalizes the sparse tensor encoding attribute by always using ordered/unique level types (LT), so that "compressed_nu_no" and "compressed_nu" (as well as other variants) lead to the same storage specifier type, and by stripping irrelevant fields that do not alter the sparse tensor memory layout. More...
static LogicalResult lvlIsInBounds (Level lvl, Value tensor)
static LogicalResult isMatchingWidth (Value mem, unsigned width)
static LogicalResult verifySparsifierGetterSetter (StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static Type getFieldElemType (SparseTensorType stt, SparseTensorFieldKind kind)
static LogicalResult verifyPackUnPack (Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
template
static LogicalResult inferSparseBufferType (ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef (SpecifierOp op)
template <class T>
static LogicalResult verifyNumBlockArgs (T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseLevelRange (OpAsmParser &parser, IntegerAttr &lvlLoAttr, IntegerAttr &lvlHiAttr)
Parses a level range in the form "$lo `to` $hi", or simply "$lo" if $hi - $lo = 1. More...
static void printLevelRange (OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, IntegerAttr lvlHi)
Prints a level range in the form "$lo `to` $hi", or simply "$lo" if $hi - $lo = 1. More...
static ParseResult parseOptionalDefinedList (OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optionally defined values in the form "(%val0, _, %val1, ...)", where _ marks a value that is not defined (e.g., to represent an undefined coordinate in the sparse iteration space). More...
static void printOptionalDefinedList (OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static ParseResult parseUsedCoordList (OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static ParseResult parseSparseIterateLoop (OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static ParseResult parseSparseCoIterateLoop (OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static void printInitializationList (OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of (inner = outer, inner2 = outer2, <...>) where 'inner' values are assumed to be region arguments and 'outer' values are regular SSA values. More...
template <typename SparseLoopOp>
static LogicalResult verifySparseLoopOp (SparseLoopOp op)
Variables
static constexpr Level kInvalidLevel = -1u
static constexpr Level kInvalidFieldIndex = -1u
static constexpr FieldIndex kDataFieldStartingIdx = 0

GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

GET_ATTRDEF_LIST

GET_OP_CLASSES

GET_OP_LIST

GET_TYPEDEF_CLASSES

#define GET_TYPEDEF_CLASSES

GET_TYPEDEF_LIST

acceptBitWidth()

static constexpr bool acceptBitWidth ( unsigned bitWidth) static constexpr

getFieldElemType()

getNormalizedEncodingForSpecifier()

static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier ( SparseTensorEncodingAttr enc) static

getSparseFieldShape()

static SmallVector<Size> getSparseFieldShape ( const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t >> dimShape ) static

getSpecifierSetDef()

template <typename SpecifierOp>

static SetStorageSpecifierOp getSpecifierSetDef ( SpecifierOp op) static

inferSparseBufferType()

template

isMatchingWidth()

static LogicalResult isMatchingWidth ( Value mem, unsigned width ) static

lvlIsInBounds()

static LogicalResult lvlIsInBounds ( Level lvl, Value tensor ) static

parseLevelRange() [1/2]

parseLevelRange() [2/2]

static ParseResult parseLevelRange ( OpAsmParser & parser, IntegerAttr & lvlLoAttr, IntegerAttr & lvlHiAttr ) static

parseOptionalDefinedList()

Parses a list of optionally defined values in the form "(%val0, _, %val1, ...)", where _ marks a value that is not defined (e.g., to represent an undefined coordinate in the sparse iteration space).

Definition at line 2150 of file SparseTensorDialect.cpp.

parseOptionalStaticSlice()

static ParseResult parseOptionalStaticSlice ( int64_t & result, AsmParser & parser ) static

parseSparseCoIterateLoop()

parseSparseIterateLoop()

parseUsedCoordList()

printInitializationList()

Prints the initialization list in the form of (inner = outer, inner2 = outer2, <...>) where 'inner' values are assumed to be region arguments and 'outer' values are regular SSA values.

Definition at line 2501 of file SparseTensorDialect.cpp.

printLevelRange() [1/2]

printLevelRange() [2/2]

static void printLevelRange ( OpAsmPrinter & p, Operation * , IntegerAttr lvlLo, IntegerAttr lvlHi ) static

printOptionalDefinedList()

verifyNumBlockArgs()

template <class T>

static LogicalResult verifyNumBlockArgs ( T * op, Region & region, const char * regionName, TypeRange inputTypes, Type outputType ) static

verifyPackUnPack()

Definition at line 1295 of file SparseTensorDialect.cpp.

References mlir::Operation::emitError(), mlir::sparse_tensor::StorageLayout::foreachField(), mlir::sparse_tensor::SparseTensorType::getAoSCOOStart(), mlir::sparse_tensor::SparseTensorType::getEncoding(), getFieldElemType(), mlir::sparse_tensor::SparseTensorType::getLvlRank(), mlir::sparse_tensor::SparseTensorType::getLvlType(), mlir::sparse_tensor::StorageLayout::getNumDataFields(), mlir::sparse_tensor::SparseTensorType::hasEncoding(), mlir::sparse_tensor::SparseTensorType::hasStaticDimShape(), mlir::sparse_tensor::StorageSpec, and mlir::sparse_tensor::ValMemRef.

verifySparseLoopOp()

template <typename SparseLoopOp>

static LogicalResult verifySparseLoopOp ( SparseLoopOp op) static

verifySparsifierGetterSetter()

static LogicalResult verifySparsifierGetterSetter ( StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation * op ) static

kDataFieldStartingIdx

kInvalidFieldIndex

constexpr Level kInvalidFieldIndex = -1u static constexpr

kInvalidLevel

constexpr Level kInvalidLevel = -1u static constexpr