MLIR: lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

23

// Defines one GEN_PASS_DEF_* macro for each of the 14 pass structs below
// before including the generated Passes.h.inc. Presumably (standard MLIR
// tablegen convention) each macro makes the generated header emit the
// matching impl::*Base class that the structs derive from; confirm against
// the generated .inc.
// NOTE(review): this listing is a lossy extraction. Every line is prefixed
// with its original line number, which is not valid C++. Restore the file
// from upstream before building.
24 namespace mlir {

25 #define GEN_PASS_DEF_SPARSEASSEMBLER

26 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP

27 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE

28 #define GEN_PASS_DEF_SPARSIFICATIONPASS

29 #define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF

30 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH

31 #define GEN_PASS_DEF_LOWERFOREACHTOSCF

32 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS

33 #define GEN_PASS_DEF_SPARSETENSORCODEGEN

34 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE

35 #define GEN_PASS_DEF_SPARSEVECTORIZATION

36 #define GEN_PASS_DEF_SPARSEGPUCODEGEN

37 #define GEN_PASS_DEF_STAGESPARSEOPERATIONS

38 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM

39 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"

40 }

41

42 using namespace mlir;

44

45 namespace {

46

47

48

49

50

// SparseAssembler pass. The bool constructor stores its argument in the
// `directOut` pass option.
// NOTE(review): lossy listing. The template argument of
// impl::SparseAssemblerBase appears stripped. The body of runOnOperation
// (embedded lines 57-60) is missing; restore from upstream.
51 struct SparseAssembler : public impl::SparseAssemblerBase {

52 SparseAssembler() = default;

53 SparseAssembler(const SparseAssembler &pass) = default;

// Sets the `directOut` option from `dO`.
54 SparseAssembler(bool dO) { directOut = dO; }

55

56 void runOnOperation() override {

// NOTE(review): the statements of this body (lines 57-60) are missing here.
61 }

62 };

63

// SparseReinterpretMap pass. It can be constructed from a
// SparseReinterpretMapOptions value.
// NOTE(review): lossy listing. The base class template argument appears
// stripped. The options-constructor body (line 69) and the runOnOperation
// body (lines 73-76) are missing; restore from upstream.
64 struct SparseReinterpretMap

65 : public impl::SparseReinterpretMapBase {

66 SparseReinterpretMap() = default;

67 SparseReinterpretMap(const SparseReinterpretMap &pass) = default;

68 SparseReinterpretMap(const SparseReinterpretMapOptions &options) {

// NOTE(review): line 69 (how `options` is applied) is missing here.
70 }

71

72 void runOnOperation() override {

// NOTE(review): the statements of this body (lines 73-76) are missing here.
77 }

78 };

79

// PreSparsificationRewrite pass. It has only default and copy constructors.
// NOTE(review): lossy listing. The base class template argument appears
// stripped, and the runOnOperation body (lines 87-90) is missing.
80 struct PreSparsificationRewritePass

81 : public impl::PreSparsificationRewriteBase {

82 PreSparsificationRewritePass() = default;

83 PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =

84 default;

85

86 void runOnOperation() override {

// NOTE(review): the statements of this body (lines 87-90) are missing here.
91 }

92 };

93

// Sparsification pass. A constructor copies parallelizationStrategy,
// sparseEmitStrategy and enableRuntimeLibrary from an `options` object into
// the pass options.
// NOTE(review): lossy listing. The base class template argument and the
// options-constructor signature (line 98) are missing, as are lines 105,
// 107, 110, 111 and 113 of runOnOperation; restore from upstream.
94 struct SparsificationPass

95 : public impl::SparsificationPassBase {

96 SparsificationPass() = default;

97 SparsificationPass(const SparsificationPass &pass) = default;

// NOTE(review): the constructor signature that declares `options` (line 98)
// is missing; the next three lines are its body.
99 parallelization = options.parallelizationStrategy;

100 sparseEmitStrategy = options.sparseEmitStrategy;

101 enableRuntimeLibrary = options.enableRuntimeLibrary;

102 }

103

104 void runOnOperation() override {

106

// NOTE(review): line 107 is missing. This line is the tail of a multi-line
// expression that passes `enableRuntimeLibrary`.
108 enableRuntimeLibrary);

109

// Adds the scf::ForOp canonicalization patterns to `patterns`.
112 scf::ForOp::getCanonicalizationPatterns(patterns, ctx);

114 }

