MLIR: lib/Conversion/GPUCommon/OpToFuncCallLowering.h Source File (original) (raw)

1

2

3

4

5

6

7

8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_

9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_

10

16

17 namespace mlir {

18

namespace detail {
/// Detection trait for a `getFastmath()` instance method on `T`.
///
/// Names the type returned by `std::declval<T>().getFastmath()`. It is
/// ill-formed when `T` has no such method, which lets a detection helper such
/// as `llvm::is_detected<has_get_fastmath_t, SourceOp>` report `false`.
///
/// This lives in a *named* namespace rather than an unnamed one on purpose.
/// This is a header, and class templates defined here pass this alias as a
/// template template argument. An unnamed namespace would give every
/// translation unit its own distinct alias, so identical instantiations in
/// different TUs would refer to different entities (an ODR violation).
template <typename T>
using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
} // namespace detail

// Keep the unqualified spelling usable from the enclosing namespace.
using detail::has_get_fastmath_t;

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

// OpToFuncCallLowering: class template covering listing lines 54-198.
// NOTE(review): this is an extracted source listing, not compilable code.
// Many original lines are absent (see the gaps in the embedded line numbers,
// e.g. the class head after 54, lines 57-64, and lines 193-197). Several
// template argument lists were also stripped (e.g. `getParentOfType()`,
// `rewriter.createLLVM::CallOp`). Consult the upstream header before
// changing anything here.
54 template

56 public:

65

// matchAndRewrite: replaces a single-result SourceOp with an LLVM::CallOp to
// the function whose name getFunctionName selects. Operands go through
// maybeCast first, and the call result is converted back when its type
// differs. Returns failure() when no function name applies.
66 LogicalResult

69 using LLVM::LLVMFuncOp;

70

71 static_assert(

73 "expected single result op");

74

// An i1 result is handled specially. The callee is declared to return i32
// (line 98), and that i32 result is later compared against zero (117-122).
75 bool isResultBool = op->getResultTypes().front().isInteger(1);

77 SourceOp>::value) {

78 assert(op->getNumOperands() > 0 &&

79 "expected op to take at least one operand");

80 assert((op->getResultTypes().front() == op->getOperand(0).getType() ||

81 isResultBool) &&

82 "expected op with same operand and result types");

83 }

84

85 if (!op->template getParentOfType()) {

87 op, "expected op to be within a function region");

88 }

89

// Each operand may be widened by maybeCast. Per the assert at line 127, the
// widened type is presumably f32.
91 for (Value operand : adaptor.getOperands())

92 castedOperands.push_back(maybeCast(operand, rewriter));

93

94 Type castedOperandType = castedOperands.front().getType();

95

96

97 Type resultType =

98 isResultBool ? rewriter.getIntegerType(32) : castedOperandType;

// The callee name is chosen from the type *after* casting. An operand that
// maybeCast widened is therefore looked up as the widened type.
100 StringRef funcName = getFunctionName(castedOperandType, op);

101 if (funcName.empty())

102 return failure();

103

105 auto callOp =

106 rewriter.createLLVM::CallOp(op->getLoc(), funcOp, castedOperands);

107

// If the call already yields the original operand type, use its result as is.
108 if (resultType == adaptor.getOperands().front().getType()) {

109 rewriter.replaceOp(op, {callOp.getResult()});

110 return success();

111 }

112

113

114

115

116

117 if (isResultBool) {

118 Value zero = rewriter.createLLVM::ConstantOp(

121 Value truncated = rewriter.createLLVM::ICmpOp(

122 op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);

123 rewriter.replaceOp(op, {truncated});

124 return success();

125 }

126

// Remaining case: the operand was widened, so the f32 call result is narrowed
// back to the original operand type.
127 assert(callOp.getResult().getType().isF32() &&

128 "only f32 types are supposed to be truncated back");

129 Value truncated = rewriter.createLLVM::FPTruncOp(

130 op->getLoc(), adaptor.getOperands().front().getType(),

131 callOp.getResult());

132 rewriter.replaceOp(op, {truncated});

133 return success();

134 }

