MLIR: lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

20

22 #define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS

23 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"

24 }

25

26 using namespace mlir;

28

// Anonymous namespace holding the pass class and the declaration of the
// CAS-loop rewrite pattern used by this file.
// NOTE(review): this listing is missing original lines 40 and 45 and has lost
// template arguments (e.g. after `OpConversionPattern`); restore them from the
// upstream file before trying to compile this region.
29 namespace {

/// The atomic-emulation pass. It derives from the generated
/// `amdgpu::impl::AmdgpuEmulateAtomicsPassBase`, instantiated on the pass type
/// itself, inherits the base-class constructors, and supplies its own
/// runOnOperation() (defined out of line in this file).
30 struct AmdgpuEmulateAtomicsPass

31 : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<

32 AmdgpuEmulateAtomicsPass> {

33 using AmdgpuEmulateAtomicsPassBase<

34 AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;

35 void runOnOperation() override;

36 };

37

/// Conversion pattern that replaces a buffer atomic `AtomicOp` with an
/// initial load followed by a compare-and-swap retry loop, computing each
/// candidate value with `ArithOp`. matchAndRewrite is defined out of line
/// further down in this file.
38 template <typename AtomicOp, typename ArithOp>

39 struct RawBufferAtomicByCasPattern : public OpConversionPattern {

// Shorthand for the atomic op's generated adaptor type, through which
// matchAndRewrite reads its operands.
41 using Adaptor = typename AtomicOp::Adaptor;

42

43 LogicalResult

44 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,

46 };

47 }

48

namespace {
/// Tells patchOperandSegmentSizes how to rewrite the leading entry of an
/// op's `operandSegmentSizes` attribute when the op is re-expressed as a
/// different op with a different number of leading data operands.
enum class DataArgAction : unsigned char {
  /// Repeat the first segment size in front of the existing list.
  Duplicate = 0,
  /// Remove the first segment size from the list.
  Drop = 1
};
} // namespace

// patchOperandSegmentSizes(attrs, newAttrs, action). The signature is taken
// from this file's symbol index; its first source lines (61-62), the `for`
// header (65) and several lines of the switch are missing from this listing.
//
// Copies every attribute of `attrs` into `newAttrs`, rewriting only the one
// named "operandSegmentSizes" so that its per-group operand counts fit the op
// that is about to be built:
//   - DataArgAction::Drop removes the first segment entry;
//   - DataArgAction::Duplicate emits the first entry twice (once prepended,
//     then the full original list).
// Other attributes are forwarded unchanged rather than discarded.
// Presumably matchAndRewrite uses Drop for the attributes of the initial load
// and Duplicate for the cmpswap. The calls that fill `loadAttrs` and
// `cmpswapAttrs` are not visible here, so confirm upstream.
55

56

57

58

59

60

63 DataArgAction action) {

64 newAttrs.reserve(attrs.size());

// (Loop over each `attr` in `attrs`: the loop header is missing here.)
// Anything that is not the segment-size attribute is copied through as-is.
66 if (attr.getName().getValue() != "operandSegmentSizes") {

67 newAttrs.push_back(attr);

68 continue;

69 }

// The cast's target type was stripped by the scrape. The symbol index lists
// DenseArrayAttrImpl<int32_t>::get, so this is presumably DenseI32ArrayAttr.
70 auto segmentAttr = cast(attr.getValue());

71 MLIRContext *context = segmentAttr.getContext();

73 switch (action) {

// Drop: new segment list = old list minus its first entry.
74 case DataArgAction::Drop:

76 context, segmentAttr.asArrayRef().drop_front());

77 break;

// Duplicate: new segment list = [old[0], old[0], old[1], ...].
// NOTE(review): oldVals[0] assumes the segment list is non-empty; an op
// with zero operand groups would index out of bounds here.
78 case DataArgAction::Duplicate: {

81 newVals.push_back(oldVals[0]);

82 newVals.append(oldVals.begin(), oldVals.end());

84 break;

85 }

86 }

// Re-attach the rewritten segment sizes under the original attribute name.
87 newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));

88 }

89 }

90

// flattenVecToBits(rewriter, loc, val). The signature is taken from this
// file's symbol index; source lines 93-94 and 101-102 are missing from this
// listing.
//
// If `val` is not vector-typed it is returned unchanged. Otherwise the
// vector's total bit width (element bit width x element count) is computed,
// `val` is bitcast to `allBitsVecType` and element 0 of that is extracted,
// yielding one scalar that carries all of the vector's bits.
// `allBitsVecType` is defined on the missing lines. It is presumably a
// one-element vector of an integer of `bitwidth` bits; confirm upstream.
91

