MLIR: lib/Conversion/LLVMCommon/Pattern.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
15
16 using namespace mlir;
17
18
19
20
21
23 StringRef rootOpName, MLIRContext *context,
25 : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26
30 }
31
34 }
35
38 }
39
43 }
44
47 }
48
51 }
52
55 Type resultType,
56 int64_t value) {
57 return builder.createLLVM::ConstantOp(loc, resultType,
59 }
60
64 LLVM::GEPNoWrapFlags noWrapFlags) const {
66 memRefDesc, indices, noWrapFlags);
67 }
68
69
70
72 MemRefType type) const {
73 if (!type.getLayout().isIdentity())
74 return false;
76 }
77
80 if (failed(addressSpace))
81 return {};
83 }
84
90 "layout maps must have been normalized away");
91 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
92 static_cast<ssize_t>(dynamicSizes.size()) &&
93 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
94
95 sizes.reserve(memRefType.getRank());
96 unsigned dynamicIndex = 0;
98 for (int64_t size : memRefType.getShape()) {
99 sizes.push_back(
100 size == ShapedType::kDynamic
101 ? dynamicSizes[dynamicIndex++]
103 }
104
105
106 int64_t stride = 1;
108 strides.resize(memRefType.getRank());
109 for (auto i = memRefType.getRank(); i-- > 0;) {
110 strides[i] = runningStride;
111
112 int64_t staticSize = memRefType.getShape()[i];
113 bool useSizeAsStride = stride == 1;
114 if (staticSize == ShapedType::kDynamic)
115 stride = ShapedType::kDynamic;
116 if (stride != ShapedType::kDynamic)
117 stride *= staticSize;
118
119 if (useSizeAsStride)
120 runningStride = sizes[i];
121 else if (stride == ShapedType::kDynamic)
122 runningStride =
123 rewriter.createLLVM::MulOp(loc, runningStride, sizes[i]);
124 else
126 }
127 if (sizeInBytes) {
128
131 Value nullPtr = rewriter.createLLVM::ZeroOp(loc, elementPtrType);
132 Value gepPtr = rewriter.createLLVM::GEPOp(
133 loc, elementPtrType, elementType, nullPtr, runningStride);
134 size = rewriter.createLLVM::PtrToIntOp(loc, getIndexType(), gepPtr);
135 } else {
136 size = runningStride;
137 }
138 }
139
142
143
144
145
146
149 auto nullPtr = rewriter.createLLVM::ZeroOp(loc, convertedPtrType);
150 auto gep = rewriter.createLLVM::GEPOp(loc, convertedPtrType, llvmType,
153 }
154
158 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
159 static_cast<ssize_t>(dynamicSizes.size()) &&
160 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
161
163 Value numElements = memRefType.getRank() == 0
165 : nullptr;
166 unsigned dynamicIndex = 0;
167
168
169 for (int64_t staticSize : memRefType.getShape()) {
170 if (numElements) {
172 staticSize == ShapedType::kDynamic
173 ? dynamicSizes[dynamicIndex++]
175 numElements = rewriter.createLLVM::MulOp(loc, numElements, size);
176 } else {
177 numElements =
178 staticSize == ShapedType::kDynamic
179 ? dynamicSizes[dynamicIndex++]
181 }
182 }
183 return numElements;
184 }
185
186
188 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
193
194
195 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
196
197
198 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
199
200
202 memRefDescriptor.setOffset(
204
205
207 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
208
209
211 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
212
213 return memRefDescriptor;
214 }
215
219 assert(origTypes.size() == operands.size() &&
220 "expected as may original types as operands");
221
222
225 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
226 if (auto memRefType = dyn_cast(origTypes[i])) {
227 unrankedMemrefs.emplace_back(operands[i]);
228 FailureOr addressSpace =
230 if (failed(addressSpace))
231 return failure();
232 unrankedAddressSpaces.emplace_back(*addressSpace);
233 }
234 }
235
236 if (unrankedMemrefs.empty())
237 return success();
238
239
242 unrankedMemrefs, unrankedAddressSpaces,
243 sizes);
244
245
247
248
249 auto module = builder.getInsertionPoint()->getParentOfType();
250 FailureOrLLVM::LLVMFuncOp freeFunc, mallocFunc;
251 if (toDynamic) {
253 if (failed(mallocFunc))
254 return failure();
255 }
256 if (!toDynamic) {
258 if (failed(freeFunc))
259 return failure();
260 }
261
262 unsigned unrankedMemrefPos = 0;
263 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
264 Type type = origTypes[i];
265 if (!isa(type))
266 continue;
267 Value allocationSize = sizes[unrankedMemrefPos++];
269
270
272 toDynamic
273 ? builder
274 .createLLVM::CallOp(loc, mallocFunc.value(), allocationSize)
275 .getResult()
278 allocationSize,
279 0);
281 builder.createLLVM::MemcpyOp(loc, memory, source, allocationSize, false);
282 if (!toDynamic)
283 builder.createLLVM::CallOp(loc, freeFunc.value(), source);
284
285
286
287
288
290 if (!descriptorType)
291 return failure();
292 auto updatedDesc =
294 Value rank = desc.rank(builder, loc);
295 updatedDesc.setRank(builder, loc, rank);
296 updatedDesc.setMemRefDescPtr(builder, loc, memory);
297
298 operands[i] = updatedDesc;
299 }
300
301 return success();
302 }
303
304
305
306
307
309 IntegerOverflowFlags overflowFlags) {
310 if (auto iface = dyn_cast(op))
311 iface.setOverflowFlags(overflowFlags);
312 }
313
314
315
320 IntegerOverflowFlags overflowFlags) {
322
324 if (numResults != 0) {
325 resultTypes.push_back(
327 if (!resultTypes.back())
328 return failure();
329 }
330
331
334 resultTypes, targetAttrs);
335
337
338
339 if (numResults == 0)
340 return rewriter.eraseOp(op), success();
341 if (numResults == 1)
343
344
345
347 results.reserve(numResults);
348 for (unsigned i = 0; i < numResults; ++i) {
349 results.push_back(rewriter.createLLVM::ExtractValueOp(
351 }
353 return success();
354 }
355
359 auto loc = op->getLoc();
360
361 if (!llvm::all_of(operands, [](Value value) {
363 }))
364 return failure();
365
367 Type resType;
368 if (numResults != 0)
370
371 auto callIntrOp = rewriter.createLLVM::CallIntrinsicOp(
372 loc, resType, rewriter.getStringAttr(intrinsic), operands);
373
375
376 if (numResults <= 1) {
377
378 rewriter.replaceOp(op, callIntrOp);
379 return success();
380 }
381
382
383
385 results.reserve(numResults);
386 Value intrRes = callIntrOp.getResults();
387 for (unsigned i = 0; i < numResults; ++i)
388 results.push_back(rewriter.createLLVM::ExtractValueOp(loc, intrRes, i));
390
391 return success();
392 }
393
397
398 auto vec = cast(type);
399 assert(!vec.isScalable() && "scalable vectors are not supported");
400 return vec.getNumElements() * getBitWidth(vec.getElementType());
401 }
402
404 int32_t value) {
406 return builder.createLLVM::ConstantOp(loc, i32, value);
407 }
408
412 if (srcType == dstType)
413 return {src};
414
415 unsigned srcBitWidth = getBitWidth(srcType);
416 unsigned dstBitWidth = getBitWidth(dstType);
417 if (srcBitWidth == dstBitWidth) {
418 Value cast = builder.createLLVM::BitcastOp(loc, dstType, src);
419 return {cast};
420 }
421
422 if (dstBitWidth > srcBitWidth) {
423 auto smallerInt = builder.getIntegerType(srcBitWidth);
424 if (srcType != smallerInt)
425 src = builder.createLLVM::BitcastOp(loc, smallerInt, src);
426
427 auto largerInt = builder.getIntegerType(dstBitWidth);
428 Value res = builder.createLLVM::ZExtOp(loc, largerInt, src);
429 return {res};
430 }
431 assert(srcBitWidth % dstBitWidth == 0 &&
432 "src bit width must be a multiple of dst bit width");
433 int64_t numElements = srcBitWidth / dstBitWidth;
435
436 src = builder.createLLVM::BitcastOp(loc, vecType, src);
437
439 for (auto i : llvm::seq(numElements)) {
441 Value elem = builder.createLLVM::ExtractElementOp(loc, src, idx);
442 res.emplace_back(elem);
443 }
444
445 return res;
446 }
447
449 Type dstType) {
450 assert(!src.empty() && "src range must not be empty");
451 if (src.size() == 1) {
452 Value res = src.front();
453 if (res.getType() == dstType)
454 return res;
455
457 unsigned dstBitWidth = getBitWidth(dstType);
458 if (dstBitWidth < srcBitWidth) {
459 auto largerInt = builder.getIntegerType(srcBitWidth);
460 if (res.getType() != largerInt)
461 res = builder.createLLVM::BitcastOp(loc, largerInt, res);
462
463 auto smallerInt = builder.getIntegerType(dstBitWidth);
464 res = builder.createLLVM::TruncOp(loc, smallerInt, res);
465 }
466
467 if (res.getType() != dstType)
468 res = builder.createLLVM::BitcastOp(loc, dstType, res);
469
470 return res;
471 }
472
473 int64_t numElements = src.size();
475 Value res = builder.createLLVM::PoisonOp(loc, srcType);
478 res = builder.createLLVM::InsertElementOp(loc, srcType, res, elem, idx);
479 }
480
481 if (res.getType() != dstType)
482 res = builder.createLLVM::BitcastOp(loc, dstType, res);
483
484 return res;
485 }
486
489 MemRefType type, Value memRefDesc,
491 LLVM::GEPNoWrapFlags noWrapFlags) {
492 auto [strides, offset] = type.getStridesAndOffset();
493
495
496
497
498
499 Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type);
500
501 LLVM::IntegerOverflowFlags intOverflowFlags =
502 LLVM::IntegerOverflowFlags::none;
503 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
504 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
505 }
506 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
507 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
508 }
509
512 for (int i = 0, e = indices.size(); i < e; ++i) {
513 Value increment = indices[i];
514 if (strides[i] != 1) {
516 ShapedType::isDynamic(strides[i])
517 ? memRefDescriptor.stride(builder, loc, i)
518 : builder.createLLVM::ConstantOp(
519 loc, indexType, builder.getIndexAttr(strides[i]));
520 increment =
521 builder.createLLVM::MulOp(loc, increment, stride, intOverflowFlags);
522 }
523 index = index ? builder.createLLVM::AddOp(loc, index, increment,
524 intOverflowFlags)
525 : increment;
526 }
527
529 return index ? builder.createLLVM::GEPOp(
530 loc, elementPtrType,
531 converter.convertType(type.getElementType()), base, index,
532 noWrapFlags)
533 : base;
534 }
static Value createI32Constant(OpBuilder &builder, Location loc, int32_t value)
static unsigned getBitWidth(Type type)
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conversion.
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
const LLVMTypeConverter * getTypeConverter() const
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of an LLVM pointer type.
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
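Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory when toDynamic is true; when toDynamic is false the copy goes to a different allocation and the source memory is freed afterwards.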
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref type is convertible to LLVM and has an identity layout map.
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Conversion from types to the LLVM IR dialect.
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type `type`, or failure if the memory space cannot be converted to an integer.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic" and the given operands.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc.
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through a series of bitcasts and vector ops.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp)
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.