MLIR: lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp Source File
//===- XeGPUUtils.cpp - MLIR Utilities for XeGPU ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// NOTE: the include list below is reconstructed from the code in this file;
// the exact set and order in the upstream source may differ.
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>

using namespace mlir;

/// Flatten a set of ValueRange into a single SmallVector<Value>.
SmallVector<Value> xegpu::flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}

/// If the tensor descriptor has a layout attribute, it is used in SIMT mode,
/// and this returns the per-lane (distributed) vector type.
FailureOr<VectorType>
mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
  auto layout =
      llvm::dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
  // This only works for subgroup-level layouts, which carry just lane_layout
  // and lane_data and distribute a SIMD value across the lanes of a subgroup.
  if (!layout || !layout.isSgLayout())
    return failure();

  // NOTE: the next two declarations are reconstructed from later uses; the
  // upstream spelling may differ.
  ArrayRef<int32_t> laneLayout = layout.getLaneLayout().asArrayRef();
  ArrayRef<int32_t> laneData = layout.getLaneData().asArrayRef();
  auto tdescShape = tdescTy.getShape();
  auto elementType = tdescTy.getElementType();

  // The subgroup size is the product of the lane_layout entries,
  // e.g. lane_layout = [1, 16] gives sgSize = 16.
  auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
                                std::multiplies<int64_t>());

  // Case 1: scattered tensor descriptors. Each lane owns chunkSize contiguous
  // elements.
  auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
  if (scatterAttr) {
    auto chunkSize = scatterAttr.getChunkSize().getInt();
    // Verify that the first dimension of the tensor descriptor shape is
    // distributable.
    assert(tdescShape[0] == laneLayout[0] &&
           "tensor descriptor shape is not distributable");
    return VectorType::get({chunkSize}, elementType);
  }

  // Case 2: block tensor descriptors. Check that every dimension of the
  // tensor descriptor shape is distributable over lane_layout * lane_data.
  int64_t tensorSize = 1;
  for (auto [tdescDim, laneDim, laneDataDim] :
       llvm::zip_equal(tdescShape, laneLayout, laneData)) {
    assert((tdescDim % (laneDim * laneDataDim) == 0) &&
           "tensor descriptor shape is not distributable");
    tensorSize *= tdescDim;
  }

  // The total size must also account for the array_length of the descriptor.
  tensorSize *= tdescTy.getArrayLength();

  return VectorType::get({tensorSize / sgSize}, elementType);
}
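
// Worked example (illustrative; not part of the original source): for a block
// tensor descriptor of shape 8x16xf32 with lane_layout = [1, 16],
// lane_data = [1, 1], and array_length = 1:
//   sgSize     = 1 * 16     = 16
//   tensorSize = 8 * 16 * 1 = 128
//   result     = vector<128/16 x f32> = vector<8xf32>
// i.e. each of the 16 lanes owns one 8-element column of the 8x16 block.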

/// Return the distributed vector type for a vector value according to the
/// given subgroup-level layout.
FailureOr<VectorType>
mlir::xegpu::getDistributedVectorType(VectorType originalType,
                                      xegpu::LayoutAttr layout) {
  int64_t rank = originalType.getRank();
  // Distribution is only supported for 1D, 2D and 3D vectors.
  if (rank < 1 || rank > 3)
    return failure();

  ArrayRef<int64_t> shape = originalType.getShape();
  // arrayLength is 1 for 1D and 2D vectors, and the outermost dimension for
  // 3D vectors.
  int arrayLength = 1;
  if (rank == 3) {
    arrayLength = shape[0];
    shape = shape.drop_front();
  }
  // Wrap the shape, element type and layout into a temporary block tensor
  // descriptor and reuse the TensorDescType-based overload. The local name
  // helperTdescTy is reconstructed.
  auto helperTdescTy = xegpu::TensorDescType::get(
      shape, originalType.getElementType(), arrayLength,
      /*boundary_check=*/true,
      /*memory_space=*/xegpu::MemorySpace::Global, layout);
  return xegpu::getDistributedVectorType(helperTdescTy);
}
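
// Worked example (illustrative; not part of the original source): a value of
// type vector<2x8x16xf16> with lane_layout = [1, 16] and lane_data = [1, 1]
// is treated as an 8x16 block with array_length = 2, so the distributed type
// is vector<(2*8*16)/16 x f16> = vector<16xf16> per lane.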

/// Return the attribute name used to attach a LayoutAttr to an OpOperand.
std::string xegpu::getLayoutName(const OpOperand &operand) {
  const StringRef prefix("layout_operand_");
  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
  return llvm::formatv("{0}{1}", prefix, idx).str();
}

/// Return the attribute name used to attach a LayoutAttr to an OpResult.
std::string xegpu::getLayoutName(const OpResult result) {
  const StringRef prefix = "layout_result_";
  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
}
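
// Example (illustrative): operand #0 maps to the attribute name
// "layout_operand_0", and result #1 maps to "layout_result_1".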

/// Retrieve the LayoutAttr associated with a given Value.
xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
  if (!value)
    return nullptr;

  // Tensor descriptors carry their layout directly in the type.
  if (auto tdescTy =
          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
    return tdescTy.getLayoutAttr();

  if (auto result = dyn_cast<OpResult>(value)) {
    Operation *defOp = result.getDefiningOp();
    assert(defOp && "result must have a defining op");

    // For a load_nd result, the layout is taken from its tensor descriptor
    // operand.
    if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
      return getLayoutAttr(loadNd.getTensorDesc());

    std::string layoutName = getLayoutName(result);
    if (defOp->hasAttr(layoutName))
      return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
  }

  // For a block argument of a loop-like op, follow the tied loop init value.
  if (auto arg = dyn_cast<BlockArgument>(value)) {
    auto parentOp = arg.getOwner()->getParentOp();
    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
      return getLayoutAttr(tiedInit->get());
    }
  }

  return nullptr;
}

/// Retrieve the LayoutAttr associated with a given OpOperand.
xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
  Operation *op = opr.getOwner();
  std::string layoutName = xegpu::getLayoutName(opr);
  if (op->hasAttr(layoutName))
    return op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
  return getLayoutAttr(opr.get());
}

/// Attach the given LayoutAttr to the owner of the OpOperand/OpResult under
/// the corresponding layout attribute name, if it is not already set.
template <typename T, typename>
void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
  Operation *owner = operandOrResult.getOwner();
  std::string name = xegpu::getLayoutName(operandOrResult);
  if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
    owner->setAttr(name, layout);
}

// Explicit instantiation for OpResult.
template void
xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
                                     const mlir::xegpu::LayoutAttr layout);

// Explicit instantiation for OpOperand.
template void
xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
                                      const mlir::xegpu::LayoutAttr layout);

/// Set the LayoutAttr for each OpOperand and OpResult of the given operation,
/// using the provided callback to compute the layout of a Value.
void xegpu::setLayoutAttrs(Operation *op,
                           function_ref<LayoutAttr(Value)> getLayoutImpl) {
  op->walk([&](Operation *nestOp) {
    for (OpOperand &opr : nestOp->getOpOperands()) {
      auto layout = getLayoutImpl(opr.get());
      setLayoutAttr(opr, layout);
    }
    for (OpResult result : nestOp->getOpResults()) {
      auto layout = getLayoutImpl(result);
      setLayoutAttr(result, layout);
    }
  });
}
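
// Usage sketch (illustrative; not part of the original source): attach the
// layouts currently derivable from types and attributes to every operand and
// result nested under `funcOp` (a hypothetical op in the caller):
//   xegpu::setLayoutAttrs(funcOp,
//                         [](Value v) { return xegpu::getLayoutAttr(v); });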

/// Extract a set of smaller vectors of the given `shape` from `value` using
/// vector.extract_strided_slice ops.
SmallVector<Value>
xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
                                        Value value, ArrayRef<int64_t> shape) {
  auto vecTy = dyn_cast<VectorType>(value.getType());
  if (!vecTy)
    return {value};

  // Only split when the source shape is an integral multiple of `shape`.
  ArrayRef<int64_t> srcShape = vecTy.getShape();
  if (!computeShapeRatio(srcShape, shape))
    return {value};

  SmallVector<Value> result;
  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
    SmallVector<int64_t> staticStrides(offsets.size(), 1);
    result.push_back(builder.create<vector::ExtractStridedSliceOp>(
        loc, value, offsets, shape, staticStrides));
  }

  return result;
}
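
// Worked example (illustrative; not part of the original source): extracting
// tiles of shape [8, 16] from a value of type vector<32x32xf32> yields
// (32/8) * (32/16) = 8 slices of vector<8x16xf32>, produced by
// vector.extract_strided_slice at offsets [0,0], [0,16], [8,0], [8,16], ...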

/// Create a vector of the given `shape` from a set of equally-shaped values
/// using vector.insert_strided_slice ops.
Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
                                             ValueRange values,
                                             ArrayRef<int64_t> shape) {
  VectorType inputTy = dyn_cast<VectorType>(values[0].getType());
  assert(llvm::all_of(values.getTypes(),
                      [&](Type type) { return type == inputTy; }) &&
         "values must be of the same VectorType");

  Type elemTy = inputTy.getElementType();
  ArrayRef<int64_t> tileShape = inputTy.getShape();

  // Start from a zero constant of the target shape and insert each input tile
  // at its corresponding offset.
  VectorType resultTy = VectorType::get(shape, elemTy);
  auto zeroAttr = builder.getZeroAttr(elemTy);
  Value result = builder.create<arith::ConstantOp>(
      loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr));

  for (auto [src, offsets] :
       llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
    SmallVector<int64_t> staticStrides(offsets.size(), 1);
    result = builder.create<vector::InsertStridedSliceOp>(
        loc, src, result, offsets, staticStrides);
  }
  return result;
}
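
// Worked example (illustrative; not part of the original source): given the
// eight vector<8x16xf32> tiles from the previous example and shape = [32, 32],
// the result is a single vector<32x32xf32>, built by inserting each tile into
// an arith.constant dense<0.0> vector via vector.insert_strided_slice.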

/// Do type conversion for SCF structural ops, e.g. scf.for and scf.while,
/// using the SCF structural type conversion patterns. Vector values are
/// temporarily carried as RankedTensorType so a LayoutAttr can ride along as
/// the tensor encoding.
void xegpu::doSCFStructuralTypeConversionWithTensorType(
    Operation *op, TypeConverter converter) {
  MLIRContext *context = op->getContext();

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
        .getResult(0);
  };

  { // Step 1: convert VectorType to RankedTensorType for SCF structural ops,
    // using a dedicated local converter.
    TypeConverter converter;
    converter.addConversion([](Type type) -> Type { return type; });
    converter.addConversion([](VectorType type) -> Type {
      return RankedTensorType::get(type.getShape(), type.getElementType());
    });
    converter.addSourceMaterialization(materializeCast);
    converter.addTargetMaterialization(materializeCast);

    mlir::ConversionTarget target(*context);
    target.addLegalOp<UnrealizedConversionCastOp>();

    mlir::RewritePatternSet patterns(context);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
  }

  { // Step 2: propagate the layout attribute to the RankedTensorType results
    // of the UnrealizedConversionCastOps that were introduced for
    // VectorType-to-RankedTensorType casts.
    op->walk([](UnrealizedConversionCastOp castOp) {
      if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
        return WalkResult::skip();

      Value input = castOp.getInputs()[0];
      Value result = castOp.getResults()[0];
      auto inputTy = dyn_cast<VectorType>(input.getType());
      auto resultTy = dyn_cast<RankedTensorType>(result.getType());

      // Only consider casts from VectorType to RankedTensorType.
      if (!inputTy || !resultTy)
        return WalkResult::skip();

      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
      if (!layout)
        return WalkResult::skip();

      RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
      result.setType(newTy);

      // Update the tied region iter_args if the user is a loop-like op.
      for (OpOperand &use : result.getUses()) {
        if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
          BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
          arg.setType(newTy);
        }

        // scf.while has a second region whose block arguments are not exposed
        // through LoopLikeOpInterface.
        if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
          unsigned idx = use.getOperandNumber();
          BlockArgument arg = whileOp.getAfterArguments()[idx];
          arg.setType(newTy);
        }
      }
      return WalkResult::advance();
    });

    // Use scf.yield ops as anchors to update the result types of their
    // parent ops.
    op->walk([](scf::YieldOp yieldOp) {
      auto *parentOp = yieldOp->getParentOp();
      for (OpResult r : parentOp->getOpResults()) {
        unsigned idx = r.getResultNumber();
        Type resultTy = r.getType();
        Type yieldTy = yieldOp.getResults()[idx].getType();
        if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
          r.setType(yieldTy);
      }
    });
  }

  { // Step 3: perform the conversion from RankedTensorType back to VectorType
    // based on the LayoutAttr encoding, using the caller-provided converter.

    // Handle the UnrealizedConversionCastOps introduced by step 1:
    //  * for VectorType -> RankedTensorType casts, simply forward the
    //    (possibly 1:N) converted inputs from the adaptor;
    //  * for RankedTensorType -> VectorType casts, rebuild the cast from the
    //    flattened adaptor inputs.
    class UnrealizedConversionCastOpPattern
        : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
      using OpConversionPattern<
          mlir::UnrealizedConversionCastOp>::OpConversionPattern;

      mlir::LogicalResult
      matchAndRewrite(mlir::UnrealizedConversionCastOp op,
                      OneToNOpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const override {
        auto inputs = op.getOperands();
        auto outputs = op.getOutputs();

        if (inputs.size() != 1 || outputs.size() != 1)
          return failure();

        auto inputTy = inputs[0].getType();
        auto outputTy = outputs[0].getType();

        if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
          rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
          return success();
        }

        if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
          SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
          auto newOp = rewriter.create<UnrealizedConversionCastOp>(
              op.getLoc(), outputTy, values);
          rewriter.replaceOp(op, newOp);
          return success();
        }
        return failure();
      }
    };

    converter.addSourceMaterialization(materializeCast);
    converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
                                           ValueRange inputs, Location loc) {
      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
          .getResults();
    });

    mlir::ConversionTarget target(*context);
    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
        [](UnrealizedConversionCastOp op) {
          auto isTensorTy = [](Type type) {
            return isa<RankedTensorType>(type);
          };
          return llvm::none_of(op->getOperandTypes(), isTensorTy) &&
                 llvm::none_of(op->getResultTypes(), isTensorTy);
        });

    mlir::RewritePatternSet patterns(context);
    patterns.insert<UnrealizedConversionCastOpPattern>(context);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
  }
}
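
// Usage sketch (illustrative; not part of the original source). A caller would
// typically pass a TypeConverter that maps each RankedTensorType carrying a
// xegpu::LayoutAttr encoding to its distributed VectorType before invoking the
// conversion; `funcOp` and the conversion body below are hypothetical:
//
//   TypeConverter converter;
//   converter.addConversion([](Type type) { return type; });
//   converter.addConversion([](RankedTensorType type) -> Type {
//     auto layout =
//         dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
//     if (!layout)
//       return type;
//     VectorType vecTy =
//         VectorType::get(type.getShape(), type.getElementType());
//     return xegpu::getDistributedVectorType(vecTy, layout).value_or(vecTy);
//   });
//   xegpu::doSCFStructuralTypeConversionWithTensorType(funcOp, converter);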