115 };

116

// StageSparseOperations pass. It has only default and copy constructors.
// NOTE(review): lossy listing. The base class template argument appears
// stripped, and the runOnOperation body (lines 122-125) is missing.
117 struct StageSparseOperationsPass

118 : public impl::StageSparseOperationsBase {

119 StageSparseOperationsPass() = default;

120 StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;

121 void runOnOperation() override {

// NOTE(review): the statements of this body (lines 122-125) are missing here.
126 }

127 };

128

// LowerSparseOpsToForeach pass. The two-bool constructor sets the
// `enableRuntimeLibrary` and `enableConvert` pass options.
// NOTE(review): lossy listing. The base class template argument appears
// stripped, and runOnOperation lines 140-142 and 144 are missing.
129 struct LowerSparseOpsToForeachPass

130 : public impl::LowerSparseOpsToForeachBase {

131 LowerSparseOpsToForeachPass() = default;

132 LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =

133 default;

// Sets `enableRuntimeLibrary` from `enableRT` and `enableConvert` from
// `convert`.
134 LowerSparseOpsToForeachPass(bool enableRT, bool convert) {

135 enableRuntimeLibrary = enableRT;

136 enableConvert = convert;

137 }

138

139 void runOnOperation() override {

// NOTE(review): lines 140-142 are missing. This line is the tail of a call
// that passes `enableConvert`.
143 enableConvert);

145 }

146 };

147

// LowerForeachToSCF pass. It has only default and copy constructors.
// NOTE(review): lossy listing. The base class template argument appears
// stripped, and the runOnOperation body (lines 154-157) is missing.
148 struct LowerForeachToSCFPass

149 : public impl::LowerForeachToSCFBase {

150 LowerForeachToSCFPass() = default;

151 LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

152

153 void runOnOperation() override {

// NOTE(review): the statements of this body (lines 154-157) are missing here.
158 }

159 };

160

// LowerSparseIterationToSCF pass. It configures a conversion target in which
// the sparse iteration ops (CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
// IterateOp) are illegal, while the arith, linalg, memref, scf and
// sparse_tensor dialects are otherwise legal.
// NOTE(review): lossy listing. The setup lines 169-172, 181 and 183-184 are
// missing, and the template argument of `addLegalOp()` (line 180) appears
// stripped; restore from upstream.
161 struct LowerSparseIterationToSCFPass

162 : public impl::LowerSparseIterationToSCFBase<

163 LowerSparseIterationToSCFPass> {

164 LowerSparseIterationToSCFPass() = default;

165 LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =

166 default;

167

168 void runOnOperation() override {

173

174

175 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,

176 memref::MemRefDialect, scf::SCFDialect,

177 sparse_tensor::SparseTensorDialect>();

// These specific sparse iteration ops must be rewritten away, even though
// the sparse_tensor dialect as a whole is legal.
178 target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,

179 IterateOp>();

// NOTE(review): the op type list of this call appears stripped.
180 target.addLegalOp();

182

// NOTE(review): the condition guarding this failure (lines 183-184) is
// missing; presumably it tests the conversion result.
185 signalPassFailure();

186 }

187 };

188

