MLIR: lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp Source File

//===- CreateAsyncGroups.cpp - Create async device copies ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())
             .isLastDimUnitStride();
}
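
// Illustrative example (hypothetical, not from the original file): a write
// such as
//   vector.transfer_write %v, %buf[%i, %j] {in_bounds = [true]}
//       : vector<4xf32>, memref<64x64xf32, 3>
// uses a minor-identity permutation map, operates on a buffer, and the
// innermost memref dimension has unit stride, so it passes this check;
// transposed or strided transfers do not.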

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}

namespace {
/// A vector.create_mask op and the position at which a 1-D mask is extracted
/// from it (empty if the mask is used directly).
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by
/// vector.create_mask, return that op.
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: The mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: The mask is extracted from a larger vector.create_mask.
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other mask producers are not supported.
  return failure();
}
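
// Illustrative example (hypothetical, not from the original file): with a
// mask chain such as
//   %m   = vector.create_mask %a, %b : vector<2x4xi1>
//   %row = vector.extract %m[1] : vector<4xi1> from vector<2x4xi1>
// feeding the 1-D transfer_read, this returns the vector.create_mask op with
// extractPosition = {1}; an unmasked read yields a TransferMask with a null
// createMaskOp, and any other mask producer is rejected with failure().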

/// Build an SSA value that represents the number of read elements.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask: the full vector is read, so no element count is needed.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return the size of the "ones" segment of the 1-D mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask): the number of read elements is the
  // innermost mask size if the extract position lies inside the masked
  // region, and 0 otherwise. Compare each leading dimension and AND the
  // conditions together.
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more mask operand than extract position; the zip stops
  // with the last position, so only the leading dimensions are compared.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt,
                              arith::ConstantIndexOp::create(b, loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = arith::AndIOp::create(b, loc, cmp, cond);
  }
  return arith::SelectOp::create(
      b, loc, cond, transferMask->createMaskOp->getOperands().back(),
      arith::ConstantIndexOp::create(b, loc, 0));
}
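
// Illustrative example (hypothetical, continuing the create_mask/extract
// sketch above with operands %a, %b and extractPosition = {1}): the value
// built here is
//   select(1 < %a, %b, 0)
// i.e. the innermost mask size %b when the extracted row lies inside the
// masked region, and 0 otherwise.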

/// Return "true" if the conversion to async copy is supported by "async copy".
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: check that the address alignments required by cp.async
  // are satisfied; misaligned addresses have undefined behavior.
  return true;
}
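
// Illustrative example (hypothetical, not from the original file):
// vector<4xf32> is 128 bits = 16 bytes and vector<2xf16> is 32 bits = 4
// bytes, so both match a supported 4/8/16-byte cp.async size, whereas
// vector<3xf32> (12 bytes) is rejected.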

/// Convert global->shared vector transfers to async device copies.
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for a contiguous 1-D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must be produced by a contiguous 1-D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // The load must read from global memory, not from shared memory.
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for a compatible mask and padding.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported by async copy.
    VectorType vecType = cast<VectorType>(vectorVal.getType());

    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first remaining write.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Scan the following operations for more copies to add to the group.
    while ((nextNode = nextNode->getNextNode())) {
      // Skip ops that are known to have no memory effects (and no recursively
      // nested effects).
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Skip reads from an address space other than shared memory; they
      // cannot interfere with the copies in the group.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy; add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // Any other op with memory effects ends the group.
      break;
    }

    // Emit the group: one async copy per store, then a single group token.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // bypass_l1 is only possible with a 16-byte transfer.
      Value token = nvgpu::DeviceAsyncCopyOp::create(
          rewriter, writeOp->getLoc(),
          nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create a group for the copies and wait for it right away.
    Value groupToken = nvgpu::DeviceAsyncCreateGroupOp::create(
        rewriter, op->getLoc(),
        nvgpu::DeviceAsyncTokenType::get(op->getContext()), tokens);
    nvgpu::DeviceAsyncWaitOp::create(rewriter, op->getLoc(), groupToken,
                                     /*numGroups=*/nullptr);

    // Clean up the old stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}
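
A minimal usage sketch, assuming only the nvgpu::createAsyncGroups entry point defined above (the pass name, headers, and registration below are hypothetical): the transform is driven with a rewriter rooted at the op whose body should be scanned, and each matched global->shared read/write pair becomes an nvgpu.device_async_copy, with one nvgpu.device_async_create_group and nvgpu.device_async_wait emitted per group.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

namespace {
// Hypothetical example pass that converts eligible copies in every function.
struct ExampleAsyncGroupsPass
    : public mlir::PassWrapper<ExampleAsyncGroupsPass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExampleAsyncGroupsPass)

  void runOnOperation() override {
    mlir::IRRewriter rewriter(&getContext());
    // Request L1 bypass; the transform only applies it to 16-byte copies.
    mlir::nvgpu::createAsyncGroups(rewriter, getOperation(),
                                   /*bypassL1=*/true);
  }
};
} // namespace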