135

// Presumably the body of maybeCast; its header (lines 136-137) is missing from
// this listing. Values that are not f16/bf16 pass through unchanged.
138 if (!isa<Float16Type, BFloat16Type>(type))

139 return operand;

140

141

// NOTE(review): this condition looks inverted. As written, an f16 operand is
// left uncast only when f16Func is EMPTY. When f16Func is non-empty, the
// operand is extended (line 145) and getFunctionName is then queried with
// the extended type (line 100), so f16Func would never be selected. The
// expected form is likely `!f16Func.empty()`. Confirm this; note that the
// isa<> template argument on this line was also lost in extraction.
142 if (f16Func.empty() && isa(type))

143 return operand;

144

145 return rewriter.createLLVM::FPExtOp(

147 }

148

// Lines 149-151 are missing from this listing; presumably getFunctionType.
152 }

153

// Presumably appendOrGetFuncOp; its header (lines 154-155) is missing.
// It reuses a symbol found via SymbolTable::lookupNearestSymbolFrom.
// Otherwise it creates a new function op (line 167). The enclosing parent
// function is asserted to exist (line 165). The builder setup on line 166
// is missing.
156 using LLVM::LLVMFuncOp;

157

159 auto funcOp =

160 SymbolTable::lookupNearestSymbolFrom(op, funcAttr);

161 if (funcOp)

162 return funcOp;

163

164 auto parentFunc = op->getParentOfType();

165 assert(parentFunc && "expected there to be a parent function");

167 return b.create(op->getLoc(), funcName, funcType);

168 }

169

// Presumably getFunctionName; its header (line 170) is missing.
// useApprox can only become true for ops that expose getFastmath(), and it
// requires the `afn` flag. The rest of that condition (line 175) is missing.
171 bool useApprox = false;

172 if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {

173 arith::FastMathFlags flag = op.getFastmath();

174 useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&

176 }

177

// Dispatches on the operand type. The isa<> template arguments and most
// return statements were lost in extraction. An empty name (line 190) means
// "no suitable function"; matchAndRewrite turns it into failure().
178 if (isa(type))

180 if (isa(type)) {

181 if (useApprox)

184 }

185 if (isa(type))

187

190 return "";

191 }

192

// Lines 193-197 are missing. Presumably they hold the f32Func/f64Func/
// f32ApproxFunc/f16Func/i32Func members named in the tooltips.
198 };

199

200 }

201

202 #endif

IntegerAttr getI32IntegerAttr(int32_t value)

IntegerType getIntegerType(unsigned width)

MLIRContext * getContext() const

This class implements a pattern rewriter for use with ConversionPatterns.

void replaceOp(Operation *op, ValueRange newValues) override

Replace the given operation with the new values.

Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.

Conversion from types to the LLVM IR dialect.

This class helps build Operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class provides return value APIs for ops that are known to have a single result.

This class provides verification for ops that are known to have the same operand and result type.

Operation is the basic unit of execution within MLIR.

MLIRContext * getContext()

Return the context this operation is associated with.

Location getLoc()

The source location the operation was defined or derived from.

OpTy getParentOfType()

Return the closest surrounding parent operation that is of type 'OpTy'.

This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

bool isInteger() const

Return true if this is an integer type (with the specified width).

This class provides an abstraction over the different types of ranges over Values.

type_range getTypes() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Location getLoc() const

Return the location of this value.

Include the generated interface declarations.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.

Rewriting that replaces SourceOp with a CallOp to f32Func, f64Func, f32ApproxFunc, f16Func or i32Func, depending on the operand element type and, if present, the op's fastmath flags.

const std::string f64Func

const std::string f32ApproxFunc

StringRef getFunctionName(Type type, SourceOp op) const

const std::string f32Func

LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, Operation *op) const

const std::string f16Func

const std::string i32Func

LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override

Methods that operate on the SourceOp type.

OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func="", PatternBenefit benefit=1)

Type getFunctionType(Type resultType, ValueRange operands) const

Value maybeCast(Value operand, PatternRewriter &rewriter) const