// SparseTensorConversion pass. It sets up a conversion target in which:
// - func, tensor and bufferization ops are dynamically legal when
//   `converter.isLegal` accepts the types they use;
// - the listed complex/linalg/tensor ops and the arith, bufferization, LLVM,
//   memref and scf dialects are legal.
// It then populates function-signature type conversion patterns.
// NOTE(review): lossy listing. The setup lines 195-198, the FuncOp/CallOp
// lambda bodies (205, 208) and the apply lines 250-251 and 253-255 are
// missing. Template arguments such as `addDynamicallyLegalOp<func::FuncOp>`
// were merged into the names. Restore from upstream.
189 struct SparseTensorConversionPass

190 : public impl::SparseTensorConversionPassBase {

191 SparseTensorConversionPass() = default;

192 SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

193

194 void runOnOperation() override {

199

// NOTE(review): the dialect template argument of this call appears stripped.
200 target.addIllegalDialect();

201

202

203

// NOTE(review): the return statement of this lambda (line 205) is missing.
204 target.addDynamicallyLegalOpfunc::FuncOp([&](func::FuncOp op) {

206 });

// NOTE(review): the return statement of this lambda (line 208) is missing.
207 target.addDynamicallyLegalOpfunc::CallOp([&](func::CallOp op) {

209 });

// A return is legal once all of its operand types are legal for `converter`.
210 target.addDynamicallyLegalOpfunc::ReturnOp([&](func::ReturnOp op) {

211 return converter.isLegal(op.getOperandTypes());

212 });

213 target.addDynamicallyLegalOptensor::DimOp([&](tensor::DimOp op) {

214 return converter.isLegal(op.getOperandTypes());

215 });

// Shape-changing tensor ops need both their source and result types legal.
216 target.addDynamicallyLegalOptensor::CastOp([&](tensor::CastOp op) {

217 return converter.isLegal(op.getSource().getType()) &&

218 converter.isLegal(op.getDest().getType());

219 });

220 target.addDynamicallyLegalOptensor::ExpandShapeOp(

221 [&](tensor::ExpandShapeOp op) {

222 return converter.isLegal(op.getSrc().getType()) &&

223 converter.isLegal(op.getResult().getType());

224 });

225 target.addDynamicallyLegalOptensor::CollapseShapeOp(

226 [&](tensor::CollapseShapeOp op) {

227 return converter.isLegal(op.getSrc().getType()) &&

228 converter.isLegal(op.getResult().getType());

229 });

// Tensor allocation and deallocation are legal when the tensor type is legal.
230 target.addDynamicallyLegalOpbufferization::AllocTensorOp(

231 [&](bufferization::AllocTensorOp op) {

232 return converter.isLegal(op.getType());

233 });

234 target.addDynamicallyLegalOpbufferization::DeallocTensorOp(

235 [&](bufferization::DeallocTensorOp op) {

236 return converter.isLegal(op.getTensor().getType());

237 });

238

239

// Ops and dialects that are unconditionally legal for this conversion.
240 target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,

241 linalg::YieldOp, tensor::ExtractOp,

242 tensor::FromElementsOp>();

243 target.addLegalDialect<

244 arith::ArithDialect, bufferization::BufferizationDialect,

245 LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();

246

247

// NOTE(review): presumably this was
// populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(...);
// the template brackets appear merged away by extraction.
248 populateFunctionOpInterfaceTypeConversionPatternfunc::FuncOp(patterns,

249 converter);

// NOTE(review): lines 250-251 are missing. This line is the tail of a call
// that passes `target`.
252 target);

// NOTE(review): the condition guarding this failure (lines 253-255) is missing.
256 signalPassFailure();

257 }

258 };

259

// SparseTensorCodegen pass. The two-bool constructor sets the
// `createSparseDeallocs` and `enableBufferInitialization` options.
// runOnOperation configures a conversion target in which func and
// bufferization ops are dynamically legal when `converter.isLegal` accepts
// their types, and in which linalg::FillOp/YieldOp and the arith,
// bufferization, complex, memref and scf dialects are legal. Both options are
// then forwarded to a pattern-population call.
// NOTE(review): lossy listing. The setup lines 270-273, the FuncOp/CallOp
// lambda bodies (289, 292) and lines 315, 317 and 319-320 are missing.
// Several `addIllegalDialect()`/`addLegalOp()` calls lost their template
// arguments. Restore from upstream.
260 struct SparseTensorCodegenPass

