MLIR: lib/Dialect/Utils/IndexingUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

15 #include "llvm/ADT/STLExtras.h"

16 #include

17 #include

18

19 using namespace mlir;

20

21 template

23 ExprType unit) {

24 if (sizes.empty())

25 return {};

27 for (int64_t r = static_cast<int64_t>(strides.size()) - 2; r >= 0; --r)

28 strides[r] = strides[r + 1] * sizes[r + 1];

29 return strides;

30 }

31

32 template

35

36 if (v1.empty() && v2.empty())

37 return {};

39 for (auto it : llvm::zip_equal(v1, v2))

40 result.push_back(std::get<0>(it) * std::get<1>(it));

41 return result;

42 }

43

44 template

46 ExprType zero) {

47 assert(offsets.size() == basis.size());

48 ExprType linearIndex = zero;

49 for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)

50 linearIndex = linearIndex + offsets[idx] * basis[idx];

51 return linearIndex;

52 }

53

54 template <typename ExprType, typename DivOpTy>

57 DivOpTy divOp) {

58 int64_t rank = strides.size();

60 for (int64_t r = 0; r < rank; ++r) {

61 offsets[r] = divOp(linearIndex, strides[r]);

62 linearIndex = linearIndex % strides[r];

63 }

64 return offsets;

65 }

66

67

68

69

70

72 assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&

73 "sizes must be nonnegative");

74 int64_t unit = 1;

76 }

77

81 }

82

84 assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&

85 "basis must be nonnegative");

86 if (basis.empty())

87 return 0;

88 return std::accumulate(basis.begin(), basis.end(), 1, std::plus<int64_t>());

89 }

90

92 assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&

93 "basis must be nonnegative");

94 if (basis.empty())

95 return 1;

96 return std::accumulate(basis.begin(), basis.end(), 1,

97 std::multiplies<int64_t>());

98 }

99

101 assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&

102 "basis must be nonnegative");

103 int64_t zero = 0;

105 }

106

109 assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&

110 "strides must be nonnegative");

112 [](int64_t e1, int64_t e2) { return e1 / e2; });

113 }

114

115 std::optional<SmallVector<int64_t>>

117 if (shape.size() < subShape.size())

118 return std::nullopt;

119 assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&

120 "shape must be nonnegative");

121 assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&

122 "subShape must be nonnegative");

123

124

125 std::vector<int64_t> result;

126 result.reserve(shape.size());

127 for (auto [size, subSize] :

128 llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {

129

130 if (size % subSize != 0)

131 return std::nullopt;

132 result.push_back(size / subSize);

133 }

134

135

136 int commonSize = subShape.size();

137 std::copy(shape.rbegin() + commonSize, shape.rend(),

138 std::back_inserter(result));

139

141 }

142

143

144

145

146

148 if (sizes.empty())

149 return {};

152 }

153

157 }

158

160 if (basis.empty())

162 return std::accumulate(basis.begin(), basis.end(),

164 std::plus());

165 }

166

168 if (basis.empty())

170 return std::accumulate(basis.begin(), basis.end(),

172 std::multiplies());

173 }

174

179 }

180

183

185 }

186

190 linearIndex, strides,

192 }

193

198 }

199

200

201

202

203

206 assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&

207 "permutation must be non-negative");

210 inversion[pos.value()] = pos.index();

211 }

212 return inversion;

213 }

214

216 for (auto i : llvm::seq<int64_t>(0, permutation.size()))

217 if (permutation[i] != i)

218 return false;

219 return true;

220 }

221

223 llvm::SmallDenseSet<int64_t, 4> seenVals;

224 for (auto val : interchange) {

225 if (val < 0 || static_cast<uint64_t>(val) >= interchange.size())

226 return false;

227 if (seenVals.count(val))

228 return false;

229 seenVals.insert(val);

230 }

231 return seenVals.size() == interchange.size();

232 }

233

239 for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {

240 res[desiredPos] = pos;

241 seen.insert(pos);

242 }

243 int64_t nextPos = 0;

244 for (int64_t &entry : res) {

245 if (entry != -1)

246 continue;

247 while (seen.contains(nextPos))

248 ++nextPos;

249 entry = nextPos;

250 ++nextPos;

251 }

252 return res;

253 }

254

257 assert(inputPerm.size() >= dropPositions.size() &&

258 "expect inputPerm size large than position to drop");

260 unsigned permSize = inputPerm.size();

261 for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {

262 int64_t targetIndex = inputPerm[inputIndex];

263 bool shouldDrop = false;

264 unsigned dropSize = dropPositions.size();

265 for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {

266 if (dropPositions[dropIndex] == inputPerm[inputIndex]) {

267 shouldDrop = true;

268 break;

269 }

270 if (dropPositions[dropIndex] < inputPerm[inputIndex]) {

271 targetIndex--;

272 }

273 }

274 if (!shouldDrop) {

275 res.push_back(targetIndex);

276 }

277 }

278 return res;

