MLIR: lib/Dialect/Tensor/Utils/Utils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

14

21

22 using namespace mlir;

24

29

30

31

32 assert(((resType.getNumDynamicDims() == dynOutDims.size()) ||

33 dynOutDims.empty()) &&

34 "Either none or all output dynamic dims must be specified!");

35

36

37

40

41 size_t outDimIdx = 0;

42

43 for (const auto [idx, val] : enumerate(resType.getShape())) {

44 bool isDimDynamic = ShapedType::isDynamic(val);

45 bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();

46

47

48

49 if (!updatePadHigh)

50 continue;

51

52

58

60 {outDim, sourceDim});

61 }

62 return b.create(loc, resType, source, low, high, pad, nofold);

63 }

64

67 Value rankedTensor) {

68 auto tensorTy = cast(rankedTensor.getType());

70 for (const auto &en : llvm::enumerate(tensorTy.getShape())) {

71 if (en.value() == ShapedType::kDynamic)

72 dynamicDims.push_back(

73 b.createtensor::DimOp(loc, rankedTensor, en.index()));

74 }

75 return dynamicDims;

76 }

77

78 FailureOr

81 if (transposeVector.empty())

82 return rankedTensorType;

83

85 transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))

86 return failure();

87

90

92 RankedTensorType transposedTensorType =

93 RTTBuilder(rankedTensorType).setShape(transposedShape);

94 return transposedTensorType;

95 }

96

97 CollapseShapeOp

99 const llvm::SmallBitVector &dropDims) {

100 auto srcType = cast(src.getType());

101 int64_t rank = srcType.getRank();

102 assert(rank == static_cast<int64_t>(dropDims.size()) &&

103 "dropDims dimension does not match src tensor rank");

104 assert(llvm::all_of(

106 [&](unsigned dim) { return srcType.getShape()[dim] == 1; }) &&

107 "Dropping non unit dimension");

108

110

111

112 int64_t nextDimToGroup = 0;

113 llvm::SmallBitVector keptDims(dropDims);

114 keptDims.flip();

115 int64_t lastSetBit = keptDims.find_last();

116 for (int64_t setBit : keptDims.set_bits()) {

117

118

119

120 int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;

121 auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);

122 reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));

123 nextDimToGroup = setBit + 1;

124 }

125 return b.createtensor::CollapseShapeOp(loc, src, reassocMaps);

126 }

127

129 llvm::SmallBitVector droppedDims = op.getDroppedDims();

130 int64_t srcDim = 0;

131 RankedTensorType resultType = op.getDestType();

132

133

134 for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {

135 if (droppedDims.test(resultDim)) {

136

137

138 if (resultType.getDimSize(resultDim) != 1)

139 return false;

140 continue;

141 }

143 {op.getSource(), srcDim}, {op.getResult(), resultDim});

144 if (failed(equalDimSize) || !*equalDimSize)

145 return false;

146 ++srcDim;

147 }

148

149 return true;

150 }

151

153 llvm::SmallBitVector droppedDims = op.getDroppedDims();

154 int64_t resultDim = 0;

155

156

157 RankedTensorType sourceType = op.getSourceType();

158 for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {

159 if (droppedDims.test(dim)) {

160

161

162 if (sourceType.getDimSize(dim) != 1)

163 return false;

164 continue;

165 }

167 {op.getSource(), dim}, {op.getResult(), resultDim});

168 if (failed(equalDimSize) || !*equalDimSize)

169 return false;

170 ++resultDim;

171 }

172

173 return true;

174 }

static void setBit(char *rawData, size_t bitPos, bool value)

Set a bit to a specific value.

Base type for affine expression.

IntegerAttr getIndexAttr(int64_t value)

MLIRContext * getContext() const

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

This class helps build Operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

This is a builder type that keeps local references to arguments.

Builder & setShape(ArrayRef< int64_t > newShape)

static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)

Compute whether the given variables are equal.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, SmallVector< Value > dynOutDims={})

SmallVector< Value > createDynamicDimValues(OpBuilder &b, Location loc, Value rankedTensor)

bool isCastLikeInsertSliceOp(InsertSliceOp op)

A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts the source tensor into a destination tensor with the same shape.

CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src, const llvm::SmallBitVector &dropDims)

Create tensor.collapse_shape to drop unit dimensions in dropDims in tensor src.

bool isCastLikeExtractSliceOp(ExtractSliceOp op)

A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the source tensor or extracts the entire source tensor.

OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)

Return the dimension of the given tensor value.

FailureOr< RankedTensorType > computeTransposedType(RankedTensorType rankedTensorType, ArrayRef< int64_t > transposeVector)

Returns the transposed rankedTensorType if transposeVector is non-empty.

Include the generated interface declarations.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].

void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)

Apply the permutation defined by permutation to inVec.

SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)

Returns a permutation vector that drops the input dims in dropPositions from inputPerm.

bool isPermutationVector(ArrayRef< int64_t > interchange)

Method to check if an interchange vector is a permutation.