MLIR: lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp Source File

//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
// gfx942 uses the FNUZ FP8 formats; newer (OCP) chipsets are detected via
// hasOcpFp8() below.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);

struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};

struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  Chipset chipset;
  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const override;
};

struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  bool saturateFP8 = false;
  Chipset chipset;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset)
      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
        chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

struct TruncfToFloat16RewritePattern final
    : public OpRewritePattern<arith::TruncFOp> {
  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

} // namespace

static bool isSupportedF8(Type elementType, Chipset chipset) {
  if (chipset == kGfx942)
    return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
  if (hasOcpFp8(chipset))
    return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
  return false;
}

static Value castF32To(Type desType, Value f32, Location loc,
                       PatternRewriter &rewriter) {
  Type elementType = getElementTypeOrSelf(desType);
  if (elementType.isF32())
    return f32;
  if (elementType.getIntOrFloatBitWidth() < 32)
    return rewriter.create<arith::TruncFOp>(loc, desType, f32);
  if (elementType.getIntOrFloatBitWidth() > 32)
    return rewriter.create<arith::ExtFOp>(loc, desType, f32);
  llvm_unreachable("The only 32-bit float type is f32");
}

LogicalResult
ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
                                            PatternRewriter &rewriter) const {
  Type inType = op.getIn().getType();
  auto inVecType = dyn_cast<VectorType>(inType);
  if (inVecType) {
    if (inVecType.isScalable())
      return failure();
    inType = inVecType.getElementType();
  }
  if (!isSupportedF8(inType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType extResType = VectorType::get(2, rewriter.getF32Type());
  if (!inVecType) {
    Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
        loc, rewriter.getF32Type(), in, 0);
    Value result = castF32To(outElemType, asFloat, loc, rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = inVecType.getNumElements();

  Value zero = rewriter.create<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  VectorType outType = cast<VectorType>(op.getOut().getType());

  // 0-D vectors: extract the scalar, extend it, and insert it back into a
  // splat of the result type.
  if (inVecType.getShape().empty()) {
    Value zerodSplat =
        rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
    Value scalarIn =
        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    Value scalarExt =
        rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
                                                     ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outType.getElementType());
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

  if (inVecType.getRank() > 1) {
    inVecType = VectorType::get(SmallVector<int64_t>{numElements},
                                inVecType.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
  }

  // Walk the flattened input four FP8 elements at a time, extending pairs with
  // ext_packed_fp8 and any trailing single element on its own.
  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, in, i, elemsThisOp, 1);
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      if (i + j + 1 < numElements) {
        Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
            loc, extResType, inSlice, j / 2);
        Type desType = VectorType::get(2, outElemType);
        Value asType = castF32To(desType, asFloats, loc, rewriter);
        result = rewriter.create<vector::InsertStridedSliceOp>(
            loc, asType, result, i + j, 1);
      } else {
        Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
            loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
        Value asType = castF32To(outElemType, asFloat, loc, rewriter);
        result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
      }
    }
  }

  if (inVecType.getRank() != outType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
  Type type = value.getType();
  if (type.isF32())
    return value;
  if (type.getIntOrFloatBitWidth() < 32)
    return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
  if (type.getIntOrFloatBitWidth() > 32)
    return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
  llvm_unreachable("The only 32-bit float type is f32");
}

// Clamp `source` to the finite range of `outElemType` so that the packed
// truncation below saturates rather than overflowing; +/-inf and NaN inputs
// are deliberately passed through unchanged.
static Value clampInput(PatternRewriter &rewriter, Location loc,
                        Type outElemType, Value source) {
  Type sourceType = source.getType();
  const llvm::fltSemantics &sourceSem =
      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
  const llvm::fltSemantics &targetSem =
      cast<FloatType>(outElemType).getFloatSemantics();

  APFloat min = APFloat::getLargest(targetSem, true);
  APFloat max = APFloat::getLargest(targetSem, false);
  bool ignoredLosesInfo = false;
  // Converting the FP8 bounds into the wider source semantics is exact, so the
  // losesInfo output of convert() can be ignored.
  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);

  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);

  Value inf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, false));
  Value negInf = createScalarOrSplatConstant(
      rewriter, loc, sourceType, APFloat::getInf(sourceSem, true));
  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, inf);
  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, negInf);
  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, source, source);
  Value isNonFinite = rewriter.create<arith::OrIOp>(
      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);

  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
  Value res =
      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
  return res;
}
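// Worked example (not from the upstream file): with f8E4M3FNUZ as the target
// type, APFloat::getLargest yields +/-240.0, so with saturation enabled an f32
// input of 1000.0 is clamped to 240.0 (and -1000.0 to -240.0) before the
// packed truncation, while +/-inf and NaN are forwarded unchanged by the
// select above.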

LogicalResult
TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
                                              PatternRewriter &rewriter) const {
  // Only the default rounding mode is supported.
  if (op.getRoundingmodeAttr())
    return failure();
  Type outType = op.getOut().getType();
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  auto inType =
      dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
  if (inType && inType.getWidth() <= 8 && saturateFP8)
    // Saturating truncation of inputs that are already 8 bits or narrower is
    // not handled here.
    return failure();

  if (!isSupportedF8(outType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  if (saturateFP8)
    in = clampInput(rewriter, loc, outElemType, in);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());
  VectorType truncResType = VectorType::get(4, outElemType);
  if (!inVectorTy) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
        loc, truncResType, asFloat, nullptr, 0,
        nullptr);
    Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }

  int64_t numElements = outVecType.getNumElements();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  if (outVecType.getShape().empty()) {
    Value zerodSplat =
        rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
    Value scalarIn =
        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    // The scalar truncation created here is picked up again by this pattern
    // and lowered through the scalar path above.
    Value scalarTrunc =
        rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
    Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc,
                                                     zerodSplat,
                                                     ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outVecType.getElementType());
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
  }

  // Truncate the flattened input four elements at a time, packing pairs of
  // f32 values into the low bytes of a 4xFP8 result.
  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      if (j + 1 < elemsThisOp) {
        Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
          loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
    }
    if (elemsThisOp < 4)
      thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, thisResult, 0, elemsThisOp, 1);
    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                           result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
    arith::TruncFOp op, PatternRewriter &rewriter) const {
  Type outType = op.getOut().getType();
  Type inputType = getElementTypeOrSelf(op.getIn());
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  if (!(outType.isF16() && inputType.isF32()))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType truncResType = VectorType::get(2, outElemType);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());

  // Scalar case: pack the value together with a poison second operand and
  // extract the single converted half.
  if (!inVectorTy) {
    auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
    Value asF16s =
        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = outVecType.getNumElements();
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
  }

  // Vector case, handled two elements at a time; odd-length vectors are
  // handled by padding the last pair with a poison operand.
  for (int64_t i = 0; i < numElements; i += 2) {
    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
    Value thisResult = nullptr;
    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
    Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());

    if (elemsThisOp == 2) {
      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
    }

    thisResult =
        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
    // Place the truncated pair back into the result vector; for an odd tail
    // only the first lane is kept.
    thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, thisResult, 0, elemsThisOp, 1);
    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                           result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
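// Background note (not from the upstream file): ROCDL::CvtPkRtz lowers to the
// llvm.amdgcn.cvt.pkrtz intrinsic, which packs two f32 operands into a
// vector<2xf16> using round-toward-zero; for odd-length vectors the second
// operand above is LLVM poison and its lane is dropped by the strided-slice
// extraction.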

void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool convertFP8Arithmetic,
    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {

  if (convertFP8Arithmetic) {
    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
                                               saturateFP8Truncf, chipset);
  }
  if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
}

void ArithToAMDGPUConversionPass::runOnOperation() {
  Operation *op = getOperation();
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  bool convertFP8Arithmetic =
      *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
      *maybeChipset);
  if (failed(applyPatternsGreedily(op, std::move(patterns))))
    return signalPassFailure();
}
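For reference, the populate hook above can also be driven outside of the pass. The following is a minimal sketch of such a use, not taken from this file; the helper name lowerArithFp8, the hard-coded "gfx942" chipset string, and the chosen option values are illustrative assumptions.

// Minimal usage sketch (hypothetical helper, not part of ArithToAMDGPU.cpp).
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerArithFp8(mlir::Operation *root) {
  // "gfx942" is an assumed example target; any chipset accepted by
  // amdgpu::Chipset::parse would do.
  mlir::FailureOr<mlir::amdgpu::Chipset> chipset =
      mlir::amdgpu::Chipset::parse("gfx942");
  if (mlir::failed(chipset))
    return mlir::failure();

  mlir::RewritePatternSet patterns(root->getContext());
  mlir::arith::populateArithToAMDGPUConversionPatterns(
      patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/true,
      /*allowPackedF16Rtz=*/true, *chipset);
  // `root` should be isolated from above (e.g. a module or function) so the
  // greedy driver can rewrite everything nested under it.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}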
