LLVM: include/llvm/IR/MatrixBuilder.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14#ifndef LLVM_IR_MATRIXBUILDER_H

15#define LLVM_IR_MATRIXBUILDER_H

16

26

27namespace llvm {

28

32

35 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }

36

// Helper for the element-wise matrix/scalar operations (used by
// CreateScalarMultiply): asserts that at least one of LHS/RHS is a vector
// (a matrix embedded in a vector) and, when the other operand is a scalar,
// splats it with CreateVectorSplat (value name "scalar.splat") so both
// operands have vector type. Presumably returns the possibly-updated
// (LHS, RHS) pair; the return statement is missing from this rendering.
// NOTE(review): this extracted page drops several source lines of this
// function (second parameter, the fixed-width asserts, the splat element
// count argument, the return); they are not reconstructed here.
37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,

39 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&

40 "One of the operands must be a matrix (embedded in a vector)");
// NOTE(review): as rendered, this `if` condition and the `else if` condition
// below are identical, so the `else if` branch (splatting a scalar LHS) can
// never execute. The assert messages imply the intended cases are
// "LHS vector, RHS scalar" vs. "LHS scalar, RHS vector"; the `!` negations
// were most likely lost during extraction. Verify against the original header.
41 if (LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {

43 "LHS Assumed to be fixed width");

44 RHS = B.CreateVectorSplat(

46 "scalar.splat");

47 } else if (LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {

49 "RHS Assumed to be fixed width");

50 LHS = B.CreateVectorSplat(

52 "scalar.splat");

53 }

55 }

56

57public:

59

60

61

62

63

64

65

// CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
//                       Value *Stride, bool IsVolatile, unsigned Rows,
//                       unsigned Columns, const Twine &Name = "")
// Create a column major, strided matrix load using the
// Intrinsic::matrix_column_major_load intrinsic. Call operands are DataPtr,
// Stride, IsVolatile as an i1 constant, and Rows/Columns as i32 constants.
// The intrinsic is overloaded on RetType and on Stride's type. An alignment
// attribute (presumably built from Alignment) is attached to parameter 0,
// i.e. DataPtr.
// NOTE(review): this rendering is missing the signature's first line and the
// lines that compute RetType, declare the intrinsic, create the call and
// return it.
67 Value *Stride, bool IsVolatile, unsigned Rows,

68 unsigned Columns, const Twine &Name = "") {

70

71 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),

72 B.getInt32(Columns)};
// Overload types for the intrinsic: the loaded matrix type and the stride type.
73 Type *OverloadedTypes[] = {RetType, Stride->getType()};

74

76 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);

77
// Operand 0 is DataPtr, so the alignment applies to the source pointer.
81 Call->addParamAttr(0, AlignAttr);

83 }

84

85

86

87

88

// CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
//                        Value *Stride, bool IsVolatile, unsigned Rows,
//                        unsigned Columns, const Twine &Name = "")
// Create a column major, strided matrix store using the
// Intrinsic::matrix_column_major_store intrinsic. The visible operands
// include Stride, IsVolatile as an i1 constant, and Rows/Columns as i32
// constants. An alignment attribute is attached to parameter 1, presumably
// Ptr.
// NOTE(review): the signature's first line, the start of the operand list,
// the overload-type list, the call creation and the return are missing from
// this rendering. Confirm that operand 1 is Ptr.
90 Value *Stride, bool IsVolatile,

91 unsigned Rows, unsigned Columns,

92 const Twine &Name = "") {

94 Stride, B.getInt1(IsVolatile),

95 B.getInt32(Rows), B.getInt32(Columns)};

97

99 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);

100
// Alignment goes on operand 1 (presumably the destination pointer).
104 Call->addParamAttr(1, AlignAttr);

106 }

107

108

109

// CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns,
//                       const Twine &Name = "")
// Create a llvm.matrix.transpose call (Intrinsic::matrix_transpose),
// transposing Matrix with Rows rows and Columns columns. Operands are Matrix
// plus Rows/Columns as i32 constants. The intrinsic is overloaded on
// ReturnType.
// NOTE(review): the lines defining ReturnType's value, declaring the
// intrinsic and creating/returning the call are missing from this rendering.
111 unsigned Columns, const Twine &Name = "") {

113 auto *ReturnType =

115

116 Type *OverloadedTypes[] = {ReturnType};

117 Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};

119 getModule(), Intrinsic::matrix_transpose, OverloadedTypes);

120

122 }

123

124

125

// CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
//                      unsigned LHSColumns, unsigned RHSColumns,
//                      const Twine &Name = "")
// Create a llvm.matrix.multiply call (Intrinsic::matrix_multiply) that
// multiplies LHS by RHS. Operands are LHS, RHS and the three dimensions
// (LHSRows, LHSColumns, RHSColumns) as i32 constants. The intrinsic is
// overloaded on the result type and on both operand types.
// NOTE(review): the lines defining LHSType/RHSType, ReturnType's value, the
// declaration lookup and the call creation/return are missing from this
// rendering.
127 unsigned LHSColumns, unsigned RHSColumns,