92

95 auto vectorType = dyn_cast(val.getType());

96 if (!vectorType)

97 return val;

98

99 int64_t bitwidth =

100 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();

103 Value bitcast = rewriter.createvector::BitCastOp(loc, allBitsVecType, val);

104 Value scalar = rewriter.createvector::ExtractOp(loc, bitcast, 0);

105 return scalar;

106 }

107

// Rewrites `atomicOp` as a compare-and-swap retry loop:
//   1. load the current value with the atomic's non-data operands;
//   2. in a loop block whose argument is the last observed value `prevLoad`,
//      compute `operated = ArithOp(data, prevLoad)` and issue a cmpswap;
//   3. if the cmpswap result equals `prevLoad` bit-for-bit, branch to the
//      block after the atomic; otherwise loop again with the cmpswap result
//      as the new `prevLoad`.
// The exit test assumes the cmpswap returns the value that was in memory
// before the operation, so equality with `prevLoad` means the swap took
// effect.
// NOTE(review): several source lines are missing from this listing (111,
// 114, 118, 120-121, 124, 126, 129, 132, 137-139, 144-152), as are the
// template arguments of the `create` calls.
108 template <typename AtomicOp, typename ArithOp>

109 LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(

110 AtomicOp atomicOp, Adaptor adaptor,

112 Location loc = atomicOp.getLoc();

113

// Operand 0 is the value to combine into memory. Everything after it is
// forwarded unchanged to both the load and the cmpswap.
115 ValueRange operands = adaptor.getOperands();

116 Value data = operands.take_front()[0];

117 ValueRange invariantArgs = operands.drop_front();

119

// Plain (non-atomic) load that seeds the first loop iteration.
122 Value initialLoad =

123 rewriter.create(loc, dataType, invariantArgs, loadAttrs);

// The current block is split (on the missing lines 124/126) so that
// `afterAtomic` holds everything that followed the atomic. The loop block
// is inserted before it and carries the last observed value as its only
// argument.
125 Block *afterAtomic =

127 Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});

128

// Enter the loop with the initially loaded value.
130 rewriter.createcf::BranchOp(loc, loopBlock, initialLoad);

131

// Loop body: combine the incoming data with the last observed value.
133 Value prevLoad = loopBlock->getArgument(0);

134 Value operated = rewriter.create(loc, data, prevLoad);

135 dataType = operated.getType();

136

// The start of `cmpswapArgs` (lines 137-139) is missing; presumably it
// holds `operated` and `prevLoad`. The remaining operands are appended
// after them.
140 cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());

141 Value atomicRes = rewriter.create(

142 loc, dataType, cmpswapArgs, cmpswapAttrs);

143

// (Lines 144-150 are missing. They presumably initialise prevLoadForCompare /
// atomicResForCompare, e.g. via flattenVecToBits. Confirm upstream.)
144

145

146

147

148

// Scalar floats are compared through a same-width integer bitcast, so the
// test is bitwise-exact. A float compare would treat +0.0 and -0.0 as equal
// and NaN as unequal to itself.
151 if (auto floatDataTy = dyn_cast(dataType)) {

153 prevLoadForCompare =

154 rewriter.createarith::BitcastOp(loc, equivInt, prevLoad);

155 atomicResForCompare =

156 rewriter.createarith::BitcastOp(loc, equivInt, atomicRes);

157 }

// Exit when the cmpswap saw the value we expected; otherwise retry,
// feeding the observed value back in as the next `prevLoad`.
158 Value canLeave = rewriter.createarith::CmpIOp(

159 loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);

160 rewriter.createcf::CondBranchOp(loc, canLeave, afterAtomic, ValueRange{},

161 loopBlock, atomicRes);

// The original atomic is erased without replacement values. This presumes
// it has no results; confirm against the op definitions.
162 rewriter.eraseOp(atomicOp);

163 return success();

164 }

165

