MLIR: lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp Source File
//===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
struct AmdgpuEmulateAtomicsPass
    : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
          AmdgpuEmulateAtomicsPass> {
  using AmdgpuEmulateAtomicsPassBase<
      AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
  void runOnOperation() override;
};

template <typename AtomicOp, typename ArithOp>
struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
  using OpConversionPattern<AtomicOp>::OpConversionPattern;
  using Adaptor = typename AtomicOp::Adaptor;

  LogicalResult
  matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

namespace {
enum class DataArgAction : unsigned char {
  Duplicate,
  Drop,
};
} // namespace

// Rewrite an atomic op's attribute list for use on a different op: `Drop`
// removes the leading (data) entry of operandSegmentSizes, as needed when
// turning the atomic into a load, while `Duplicate` repeats that entry, as
// needed when turning it into a compare-and-swap that takes the data operand
// twice (source and comparison value). All other attributes pass through.
static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
                                     SmallVectorImpl<NamedAttribute> &newAttrs,
                                     DataArgAction action) {
  newAttrs.reserve(attrs.size());
  for (NamedAttribute attr : attrs) {
    if (attr.getName().getValue() != "operandSegmentSizes") {
      newAttrs.push_back(attr);
      continue;
    }
    auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
    MLIRContext *context = segmentAttr.getContext();
    DenseI32ArrayAttr newSegments;
    switch (action) {
    case DataArgAction::Drop:
      newSegments = DenseI32ArrayAttr::get(
          context, segmentAttr.asArrayRef().drop_front());
      break;
    case DataArgAction::Duplicate: {
      SmallVector<int32_t> newVals;
      ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
      newVals.push_back(oldVals[0]);
      newVals.append(oldVals.begin(), oldVals.end());
      newSegments = DenseI32ArrayAttr::get(context, newVals);
      break;
    }
    }
    newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
  }
}
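
// For illustration (hypothetical layout; the real segment order comes from
// each op's ODS definition), an atomic op carrying
//   operandSegmentSizes = array<i32: 1, 1, 2, 1>  // value, memref, indices, sgprOffset
// would be patched to
//   operandSegmentSizes = array<i32: 1, 2, 1>        // Drop: the load has no data operand
//   operandSegmentSizes = array<i32: 1, 1, 1, 2, 1>  // Duplicate: cmpswap takes src and cmp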

// If `val` is a vector, flatten it to a scalar containing the same bits:
// bitcast it to a one-element vector of an integer of equal total width and
// extract that element. Non-vector values are returned unchanged.
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
                              Value val) {
  auto vectorType = dyn_cast<VectorType>(val.getType());
  if (!vectorType)
    return val;

  int64_t bitwidth =
      vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  Type allBitsType = rewriter.getIntegerType(bitwidth);
  auto allBitsVecType = VectorType::get({1}, allBitsType);
  Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
  Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
  return scalar;
}
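
// For example (assumed types; `vector.extract` syntax varies across MLIR
// versions), flattening a vector<4xi8> value emits
//   %bits   = vector.bitcast %val : vector<4xi8> to vector<1xi32>
//   %scalar = vector.extract %bits[0] : i32 from vector<1xi32>
// so the compare-and-swap's exit test can compare a single i32.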

template <typename AtomicOp, typename ArithOp>
LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
    AtomicOp atomicOp, Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = atomicOp.getLoc();

  ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
  ValueRange operands = adaptor.getOperands();
  Value data = operands.take_front()[0];
  ValueRange invariantArgs = operands.drop_front();
  Type dataType = data.getType();

  // Load the initial value, then split the block so the CAS loop below has a
  // block to iterate in and a block to exit to.
  SmallVector<NamedAttribute> loadAttrs;
  patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
  Value initialLoad =
      rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
  Block *currentBlock = rewriter.getInsertionBlock();
  Block *afterAtomic =
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
  Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});

  rewriter.setInsertionPointToEnd(currentBlock);
  rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);

  rewriter.setInsertionPointToEnd(loopBlock);
  Value prevLoad = loopBlock->getArgument(0);
  Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
  dataType = operated.getType();

  SmallVector<NamedAttribute> cmpswapAttrs;
  patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
  SmallVector<Value> cmpswapArgs = {operated, prevLoad};
  cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
  Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
      loc, dataType, cmpswapArgs, cmpswapAttrs);

  // The exit test must be a bitwise comparison: a floating-point compare
  // would spin forever on NaN (NaN != NaN) and exit wrongly on -0.0 vs. +0.0.
  // So flatten vectors to single integers and bitcast floats to equivalent
  // integers before comparing.
  Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
  Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
  if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
    Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
    prevLoadForCompare =
        rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
    atomicResForCompare =
        rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
  }
  Value canLeave = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
  rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
                                    loopBlock, atomicRes);
  rewriter.eraseOp(atomicOp);
  return success();
}
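
// Schematically, for a raw buffer fmax on f32, the pattern above produces IR
// of roughly this shape (a hand-written sketch with abbreviated attribute and
// type syntax, not CHECK output from the repository's tests):
//
//     %init = amdgpu.raw_buffer_load %buf[%idx] : memref<64xf32>, i32 -> f32
//     cf.br ^loop(%init : f32)
//   ^loop(%prev: f32):
//     %new = arith.maximumf %val, %prev : f32
//     %swapped = amdgpu.raw_buffer_atomic_cmpswap %new, %prev -> %buf[%idx]
//         : f32 -> memref<64xf32>, i32
//     %prev_i = arith.bitcast %prev : f32 to i32
//     %swapped_i = arith.bitcast %swapped : f32 to i32
//     %done = arith.cmpi eq, %swapped_i, %prev_i : i32
//     cf.cond_br %done, ^after, ^loop(%swapped : f32)
//   ^after: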

void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
    ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
  // gfx10 and chips before gfx908 have no buffer atomic fadd.
  if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
    target.addIllegalOp<RawBufferAtomicFaddOp>();
  }
  // gfx11 has no fp16/bf16 buffer atomic fadd.
  if (chipset.majorVersion == 11) {
    target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
        [](RawBufferAtomicFaddOp op) -> bool {
          Type elemType = getElementTypeOrSelf(op.getValue().getType());
          return !isa<Float16Type, BFloat16Type>(elemType);
        });
  }
  // gfx9 has limited floating-point min/max support.
  if (chipset.majorVersion == 9) {
    if (chipset >= Chipset(9, 0, 0xa)) {
      // gfx90a and later support f64 atomic fmax; all other types must be
      // emulated.
      target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
          [](RawBufferAtomicFmaxOp op) -> bool {
            return op.getValue().getType().isF64();
          });
    } else {
      target.addIllegalOp<RawBufferAtomicFmaxOp>();
    }

    // Before gfx950, fp8 atomic fadd must be emulated as well.
    if (chipset < Chipset(9, 5, 0)) {
      target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
          [](RawBufferAtomicFaddOp op) -> bool {
            Type elemType = getElementTypeOrSelf(op.getValue().getType());
            return !isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elemType);
          });
    }
  }
  patterns.add<
      RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
      patterns.getContext());
}

void AmdgpuEmulateAtomicsPass::runOnOperation() {
  Operation *op = getOperation();
  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(op->getLoc(), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  RewritePatternSet patterns(&getContext());
  ConversionTarget target(getContext());
  target.markUnknownOpDynamicallyLegal(
      [](Operation *op) -> bool { return true; });
  populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset);
  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    return signalPassFailure();
}
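
// Usage sketch: the pass is registered as "amdgpu-emulate-atomics" with a
// "chipset" option (see the AMDGPU Passes.td definition; the flag spelling
// below is assumed from it):
//
//   mlir-opt input.mlir --amdgpu-emulate-atomics=chipset=gfx90a
//
// On gfx90a, for example, an f32 raw_buffer_atomic_fmax gets rewritten into
// the load/compare-and-swap loop sketched above, while an f64 one is legal
// and left intact.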