MLIR: lib/Dialect/ArmSME/IR/Utils.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
15
17
21 }
22
27 }
28
30 if ((vType.getRank() != 2) || !vType.allDimsScalable())
31 return false;
32
33 auto elemType = vType.getElementType();
35 return false;
36
38 if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
39 return false;
40
41 return true;
42 }
43
44 std::optional<ArmSMETileType> getSMETileType(VectorType type) {
46 return {};
47 switch (type.getElementTypeBitWidth()) {
48 case 8:
49 return ArmSMETileType::ZAB;
50 case 16:
51 return ArmSMETileType::ZAH;
52 case 32:
53 return ArmSMETileType::ZAS;
54 case 64:
55 return ArmSMETileType::ZAD;
56 case 128:
57 return ArmSMETileType::ZAQ;
58 default:
59 llvm_unreachable("unknown SME tile type");
60 }
61 }
62
64 auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
65 if (!tileOp)
66 return success();
67 auto tileId = tileOp.getTileId();
68 if (!tileId)
69 return success();
70 if (!tileId.getType().isSignlessInteger(32))
71 return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
72 return success();
73 }
74
81 loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
82 auto vscale =
85 auto numTileSlices =
86 rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
87 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
91 makeLoopBody(rewriter, loc, forOp.getInductionVar(),
92 forOp.getRegionIterArg(0));
93 rewriter.create<scf::YieldOp>(loc, nextTile);
94 return forOp;
95 }
96
98 if (vType.getRank() != 2 || !vType.allDimsScalable())
99 return false;
100
101 auto elementType = vType.getElementType();
103 return false;
104
106
107 int64_t vectorRows = vType.getDimSize(0);
108 int64_t vectorCols = vType.getDimSize(1);
109
110 return (vectorRows > minNumElts || vectorCols > minNumElts) &&
111 vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
112 }
113
116 return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
117 }
118
120 FunctionOpInterface function) {
122 function->walk([&](Operation *op) {
123 auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
125 worklist.push_back(armSMEOp);
126 });
127 while (!worklist.empty()) {
128 Operation *op = worklist.pop_back_val();
130 continue;
132 if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
133 worklist.push_back(armSMEOp);
134 }
136 }
137 }
138
140 return tileOp && tileOp->getNumResults() == 1 &&
141 tileOp->getNumOperands() == 0 && isPure(tileOp);
142 }
143
145 for (Value result : tileOp->getResults()) {
147 return true;
148 }
149 return false;
150 }
151
153 if (!tileOp)
154 return nullptr;
155 auto isTileOperandType = [](OpOperand &operand) {
157 };
158 assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
159 "expected at most one tile operand");
161 llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
162 if (tileOperand == tileOp->getOpOperands().end())
163 return nullptr;
164 return tileOperand;
165 }
166
168
169 return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
170 }
171
172 }
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
bool isValidSMETileElementType(Type type)
Returns true if type is a valid element type for an SME tile or false otherwise.
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is >= typeB (in terms of bytes).
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
LogicalResult verifyOperationHasValidTileId(Operation *)
Verifies the tile ID (if set) on this tile operation is valid.
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
constexpr unsigned MinStreamingVectorLengthInBits
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...