261 : public impl::SparseTensorCodegenBase {

262 SparseTensorCodegenPass() = default;

263 SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;

// Sets `createSparseDeallocs` from `createDeallocs` and
// `enableBufferInitialization` from `enableInit`.
264 SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {

265 createSparseDeallocs = createDeallocs;

266 enableBufferInitialization = enableInit;

267 }

268

269 void runOnOperation() override {

274

// NOTE(review): the template arguments of the next six calls appear stripped.
275 target.addIllegalDialect();

276 target.addLegalOp();

277 target.addLegalOp();

278

279 target.addLegalOp();

280 target.addLegalOp();

281 target.addLegalOp();

282

283 target.addLegalOptensor::FromElementsOp();

284

285

286

287

// NOTE(review): the return statement of this lambda (line 289) is missing.
288 target.addDynamicallyLegalOpfunc::FuncOp([&](func::FuncOp op) {

290 });

// NOTE(review): the return statement of this lambda (line 292) is missing.
291 target.addDynamicallyLegalOpfunc::CallOp([&](func::CallOp op) {

293 });

// A return is legal once all of its operand types are legal for `converter`.
294 target.addDynamicallyLegalOpfunc::ReturnOp([&](func::ReturnOp op) {

295 return converter.isLegal(op.getOperandTypes());

296 });

// Tensor allocation and deallocation are legal when the tensor type is legal.
297 target.addDynamicallyLegalOpbufferization::AllocTensorOp(

298 [&](bufferization::AllocTensorOp op) {

299 return converter.isLegal(op.getType());

300 });

301 target.addDynamicallyLegalOpbufferization::DeallocTensorOp(

302 [&](bufferization::DeallocTensorOp op) {

303 return converter.isLegal(op.getTensor().getType());

304 });

305

306

// Ops and dialects that are unconditionally legal for this conversion.
307 target.addLegalOp<linalg::FillOp, linalg::YieldOp>();

308 target.addLegalDialect<

309 arith::ArithDialect, bufferization::BufferizationDialect,

310 complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();

311 target.addLegalOp();

312

313 populateFunctionOpInterfaceTypeConversionPatternfunc::FuncOp(patterns,

314 converter);

316 target);

// NOTE(review): line 317 is missing. This line is the tail of a call that
// forwards both pass options, presumably populateSparseTensorCodegenPatterns.
318 converter, patterns, createSparseDeallocs, enableBufferInitialization);

// NOTE(review): the condition guarding this failure (lines 319-320) is missing.
321 signalPassFailure();

322 }

323 };

324

// SparseBufferRewrite pass. The bool constructor sets the
// `enableBufferInitialization` option.
// NOTE(review): lossy listing. The base class template argument appears
// stripped, and the runOnOperation body (lines 334-337) is missing.
325 struct SparseBufferRewritePass

326 : public impl::SparseBufferRewriteBase {

327 SparseBufferRewritePass() = default;

328 SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;

// Sets `enableBufferInitialization` from `enableInit`.
329 SparseBufferRewritePass(bool enableInit) {

330 enableBufferInitialization = enableInit;

331 }

332

333 void runOnOperation() override {

// NOTE(review): the statements of this body (lines 334-337) are missing here.
338 }

339 };

340

// SparseVectorization pass. The three-argument constructor sets
// `vectorLength`, `enableVLAVectorization` and `enableSIMDIndex32`.
// NOTE(review): lossy listing. The base class template argument appears
// stripped, and runOnOperation lines 354-356 and 358-359 are missing.
341 struct SparseVectorizationPass

