MLIR: lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

17 #include "llvm/Support/FormatVariadic.h"

18

19 using namespace mlir;

20 using namespace sparse_tensor;

21

22

23

24

25

26

30 for (auto type : types) {

31

34 continue;

35 }

36 hasAnnotation = true;

37

38

47 auto rtp = cast(t);

48 if (!directOut) {

50 if (extraTypes)

51 extraTypes->push_back(rtp);

52 }

54 }

55 return true;

56 });

57 }

58 }

59

60

64 bool directOut) {

65 unsigned idx = 0;

66 for (auto type : types) {

67

69 toVals.push_back(fromVals[idx++]);

70 continue;

71 }

72

73 auto rtp = cast(type);

78 if (!isIn)

79 inputs.push_back(fromVals[idx++]);

80

81

88 if (isIn) {

89 inputs.push_back(fromVals[idx++]);

90 } else if (directOut) {

93 mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],

94 lv);

96 mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],

97 lv);

98 else

99 mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);

100 toVals.push_back(mem);

101 } else {

102 ShapedType rtp = cast(t);

104 inputs.push_back(extraVals[extra++]);

105 retTypes.push_back(rtp);

107 }

108 }

109 return true;

110 });

111

112 if (isIn) {

113

114 auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);

115 toVals.push_back(a.getResult());

116 } else if (!directOut) {

117

118

119 unsigned len = retTypes.size();

120 retTypes.append(cntTypes);

121 auto d =

122 builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);

123 for (unsigned i = 0; i < len; i++)

124 toVals.push_back(d.getResult(i));

125 }

126 }

127 }

128

129

130

131

132

133 namespace {

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165 struct SparseFuncAssembler : public OpRewritePatternfunc::FuncOp {

167

168 SparseFuncAssembler(MLIRContext *context, bool dO)

170

171 LogicalResult matchAndRewrite(func::FuncOp funcOp,

173

174 if (funcOp.isPrivate())

175 return failure();

176

177

181 bool hasAnnotation = false;

182 convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,

183 false);

184 convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,

185 directOut);

186

187

188 if (!hasAnnotation)

189 return failure();

190

191

192 auto orgName = funcOp.getName();

193 std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();

194 funcOp.setName(wrapper);

195 funcOp.setPrivate();

196

197

198 Location loc = funcOp.getLoc();

199 ModuleOp modOp = funcOp->getParentOfType();

200 MLIRContext *context = modOp.getContext();

201 OpBuilder moduleBuilder(modOp.getBodyRegion());

202 unsigned extra = inputTypes.size();

203 inputTypes.append(extraTypes);

204 auto func = moduleBuilder.createfunc::FuncOp(

205 loc, orgName, FunctionType::get(context, inputTypes, outputTypes));

206 func.setPublic();

207

208

210 Block *body = func.addEntryBlock();

212

213

216 ValueRange(), inputs, 0, true, directOut);

217

218

219

221 auto call = rewriter.createfunc::CallOp(loc, funcOp.getResultTypes(), org,

222 inputs);

223

224

226 convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),

227 body->getArguments(), outputs, extra, false, directOut);

228 rewriter.createfunc::ReturnOp(loc, outputs);

229

230

231 if (funcOp->getAttrOfType(

232 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {

233 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),

235 funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());

236 }

237 return success();

238 }

239

240 private:

241 const bool directOut;

242 };

243

244 }

245

246

247

248

249

251 bool directOut) {

252 patterns.add(patterns.getContext(), directOut);

253 }

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

static void convTypes(bool &hasAnnotation, TypeRange types, SmallVectorImpl< Type > &convTypes, SmallVectorImpl< Type > *extraTypes, bool directOut)

static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl< Value > &toVals, unsigned extra, bool isIn, bool directOut)

Block represents an ordered list of Operations.

BlockArgListType getArguments()

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setAttr(StringAttr name, Attribute value)

If an attribute exists with the specified name, change it to the new value.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

A wrapper around RankedTensorType, which has three goals:

void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)

unsigned FieldIndex

The type of field indices.

uint64_t Level

The type of level identifiers and level-ranks.

SparseTensorEncodingAttr getSparseTensorEncoding(Type type)

Convenience method to get a sparse encoding attribute from a type.

SparseTensorFieldKind

The sparse tensor storage ...

Include the generated interface declarations.

const FrozenRewritePatternSet & patterns

void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects the context only if needed; this helps unify some of the attribute construction ...

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...

This enum defines all the sparse representations supportable by the SparseTensor dialect.