// populateAmdgpuEmulateAtomicsPatterns(target, patterns, chipset). The
// signature is taken from this file's symbol index; the function header
// (166-167) and several branch conditions (169, 173-174, 181, 185, 195) are
// missing from this listing.
//
// Marks, per chipset, which raw-buffer atomics must be emulated, either as
// illegal or as dynamically legal with a type predicate. It then registers
// CAS-loop patterns for:
//   - fadd (arith::AddFOp)
//   - fmax (arith::MaximumFOp)
//   - smax (arith::MaxSIOp)
//   - umin (arith::MinUIOp)
// NOTE(review): ConversionTarget keeps a single legality action per op, so a
// later addDynamicallyLegalOp<RawBufferAtomicFaddOp> replaces an earlier
// addIllegalOp registration of the same op. Check that the (missing)
// conditions guarding line 170 and line 196 cannot both hold for one chipset;
// otherwise the "always emulate" requirement from line 170 is lost.
168

// The illegal op's template argument was stripped by the scrape.
170 target.addIllegalOp();

171 }

172

// Fadd stays legal unless its element type is f16 or bf16. `elemType` is
// defined on the missing line 176, presumably via getElementTypeOrSelf.
175 [](RawBufferAtomicFaddOp op) -> bool {

177 return !isa<Float16Type, BFloat16Type>(elemType);

178 });

179 }

180

// Chipset (9, 0, 0xa), presumably gfx90a, and newer: fmax is native only
// for f64 values; every other fmax is emulated.
182 if (chipset >= Chipset(9, 0, 0xa)) {

183

184

186 [](RawBufferAtomicFmaxOp op) -> bool {

187 return op.getValue().getType().isF64();

188 });

// Older chipsets in this branch: presumably every RawBufferAtomicFmaxOp is
// emulated (template argument stripped).
189 } else {

190 target.addIllegalOp();

191 }

192

193

// Below chipset (9, 5, 0): fadd is emulated for the element type named in
// the stripped isa<> (presumably BFloat16Type). Confirm upstream.
194 if (chipset < Chipset(9, 5, 0)) {

196 [](RawBufferAtomicFaddOp op) -> bool {

198 return !isa(elemType);

199 });

200 }

201 }

// CAS-loop emulation patterns. The surrounding `patterns.add<` (202) and its
// argument line (207) are missing from this listing.
203 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,

204 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,

205 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,

206 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(

208 }

209

// Pass entry point:
//   - parses the `chipset` string with Chipset::parse;
//   - on failure, reports "Invalid chipset name: <chipset>" and fails the pass;
//   - otherwise runs a conversion in which every op not explicitly
//     registered is considered legal.
// NOTE(review): lines 211, 218-220 and most of 224-225 are missing from this
// listing. They presumably define `op` (the pass's operation) and the
// target/pattern set, call populateAmdgpuEmulateAtomicsPatterns, and run
// applyPartialConversion.
210 void AmdgpuEmulateAtomicsPass::runOnOperation() {

212 FailureOr maybeChipset = Chipset::parse(chipset);

213 if (failed(maybeChipset)) {

214 emitError(op->getLoc(), "Invalid chipset name: " + chipset);

215 return signalPassFailure();

216 }

217

// Unknown ops are always legal, so only ops that the populate function
// marked illegal or dynamically legal can be rewritten.
221 target.markUnknownOpDynamicallyLegal(

222 [](Operation *op) -> bool { return true; });

223

// Reached only when the (missing) conversion call on line 224/225 fails.
226 return signalPassFailure();

227 }

static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, Value val)

static void patchOperandSegmentSizes(ArrayRef< NamedAttribute > attrs, SmallVectorImpl< NamedAttribute > &newAttrs, DataArgAction action)

static MLIRContext * getContext(OpFoldResult val)

Block represents an ordered list of Operations.

IntegerType getIntegerType(unsigned width)

This class implements a pattern rewriter for use with ConversionPatterns.

void eraseOp(Operation *op) override

PatternRewriter hook for erasing a dead operation.

This class describes a specific conversion target.

void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)

Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.

void addIllegalOp(OperationName op)

Register the given operation as illegal, i.e. legalization of this operation will always fail.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

NamedAttribute represents a combination of a name and an Attribute value.

Block::iterator getInsertionPoint() const

Returns the current insertion point of the builder.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

Block * getInsertionBlock() const

Return the block the current insertion point belongs to.

OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.

Operation is the basic unit of execution within MLIR.

Location getLoc()

The source location the operation was defined or derived from.

Block * splitBlock(Block *block, Block::iterator before)

Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)

Builder from ArrayRef.

void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset)

Include the generated interface declarations.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.

LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())

Below we define several entry points for operation conversion.

Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.

static FailureOr< Chipset > parse(StringRef name)

Parses the chipset version string and returns the chipset on success, and failure otherwise.