342 : public impl::SparseVectorizationBase {

343 SparseVectorizationPass() = default;

344 SparseVectorizationPass(const SparseVectorizationPass &pass) = default;

// Sets `vectorLength` from `vl`, `enableVLAVectorization` from `vla` and
// `enableSIMDIndex32` from `sidx32`.
345 SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {

346 vectorLength = vl;

347 enableVLAVectorization = vla;

348 enableSIMDIndex32 = sidx32;

349 }

350

351 void runOnOperation() override {

// A zero vector length is rejected: the pass signals failure and returns
// before doing any other work.
352 if (vectorLength == 0)

353 return signalPassFailure();

// NOTE(review): lines 354-356 are missing. This line is the tail of a call
// that forwards all three options.
357 patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);

360 }

361 };

362

// SparseGPUCodegen pass. The constructor sets `numThreads` and
// `enableRuntimeLibrary`. runOnOperation branches on whether `numThreads` is
// zero.
// NOTE(review): lossy listing. The base class template argument appears
// stripped, and lines 373-374, 376, 378 and 379 are missing, including
// both branch statements of the if/else below.
363 struct SparseGPUCodegenPass

364 : public impl::SparseGPUCodegenBase {

365 SparseGPUCodegenPass() = default;

366 SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;

// Sets `numThreads` from `nT` and `enableRuntimeLibrary` from `enableRT`.
367 SparseGPUCodegenPass(unsigned nT, bool enableRT) {

368 numThreads = nT;

369 enableRuntimeLibrary = enableRT;

370 }

371

372 void runOnOperation() override {

// NOTE(review): both branch bodies (lines 376 and 378) are missing.
// Presumably one path uses populateSparseGPULibgenPatterns and the other
// populateSparseGPUCodegenPatterns; confirm upstream.
375 if (numThreads == 0)

377 else

380 }

381 };

382

// StorageSpecifierToLLVM pass. It configures a conversion target in which
// func ops are dynamically legal when `converter.isLegal` accepts their
// types, and in which the arith and LLVM dialects are legal. It then
// populates function-signature type conversion patterns.
// NOTE(review): lossy listing. The setup lines 388-391, the FuncOp/CallOp
// lambda bodies (396, 399) and lines 408-411 and 413-415 are missing.
// `addIllegalDialect()` lost its template argument. Restore from upstream.
383 struct StorageSpecifierToLLVMPass

384 : public impl::StorageSpecifierToLLVMBase {

385 StorageSpecifierToLLVMPass() = default;

386

387 void runOnOperation() override {

392

393

394 target.addIllegalDialect();

// NOTE(review): the return statement of this lambda (line 396) is missing.
395 target.addDynamicallyLegalOpfunc::FuncOp([&](func::FuncOp op) {

397 });

// NOTE(review): the return statement of this lambda (line 399) is missing.
398 target.addDynamicallyLegalOpfunc::CallOp([&](func::CallOp op) {

400 });

// A return is legal once all of its operand types are legal for `converter`.
401 target.addDynamicallyLegalOpfunc::ReturnOp([&](func::ReturnOp op) {

402 return converter.isLegal(op.getOperandTypes());

403 });

404 target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

405

406 populateFunctionOpInterfaceTypeConversionPatternfunc::FuncOp(patterns,

407 converter);

412 target);

// NOTE(review): the condition guarding this failure (lines 413-415) is missing.
416 signalPassFailure();

417 }

418 };

419

420 }

421

422

423

424

425

// Public factory functions. Each fragment below returns a
// std::make_unique-constructed pass; some forward constructor arguments
// (options, enableRT/enableConvert, createSparseDeallocs/
// enableBufferInitialization, the vectorization flags, numThreads/enableRT).
// NOTE(review): lossy listing. Every function signature line (e.g. 426,
// 430, 435, 441, ...) is missing, and the template arguments of
// `std::unique_ptr` and `std::make_unique` appear stripped. Line 437
// (between declaring `options` and using it) is also missing. Restore from
// upstream; the names of the create* functions appear in the symbol index at
// the end of this page.
427 return std::make_unique();

