MLIR: lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp File Reference
Macros | |
---|---|
#define | DEBUG_TYPE "linalg-transforms" |
#define | DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
#define | DBGSNL() (llvm::dbgs() << "\n") |
#define | LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") |
#define | DOWNSCALE(trans) |
#define | DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b> |
#define | DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b)) |
#define | GET_OP_CLASSES |
Functions | |
---|---|
template<typename PatternTy , typename... Args> | |
static FailureOr< LinalgOp > | tryApply (Operation *operation, Args &&...args) |
Attempts to apply the pattern specified as template argument to the given operation. | |
static DiagnosedSilenceableFailure | unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs) |
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value. | |
static DiagnosedSilenceableFailure | unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, Value packedHandle) |
static DiagnosedSilenceableFailure | reifyMixedParamAndHandleResults (TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified) |
When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred. | |
template<typename Range> | |
static LogicalResult | applyTilingToAll (RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn) |
Apply a tiling transformation to all payload ops and store both the tiled operation and the created tile loops. | |
static Operation * | replaceForAllWithNewSignature (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes) |
Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op. | |
static bool | sameOrEquivalentIterArg (Value src, Value dst) |
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is equal to 'dst' or equal to an iter arg present in an outer loop. | |
static std::tuple< SmallVector< Operation * >, Operation * > | tileAndFuseFirstExtractUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
Find the first "extract" user of producerOp and tile it right before its use. | |
static SmallVector< Operation * > | tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail. | |
static Operation * | cloneAndFuseFirstUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
static void | printMultitileSizesTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type) |
static ParseResult | parseMultitileSizesTypes (OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType) |
template<typename RelayoutOpTy> | |
bool | isValidPackingPermutation (RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer) |
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op. | |
static void | printContinuousTileSizeTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type) |
static ParseResult | parseContinuousTileSizeTypes (OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType) |
static SmallVector< OpFoldResult > | normalizeUpperBounds (RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps) |
Given lbs, ubs and steps of loops, return, for each loop, the normalized upper bound. | |
static SmallVector< Value > | denormalizeIndVar (RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps) |
When a loop is normalized, the uses of the induction variable within the loop need to be replaced with original_lb + old_iv * original_step. | |
static scf::ForallOp | normalizeForallLoopOp (RewriterBase &rewriter, scf::ForallOp loop) |
Given an scf.forall loop, return a loop op with the loop bounds normalized. | |
template<typename OpTy> | |
DiagnosedSilenceableFailure | doit (RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state) |
◆ DBGS
#define DBGS | ( | ) | (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
---|
◆ DBGSNL
#define DBGSNL | ( | ) | (llvm::dbgs() << "\n") |
---|
◆ DEBUG_TYPE
#define DEBUG_TYPE "linalg-transforms"
◆ DOWNSCALE
#define DOWNSCALE | ( | trans | ) |
---|
Value:
{ \
  FailureOr<LinalgOp> res = tryApply<trans>(target); \
  if (succeeded(res)) { \
    results.push_back(*res); \
    return DiagnosedSilenceableFailure::success(); \
  } \
}
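The macro above tries one pattern and returns early from the enclosing function on success, so several invocations in a row behave like a "first match wins" chain. The following is a minimal standalone sketch of that control flow, not the MLIR code itself: `HalvePattern`, `ThirdPattern`, `TRY_PATTERN`, and `applyFirstMatching` are hypothetical names, and `std::optional<int>` stands in for `FailureOr<LinalgOp>`.

```cpp
#include <optional>
#include <vector>

// Hypothetical pattern: succeeds only on even inputs.
struct HalvePattern {
  static std::optional<int> apply(int x) {
    if (x % 2 != 0) return std::nullopt;
    return x / 2;
  }
};

// Hypothetical pattern: succeeds only on multiples of three.
struct ThirdPattern {
  static std::optional<int> apply(int x) {
    if (x % 3 != 0) return std::nullopt;
    return x / 3;
  }
};

// Mirrors the shape of DOWNSCALE: try one pattern on `target`; on success,
// record the result and return from the *enclosing* function.
#define TRY_PATTERN(pattern)                                                   \
  {                                                                            \
    std::optional<int> res = pattern::apply(target);                           \
    if (res) {                                                                 \
      results.push_back(*res);                                                 \
      return true;                                                             \
    }                                                                          \
  }

// Tries each pattern in order; falls through to `false` if none applies.
static bool applyFirstMatching(int target, std::vector<int> &results) {
  TRY_PATTERN(HalvePattern)
  TRY_PATTERN(ThirdPattern)
  return false;
}
```

Because the macro body relies on the names `target` and `results` being in scope, it is only usable inside functions that follow that naming convention; that is a property of this sketch and, judging by the macro text, of the original as well.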
◆ DOWNSCALE_CALL
#define DOWNSCALE_CALL | ( | a, b | ) | DownscaleSizeOneWindowed2DConvolution<a, b> |
---|
◆ DOWNSCALE_NORMAL
◆ GET_OP_CLASSES
◆ LDBG
#define LDBG | ( | X | ) | LLVM_DEBUG(DBGS() << (X) << "\n") |
---|
◆ applyTilingToAll()
template<typename Range>
◆ cloneAndFuseFirstUse()
Definition at line 998 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::clone(), DBGS, diag(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::Operation::getOpResults(), mlir::detail::IROperandBase::getOwner(), mlir::Operation::isProperAncestor(), mlir::RewriterBase::modifyOpInPlace(), and mlir::OpBuilder::setInsertionPoint().
◆ denormalizeIndVar()
◆ doit()
◆ isValidPackingPermutation()
template<typename RelayoutOpTy>
bool isValidPackingPermutation ( RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer ) |
---|
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op.
This is the case when the permutation rank matches the rank expected by op and permutation is itself a permutation vector. Return true if either op or permutation is empty, to allow a simpler polymorphic implementation.
Definition at line 1781 of file LinalgTransformOps.cpp.
References mlir::isPermutationVector().
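The two conditions described above (matching rank, and being a permutation vector) can be illustrated without MLIR. This is a sketch under assumptions: `expectedRank` is a hypothetical parameter standing in for the rank the pack/unpack op would expect, `std::vector` replaces `ArrayRef`, and the "empty op" case is not modeled.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// True if `perm` contains each of 0 .. perm.size()-1 exactly once.
static bool isPermutationVectorSketch(const std::vector<int64_t> &perm) {
  std::vector<bool> seen(perm.size(), false);
  for (int64_t idx : perm) {
    if (idx < 0 || static_cast<std::size_t>(idx) >= perm.size())
      return false; // out of range
    std::size_t pos = static_cast<std::size_t>(idx);
    if (seen[pos])
      return false; // duplicate entry
    seen[pos] = true;
  }
  return true;
}

// Hypothetical stand-in for the validity check: an empty permutation is
// accepted, otherwise the rank must match and the entries must permute.
static bool isValidPackingPermutationSketch(std::size_t expectedRank,
                                            const std::vector<int64_t> &perm) {
  if (perm.empty())
    return true;
  return perm.size() == expectedRank && isPermutationVectorSketch(perm);
}
```

For example, with an expected rank of 3, `{1, 0, 2}` is accepted, while `{0, 0, 1}` (duplicate) and `{0, 1}` (wrong rank) are rejected.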
◆ normalizeForallLoopOp()
static scf::ForallOp normalizeForallLoopOp ( RewriterBase & rewriter, scf::ForallOp loop ) | static |
---|
Given an scf.forall loop, return a loop op with the loop bounds normalized.
TODO: Replace this with a general utility to normalize scf.forall. At the time of writing, this wasn't done since adding this to the scf dialect would disallow the use of affine.apply operations due to cyclic dependencies. To avoid churn in lit tests with the change this was added with, defer that to a follow-up.
Definition at line 3398 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::create(), denormalizeIndVar(), mlir::Builder::getIndexAttr(), mlir::isOneInteger(), mlir::isZeroInteger(), mlir::RewriterBase::mergeBlocks(), normalizeUpperBounds(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPointToStart().
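The arithmetic behind loop normalization can be sketched on plain integers. The block below assumes the usual convention that a normalized loop runs from 0 with step 1, so its upper bound is the trip count ceil((ub - lb) / step), and that the original induction value is recovered as original_lb + old_iv * original_step, as stated for denormalizeIndVar. `LoopRange` and the function names are hypothetical; the real helpers operate on OpFoldResult and Value, not on integers.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Ceiling division for a non-negative numerator and a positive denominator.
static int64_t ceilDiv(int64_t num, int64_t den) {
  assert(num >= 0 && den > 0 && "sketch only handles forward loops");
  return (num + den - 1) / den;
}

// Hypothetical stand-in for one loop dimension: lb, lb+step, ... while < ub.
struct LoopRange {
  int64_t lb, ub, step;
};

// Upper bound of the equivalent loop with lb == 0 and step == 1.
static int64_t normalizedUpperBound(const LoopRange &r) {
  if (r.ub <= r.lb)
    return 0; // empty iteration space
  return ceilDiv(r.ub - r.lb, r.step);
}

// Maps a normalized induction value back to the original iteration space.
static int64_t denormalize(const LoopRange &r, int64_t normalizedIv) {
  return r.lb + normalizedIv * r.step;
}

// Enumerates the original induction values by running the normalized loop.
static std::vector<int64_t> originalIvs(const LoopRange &r) {
  std::vector<int64_t> ivs;
  for (int64_t i = 0, e = normalizedUpperBound(r); i < e; ++i)
    ivs.push_back(denormalize(r, i));
  return ivs;
}
```

For a loop with lb = 2, ub = 11, step = 3, the normalized upper bound is 3 and the denormalized induction values are 2, 5, 8, which matches the original iteration space.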
◆ normalizeUpperBounds()
◆ parseContinuousTileSizeTypes()
static ParseResult parseContinuousTileSizeTypes ( OpAsmParser & parser, Type & targetType, Type & tileSizesType, Type & chunkSizesType ) | static |
---|
◆ parseMultitileSizesTypes()
static ParseResult parseMultitileSizesTypes ( OpAsmParser & parser, Type & targetType, Type & lowSizeType, Type & highSizeType, Type & splitPointType ) | static |
---|
◆ printContinuousTileSizeTypes()
◆ printMultitileSizesTypes()
◆ reifyMixedParamAndHandleResults()
When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred.
If a result is a Value then it must be either a ParamType or a handle to a constant-like op.
Definition at line 179 of file LinalgTransformOps.cpp.
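As a rough illustration of reifying a mixed list, the sketch below uses `std::variant` in place of OpFoldResult: an entry is either an integer known up front or a handle that must be looked up. Everything here is hypothetical (`Handle`, `HandleTable`, `reifySketch`), and the rule that a handle must map to exactly one constant is an assumption of this sketch rather than something the description above states.

```cpp
#include <cstdint>
#include <map>
#include <variant>
#include <vector>

// Hypothetical handle: an opaque id resolved through a lookup table.
struct Handle {
  explicit Handle(int id) : id(id) {}
  int id;
};

// Either a static integer or a handle to be resolved.
using MixedResult = std::variant<int64_t, Handle>;

// Hypothetical state: handle id -> constant values the handle maps to.
using HandleTable = std::map<int, std::vector<int64_t>>;

// Appends one integer per entry to `reified`; fails if a handle is unknown
// or does not map to exactly one constant (assumption of this sketch).
static bool reifySketch(const std::vector<MixedResult> &mixed,
                        const HandleTable &table,
                        std::vector<int64_t> &reified) {
  for (const MixedResult &entry : mixed) {
    if (const int64_t *attr = std::get_if<int64_t>(&entry)) {
      reified.push_back(*attr); // already static
      continue;
    }
    const Handle &handle = std::get<Handle>(entry);
    auto it = table.find(handle.id);
    if (it == table.end() || it->second.size() != 1)
      return false;
    reified.push_back(it->second.front());
  }
  return true;
}
```

Returning a plain `bool` loses the diagnostic detail that DiagnosedSilenceableFailure carries in the real signature; the sketch only shows the dispatch on the two kinds of entries.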
◆ replaceForAllWithNewSignature()
Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op.
Definition at line 645 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::create(), mlir::DominanceInfo::dominates(), mlir::detail::enumerate(), mlir::RewriterBase::eraseBlock(), mlir::Builder::getIndexAttr(), mlir::Operation::getLoc(), mlir::Operation::getResult(), mlir::Value::getUsers(), mlir::Operation::isAncestor(), mlir::Operation::isProperAncestor(), mlir::RewriterBase::replaceAllUsesWith(), mlir::RewriterBase::replaceUsesWithIf(), mlir::OpBuilder::setInsertionPoint(), and mlir::TilingResult::tiledValues.
◆ sameOrEquivalentIterArg()
static bool sameOrEquivalentIterArg ( Value src, Value dst ) | static |
---|
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is equal to 'dst' or equal to an iter arg present in an outer loop.
To determine the second condition, this function iterates using a worklist over the enclosing loops, trying to find 'src' in any of the parent loop's iter args.
Definition at line 726 of file LinalgTransformOps.cpp.
References mlir::OpOperand::getOperandNumber(), and mlir::Block::getParentOp().
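The worklist traversal described above can be sketched on a toy data structure. This is an illustration under assumptions, not the MLIR implementation: `ToyValue` is a hypothetical type in which a loop iter arg records the init values of its owning loop's iter args, and a value defined outside any loop records none. The chain of enclosing loops is assumed to be acyclic.

```cpp
#include <cassert>
#include <vector>

// Hypothetical value. For a loop iter arg, `owningLoopInits` lists the init
// values of the owning loop's iter args; otherwise it is empty.
struct ToyValue {
  std::vector<const ToyValue *> owningLoopInits;
};

// Returns true if `src` is `dst` itself or is reachable from `dst` by
// repeatedly stepping from an iter arg to the inits of its owning loop.
static bool sameOrEquivalentIterArgSketch(const ToyValue *src,
                                          const ToyValue *dst) {
  assert(src && dst && "both values must be non-null");
  std::vector<const ToyValue *> worklist{dst};
  while (!worklist.empty()) {
    const ToyValue *current = worklist.back();
    worklist.pop_back();
    if (current == src)
      return true;
    // Walk outwards: enqueue the inits of the loop that owns `current`.
    for (const ToyValue *init : current->owningLoopInits)
      worklist.push_back(init);
  }
  return false;
}
```

With an outer iter arg initialized from some tensor and an inner iter arg initialized from the outer one, the check succeeds for (outer, inner) and (tensor, inner) but fails in the opposite direction, since the walk only moves from inner loops to outer ones.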
◆ tileAndFuseFirstExtractUse()
◆ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument()
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail.
Then, find the first "extract" user of the tied block argument and tile it right before its "extract" use. The tiled op is fused under the containingOp. Return this fused op on success or nullptr if anything fails.
Definition at line 896 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::clone(), DBGS, diag(), mlir::RewriterBase::eraseOp(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::OpOperand::getOperandNumber(), mlir::tensor::getOrCreateDestinations(), mlir::Value::getUsers(), mlir::IRMapping::map(), mlir::RewriterBase::modifyOpInPlace(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), and mlir::Operation::setOperand().
◆ tryApply()
template<typename PatternTy , typename... Args>
static FailureOr< LinalgOp > tryApply ( Operation * operation, Args &&... args ) | static |
---|
Attempts to apply the pattern specified as template argument to the given operation.
The pattern is expected to have a returningMatchAndRewrite function that returns the "main" result or failure. Returns failure if the pattern failed to apply. Extra arguments are forwarded to the pattern constructor.
Definition at line 65 of file LinalgTransformOps.cpp.
References mlir::Operation::getContext().
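The shape of this helper (pattern type as template argument, extra arguments perfectly forwarded to the pattern constructor, result or failure returned) can be sketched in standalone C++. The names `ToyOp`, `RenamePattern`, and `tryApplySketch` are hypothetical, and `std::optional` stands in for FailureOr; the real function additionally works with an MLIR context and rewriter, which this sketch omits.

```cpp
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-in for a payload operation.
struct ToyOp {
  std::string name;
};

// Hypothetical pattern: renames an op whose name equals `from`.
class RenamePattern {
public:
  RenamePattern(std::string from, std::string to)
      : from(std::move(from)), to(std::move(to)) {}

  // Returns the "main" result on success, std::nullopt on failure.
  std::optional<ToyOp> returningMatchAndRewrite(const ToyOp &op) const {
    if (op.name != from)
      return std::nullopt;
    return ToyOp{to};
  }

private:
  std::string from, to;
};

// Constructs PatternTy from the forwarded arguments and applies it once.
template <typename PatternTy, typename... Args>
static std::optional<ToyOp> tryApplySketch(const ToyOp &op, Args &&...args) {
  PatternTy pattern(std::forward<Args>(args)...);
  return pattern.returningMatchAndRewrite(op);
}
```

A call such as `tryApplySketch<RenamePattern>(ToyOp{"conv2d"}, "conv2d", "conv1d")` constructs the pattern from the two trailing arguments and yields an op named "conv1d", whereas an op with a different name yields an empty optional.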
◆ unpackSingleIndexResultPayloadOperations() [1/2]
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value.
Definition at line 93 of file LinalgTransformOps.cpp.