//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"

#include <numeric>
#include <optional>

using namespace mlir;

template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = static_cast<int64_t>(strides.size()) - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}
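
// For example (values chosen for illustration): sizes = [2, 3, 4] with
// unit = 1 yields strides = [12, 4, 1], since strides[2] = 1,
// strides[1] = 1 * 4 = 4, and strides[0] = 4 * 3 = 12. These are the
// row-major strides of a 2x3x4 shape.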

template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early exit if both are empty; let zip_equal assert if only one is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  result.reserve(v1.size());
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size());
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}
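
// For example, offsets = [1, 2, 3] against basis = [12, 4, 1] (the strides of
// a 2x3x4 shape) linearize to 1 * 12 + 2 * 4 + 3 * 1 = 23.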

template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}
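
// Delinearization inverts the linearization above. For example,
// linearIndex = 23 with strides = [12, 4, 1] and integer division recovers
// offsets = [23 / 12, 11 / 4, 3 / 1] = [1, 2, 3].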

//===----------------------------------------------------------------------===//
// Static IndexingUtils.
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
         "sizes must be nonnegative");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 0;
  // Accumulate from an int64_t seed of 0, not 1: a nonzero seed would be
  // added into the sum.
  return std::accumulate(basis.begin(), basis.end(), int64_t(0),
                         std::plus<int64_t>());
}
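
// For example, computeSum([2, 3, 4]) == 9 and computeSum([]) == 0.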

int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 1;
  // Seed with an int64_t so the accumulation is not truncated to int.
  return std::accumulate(basis.begin(), basis.end(), int64_t(1),
                         std::multiplies<int64_t>());
}

int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}

SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be positive");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}

std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be positive");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be positive");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If integral division does not occur, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }

  // The remaining leading dimensions of `shape` pass through unchanged; copy
  // them (still in reverse order at this point).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));

  // Reverse to restore the original dimension order.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}
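
// For example, shape = [8, 16, 32] and subShape = [8, 16] align on the
// trailing dimensions (16 / 8 = 2, 32 / 16 = 2) and pass the leading 8
// through, giving [8, 2, 2]; shape = [8, 15] with subShape = [2] returns
// std::nullopt because 15 is not divisible by 2.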

//===----------------------------------------------------------------------===//
// Dynamic IndexingUtils.
//===----------------------------------------------------------------------===//

SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}

//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//

SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}
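
// For example, permutation = [1, 2, 0] inverts to [2, 0, 1]: value 0 sits at
// position 2, value 1 at position 0, and value 2 at position 1.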

bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
  for (auto i : llvm::seq<int64_t>(0, permutation.size()))
    if (permutation[i] != i)
      return false;
  return true;
}

bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (val < 0 || static_cast<uint64_t>(val) >= interchange.size())
      return false;
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}

SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  llvm::SmallDenseSet<int64_t, 4> seen;
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    // Fill the remaining slots with the smallest unused values, in order.
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos;
    ++nextPos;
  }
  return res;
}
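
// For example, permSize = 5 with positions = [1, 3] and
// desiredPositions = [0, 1] produces [1, 3, 0, 2, 4]: values 1 and 3 are
// pinned to slots 0 and 1, and the remaining slots take the unused values in
// increasing order.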

SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
                                    ArrayRef<int64_t> dropPositions) {
  assert(inputPerm.size() >= dropPositions.size() &&
         "expected inputPerm size to be at least the number of positions to "
         "drop");
  SmallVector<int64_t> res;
  res.reserve(inputPerm.size());
  unsigned permSize = inputPerm.size();
  for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
    int64_t targetIndex = inputPerm[inputIndex];
    bool shouldDrop = false;
    unsigned dropSize = dropPositions.size();
    for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
      if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
        shouldDrop = true;
        break;
      }
      // Each dropped position below the current value shifts it down by one.
      if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
        targetIndex--;
      }
    }
    if (!shouldDrop) {
      res.push_back(targetIndex);
    }
  }
  return res;
}
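
// For example, inputPerm = [2, 0, 3, 1] with dropPositions = [1] gives
// [1, 0, 2]: the entry equal to 1 is dropped, and every remaining entry
// greater than 1 shifts down by one.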

SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val)) {
    return attr.getContext();
  }
  return cast<Value>(val).getContext();
}

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size());
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the offset computation: symbol 0
  // is the source offset; symbols 2*i+1 and 2*i+2 are the index and stride
  // for dimension i.
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);

  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}
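
// For a rank-2 source, the returned expression is s0 + s1 * s2 + s3 * s4,
// paired with values = [sourceOffset, indices[0], strides[0], indices[1],
// strides[1]]; callers bind the symbols to the values when materializing the
// index computation.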

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                         ArrayRef<Value> indices) {
  return computeLinearIndex(
      sourceOffset, getAsIndexOpFoldResult(getContext(sourceOffset), strides),
      getAsOpFoldResult(ValueRange(indices)));
}

//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//

/// Apply left-padding by 1 to the tile shape if required.
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape size to be <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}
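
// For example, padTileShapeToSize([4, 8], 4) returns [1, 1, 4, 8]: the tile
// shape is left-padded with unit dimensions up to the requested rank.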

mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape to get the number of tiles per dim.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}

SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}

SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}
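
// Worked example: shape = [4, 8], tileShape = [2, 4], loopOrder = [0, 1]
// gives a shape ratio of [2, 2] and sliceStrides = [2, 1]. Linear index 3
// delinearizes to tile coordinates [1, 1], which scale elementwise by the
// tile shape to static tile offsets [2, 4].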