MLIR: lib/Analysis/Presburger/PWMAFunction.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/ADT/STLFunctionalExtras.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Support/raw_ostream.h"
18 #include
19 #include
20 #include
21
22 using namespace mlir;
23 using namespace presburger;
24
25 void MultiAffineFunction::assertIsConsistent() const {
28 "Inconsistent number of output columns");
31 "Inconsistent number of non-division variables in divs");
33 "Inconsistent number of output rows");
35 "Inconsistent number of divisions.");
36 assert(divs.hasAllReprs() && "All divisions should have a representation");
37 }
38
39
40
41
44 assert(vecA.size() == vecB.size() &&
45 "Cannot subtract vectors of differing lengths!");
47 result.reserve(vecA.size());
48 for (unsigned i = 0, e = vecA.size(); i < e; ++i)
49 result.emplace_back(vecA[i] - vecB[i]);
50 return result;
51 }
52
55 for (const Piece &piece : pieces)
57 return domain;
58 }
59
62 os << "Division Representation:\n";
64 os << "Output:\n";
66 }
67
69
73 "Point has incorrect dimensionality!");
74
76
79
80
81 pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
82 for (const std::optional &divVal : divValues)
83 pointHomogenous.emplace_back(*divVal);
84
85
86
87
88
89 pointHomogenous.emplace_back(1);
93 return result;
94 }
95
98 "Spaces should be compatible for equality check.");
100 }
101
105 "Spaces should be compatible for equality check.");
108
111
112 return restrictedThis.isEqual(restrictedOther);
113 }
114
118 "Spaces should be compatible for equality check.");
121 return isEqual(other, IntegerPolyhedron(disjunct));
122 });
123 }
124
126 assert(end <= getNumOutputs() && "Invalid range");
127
128 if (start >= end)
129 return;
130
132 output.removeRows(start, end - start);
133 }
134
136 assert(space.isCompatible(other.space) && "Functions should be compatible");
137
140
142
144 for (unsigned i = 0; i < nDivs; ++i) {
145
146 std::fill(div.begin(), div.end(), 0);
147
149 div.begin());
150
153 }
154
157
158 auto merge = [&](unsigned i, unsigned j) {
159
160 if (i >= j)
161 return false;
162
163
164
165 if (j < nDivs)
166 return false;
167
168
170 other.output.addToColumn(divOffset + i, divOffset + j, 1);
172 return true;
173 };
174
176
177 unsigned newDivs = other.divs.getNumDivs() - nDivs;
178
180 output.insertColumns(divOffset + nDivs, newDivs);
181 divs = other.divs;
182
183
184 assertIsConsistent();
185 other.assertIsConsistent();
186 }
187
192 "Output space of funcs should be compatible");
193
194
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
221 1 + 2 * resultSpace.getNumLocalVars(),
223 resultSpace.getNumVars() + 1, resultSpace);
224
225
226 for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
233 }
234
235 for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
236
239
240
241 switch (comp) {
243
244
245
246
248 break;
250
251
252
253
255 break;
260 assert(false && "Not implemented case");
261 }
262
263
265
266
268
270 }
271
272 return result;
273 }
274
275
276
279 return false;
280
282 return false;
283
284
285
286
287
288 return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
289 return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
290 PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
291 return pieceA.output.isEqual(pieceB.output, commonDomain);
292 });
293 });
294 }
295
297 assert(piece.isConsistent() && "Piece should be consistent");
299 "Piece should be disjoint from the function");
300 pieces.emplace_back(piece);
301 }
302
306 for (const Piece &piece : pieces) {
307 os << "Domain of piece:\n";
308 piece.domain.print(os);
309 os << "Output of piece\n";
310 piece.output.print(os);
311 }
312 }
313
315
320 "Ranges of functions should be same.");
322 "Space is not compatible.");
323
324
325
326
327
328
329
330
331
332
333
334
336 for (const Piece &pieceA : pieces) {
338 for (const Piece &pieceB : func.pieces) {
340
341
342
343 result.addPiece({better, pieceB.output});
344 dom = dom.subtract(better);
345 }
346
347
348
349
350
351
352
353
354
355
356 result.addPiece({dom, pieceA.output});
357 }
358
359
361 for (const Piece &pieceB : func.pieces)
362 result.addPiece({pieceB.domain.subtract(dom), pieceB.output});
363
364 return result;
365 }
366
367
368
369
370 template
375
376 return result;
377 }
378
380 return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::LT>);
381 }
382
384 return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::GT>);
385 }
386
389 "Spaces should be compatible for subtraction.");
390
393 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
395
396
397 assertIsConsistent();
398 }
399
400
401
402
406 "All divisions in divs should have a representation");
408 "Relation and divs should have the same number of vars");
410 "Relation and divs should have the same number of local vars");
411
412 for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
417 }
418 }
419
421
422
426
428
429
431
432
433
435 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
436
437
438
439
440
442
444
447
450
451
452
454
456 }
457
458 return result;
459 }
460
463 for (Piece &piece : pieces)
464 piece.output.removeOutputs(start, end);
465 }
466
467 std::optional<SmallVector<DynamicAPInt, 8>>
470
471 for (const Piece &piece : pieces)
472 if (piece.domain.containsPoint(point))
473 return piece.output.valueAt(point);
474 return std::nullopt;
475 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static SmallVector< DynamicAPInt, 8 > subtractExprs(ArrayRef< DynamicAPInt > vecA, ArrayRef< DynamicAPInt > vecB)
static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA, const PWMAFunction::Piece &pieceB)
A tiebreak function which breaks ties by comparing the outputs lexicographically based on the given c...
static void addDivisionConstraints(IntegerRelation &rel, const DivisionRepr &divs)
Adds division constraints corresponding to local variables, given a relation and division representat...
Class storing division representation of local variables of a constraint system.
void removeDuplicateDivs(llvm::function_ref< bool(unsigned i, unsigned j)> merge)
Removes duplicate divisions.
unsigned getNumNonDivs() const
unsigned getNumVars() const
unsigned getDivOffset() const
void print(raw_ostream &os) const
unsigned getNumDivs() const
SmallVector< std::optional< DynamicAPInt >, 4 > divValuesAt(ArrayRef< DynamicAPInt > point) const
DynamicAPInt & getDenom(unsigned i)
void insertDiv(unsigned pos, ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor)
void setDiv(unsigned i, ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor)
MutableArrayRef< DynamicAPInt > getDividend(unsigned i)
An IntegerPolyhedron represents the set of points from a PresburgerSpace that satisfy a list of affin...
An IntegerRelation represents the set of points from a PresburgerSpace that satisfy a list of affine ...
unsigned getVarKindEnd(VarKind kind) const
Return the index at which the specified kind of vars ends.
void addBound(BoundType type, unsigned pos, const DynamicAPInt &value)
Adds a constant bound for the specified variable.
virtual unsigned insertVar(VarKind kind, unsigned pos, unsigned num=1)
Insert num variables of the specified kind at position pos.
unsigned getNumVars() const
void intersectDomain(const IntegerPolyhedron &poly)
Intersect the given poly with the domain in-place.
bool isEqual(const IntegerRelation &other) const
Return whether this and other are equal.
void addEquality(ArrayRef< DynamicAPInt > eq)
Adds an equality from the coefficients specified in eq.
unsigned getNumLocalVars() const
unsigned getNumCols() const
Returns the number of columns in the constraint system.
void removeInequality(unsigned pos)
void addInequality(ArrayRef< DynamicAPInt > inEq)
Adds an inequality (>= 0) from the coefficients specified in inEq.
unsigned getNumInequalities() const
unsigned getVarKindOffset(VarKind kind) const
Return the index at which the specified kind of vars starts.
unsigned getNumRows() const
void removeColumn(unsigned pos)
void addToColumn(unsigned sourceColumn, unsigned targetColumn, const T &scale)
Add scale multiples of the source column to the target column.
void print(raw_ostream &os) const
Print the matrix.
void insertColumns(unsigned pos, unsigned count)
Insert columns having positions pos, pos + 1, ...
unsigned getNumColumns() const
SmallVector< T, 8 > postMultiplyWithColumn(ArrayRef< T > colVec) const
The given vector is interpreted as a column vector v.
void addToRow(unsigned sourceRow, unsigned targetRow, const T &scale)
Add scale multiples of the source row to the target row.
void removeRows(unsigned pos, unsigned count)
Remove the rows having positions pos, pos + 1, ...
This class represents a multi-affine function with the domain as Z^d, where d is the number of domain...
void subtract(const MultiAffineFunction &other)
void removeOutputs(unsigned start, unsigned end)
Remove the specified range of outputs.
void print(raw_ostream &os) const
unsigned getNumDivs() const
const PresburgerSpace & getSpace() const
Get the space of this function.
PresburgerSpace getDomainSpace() const
Get the domain/output space of the function.
PresburgerSet getLexSet(OrderingKind comp, const MultiAffineFunction &other) const
Return the set of domain points where the output of this and other are ordered lexicographically acco...
ArrayRef< DynamicAPInt > getOutputExpr(unsigned i) const
Get the i^th output expression.
unsigned getNumOutputs() const
SmallVector< DynamicAPInt, 8 > valueAt(ArrayRef< DynamicAPInt > point) const
unsigned getNumDomainVars() const
void mergeDivs(MultiAffineFunction &other)
Given a MAF other, merges division variables such that both functions have the union of the division ...
unsigned getNumSymbolVars() const
IntegerRelation getAsRelation() const
Get this function as a relation.
bool isEqual(const MultiAffineFunction &other) const
Return whether this and other are equal when the domain is restricted to domain.
This class represents a piece-wise MultiAffineFunction.
const PresburgerSpace & getSpace() const
void addPiece(const Piece &piece)
unsigned getNumDomainVars() const
void print(raw_ostream &os) const
PWMAFunction unionLexMax(const PWMAFunction &func)
unsigned getNumPieces() const
void removeOutputs(unsigned start, unsigned end)
Remove the specified range of outputs.
unsigned getNumOutputs() const
PWMAFunction unionLexMin(const PWMAFunction &func)
Return a function defined on the union of the domains of this and func, such that when only one of th...
std::optional< SmallVector< DynamicAPInt, 8 > > valueAt(ArrayRef< DynamicAPInt > point) const
Return the output of the function at the given point.
PresburgerSet getDomain() const
Return the domain of this piece-wise MultiAffineFunction.
PresburgerSpace getDomainSpace() const
Get the domain/output space of the function.
unsigned getNumSymbolVars() const
bool isEqual(const PWMAFunction &other) const
Return whether this and other are equal as PWMAFunctions, i.e.
void unionInPlace(const IntegerRelation &disjunct)
Mutate this set, turning it into the union of this set and the given disjunct.
ArrayRef< IntegerRelation > getAllDisjuncts() const
Return a reference to the list of disjuncts.
PresburgerSet intersect(const PresburgerRelation &set) const
static PresburgerSet getEmpty(const PresburgerSpace &space)
Return an empty set of the specified type that contains no points.
PresburgerSpace is the space of all possible values of a tuple of integer-valued variables.
unsigned getNumRangeVars() const
unsigned getNumSymbolVars() const
void removeVarRange(VarKind kind, unsigned varStart, unsigned varLimit)
Removes variables of the specified kind in the column range [varStart, varLimit).
unsigned getNumVars() const
unsigned getNumLocalVars() const
unsigned getNumDomainVars() const
bool isCompatible(const PresburgerSpace &other) const
Returns true if both the spaces are compatible i.e.
void print(llvm::raw_ostream &os) const
PresburgerSpace getSpaceWithoutLocals() const
Get the space without local variables.
static PresburgerSpace getRelationSpace(unsigned numDomain=0, unsigned numRange=0, unsigned numSymbols=0, unsigned numLocals=0)
unsigned insertVar(VarKind kind, unsigned pos, unsigned num=1)
Insert num variables of the specified kind at position pos.
OrderingKind
Enum representing a binary comparison operator: equal, not equal, less than, less than or equal,...
SmallVector< DynamicAPInt, 8 > getDivUpperBound(ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor, unsigned localVarIdx)
If q is defined to be equal to expr floordiv d, this is equivalent to saying that q is an integer and q ...
SmallVector< DynamicAPInt, 8 > getDivLowerBound(ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor, unsigned localVarIdx)
Include the generated interface declarations.
MultiAffineFunction output
bool isConsistent() const
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.