LLVM: include/llvm/IR/MatrixBuilder.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14#ifndef LLVM_IR_MATRIXBUILDER_H
15#define LLVM_IR_MATRIXBUILDER_H
16
26
27namespace llvm {
28
32
35 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36
37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
39 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40 "One of the operands must be a matrix (embedded in a vector)");
41 if (LHS->getType()->isVectorTy() && ->getType()->isVectorTy()) {
43 "LHS Assumed to be fixed width");
44 RHS = B.CreateVectorSplat(
46 "scalar.splat");
47 } else if (->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
49 "RHS Assumed to be fixed width");
50 LHS = B.CreateVectorSplat(
52 "scalar.splat");
53 }
55 }
56
57public:
59
60
61
62
63
64
65
67 Value *Stride, bool IsVolatile, unsigned Rows,
68 unsigned Columns, const Twine &Name = "") {
70
71 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
72 B.getInt32(Columns)};
73 Type *OverloadedTypes[] = {RetType, Stride->getType()};
74
76 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
77
81 Call->addParamAttr(0, AlignAttr);
83 }
84
85
86
87
88
90 Value *Stride, bool IsVolatile,
91 unsigned Rows, unsigned Columns,
92 const Twine &Name = "") {
94 Stride, B.getInt1(IsVolatile),
95 B.getInt32(Rows), B.getInt32(Columns)};
97
99 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
100
104 Call->addParamAttr(1, AlignAttr);
106 }
107
108
109
111 unsigned Columns, const Twine &Name = "") {
113 auto *ReturnType =
115
116 Type *OverloadedTypes[] = {ReturnType};
117 Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
119 getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
120
122 }
123
124
125
127 unsigned LHSColumns, unsigned RHSColumns,
128 const Twine &Name = "") {
131
132 auto *ReturnType =
134
135 Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
136 B.getInt32(RHSColumns)};
137 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
138
140 getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
142 }
143
144
145
147 Value *ColumnIdx, unsigned NumRows) {
148 return B.CreateInsertElement(
150 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
151 ColumnIdx->getType(), NumRows)),
152 RowIdx));
153 }
154
155
156
158 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
159 if (LHS->getType()->isVectorTy() && ->getType()->isVectorTy()) {
161 "LHS Assumed to be fixed width");
162 RHS = B.CreateVectorSplat(
164 "scalar.splat");
165 } else if (->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
167 "RHS Assumed to be fixed width");
168 LHS = B.CreateVectorSplat(
170 "scalar.splat");
171 }
172
174 ->getElementType()
175 ->isFloatingPointTy()
178 }
179
180
181
183 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
184 if (LHS->getType()->isVectorTy() && ->getType()->isVectorTy()) {
186 "LHS Assumed to be fixed width");
187 RHS = B.CreateVectorSplat(
189 "scalar.splat");
190 } else if (->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
192 "RHS Assumed to be fixed width");
193 LHS = B.CreateVectorSplat(
195 "scalar.splat");
196 }
197
199 ->getElementType()
200 ->isFloatingPointTy()
203 }
204
205
206
208 std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
209 if (LHS->getType()->getScalarType()->isFloatingPointTy())
210 return B.CreateFMul(LHS, RHS);
211 return B.CreateMul(LHS, RHS);
212 }
213
214
215
217 assert(LHS->getType()->isVectorTy() && ->getType()->isVectorTy());
219 "LHS Assumed to be fixed width");
222 RHS, "scalar.splat");
224 ->getElementType()
225 ->isFloatingPointTy()
227 : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
228 }
229
230
232 Twine const &Name = "") {
235 auto *Cmp = B.CreateICmpULT(Idx, NumElts);
238 else
239 B.CreateAssumption(Cmp);
240 }
241
242
243
245 Twine const &Name = "") {
249 RowIdx = B.CreateZExt(RowIdx, IntTy);
250 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
251 Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
252 return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
253 }
254};
255
256}
257
258#endif
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
Functions, function parameters, and return types can have attributes to indicate how they should be t...
static LLVM_ABI Attribute getWithAlignment(LLVMContext &Context, Align Alignment)
Return a uniquified Attribute object that has the specific alignment set.
This class represents a function call, abstracting a target machine's calling convention.
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Common base class shared among various IRBuilders.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Value * CreateScalarMultiply(Value *LHS, Value *RHS)
Multiply matrix LHS with scalar RHS or scalar LHS with matrix RHS.
Definition MatrixBuilder.h:207
Value * CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned)
Divide matrix LHS by scalar RHS.
Definition MatrixBuilder.h:216
Value * CreateSub(Value *LHS, Value *RHS)
Subtract matrixes LHS and RHS.
Definition MatrixBuilder.h:182
MatrixBuilder(IRBuilderBase &Builder)
Definition MatrixBuilder.h:58
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
Definition MatrixBuilder.h:110
CallInst * CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix store.
Definition MatrixBuilder.h:89
Value * CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, Value *ColumnIdx, unsigned NumRows)
Insert a single element NewVal into Matrix at indices (RowIdx, ColumnIdx).
Definition MatrixBuilder.h:146
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.
Definition MatrixBuilder.h:126
void CreateIndexAssumption(Value *Idx, unsigned NumElements, Twine const &Name="")
Create an assumption that Idx is less than NumElements.
Definition MatrixBuilder.h:231
Value * CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, Twine const &Name="")
Compute the index to access the element at (RowIdx, ColumnIdx) from a matrix with NumRows embedded in...
Definition MatrixBuilder.h:244
CallInst * CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a column major, strided matrix load.
Definition MatrixBuilder.h:66
Value * CreateAdd(Value *LHS, Value *RHS)
Add matrixes LHS and RHS.
Definition MatrixBuilder.h:157
A Module instance is used to store all the information related to an LLVM module.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
The instances of the Type class are immutable: once they are created, they are never changed.
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI Function * getOrInsertDeclaration(Module *M, ID id, ArrayRef< Type * > Tys={})
Look up the Function declaration of the intrinsic id in the Module M.
This is an optimization pass for GlobalISel generic memory operations.
bool isa(const From &Val)
isa - Return true if the parameter to the template is an instance of one of the template type argu...
decltype(auto) cast(const From &Val)
cast - Return the argument parameter cast to the specified type.
This struct is a compact representation of a valid (non-zero power of two) alignment.