MLIR: lib/Conversion/GPUCommon/OpToFuncCallLowering.h Source File (original) (raw)
1
2
3
4
5
6
7
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10
16
17 namespace mlir {
18
namespace {

/// Detection trait for an instance method `getFastmath()` on `T`.
///
/// `has_get_fastmath_t<T>` names the type of the expression
/// `std::declval<T>().getFastmath()`. It is well-formed only when that
/// expression is valid, so it can be plugged into a detection idiom such as
/// `llvm::is_detected<has_get_fastmath_t, SourceOp>` to decide at compile time
/// whether an op exposes fast-math flags.
///
/// NOTE(review): an unnamed namespace in a header gives this alias internal
/// linkage separately in every including translation unit; consider a named
/// `detail` namespace (and qualifying its users) instead.
template <typename T>
using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());

} // namespace
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 template
56 public:
65
66 LogicalResult
69 using LLVM::LLVMFuncOp;
70
71 static_assert(
73 "expected single result op");
74
75 bool isResultBool = op->getResultTypes().front().isInteger(1);
77 SourceOp>::value) {
78 assert(op->getNumOperands() > 0 &&
79 "expected op to take at least one operand");
80 assert((op->getResultTypes().front() == op->getOperand(0).getType() ||
81 isResultBool) &&
82 "expected op with same operand and result types");
83 }
84
85 if (!op->template getParentOfType()) {
87 op, "expected op to be within a function region");
88 }
89
91 for (Value operand : adaptor.getOperands())
92 castedOperands.push_back(maybeCast(operand, rewriter));
93
94 Type castedOperandType = castedOperands.front().getType();
95
96
97 Type resultType =
98 isResultBool ? rewriter.getIntegerType(32) : castedOperandType;
100 StringRef funcName = getFunctionName(castedOperandType, op);
101 if (funcName.empty())
102 return failure();
103
105 auto callOp =
106 rewriter.createLLVM::CallOp(op->getLoc(), funcOp, castedOperands);
107
108 if (resultType == adaptor.getOperands().front().getType()) {
109 rewriter.replaceOp(op, {callOp.getResult()});
110 return success();
111 }
112
113
114
115
116
117 if (isResultBool) {
118 Value zero = rewriter.createLLVM::ConstantOp(
121 Value truncated = rewriter.createLLVM::ICmpOp(
122 op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);
123 rewriter.replaceOp(op, {truncated});
124 return success();
125 }
126
127 assert(callOp.getResult().getType().isF32() &&
128 "only f32 types are supposed to be truncated back");
129 Value truncated = rewriter.createLLVM::FPTruncOp(
130 op->getLoc(), adaptor.getOperands().front().getType(),
131 callOp.getResult());
132 rewriter.replaceOp(op, {truncated});
133 return success();
134 }
135
138 if (!isa<Float16Type, BFloat16Type>(type))
139 return operand;
140
141
142 if (.empty() && isa(type))
143 return operand;
144
145 return rewriter.createLLVM::FPExtOp(
147 }
148
152 }
153
156 using LLVM::LLVMFuncOp;
157
159 auto funcOp =
160 SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
161 if (funcOp)
162 return funcOp;
163
164 auto parentFunc = op->getParentOfType();
165 assert(parentFunc && "expected there to be a parent function");
167 return b.create(op->getLoc(), funcName, funcType);
168 }
169
171 bool useApprox = false;
172 if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
173 arith::FastMathFlags flag = op.getFastmath();
174 useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
176 }
177
178 if (isa(type))
180 if (isa(type)) {
181 if (useApprox)
184 }
185 if (isa(type))
187
190 return "";
191 }
192
198 };
199
200 }
201
202 #endif
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides return value APIs for ops that are known to have a single result.
This class provides verification for ops that are known to have the same operand and result type.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func or f32ApproxFunc or f16Func or i...
const std::string f64Func
const std::string f32ApproxFunc
StringRef getFunctionName(Type type, SourceOp op) const
const std::string f32Func
LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, Operation *op) const
const std::string f16Func
const std::string i32Func
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
OpToFuncCallLowering(const LLVMTypeConverter &lowering, StringRef f32Func, StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func, StringRef i32Func="", PatternBenefit benefit=1)
Type getFunctionType(Type resultType, ValueRange operands) const
Value maybeCast(Value operand, PatternRewriter &rewriter) const