MLIR: lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

22 #include "llvm/Support/Debug.h"

23 #include "llvm/Support/FormatVariadic.h"

24 #include

25 #include

26

27 using namespace mlir;

28

29

32 for (const auto &vals : values)

33 llvm::append_range(result, vals);

34 return result;

35 }

36

37 FailureOr

39 auto layout = llvm::dyn_cast_if_present(tdescTy.getLayout());

40

41

42 if (!layout || !layout.isSgLayout())

43 return failure();

44

47 auto tdescShape = tdescTy.getShape();

48 auto elementType = tdescTy.getElementType();

49

50

51

52

53 auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,

54 std::multiplies<int64_t>());

55

56

57 auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();

58 if (scatterAttr) {

59 auto chunkSize = scatterAttr.getChunkSize().getInt();

60

61

62 assert(tdescShape[0] == laneLayout[0] &&

63 "tensor descriptor shape is not distributable");

65 }

66

67

68

69 int64_t tensorSize = 1;

70 for (auto [tdescDim, laneDim, laneDataDim] :

71 llvm::zip_equal(tdescShape, laneLayout, laneData)) {

72 assert((tdescDim % (laneDim * laneDataDim) == 0) &&

73 "tensor descriptor shape is not distributable");

74 tensorSize *= tdescDim;

75 }

76

77 tensorSize *= tdescTy.getArrayLength();

78

79 return VectorType::get({tensorSize / sgSize}, elementType);

80 }

81

82 FailureOr

84 xegpu::LayoutAttr layout) {

85 int64_t rank = originalType.getRank();

86

87 if (rank < 1 || rank > 3)

88 return failure();

90

91

92 int arrayLength = 1;

93 if (rank == 3) {

94 arrayLength = shape[0];

95 shape = shape.drop_front();

96 }

98 shape, originalType.getElementType(), arrayLength,

99 true,

100 xegpu::MemorySpace::Global, layout);

102 }

103

105 const StringRef prefix("layout_operand_");

106 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();

107 return llvm::formatv("{0}{1}", prefix, idx).str();

108 }

109

111 const StringRef prefix = "layout_result_";

112 return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();

113 }

114

116 if (!value)

117 return nullptr;

118

119 if (auto tdescTy =

120 dyn_cast_if_presentxegpu::TensorDescType(value.getType()))

121 return tdescTy.getLayoutAttr();

122

123 if (auto result = dyn_cast(value)) {

124 Operation *defOp = result.getDefiningOp();

125 assert(defOp && "result must have a defining op");

126

127

128 if (auto loadNd = dyn_castxegpu::LoadNdOp(defOp))

130

132 if (defOp->hasAttr(layoutName))

133 return defOp->getAttrOfTypexegpu::LayoutAttr(layoutName);

134 }

135

136 if (auto arg = dyn_cast(value)) {

137 auto parentOp = arg.getOwner()->getParentOp();

138 if (auto loop = dyn_cast(parentOp)) {

139 OpOperand *tiedInit = loop.getTiedLoopInit(arg);

141 }

142 }

143

144 return nullptr;

145 }

146

150 if (op->hasAttr(layoutName))

151 return op->getAttrOfTypexegpu::LayoutAttr(layoutName);

153 }

154

155 template <typename T, typename>

157 Operation *owner = operandOrResult.getOwner();

159 if (layout && !owner->hasAttrOfType(name))

160 owner->setAttr(name, layout);

161 }

162

163

164 template void

165 xegpu::setLayoutAttrmlir::OpResult(const mlir::OpResult &result,

166 const mlir::xegpu::LayoutAttr layout);

167

168

169 template void

170 xegpu::setLayoutAttrmlir::OpOperand(const mlir::OpOperand &operand,

171 const mlir::xegpu::LayoutAttr layout);

172

177 auto layout = getLayoutImpl(opr.get());

178 setLayoutAttr(opr, layout);

179 }

181 auto layout = getLayoutImpl(result);

182 setLayoutAttr(result, layout);

183 }

184 });

185 }

186

190 auto vecTy = dyn_cast(value.getType());

191 if (!vecTy)

192 return {value};

193

196 return {value};

197

201 result.push_back(builder.createvector::ExtractStridedSliceOp(

202 loc, value, offsets, shape, staticStrides));

203 }

204

205 return result;

206 }

207

211 VectorType inputTy = dyn_cast(values[0].getType());

212 assert(llvm::all_of(values.getTypes(),

213 [&](Type type) { return type == inputTy; }) &&

214 "values must be of the same VectorType");

215

216 Type elemTy = inputTy.getElementType();

218

220 auto zeroAttr = builder.getZeroAttr(elemTy);

221 Value result = builder.createarith::ConstantOp(

223

224 for (auto [src, offsets] :

227 result = builder.createvector::InsertStridedSliceOp(

228 loc, src, result, offsets, staticStrides);

229 }

230 return result;

231 }

232

236

239 return builder.create(loc, type, inputs)

240 .getResult(0);

241 };

242

243 {

248 });

251

253 target.addLegalOp();

254

257 target);

259 }

260