428 }

429

431 return std::make_unique();

432 }

433

434 std::unique_ptr

436 SparseReinterpretMapOptions options;

438 return std::make_unique(options);

439 }

440

442 return std::make_unique();

443 }

444

446 return std::make_unique();

447 }

448

449 std::unique_ptr

451 return std::make_unique(options);

452 }

453

455 return std::make_unique();

456 }

457

459 return std::make_unique();

460 }

461

462 std::unique_ptr

464 return std::make_unique(enableRT, enableConvert);

465 }

466

468 return std::make_unique();

469 }

470

472 return std::make_unique();

473 }

474

476 return std::make_unique();

477 }

478

480 return std::make_unique();

481 }

482

483 std::unique_ptr

485 bool enableBufferInitialization) {

486 return std::make_unique(createSparseDeallocs,

487 enableBufferInitialization);

488 }

489

491 return std::make_unique();

492 }

493

494 std::unique_ptr

496 return std::make_unique(enableBufferInitialization);

497 }

498

500 return std::make_unique();

501 }

502

503 std::unique_ptr

505 bool enableVLAVectorization,

506 bool enableSIMDIndex32) {

507 return std::make_unique(

508 vectorLength, enableVLAVectorization, enableSIMDIndex32);

509 }

510

512 return std::make_unique();

513 }

514

516 bool enableRT) {

517 return std::make_unique(numThreads, enableRT);

518 }

519

521 return std::make_unique();

522 }

static MLIRContext * getContext(OpFoldResult val)

static llvm::ManagedStatic< PassManagerOptions > options

This class describes a specific conversion target.

Sparse tensor type converter into an actual buffer.

Sparse tensor type converter into an opaque pointer.

bool isLegal(Type type) const

Return true if the given type is legal for this type converter, i.e.

bool isSignatureLegal(FunctionType ty) const

Return true if the inputs and outputs of the given function type are legal.

void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)

Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...

void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Collect a set of vector-to-vector canonicalization patterns.

Include the generated interface declarations.

std::unique_ptr< Pass > createSparseVectorizationPass()

std::unique_ptr< Pass > createSparseAssembler()

void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)

std::unique_ptr< Pass > createLowerSparseOpsToForeachPass()

std::unique_ptr< Pass > createSparseTensorCodegenPass()

void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)

Sets up sparse tensor codegen rules.

std::unique_ptr< Pass > createSparseGPUCodegenPass()

std::unique_ptr< Pass > createSparseReinterpretMapPass()

void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope)

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...

void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT)

std::unique_ptr< Pass > createSparseTensorConversionPass()

std::unique_ptr< Pass > createSparseBufferRewritePass()

void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)

Sets up sparse tensor conversion rules.

void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)

void populatePreSparsificationRewriting(RewritePatternSet &patterns)

void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)

Populates the given patterns list with vectorization rules.

ReinterpretMapScope

Defines a scope for reinterpret map pass.

void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())

Sets up sparsification rewriting rules with the given options.

const FrozenRewritePatternSet & patterns

void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)

void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)

std::unique_ptr< Pass > createStorageSpecifierToLLVMPass()

std::unique_ptr< Pass > createPreSparsificationRewritePass()

std::unique_ptr< Pass > createLowerForeachToSCFPass()

void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)

void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)

Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...

void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)

Sets up StageSparseOperation rewriting rules.

void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)

Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...

std::unique_ptr< Pass > createLowerSparseIterationToSCFPass()

void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads)

void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)

Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...

std::unique_ptr< Pass > createStageSparseOperationsPass()

std::unique_ptr< Pass > createSparsificationPass()

void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)

LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())

Below we define several entry points for operation conversion.

Type converter for iter_space and iterator.

Options for the Sparsification pass.