MLIR: lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp File Reference
#include <utility>
#include "[Detail/DimLvlMapParser.h](DimLvlMapParser%5F8h%5Fsource.html)"
#include "[mlir/Dialect/SparseTensor/IR/Enums.h](Enums%5F8h%5Fsource.html)"
#include "[mlir/Dialect/SparseTensor/IR/SparseTensor.h](mlir%5F2Dialect%5F2SparseTensor%5F2IR%5F2SparseTensor%5F8h%5Fsource.html)"
#include "[mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h](SparseTensorStorageLayout%5F8h%5Fsource.html)"
#include "[mlir/Dialect/SparseTensor/IR/SparseTensorType.h](SparseTensorType%5F8h%5Fsource.html)"
#include "[mlir/Dialect/Arith/IR/Arith.h](mlir%5F2Dialect%5F2Arith%5F2IR%5F2Arith%5F8h%5Fsource.html)"
#include "[mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h](BufferizableOpInterface%5F8h%5Fsource.html)"
#include "[mlir/Dialect/Complex/IR/Complex.h](Complex%5F8h%5Fsource.html)"
#include "[mlir/Dialect/Utils/StaticValueUtils.h](StaticValueUtils%5F8h%5Fsource.html)"
#include "[mlir/IR/Builders.h](Builders%5F8h%5Fsource.html)"
#include "[mlir/IR/DialectImplementation.h](DialectImplementation%5F8h%5Fsource.html)"
#include "[mlir/IR/Matchers.h](Matchers%5F8h%5Fsource.html)"
#include "[mlir/IR/OpImplementation.h](OpImplementation%5F8h%5Fsource.html)"
#include "[mlir/IR/PatternMatch.h](PatternMatch%5F8h%5Fsource.html)"
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
Go to the source code of this file.
Functions
static mlir::ParseResult
parseLevelRange (AsmParser &parser, Level &lvlLo, Level &lvlHi)
Parses a level range in the form "$lo `to` $hi", or simply "$lo" if $hi - $lo = 1. More...
static void
printLevelRange (AsmPrinter &p, Level lo, Level hi)
Prints a level range in the form "$lo `to` $hi", or simply "$lo" if $hi - $lo = 1. More...
llvm::hash_code
mlir::sparse_tensor::hash_value (LevelType lt)
static constexpr bool
acceptBitWidth (unsigned bitWidth)
static SmallVector < Size >
getSparseFieldShape (const SparseTensorEncodingAttr enc, std::optional< ArrayRef < int64_t >> dimShape)
static ParseResult
parseOptionalStaticSlice (int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier (SparseTensorEncodingAttr enc)
We normalize the sparse tensor encoding attribute by always using ordered/unique level types (LTs), so that "compressed_nu_no" and "compressed_nu" (as well as other variants) lead to the same storage specifier type, and by stripping irrelevant fields that do not alter the sparse tensor memory layout. More...
static LogicalResult
lvlIsInBounds (Level lvl, Value tensor)
static LogicalResult
isMatchingWidth (Value mem, unsigned width)
static LogicalResult
verifySparsifierGetterSetter (StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue < StorageSpecifierType > md, Operation *op)
static Type
getFieldElemType (SparseTensorType stt, SparseTensorFieldKind kind )
static LogicalResult
verifyPackUnPack (Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
template
static LogicalResult
inferSparseBufferType (ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl < mlir::Type > &ret)
template
static SetStorageSpecifierOp
getSpecifierSetDef (SpecifierOp op)
template
static LogicalResult
verifyNumBlockArgs (T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult
parseLevelRange (OpAsmParser &parser, IntegerAttr &lvlLoAttr, IntegerAttr &lvlHiAttr)
Parses a level range in the form "$lo `to` $hi", or simply "$lo" if $hi - $lo = 1. More...
static void
printLevelRange (OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, IntegerAttr lvlHi)
Prints a level range in the form "$lo `to` $hi", or simply "$lo" if $hi - $lo = 1. More...
static ParseResult
parseOptionalDefinedList (OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl < OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max (), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses an optionally-defined list in the form "(%val0, _, %val1, ...)", where _ marks that the corresponding value is not defined (e.g., to represent an undefined coordinate in the sparse iteration space). More...
static void
printOptionalDefinedList (OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static ParseResult
parseUsedCoordList (OpAsmParser &parser, OperationState &state, SmallVectorImpl < OpAsmParser::Argument > &coords)
static ParseResult
parseSparseIterateLoop (OpAsmParser &parser, OperationState &state, SmallVectorImpl < OpAsmParser::Argument > &iterators, SmallVectorImpl < OpAsmParser::Argument > &blockArgs)
static ParseResult
parseSparseCoIterateLoop (OpAsmParser &parser, OperationState &state, SmallVectorImpl < Value > &spacesVals, SmallVectorImpl < OpAsmParser::Argument > &blockArgs)
static void
printInitializationList (OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of (inner = outer, inner2 = outer2, <...>) where 'inner' values are assumed to be region arguments and 'outer' values are regular SSA values. More...
template
static LogicalResult
verifySparseLoopOp (SparseLoopOp op)
◆ GET_ATTRDEF_CLASSES#define GET_ATTRDEF_CLASSES
◆ GET_ATTRDEF_LIST◆ GET_OP_CLASSES◆ GET_OP_LIST◆ GET_TYPEDEF_CLASSES#define GET_TYPEDEF_CLASSES
◆ GET_TYPEDEF_LIST◆ acceptBitWidth()
static constexpr bool acceptBitWidth ( unsigned bitWidth )
static constexpr
◆ getFieldElemType()◆ getNormalizedEncodingForSpecifier()
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier ( SparseTensorEncodingAttr enc )
static
◆ getSparseFieldShape()
static SmallVector <Size > getSparseFieldShape ( const SparseTensorEncodingAttr enc , std::optional< ArrayRef < int64_t >> dimShape )
static
◆ getSpecifierSetDef()template
static SetStorageSpecifierOp getSpecifierSetDef ( SpecifierOp op )
static
◆ inferSparseBufferType()template
◆ isMatchingWidth()
static LogicalResult isMatchingWidth ( Value mem , unsigned width )
static
◆ lvlIsInBounds()
static LogicalResult lvlIsInBounds ( Level lvl , Value tensor )
static
◆ parseLevelRange() [1/2]◆ parseLevelRange() [2/2]
static ParseResult parseLevelRange ( OpAsmParser & parser , IntegerAttr & lvlLoAttr , IntegerAttr & lvlHiAttr )
static
◆ parseOptionalDefinedList()
Parses an optionally-defined list in the form "(%val0, _, %val1, ...)", where _ marks that the corresponding value is not defined (e.g., to represent an undefined coordinate in the sparse iteration space).
Definition at line 2150 of file SparseTensorDialect.cpp .
◆ parseOptionalStaticSlice()
static ParseResult parseOptionalStaticSlice ( int64_t & result , AsmParser & parser )
static
◆ parseSparseCoIterateLoop()
◆ parseSparseIterateLoop()
◆ parseUsedCoordList()
◆ printInitializationList()
Prints the initialization list in the form (inner = outer, inner2 = outer2, <...>), where 'inner' values are assumed to be region arguments and 'outer' values are regular SSA values.
Definition at line 2501 of file SparseTensorDialect.cpp .
◆ printLevelRange() [1/2]◆ printLevelRange() [2/2]
static void printLevelRange ( OpAsmPrinter & p , Operation * , IntegerAttr lvlLo , IntegerAttr lvlHi )
static
◆ printOptionalDefinedList()◆ verifyNumBlockArgs()template
static LogicalResult verifyNumBlockArgs ( T * op , Region & region , const char * regionName , TypeRange inputTypes , Type outputType )
static
◆ verifyPackUnPack()
Definition at line 1295 of file SparseTensorDialect.cpp .
References mlir::Operation::emitError() , mlir::sparse_tensor::StorageLayout::foreachField() , mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() , mlir::sparse_tensor::SparseTensorType::getEncoding() , getFieldElemType() , mlir::sparse_tensor::SparseTensorType::getLvlRank() , mlir::sparse_tensor::SparseTensorType::getLvlType() , mlir::sparse_tensor::StorageLayout::getNumDataFields() , mlir::sparse_tensor::SparseTensorType::hasEncoding() , mlir::sparse_tensor::SparseTensorType::hasStaticDimShape() , mlir::sparse_tensor::StorageSpec , and mlir::sparse_tensor::ValMemRef .
◆ verifySparseLoopOp()template
static LogicalResult verifySparseLoopOp ( SparseLoopOp op )
static
◆ verifySparsifierGetterSetter()
static LogicalResult verifySparsifierGetterSetter ( StorageSpecifierKind mdKind , std::optional< Level > lvl , TypedValue < StorageSpecifierType > md , Operation * op )
static
◆ kDataFieldStartingIdx◆ kInvalidFieldIndex
constexpr Level kInvalidFieldIndex = -1u
static constexpr
◆ kInvalidLevel
constexpr Level kInvalidLevel = -1u
static constexpr