MLIR: lib/Dialect/Utils/StaticValueUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

12 #include "llvm/ADT/APSInt.h"

13 #include "llvm/ADT/STLExtras.h"

14 #include "llvm/Support/MathExtras.h"

15

16 namespace mlir {

17

19

21

22 std::tuple<SmallVector, SmallVector,

23 SmallVector>

26 offsets.reserve(ranges.size());

27 sizes.reserve(ranges.size());

28 strides.reserve(ranges.size());

29 for (const auto &[offset, size, stride] : ranges) {

30 offsets.push_back(offset);

31 sizes.push_back(size);

32 strides.push_back(stride);

33 }

34 return std::make_tuple(offsets, sizes, strides);

35 }

36

37

38

39

40

41

42

46 auto v = llvm::dyn_cast_if_present(ofr);

47 if (!v) {

48 APInt apInt = cast(cast(ofr)).getValue();

49 staticVec.push_back(apInt.getSExtValue());

50 return;

51 }

52 dynamicVec.push_back(v);

53 staticVec.push_back(ShapedType::kDynamic);

54 }

55

56 std::pair<int64_t, OpFoldResult>

58 int64_t tileSizeForShape =

60

62 (tileSizeForShape != ShapedType::kDynamic)

64 : tileSizeOfr;

65

66 return std::pair<int64_t, OpFoldResult>(tileSizeForShape,

67 tileSizeOfrSimplified);

68 }

69

75 }

76

77

78

80 if (!val)

84 return attr;

85 return val;

86 }

87

88

89

91 return llvm::to_vector(

93 }

94

95

98 res.reserve(arrayAttr.size());

100 res.push_back(a);

101 return res;

102 }

103

106 }

107

110 return llvm::to_vector(llvm::map_range(

112 }

113

114

116

117 if (auto val = llvm::dyn_cast_if_present(ofr)) {

118 APSInt intVal;

120 return intVal.getSExtValue();

121 return std::nullopt;

122 }

123

124 Attribute attr = llvm::dyn_cast_if_present(ofr);

125 if (auto intAttr = dyn_cast_or_null(attr))

126 return intAttr.getValue().getSExtValue();

127 return std::nullopt;

128 }

129

130 std::optional<SmallVector<int64_t>>

132 bool failed = false;

135 if (!cv.has_value())

136 failed = true;

137 return cv.value_or(0);

138 });

139 if (failed)

140 return std::nullopt;

141 return res;

142 }

143

146 }

147

149 return llvm::all_of(

151 }

152

155 if (ofrs.size() != values.size())

156 return false;

158 return constOfrs && llvm::equal(constOfrs.value(), values);

159 }

160

161

162

163

164

167 if (cst1 && cst2 && *cst1 == *cst2)

168 return true;

169 auto v1 = llvm::dyn_cast_if_present(ofr1),

170 v2 = llvm::dyn_cast_if_present(ofr2);

171 return v1 && v1 == v2;

172 }

173

176 if (ofrs1.size() != ofrs2.size())

177 return false;

178 for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))

180 return false;

181 return true;

182 }

183

184

185

186

191 res.reserve(staticValues.size());

192 unsigned numDynamic = 0;

193 unsigned count = static_cast<unsigned>(staticValues.size());

194 for (unsigned idx = 0; idx < count; ++idx) {

195 int64_t value = staticValues[idx];

196 res.push_back(ShapedType::isDynamic(value)

200 }

201 return res;

202 }

206 }

207

208

209

210 std::pair<SmallVector<int64_t>, SmallVector>

214 for (const auto &it : mixedValues) {

215 if (auto attr = dyn_cast(it)) {

216 staticValues.push_back(cast(attr).getInt());

217 } else {

218 staticValues.push_back(ShapedType::kDynamic);

219 dynamicValues.push_back(cast(it));

220 }

221 }

222 return {staticValues, dynamicValues};

223 }

224

225

226 template <typename K, typename V>

227 static SmallVector

230 if (keys.empty())

232 assert(keys.size() == values.size() && "unexpected mismatching sizes");

233 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));

234 llvm::sort(indices,

235 [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });

237 res.reserve(values.size());

238 for (int64_t i = 0, e = indices.size(); i < e; ++i)

239 res.push_back(values[indices[i]]);

240 return res;

241 }

242

243 SmallVector

247 }

248

249 SmallVector

253 }

254

255 SmallVector<int64_t>

259 }

260

261

262

265 if (lb == ub)

266 return 0;

267

269 if (!lbConstant)

270 return std::nullopt;

272 if (!ubConstant)

273 return std::nullopt;

275 if (!stepConstant)

276 return std::nullopt;

277

278 return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);

279 }

280

282 return llvm::none_of(sizesOrOffsets, [](int64_t value) {

283 return !ShapedType::isDynamic(value) && value < 0;

284 });

285 }

286

288 return llvm::none_of(strides, [](int64_t value) {

289 return !ShapedType::isDynamic(value) && value == 0;

290 });

291 }

292

294 bool onlyNonNegative, bool onlyNonZero) {

295 bool valuesChanged = false;

297 if (isa(ofr))

298 continue;

301

303 continue;

305 continue;

306 ofr = attr;

307 valuesChanged = true;

308 }

309 }

310 return success(valuesChanged);

311 }

312

313 LogicalResult

316 false);

317 }

318

321 true);

322 }

323

324 }

Attributes are known-constant values of operations.

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

MLIRContext * getContext() const

MLIRContext is the top-level object for a collection of MLIR operations.

This class represents a single result from folding an operation.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

int compare(const Fraction &x, const Fraction &y)

Three-way comparison between two fractions.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

bool isConstantIntValue(OpFoldResult ofr, int64_t value)

Return true if ofr is constant integer equal to value.

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)

Return true if all of ofrs are constant integers equal to the corresponding value in values.

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)

Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)

Returns "success" when any of the elements in strides is a constant value.

bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)

Return true if all of ofrs are constant integers equal to value.

bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)

Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.

bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)

Helper function to check whether the passed in sizes or offsets are valid.

std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)

Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

bool hasValidStrides(SmallVector< int64_t > strides)

Helper function to check whether the passed in strides are valid.

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...

std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)

Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.

static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)

Helper to sort values according to matching keys.

bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute constructio...

void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr. In other cases...

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)

Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)

Given OpFoldResult representing dim size value (*), generates a pair of sizes:

std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)

If all ofrs are constant integers or IntegerAttrs, return the integers.

SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)

Helper to sort values according to matching keys.

LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)

Returns "success" when any of the elements in offsetsOrSizes is a constant value.

bool isOneInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 1.

LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)

Returns "success" when any of the elements in ofrs is a constant value.

Eliminates variable at the specified position using Fourier-Motzkin variable elimination.