MLIR: lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp File Reference (original) (raw)

Go to the source code of this file.

Macros
#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
#define DOWNSCALE(trans)
#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
#define GET_OP_CLASSES
Functions
template<typename PatternTy , typename... Args>
static FailureOr< LinalgOp > tryApply (Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation. More...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value. More...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, Value packedHandle)
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults (TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred. More...
template<typename Range>
static LogicalResult applyTilingToAll (RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops. More...
static Operation * replaceForAllWithNewSignature (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op. More...
static bool sameOrEquivalentIterArg (Value src, Value dst)
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is equal to 'dst' or equal to an iter arg present in an outer loop. More...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use. More...
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail. More...
static Operation * cloneAndFuseFirstUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
static void printMultitileSizesTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static ParseResult parseMultitileSizesTypes (OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
template<typename RelayoutOpTy>
bool isValidPackingPermutation (RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op. More...
static void printContinuousTileSizeTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
static ParseResult parseContinuousTileSizeTypes (OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static SmallVector< OpFoldResult > normalizeUpperBounds (RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound. More...
static SmallVector< Value > denormalizeIndVar (RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to be replaced with original_lb + old_iv * original_step. More...
static scf::ForallOp normalizeForallLoopOp (RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized. More...
template<typename OpTy>
DiagnosedSilenceableFailure doit (RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)

DBGS

#define DBGS ( ) (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

DBGSNL

#define DBGSNL ( ) (llvm::dbgs() << "\n")

DEBUG_TYPE

#define DEBUG_TYPE "linalg-transforms"

DOWNSCALE

#define DOWNSCALE( trans )

Value:

{ \

FailureOr<LinalgOp> res = tryApply<trans>(target); \

if (succeeded(res)) { \

results.push_back(*res); \

return DiagnosedSilenceableFailure::success(); \

} \

}

DOWNSCALE_CALL

#define DOWNSCALE_CALL( a, b )   DownscaleSizeOneWindowed2DConvolution<a, b>

DOWNSCALE_NORMAL

GET_OP_CLASSES

LDBG

#define LDBG( X )   LLVM_DEBUG(DBGS() << (X) << "\n")

applyTilingToAll()

template<typename Range>

cloneAndFuseFirstUse()

Definition at line 998 of file LinalgTransformOps.cpp.

References mlir::OpBuilder::clone(), DBGS, diag(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::Operation::getOpResults(), mlir::detail::IROperandBase::getOwner(), mlir::Operation::isProperAncestor(), mlir::RewriterBase::modifyOpInPlace(), and mlir::OpBuilder::setInsertionPoint().

denormalizeIndVar()

doit()

isValidPackingPermutation()

template<typename RelayoutOpTy>

bool isValidPackingPermutation ( RelayoutOpTy op,
ArrayRef< int64_t > permutation,
OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer
)

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op.

This is the case when the `permutation` rank matches the rank expected by `op` and `permutation` is itself a permutation vector. Return true if either `op` or `permutation` is empty, to allow a simpler polymorphic implementation.

Definition at line 1781 of file LinalgTransformOps.cpp.

References mlir::isPermutationVector().

normalizeForallLoopOp()

static scf::ForallOp normalizeForallLoopOp ( RewriterBase & rewriter, scf::ForallOp loop ) static

Given a scf.forall loop return a loop op with the loop bounds normalized.

TODO: Replace this with a general utility to normalize scf.forall. At the time of writing, this wasn't done since adding this to the scf dialect would disallow the use of affine.apply operations due to cyclic dependencies. To avoid churn in lit tests with the change this was added with, defer that to a follow-up.

Definition at line 3398 of file LinalgTransformOps.cpp.

References mlir::OpBuilder::create(), denormalizeIndVar(), mlir::Builder::getIndexAttr(), mlir::isOneInteger(), mlir::isZeroInteger(), mlir::RewriterBase::mergeBlocks(), normalizeUpperBounds(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPointToStart().

normalizeUpperBounds()

parseContinuousTileSizeTypes()

static ParseResult parseContinuousTileSizeTypes ( OpAsmParser & parser, Type & targetType, Type & tileSizesType, Type & chunkSizesType ) static

parseMultitileSizesTypes()

static ParseResult parseMultitileSizesTypes ( OpAsmParser & parser, Type & targetType, Type & lowSizeType, Type & highSizeType, Type & splitPointType ) static

printContinuousTileSizeTypes()

printMultitileSizesTypes()

reifyMixedParamAndHandleResults()

When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred.

If a result is a Value then it must be either a ParamType or a handle to a constant-like op.

Definition at line 179 of file LinalgTransformOps.cpp.

replaceForAllWithNewSignature()

Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op.

Definition at line 645 of file LinalgTransformOps.cpp.

References mlir::OpBuilder::create(), mlir::DominanceInfo::dominates(), mlir::detail::enumerate(), mlir::RewriterBase::eraseBlock(), mlir::Builder::getIndexAttr(), mlir::Operation::getLoc(), mlir::Operation::getResult(), mlir::Value::getUsers(), mlir::Operation::isAncestor(), mlir::Operation::isProperAncestor(), mlir::RewriterBase::replaceAllUsesWith(), mlir::RewriterBase::replaceUsesWithIf(), mlir::OpBuilder::setInsertionPoint(), and mlir::TilingResult::tiledValues.

sameOrEquivalentIterArg()

static bool sameOrEquivalentIterArg ( Value src, Value dst ) static

Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is equal to 'dst' or equal to an iter arg present in an outer loop.

To determine the second condition, this function iterates using a worklist over the enclosing loops, trying to find 'src' in any of the parent loop's iter args.

Definition at line 726 of file LinalgTransformOps.cpp.

References mlir::OpOperand::getOperandNumber(), and mlir::Block::getParentOp().

tileAndFuseFirstExtractUse()

tileAndFuseFirstExtractUseThroughContainingOpBlockArgument()

First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail.

Then, find the first "extract" user of the tied block argument and tile it right before its "extract" use. The tiled op is fused under the containingOp. Return this fused op on success or nullptr if anything fails.

Definition at line 896 of file LinalgTransformOps.cpp.

References mlir::OpBuilder::clone(), DBGS, diag(), mlir::RewriterBase::eraseOp(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::OpOperand::getOperandNumber(), mlir::tensor::getOrCreateDestinations(), mlir::Value::getUsers(), mlir::IRMapping::map(), mlir::RewriterBase::modifyOpInPlace(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), and mlir::Operation::setOperand().

tryApply()

template<typename PatternTy , typename... Args>

static FailureOr< LinalgOp > tryApply ( Operation * operation, Args &&... args ) static

Attempts to apply the pattern specified as template argument to the given operation.

The pattern is expected to have a returningMatchAndRewrite function that returns the "main" result or failure. Returns failure if the pattern failed to apply. Extra arguments are forwarded to the pattern constructor.

Definition at line 65 of file LinalgTransformOps.cpp.

References mlir::Operation::getContext().

unpackSingleIndexResultPayloadOperations() [1/2]

Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value.

Definition at line 93 of file LinalgTransformOps.cpp.

unpackSingleIndexResultPayloadOperations() [2/2]