279 }

280

283 unsigned dropBack) {

284 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");

285 auto range = arrayAttr.getAsRange();

287 res.reserve(arrayAttr.size() - dropFront - dropBack);

288 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;

289 it != eit; ++it)

290 res.push_back((*it).getValue().getSExtValue());

291 return res;

292 }

293

294

296 assert(val && "Invalid value");

297 if (auto attr = dyn_cast(val)) {

298 return attr.getContext();

299 }

300 return cast(val).getContext();

301 }

302

303 std::pair<AffineExpr, SmallVector>

307 assert(strides.size() == indices.size());

308 auto sourceRank = static_cast<unsigned>(strides.size());

309

310

313

316 values[0] = sourceOffset;

317

318 for (unsigned i = 0; i < sourceRank; ++i) {

319

321

322

323 unsigned baseIdxForDim = 1 + 2 * i;

324 unsigned subOffsetForDim = baseIdxForDim;

325 unsigned origStrideForDim = baseIdxForDim + 1;

326 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];

327 values[subOffsetForDim] = indices[i];

328 values[origStrideForDim] = origStride;

329 }

330

331 return {expr, values};

332 }

333

334 std::pair<AffineExpr, SmallVector>

340 }

341

342

343

344

345

346

348 unsigned paddedSize) {

349 assert(tileShape.size() <= paddedSize &&

350 "expected tileShape to <= paddedSize");

351 if (tileShape.size() == paddedSize)

352 return to_vector(tileShape);

354 llvm::append_range(result, tileShape);

355 return result;

356 }

357

363 sliceStrides(shape.size()) {

364

365 std::optional<SmallVector<int64_t>> shapeRatio =

367 assert(shapeRatio && shapeRatio->size() == shape.size() &&

368 "target shape does not evenly divide the original shape");

369 assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&

370 "expected loop order to be a permutation of rank equal to outer "

371 "shape");

372

376 }

377

379 int64_t linearIndex) const {

381 delinearize(linearIndex, sliceStrides), inverseLoopOrder);

383 }

384

390 delinearize(linearIndex, sliceStrides), inverseLoopOrder);

393 }

void dropFront(int64_t arr[N], int64_t *res)

static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)

Copies the given number of bytes from src to dst pointers.

static MLIRContext * getContext(OpFoldResult val)

static SmallVector< int64_t > padTileShapeToSize(ArrayRef< int64_t > tileShape, unsigned paddedSize)

Apply left-padding by 1 to the tile shape if required.

SmallVector< ExprType > computeElementwiseMulImpl(ArrayRef< ExprType > v1, ArrayRef< ExprType > v2)

SmallVector< ExprType > computeSuffixProductImpl(ArrayRef< ExprType > sizes, ExprType unit)

ExprType linearizeImpl(ArrayRef< ExprType > offsets, ArrayRef< ExprType > basis, ExprType zero)

SmallVector< ExprType > delinearizeImpl(ExprType linearIndex, ArrayRef< ExprType > strides, DivOpTy divOp)

Base type for affine expression.

AffineExpr floorDiv(uint64_t v) const

MLIRContext * getContext() const

MLIRContext is the top-level object for a collection of MLIR operations.

This class represents a single result from folding an operation.

MLIRContext * getContext() const

This class provides an abstraction over the different types of ranges over Values.

SmallVector< int64_t > getStaticTileOffsets(int64_t linearIndex) const

TileOffsetRangeImpl(ArrayRef< int64_t > shape, ArrayRef< int64_t > tileShape, ArrayRef< int64_t > loopOrder)

SmallVector< AffineExpr > getDynamicTileOffsets(AffineExpr linearIndex) const

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Include the generated interface declarations.

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)

Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.

std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)

Compute linear index from provided strides and indices, assuming strided layout.

SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)

SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)

SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)

Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for the de-linearized index.

int64_t computeProduct(ArrayRef< int64_t > basis)

Self-explanatory: returns the product of the entries of basis (1 if basis is empty).

bool isIdentityPermutation(ArrayRef< int64_t > permutation)

Returns true if permutation is an identity permutation.

SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)

Return a permutation vector of size permSize that would result in moving positions into desiredPositions.

SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)

Helper to return a subset of arrayAttr as a vector of int64_t.

SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)

Given a set of sizes, return the suffix product.

int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)

Return the number of elements of basis (i.e. the max linear index).

AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)

Return the linearized index of 'offsets' w.r.t. 'basis'.

SmallVector< AffineExpr > getAffineConstantExprs(ArrayRef< int64_t > constants, MLIRContext *context)

std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)

Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.

void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)

Apply the permutation defined by permutation to inVec.

int64_t computeSum(ArrayRef< int64_t > basis)

Self-explanatory.

SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)

Returns a permutation vector that drops the input dims in dropPositions from inputPerm.

bool isPermutationVector(ArrayRef< int64_t > interchange)

Method to check if an interchange vector is a permutation.

void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)

SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)

Helper method to invert a permutation.