MLIR: lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

25 #include "llvm/Support/MathExtras.h"

26

28 #define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS

29 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"

30 }

31

32 using namespace mlir;

34

35

36

37

38

39

40

41

42

43

44

45

46

48 PatternRewriter &rewriter, VectorTransferOpInterface xferOp,

49 bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {

50 if (!xferOp.getMask())

51 return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");

52

53

54

55

57 if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(

58 &broadcastedDims))

59 return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");

60

61 auto memRefType = dyn_cast(xferOp.getShapedType());

62 if (!memRefType)

64

65 Attribute addrSpace = memRefType.getMemorySpace();

66 if (!isa_and_nonnullamdgpu::AddressSpaceAttr(addrSpace))

68

69 if (dyn_castamdgpu::AddressSpaceAttr(addrSpace).getValue() !=

70 amdgpu::AddressSpace::FatRawBuffer)

71 return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");

72

73

74 if (!memRefType.isLastDimUnitStride())

75 return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");

76

77 if (memRefType.getElementTypeBitWidth() < 8)

78 return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");

79

80

81

84 for (unsigned i : broadcastedDims)

85 unbroadcastedVectorShape[i] = 1;

86 unbroadcastedVectorType = xferOp.getVectorType().cloneWith(

87 unbroadcastedVectorShape, xferOp.getVectorType().getElementType());

88 requiresBroadcasting = !broadcastedDims.empty();

89

90

91

92 auto memrefElTy = memRefType.getElementType();

93 if (isa(memrefElTy) && memrefElTy != unbroadcastedVectorType)

94 return rewriter.notifyMatchFailure(xferOp, "incompatible element type");

95

96

97 if (!isa(memrefElTy) &&

98 memrefElTy != xferOp.getVectorType().getElementType())

99 return rewriter.notifyMatchFailure(xferOp, "non-matching element type");

100

101

102 if (xferOp.hasOutOfBoundsDim())

103 return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");

104

105 if (xferOp.getVectorType().getRank() != 1)

106

108 xferOp, "vector type is not rank 1, can't create masked load, needs "

109 "VectorToSCF");

110

111 return success();

112 }

113

115 vector::TransferReadOp readOp,

116 bool requiresBroadcasting,

117 VectorType unbroadcastedVectorType) {

118 Value fill = builder.createvector::SplatOp(loc, unbroadcastedVectorType,

119 readOp.getPadding());

120 Value load = builder.createvector::LoadOp(

121 loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices());

122 Value res = builder.createarith::SelectOp(loc, unbroadcastedVectorType,

123 readOp.getMask(), load, fill);

124

125 if (requiresBroadcasting) {

126 res = builder.createvector::BroadcastOp(loc, readOp.getVectorType(), res);

127 }

128 return res;

129 }

130

132 "amdgpu.buffer_transfer_read_needs_mask";

133

134 namespace {

135

136 struct TransferReadLowering final : OpRewritePatternvector::TransferReadOp {

138

139 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,

142 return failure();

143

144 bool requiresBroadcasting = false;

145 VectorType unbroadcastedVectorType;

147 unbroadcastedVectorType))) {

148 return failure();

149 }

150

151 Location loc = readOp.getLoc();

152 Value src = readOp.getBase();

153

154 VectorType vectorType = readOp.getVectorType();

155 int64_t vectorSize = vectorType.getNumElements();

156 int64_t elementBitWidth = vectorType.getElementTypeBitWidth();

158

159 auto stridedMetadata =

160 rewriter.creatememref::ExtractStridedMetadataOp(loc, src);

162 stridedMetadata.getConstifiedMixedStrides();

164 OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();

167 std::tie(linearizedInfo, linearizedIndices) =

169 elementBitWidth, offset, sizes,

170 strides, indices);

171

172

173 Value vectorSizeOffset =

174 rewriter.createarith::ConstantIndexOp(loc, vectorSize);

175 Value linearIndex =

179 Value delta = rewriter.createarith::SubIOp(loc, totalSize, linearIndex);

180

181

182 Value isOutofBounds = rewriter.createarith::CmpIOp(

183 loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);

184

185

186 Value elementsPerWord = rewriter.createarith::ConstantIndexOp(

188 Value isNotWordAligned = rewriter.createarith::CmpIOp(

189 loc, arith::CmpIPredicate::ne,

190 rewriter.createarith::RemUIOp(loc, delta, elementsPerWord),

191 rewriter.createarith::ConstantIndexOp(loc, 0));

192

193

194

195

196

197 Value ifCondition =

198 rewriter.createarith::AndIOp(loc, isOutofBounds, isNotWordAligned);

199

201 Operation *read = builder.clone(*readOp.getOperation());

204 builder.createscf::YieldOp(loc, readResult);

205 };

206

209 builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);

210 rewriter.createscf::YieldOp(loc, res);

211 };

212

213 auto ifOp =

214 rewriter.createscf::IfOp(loc, ifCondition, thenBuilder, elseBuilder);

215

216 rewriter.replaceOp(readOp, ifOp);

217

218 return success();

219 }

220 };

221

222 }

223

227 }

228

230 : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<

231 AmdgpuTransferReadToLoadPass> {

236 return signalPassFailure();

237 }

238 }

239 };

static MLIRContext * getContext(OpFoldResult val)

static std::optional< VectorShape > vectorShape(Type type)

static LogicalResult transferPreconditions(PatternRewriter &rewriter, VectorTransferOpInterface xferOp, bool &requiresBroadcasting, VectorType &unbroadcastedVectorType)

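This pattern supports lowering of: vector.transfer_read to a combination of vector.load, arith.select, and (when broadcasting is required) vector.broadcast, provided the checks in transferPreconditions hold.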

static constexpr char kTransferReadNeedsMask[]

static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, vector::TransferReadOp readOp, bool requiresBroadcasting, VectorType unbroadcastedVectorType)

Attributes are known-constant values of operations.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the provided map (operands with no entry in the map are left unchanged).

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

void setAttr(StringAttr name, Attribute value)

If an attribute exists with the specified name, change it to the new value.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements).

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns)

llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)

Divides the known min value of the numerator by the denominator and rounds the result up to the next ...

std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})

Include the generated interface declarations.

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.

const FrozenRewritePatternSet & patterns

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

void runOnOperation() override

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.

For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...

OpFoldResult linearizedSize