MLIR: lib/Dialect/Tensor/Utils/Utils.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
14
21
22 using namespace mlir;
24
29
30
31
32 assert(((resType.getNumDynamicDims() == dynOutDims.size()) ||
33 dynOutDims.empty()) &&
34 "Either none or all output dynamic dims must be specified!");
35
36
37
40
41 size_t outDimIdx = 0;
42
43 for (const auto [idx, val] : enumerate(resType.getShape())) {
44 bool isDimDynamic = ShapedType::isDynamic(val);
45 bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();
46
47
48
49 if (!updatePadHigh)
50 continue;
51
52
58
60 {outDim, sourceDim});
61 }
62 return b.create(loc, resType, source, low, high, pad, nofold);
63 }
64
67 Value rankedTensor) {
68 auto tensorTy = cast(rankedTensor.getType());
70 for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
71 if (en.value() == ShapedType::kDynamic)
72 dynamicDims.push_back(
73 b.createtensor::DimOp(loc, rankedTensor, en.index()));
74 }
75 return dynamicDims;
76 }
77
78 FailureOr
81 if (transposeVector.empty())
82 return rankedTensorType;
83
85 transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
86 return failure();
87
90
92 RankedTensorType transposedTensorType =
93 RTTBuilder(rankedTensorType).setShape(transposedShape);
94 return transposedTensorType;
95 }
96
97 CollapseShapeOp
99 const llvm::SmallBitVector &dropDims) {
100 auto srcType = cast(src.getType());
101 int64_t rank = srcType.getRank();
102 assert(rank == static_cast<int64_t>(dropDims.size()) &&
103 "dropDims dimension does not match src tensor rank");
104 assert(llvm::all_of(
106 [&](unsigned dim) { return srcType.getShape()[dim] == 1; }) &&
107 "Dropping non unit dimension");
108
110
111
112 int64_t nextDimToGroup = 0;
113 llvm::SmallBitVector keptDims(dropDims);
114 keptDims.flip();
115 int64_t lastSetBit = keptDims.find_last();
116 for (int64_t setBit : keptDims.set_bits()) {
117
118
119
120 int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
121 auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
122 reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
123 nextDimToGroup = setBit + 1;
124 }
125 return b.createtensor::CollapseShapeOp(loc, src, reassocMaps);
126 }
127
129 llvm::SmallBitVector droppedDims = op.getDroppedDims();
130 int64_t srcDim = 0;
131 RankedTensorType resultType = op.getDestType();
132
133
134 for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
135 if (droppedDims.test(resultDim)) {
136
137
138 if (resultType.getDimSize(resultDim) != 1)
139 return false;
140 continue;
141 }
143 {op.getSource(), srcDim}, {op.getResult(), resultDim});
144 if (failed(equalDimSize) || !*equalDimSize)
145 return false;
146 ++srcDim;
147 }
148
149 return true;
150 }
151
153 llvm::SmallBitVector droppedDims = op.getDroppedDims();
154 int64_t resultDim = 0;
155
156
157 RankedTensorType sourceType = op.getSourceType();
158 for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
159 if (droppedDims.test(dim)) {
160
161
162 if (sourceType.getDimSize(dim) != 1)
163 return false;
164 continue;
165 }
167 {op.getSource(), dim}, {op.getResult(), resultDim});
168 if (failed(equalDimSize) || !*equalDimSize)
169 return false;
170 ++resultDim;
171 }
172
173 return true;
174 }
static void setBit(char *rawData, size_t bitPos, bool value)
Set a bit to a specific value.
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, SmallVector< Value > dynOutDims={})
SmallVector< Value > createDynamicDimValues(OpBuilder &b, Location loc, Value rankedTensor)
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor (adding only unit dimensions) or inserts the source tensor into a destination tensor of the same shape.
CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src, const llvm::SmallBitVector &dropDims)
Create tensor.collapse_shape to drop unit dimensions in dropDims in tensor src.
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the source tensor or extracts the entire source tensor.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
FailureOr< RankedTensorType > computeTransposedType(RankedTensorType rankedTensorType, ArrayRef< int64_t > transposeVector)
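Returns the transposed rankedTensorType if transposeVector is non-empty; returns rankedTensorType unchanged if it is empty, and fails if transposeVector is not a permutation matching the tensor's rank.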
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.