261 {

262

263

264 op->walk([](UnrealizedConversionCastOp castOp) {

265 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)

267

268 Value input = castOp.getInputs()[0];

269 Value result = castOp.getResults()[0];

270 auto inputTy = dyn_cast(input.getType());

271 auto resultTy = dyn_cast(result.getType());

272

273

274 if (!inputTy || !resultTy)

276

278 if (!layout)

280

281 RankedTensorType newTy = resultTy.cloneWithEncoding(layout);

283

284

286 if (auto loop = dyn_cast(use.getOwner())) {

287 BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);

289 }

290

291

292 if (auto whileOp = dyn_castscf::WhileOp(use.getOwner())) {

293 unsigned idx = use.getOperandNumber();

294 BlockArgument arg = whileOp.getAfterArguments()[idx];

296 }

297 }

299 });

300

301

302 op->walk([](scf::YieldOp yieldOp) {

305 unsigned idx = r.getResultNumber();

306 Type resultTy = r.getType();

307 Type yieldTy = yieldOp.getResults()[idx].getType();

308 if (isa(resultTy) && yieldTy != resultTy)

309 r.setType(yieldTy);

310 }

311 });

312 }

313

314 {

315

316

317

318

319

320

321 class UnrealizedConversionCastOpPattern

325

326 mlir::LogicalResult

327 matchAndRewrite(mlir::UnrealizedConversionCastOp op,

328 OneToNOpAdaptor adaptor,

330 auto inputs = op.getOperands();

331 auto outputs = op.getOutputs();

332

333 if (inputs.size() != 1 || outputs.size() != 1)

334 return failure();

335

336 auto inputTy = inputs[0].getType();

337 auto outputTy = outputs[0].getType();

338

339 if (isa(inputTy) && isa(outputTy)) {

341 return success();

342 }

343

344 if (isa(inputTy) && isa(outputTy)) {

346 auto newOp = rewriter.create(

347 op.getLoc(), outputTy, values);

349 return success();

350 }

351 return failure();

352 }

353 };

354

358 return builder.create(loc, type, inputs)

359 .getResults();

360 });

361

364 [](UnrealizedConversionCastOp op) {

365 auto isTensorTy = [](Type type) {

366 return isa(type);

367 };

368 return llvm::none_of(op->getOperandTypes(), isTensorTy) &&

370 });

372 patterns.insert(context);

374 target);

376 }

377 }

This class represents an argument of a Block.

TypedAttr getZeroAttr(Type type)

This class implements a pattern rewriter for use with ConversionPatterns.

void replaceOp(Operation *op, ValueRange newValues) override

Replace the given operation with the new values.

void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)

Replace the given operation with the new value ranges.

This class describes a specific conversion target.

void addLegalOp(OperationName op)

Register the given operations as legal.

void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)

Register the given operation as dynamically legal and set the dynamic legalization callback to the on...

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

IRValueT get() const

Return the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

This class helps build Operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...

This class represents an operand of an operation.

This is a value defined by a result of an operation.

unsigned getResultNumber() const

Returns the number of this result.

Operation is the basic unit of execution within MLIR.

AttrClass getAttrOfType(StringAttr name)

bool hasAttrOfType(NameT &&name)

bool hasAttr(StringAttr name)

Return true if the operation has an attribute with the provided name, false otherwise.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

MLIRContext * getContext()

Return the context this operation is associated with.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...

void setAttr(StringAttr name, Attribute value)

If an attribute exists with the specified name, change it to the new value.

operand_type_range getOperandTypes()

MutableArrayRef< OpOperand > getOpOperands()

result_type_range getResultTypes()

result_range getOpResults()

A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...

void addConversion(FnT &&callback)

Register a conversion function.

void addSourceMaterialization(FnT &&callback)

All of the following materializations require function objects that are convertible to the following ...

void addTargetMaterialization(FnT &&callback)

This method registers a materialization that will be called when converting a value to a target type ...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

type_range getTypes() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

void setType(Type newType)

Mutate the type of this Value to be of the specified type.

Type getType() const

Return the type of this value.

use_range getUses() const

Returns a range of all uses, which is useful for iterating over all uses.

static WalkResult advance()

Operation * getOwner() const

Return the owner of this operand.

void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)

Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...

Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)

Create a vector of the given shape from a set of values using vector.insert_strided_slice.

LayoutAttr getLayoutAttr(const Value value)

Retrieves the LayoutAttr associated with a given Value.

void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout)

Sets the LayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attri...

std::string getLayoutName(const OpOperand &operand)

Return the attribute name for the OpOperand to attach LayoutAttr.

void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)

Do type conversion for SCF structural ops (e.g., scf.for) using the SCF structural type conversion patterns...

void setLayoutAttrs(Operation *op, function_ref< LayoutAttr(Value)> getLayoutImpl)

Set the LayoutAttr for each OpOperand and OpResult of the given operation.

SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)

Extract a set of small vectors with the given shape from a value using vector.extract_strided_slice.

SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)

Flatten a set of ValueRanges into a single SmallVector<Value>.

FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)

If the tensor descriptor has a layout attribute, it is used in SIMT mode.

Include the generated interface declarations.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

const FrozenRewritePatternSet & patterns

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)

Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.

LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())

Below we define several entry points for operation conversion.