MLIR: lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

13 #include "llvm/ADT/STLForwardCompat.h"

14 #include "llvm/Support/ErrorHandling.h"

15 #include "llvm/Support/MathExtras.h"

16 #include

17

18 using namespace mlir;

20

21

23

25

27

29 unsigned width = IndexType::kInternalStorageBitWidth;

31 APInt(width, umax));

32 }

33

namespace {
/// Chooses which family of launch bounds a lookup consults: `Block` selects
/// the known block-size bounds and `Grid` selects the known grid-size bounds.
/// The explicit enumerator values are fixed so the enum has a stable
/// 32-bit representation.
enum class LaunchDims : uint32_t {
  Block = 0,
  Grid = 1,
};
} // namespace

37

38

39

40

41

42

43

45 switch (dim) {

46 case Dimension::x:

47 return dims.x;

48 case Dimension::y:

49 return dims.y;

50 case Dimension::z:

51 return dims.z;

52 }

53 llvm_unreachable("All dimension enum cases handled above");

54 }

55

/// Widens an unsigned 32-bit value to 64 bits. Because the source type is
/// unsigned, the widening conversion zero-fills the upper 32 bits.
static uint64_t zext(uint32_t arg) {
  const uint64_t widened = arg;
  return widened;
}

57

58 static std::optional<uint64_t>

61 switch (dims) {

62 case LaunchDims::Block:

63 bounds = func.getKnownBlockSizeAttr();

64 break;

65 case LaunchDims::Grid:

66 bounds = func.getKnownGridSizeAttr();

67 break;

68 }

69 if (!bounds)

70 return std::nullopt;

71 if (bounds.size() < static_cast<uint32_t>(dim))

72 return std::nullopt;

73 return zext(bounds[static_cast<uint32_t>(dim)]);

74 }

75

77 StringRef attrName,

78 Dimension dim) {

79 auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);

80 if (!bounds)

81 return std::nullopt;

82 if (bounds.size() < static_cast<uint32_t>(dim))

83 return std::nullopt;

84 return zext(bounds[static_cast<uint32_t>(dim)]);

85 }

86

87 template

89 Dimension dim = op.getDimension();

90 if (auto launch = op->template getParentOfType()) {

92 switch (type) {

93 case LaunchDims::Block:

94 bounds = launch.getBlockSizeOperandValues();

95 break;

96 case LaunchDims::Grid:

97 bounds = launch.getGridSizeOperandValues();

98 break;

99 }

101 APInt value;

103 return value.getZExtValue();

104 }

105

106 if (auto gpuFunc = op->template getParentOfType()) {

108 if (inherentAttr)

109 return inherentAttr;

110 }

111 if (auto func = op->template getParentOfType()) {

112 StringRef attrName;

113 switch (type) {

114 case LaunchDims::Block:

115 attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();

116 break;

117 case LaunchDims::Grid:

118 attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();

119 break;

120 }

122 if (discardableAttr)

123 return discardableAttr;

124 }

125 return std::nullopt;

126 }

127

132 max = specified->getZExtValue();

134 }

135

140 max = specified->getZExtValue();

142 }

143

148 max = specified->getZExtValue();

150 }

151

156 max = specified->getZExtValue();

158 }

159

162 std::optional<uint64_t> knownVal =

164 if (knownVal)

165 return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));

166 ;

169 max = specified->getZExtValue();

171 }

172

176 if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid))

177 max = fromContext.value();

179 max = specified->getZExtValue();

181 }

182

185 std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);

186 if (knownVal)

187 return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));

190 max = specified->getZExtValue();

192 }

193

197 if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block))

198 max = fromContext.value();

200 max = specified->getZExtValue();

202 }

203

208 max = specified->getZExtValue();

210 }

211

216 max = specified->getZExtValue();

218 }

219

223 return setResultRange(getResult(),

224 getIndexRange(0, specified->getZExtValue() - 1ULL));

225

226 uint64_t blockDimMax =

228 uint64_t gridDimMax =

230 setResultRange(getResult(),

231 getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));

232 }

233

238 max = specified->getZExtValue();

240 }

241

246 max = specified->getZExtValue();

248 }

249

253 Value idxResult) {

254 if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)

255 return;

258 setResultRange(dimResult, dimRange);

261 setResultRange(idxResult, idxRange);

262 };

263

264 argRanges = argRanges.drop_front(getAsyncDependencies().size());

267 setRange(argRanges[0], gridDims.x, blockIds.x);

268 setRange(argRanges[1], gridDims.y, blockIds.y);

269 setRange(argRanges[2], gridDims.z, blockIds.z);

271 KernelDim3 threadIds = getThreadIds();

272 setRange(argRanges[3], blockDims.x, threadIds.x);

273 setRange(argRanges[4], blockDims.y, threadIds.y);

274 setRange(argRanges[5], blockDims.z, threadIds.z);

275 }

static std::optional< int64_t > getUpperBound(Value iv)

Gets the constant upper bound on an affine.for iv.

static Value valueByDim(KernelDim3 dims, Dimension dim)

If the operation op is in a context that is annotated with maximum launch dimensions (a launch op wit...

static constexpr uint64_t kMaxClusterDim

static std::optional< uint64_t > getKnownLaunchDim(Op op, LaunchDims type)

static constexpr uint64_t kMaxDim

static uint64_t zext(uint32_t arg)

static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax)

static constexpr uint64_t kMaxSubgroupSize

static std::optional< uint64_t > getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim)

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

Block represents an ordered list of Operations.

A set of arbitrary-precision integers representing bounds on a given integer value.

static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)

Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...

ConstantIntRanges intersection(const ConstantIntRanges &other) const

Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...

const APInt & umax() const

The maximum value of an integer when it is interpreted as unsigned.

const APInt & umin() const

The minimum value of an integer when it is interpreted as unsigned.

This provides public APIs that all operations should have.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)

Given the dimToLvl map, returns the block sizes in a vector.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

Utility class for the GPU dialect to represent triples of Values accessible through ....