128 const Twine &Name = "") {

131

132 auto *ReturnType =

134

135 Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),

136 B.getInt32(RHSColumns)};

137 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};

138

140 getModule(), Intrinsic::matrix_multiply, OverloadedTypes);

142 }

143

144

145

// CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
//                    Value *ColumnIdx, unsigned NumRows)
// Insert a single element NewVal into Matrix at (RowIdx, ColumnIdx) with
// insertelement. The flat element index is ColumnIdx * NumRows + RowIdx,
// which is column-major linearization.
// NOTE(review): unlike CreateIndex, this does not extend RowIdx/ColumnIdx.
// NumRows is materialized with ColumnIdx's type and then added to RowIdx,
// so both indices presumably must already share one integer type. Line 149
// (presumably the Matrix/NewVal operands) is missing from this rendering.
147 Value *ColumnIdx, unsigned NumRows) {

148 return B.CreateInsertElement(

150 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(

151 ColumnIdx->getType(), NumRows)),

152 RowIdx));

153 }

154

155

156

// CreateAdd(Value *LHS, Value *RHS): add matrices LHS and RHS element-wise.
// If exactly one operand is a scalar, it is first splatted to a vector
// ("scalar.splat"). The result uses FAdd when the vector element type is
// floating point and integer Add otherwise.
// NOTE(review): the splat logic here duplicates splatScalarOperandIfNeeded,
// which CreateScalarMultiply calls instead. Consider reusing that helper.
// Lines holding the signature, the fixed-width asserts and the splat
// element-count arguments are missing from this rendering.
158 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
// NOTE(review): as rendered, this condition and the `else if` condition are
// identical, so the LHS-splat branch is unreachable. The `!` negations were
// likely lost during extraction. Verify against the original header.
159 if (LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {

161 "LHS Assumed to be fixed width");

162 RHS = B.CreateVectorSplat(

164 "scalar.splat");

165 } else if (LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {

167 "RHS Assumed to be fixed width");

168 LHS = B.CreateVectorSplat(

170 "scalar.splat");

171 }

172
// Choose the FP or integer add from the element type. The object whose
// element type is inspected (presumably LHS's vector type) is on a missing line.
174 ->getElementType()

175 ->isFloatingPointTy()

176 ? B.CreateFAdd(LHS, RHS)

177 : B.CreateAdd(LHS, RHS);

178 }

179

180

181

// CreateSub(Value *LHS, Value *RHS): subtract matrix RHS from LHS
// element-wise. If exactly one operand is a scalar, it is first splatted to a
// vector ("scalar.splat"). The result uses FSub when the vector element type
// is floating point and integer Sub otherwise.
// NOTE(review): the splat logic duplicates splatScalarOperandIfNeeded.
// Lines holding the signature, the fixed-width asserts and the splat
// element-count arguments are missing from this rendering.
183 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
// NOTE(review): as rendered, this condition and the `else if` condition are
// identical, so the LHS-splat branch is unreachable. The `!` negations were
// likely lost during extraction. Verify against the original header.
184 if (LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {

186 "LHS Assumed to be fixed width");

187 RHS = B.CreateVectorSplat(

189 "scalar.splat");

190 } else if (LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {

192 "RHS Assumed to be fixed width");

193 LHS = B.CreateVectorSplat(

195 "scalar.splat");

196 }

197
// Choose the FP or integer subtract from the element type. The object whose
// element type is inspected is on a missing line.
199 ->getElementType()

200 ->isFloatingPointTy()

201 ? B.CreateFSub(LHS, RHS)

202 : B.CreateSub(LHS, RHS);

203 }

204

205

206

// CreateScalarMultiply(Value *LHS, Value *RHS): multiply matrix LHS by
// scalar RHS, or scalar LHS by matrix RHS. The scalar operand is splatted
// through splatScalarOperandIfNeeded. The result uses FMul when the scalar
// element type is floating point and integer Mul otherwise.
// NOTE(review): the signature line is missing from this rendering.
208 std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
// getScalarType() yields the element type for vectors, so this checks the
// matrix element type.
209 if (LHS->getType()->getScalarType()->isFloatingPointTy())

210 return B.CreateFMul(LHS, RHS);

211 return B.CreateMul(LHS, RHS);

212 }

213

214

215

// CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned): divide matrix
// LHS by scalar RHS. RHS is splatted to a vector ("scalar.splat"). The
// division is FDiv for floating-point elements; otherwise it is UDiv when
// IsUnsigned is set and SDiv when it is not.
// NOTE(review): as rendered, the assert below requires RHS to be a vector.
// That contradicts the "divide by scalar" contract and the splat of RHS that
// follows. The intended check is presumably "LHS is a vector and RHS is not";
// a `!` was likely lost during extraction. The signature and several lines
// are also missing from this rendering.
217 assert(LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy());

219 "LHS Assumed to be fixed width");

222 RHS, "scalar.splat");

224 ->getElementType()

225 ->isFloatingPointTy()

226 ? B.CreateFDiv(LHS, RHS)

227 : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));

228 }

229

230

// CreateIndexAssumption(Value *Idx, unsigned NumElements,
//                       Twine const &Name = "")
// Create an assumption that Idx is less than NumElements. The comparison
// is an unsigned ICmpULT, and the assumption is emitted with
// CreateAssumption.
// NOTE(review): the lines that build NumElts and the `if` branch paired with
// the `else` below are missing from this rendering. What that branch does,
// e.g. for a constant-folded Cmp, cannot be determined from here.
232 Twine const &Name = "") {

235 auto *Cmp = B.CreateICmpULT(Idx, NumElts);

238 else

239 B.CreateAssumption(Cmp);

240 }

241

242

243

// CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
//             Twine const &Name = "")
// Compute the flat index of element (RowIdx, ColumnIdx) in a matrix with
// NumRows rows embedded in a vector: ColumnIdx * NumRows + RowIdx
// (column-major). Both indices are zero-extended to IntTy first, so mixed
// index widths are handled. NumRows is materialized as a MaxWidth-bit
// constant.
// NOTE(review): the lines defining MaxWidth and IntTy are missing from this
// rendering. Presumably IntTy is the MaxWidth-bit integer type, which
// CreateMul/CreateAdd require to match NumRowsV. Name is not visibly used in
// the lines shown.
245 Twine const &Name = "") {

249 RowIdx = B.CreateZExt(RowIdx, IntTy);

250 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);

251 Value *NumRowsV = B.getIntN(MaxWidth, NumRows);

252 return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);

253 }

254};

255

256}

257

258#endif

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

This file contains the declarations for the subclasses of Constant, which represent the different fla...

const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]

Functions, function parameters, and return types can have attributes to indicate how they should be t...

static LLVM_ABI Attribute getWithAlignment(LLVMContext &Context, Align Alignment)

Return a uniquified Attribute object that has the specific alignment set.

This class represents a function call, abstracting a target machine's calling convention.

static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)

FunctionType * getFunctionType() const

Returns the FunctionType for me.

Common base class shared among various IRBuilders.

static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)

This static method is the primary way of constructing an IntegerType.

Value * CreateScalarMultiply(Value *LHS, Value *RHS)

Multiply matrix LHS with scalar RHS or scalar LHS with matrix RHS.

Definition MatrixBuilder.h:207

Value * CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned)

Divide matrix LHS by scalar RHS.

Definition MatrixBuilder.h:216

Value * CreateSub(Value *LHS, Value *RHS)

Subtract matrices LHS and RHS.

Definition MatrixBuilder.h:182

MatrixBuilder(IRBuilderBase &Builder)

Definition MatrixBuilder.h:58

CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")

Create an llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.

Definition MatrixBuilder.h:110

CallInst * CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")

Create a column major, strided matrix store.

Definition MatrixBuilder.h:89

Value * CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, Value *ColumnIdx, unsigned NumRows)

Insert a single element NewVal into Matrix at indices (RowIdx, ColumnIdx).

Definition MatrixBuilder.h:146

CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")

Create an llvm.matrix.multiply call, multiplying matrices LHS and RHS.

Definition MatrixBuilder.h:126

void CreateIndexAssumption(Value *Idx, unsigned NumElements, Twine const &Name="")

Create an assumption that Idx is less than NumElements.

Definition MatrixBuilder.h:231

Value * CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, Twine const &Name="")

Compute the index to access the element at (RowIdx, ColumnIdx) from a matrix with NumRows rows embedded in a vector.

Definition MatrixBuilder.h:244

CallInst * CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name="")

Create a column major, strided matrix load.

Definition MatrixBuilder.h:66

Value * CreateAdd(Value *LHS, Value *RHS)

Add matrices LHS and RHS.

Definition MatrixBuilder.h:157

A Module instance is used to store all the information related to an LLVM module.

Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...

The instances of the Type class are immutable: once they are created, they are never changed.

LLVMContext & getContext() const

Return the LLVMContext in which this type was uniqued.

LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY

If this is a vector type, return the getPrimitiveSizeInBits value for the element type.

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

LLVM_ABI Function * getOrInsertDeclaration(Module *M, ID id, ArrayRef< Type * > Tys={})

Look up the Function declaration of the intrinsic id in the Module M.

This is an optimization pass for GlobalISel generic memory operations.

bool isa(const From &Val)

isa - Return true if the parameter to the template is an instance of one of the template type argu...

decltype(auto) cast(const From &Val)

cast - Return the argument parameter cast to the specified type.

This struct is a compact representation of a valid (non-zero power of two) alignment.