MLIR: lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
23
24 namespace mlir {
25 #define GEN_PASS_DEF_SPARSEASSEMBLER
26 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
27 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
28 #define GEN_PASS_DEF_SPARSIFICATIONPASS
29 #define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
30 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
31 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
32 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
33 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
34 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
35 #define GEN_PASS_DEF_SPARSEVECTORIZATION
36 #define GEN_PASS_DEF_SPARSEGPUCODEGEN
37 #define GEN_PASS_DEF_STAGESPARSEOPERATIONS
38 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
39 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
40 }
41
42 using namespace mlir;
44
45 namespace {
46
47
48
49
50
// Pass struct for the sparse assembler, derived from the generated
// impl::SparseAssemblerBase.
// NOTE(review): this listing looks like a documentation-page extract: every
// line is prefixed with its original source line number, the base class has no
// template argument list, and original lines 57-60 (the body of
// runOnOperation) are absent. Confirm against the real source file.
51 struct SparseAssembler : public impl::SparseAssemblerBase {
52 SparseAssembler() = default;
53 SparseAssembler(const SparseAssembler &pass) = default;
// Stores the caller-supplied flag in `directOut` (not declared in this
// struct; presumably a pass option inherited from the base -- verify).
54 SparseAssembler(bool dO) { directOut = dO; }
55
// Pass entry point. Its statements are not present in this extract.
56 void runOnOperation() override {
61 }
62 };
63
// Pass struct for the sparse reinterpret-map pass, derived from the generated
// impl::SparseReinterpretMapBase.
// NOTE(review): the base class template arguments and the bodies of both the
// options constructor (original line 69) and runOnOperation (original lines
// 73-76) are missing from this extract.
64 struct SparseReinterpretMap
65 : public impl::SparseReinterpretMapBase {
66 SparseReinterpretMap() = default;
67 SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
// Constructs the pass from an options struct; what is copied out of `options`
// is not visible here.
68 SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
70 }
71
// Pass entry point. Its statements are not present in this extract.
72 void runOnOperation() override {
77 }
78 };
79
// Pass struct for the pre-sparsification rewrite, derived from the generated
// impl::PreSparsificationRewriteBase.
// NOTE(review): the base class template arguments and original lines 87-90
// (the body of runOnOperation) are missing from this extract.
80 struct PreSparsificationRewritePass
81 : public impl::PreSparsificationRewriteBase {
82 PreSparsificationRewritePass() = default;
83 PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
84 default;
85
// Pass entry point. Its statements are not present in this extract.
86 void runOnOperation() override {
91 }
92 };
93
// Pass struct for sparsification, derived from the generated
// impl::SparsificationPassBase.
// NOTE(review): several original lines are absent from this extract (98, 105,
// 107, 110-111, 113), including the header of the options-taking constructor
// and most of runOnOperation. Confirm against the real source file.
94 struct SparsificationPass
95 : public impl::SparsificationPassBase {
96 SparsificationPass() = default;
97 SparsificationPass(const SparsificationPass &pass) = default;
// Body of a constructor whose signature line (original line 98) is missing;
// it copies three fields of `options` into same-purpose members.
99 parallelization = options.parallelizationStrategy;
100 sparseEmitStrategy = options.sparseEmitStrategy;
101 enableRuntimeLibrary = options.enableRuntimeLibrary;
102 }
103
// Pass entry point. Only fragments remain: the tail of a call that passes
// `enableRuntimeLibrary`, and the registration of scf::ForOp canonicalization
// patterns into `patterns` for context `ctx`.
104 void runOnOperation() override {
106
108 enableRuntimeLibrary);
109
112 scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
114 }
115 };
116
// Pass struct for staging sparse operations, derived from the generated
// impl::StageSparseOperationsBase.
// NOTE(review): the base class template arguments and original lines 122-125
// (the body of runOnOperation) are missing from this extract.
117 struct StageSparseOperationsPass
118 : public impl::StageSparseOperationsBase {
119 StageSparseOperationsPass() = default;
120 StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
// Pass entry point. Its statements are not present in this extract.
121 void runOnOperation() override {
126 }
127 };
128
// Pass struct for lowering sparse ops to foreach, derived from the generated
// impl::LowerSparseOpsToForeachBase.
// NOTE(review): the base class template arguments and original lines 140-142
// and 144 (most of runOnOperation) are missing from this extract.
129 struct LowerSparseOpsToForeachPass
130 : public impl::LowerSparseOpsToForeachBase {
131 LowerSparseOpsToForeachPass() = default;
132 LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
133 default;
// Stores the two caller-supplied flags in `enableRuntimeLibrary` and
// `enableConvert` (not declared here; presumably inherited pass options).
134 LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
135 enableRuntimeLibrary = enableRT;
136 enableConvert = convert;
137 }
138
// Pass entry point. Only the tail of one call, whose last argument is
// `enableConvert`, survives in this extract.
139 void runOnOperation() override {
143 enableConvert);
145 }
146 };
147
// Pass struct for lowering foreach to SCF, derived from the generated
// impl::LowerForeachToSCFBase.
// NOTE(review): the base class template arguments and original lines 154-157
// (the body of runOnOperation) are missing from this extract.
148 struct LowerForeachToSCFPass
149 : public impl::LowerForeachToSCFBase {
150 LowerForeachToSCFPass() = default;
151 LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;
152
// Pass entry point. Its statements are not present in this extract.
153 void runOnOperation() override {
158 }
159 };
160
// Pass struct for lowering sparse iteration ops to SCF. Unlike its siblings in
// this extract, the base class still shows its template argument (the derived
// struct itself).
// NOTE(review): original lines 169-172, 181 and 183-184 are absent, so the
// declarations of `target`, any converter/pattern set, and the condition that
// guards signalPassFailure() are not visible.
161 struct LowerSparseIterationToSCFPass
162 : public impl::LowerSparseIterationToSCFBase<
163 LowerSparseIterationToSCFPass> {
164 LowerSparseIterationToSCFPass() = default;
165 LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
166 default;
167
// Pass entry point: configures legality on `target` and signals failure on a
// path whose condition is missing from this extract.
168 void runOnOperation() override {
173
174
// The arith, linalg, memref, scf and sparse_tensor dialects are marked legal.
175 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
176 memref::MemRefDialect, scf::SCFDialect,
177 sparse_tensor::SparseTensorDialect>();
// These four ops are marked illegal even though the sparse_tensor dialect is
// marked legal just above.
178 target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
179 IterateOp>();
// NOTE(review): the op type list of this call was lost in extraction.
180 target.addLegalOp();
182
185 signalPassFailure();
186 }
187 };
188
// Pass struct for sparse tensor conversion, derived from the generated
// impl::SparseTensorConversionPassBase.
// NOTE(review): this is a lossy extract. Template angle brackets were stripped
// (e.g. `addDynamicallyLegalOpfunc::FuncOp(` has lost the `<...>` around the
// op type), and original lines 195-198, 205, 208, 250-251 and 253-255 are
// absent, so the declarations of `converter`, `target` and `patterns`, two
// lambda bodies, and the condition guarding signalPassFailure() are not
// visible. Confirm against the real source file.
189 struct SparseTensorConversionPass
190 : public impl::SparseTensorConversionPassBase {
191 SparseTensorConversionPass() = default;
192 SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
193
// Pass entry point: sets up legality rules on `target`, registers conversion
// patterns, and signals failure on a path whose condition is missing here.
194 void runOnOperation() override {
199
// NOTE(review): the dialect type list of this call was lost in extraction.
200 target.addIllegalDialect();
201
202
203
// func::FuncOp and func::CallOp are dynamically legal; their predicate bodies
// (original lines 205 and 208) are missing from this extract.
204 target.addDynamicallyLegalOpfunc::FuncOp([&](func::FuncOp op) {
206 });
207 target.addDynamicallyLegalOpfunc::CallOp([&](func::CallOp op) {
209 });
// A return op is legal when the converter accepts all of its operand types.
210 target.addDynamicallyLegalOpfunc::ReturnOp([&](func::ReturnOp op) {
211 return converter.isLegal(op.getOperandTypes());
212 });
// tensor.dim is legal when the converter accepts all of its operand types.
213 target.addDynamicallyLegalOptensor::DimOp([&](tensor::DimOp op) {
214 return converter.isLegal(op.getOperandTypes());
215 });
// tensor.cast is legal when both its source and destination types are legal.
216 target.addDynamicallyLegalOptensor::CastOp([&](tensor::CastOp op) {
217 return converter.isLegal(op.getSource().getType()) &&
218 converter.isLegal(op.getDest().getType());
219 });
// Expand/collapse shape ops are legal when both source and result types are
// legal.
220 target.addDynamicallyLegalOptensor::ExpandShapeOp(
221 [&](tensor::ExpandShapeOp op) {
222 return converter.isLegal(op.getSrc().getType()) &&
223 converter.isLegal(op.getResult().getType());
224 });
225 target.addDynamicallyLegalOptensor::CollapseShapeOp(
226 [&](tensor::CollapseShapeOp op) {
227 return converter.isLegal(op.getSrc().getType()) &&
228 converter.isLegal(op.getResult().getType());
229 });
// Tensor allocation is legal when its result type is legal; deallocation is
// legal when the type of the tensor being released is legal.
230 target.addDynamicallyLegalOpbufferization::AllocTensorOp(
231 [&](bufferization::AllocTensorOp op) {
232 return converter.isLegal(op.getType());
233 });
234 target.addDynamicallyLegalOpbufferization::DeallocTensorOp(
235 [&](bufferization::DeallocTensorOp op) {
236 return converter.isLegal(op.getTensor().getType());
237 });
238
239
// Individual ops that are marked legal unconditionally.
240 target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
241 linalg::YieldOp, tensor::ExtractOp,
242 tensor::FromElementsOp>();
// Whole dialects that are marked legal unconditionally.
243 target.addLegalDialect<
244 arith::ArithDialect, bufferization::BufferizationDialect,
245 LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
246
247
// Registers the function-signature type conversion pattern for func::FuncOp.
// The following line is the tail of a call whose start is missing.
248 populateFunctionOpInterfaceTypeConversionPatternfunc::FuncOp(patterns,
249 converter);
252 target);
256 signalPassFailure();
257 }
258 };
259
// Pass struct for sparse tensor codegen, derived from the generated
// impl::SparseTensorCodegenBase.
// NOTE(review): this is a lossy extract. Template angle brackets were stripped
// (several `addLegalOp()` calls have lost their op type lists), and original
// lines 270-273, 289, 292, 315, 317 and 319-320 are absent, so the
// declarations of `converter`, `target` and `patterns`, two lambda bodies, and
// the condition guarding signalPassFailure() are not visible.
260 struct SparseTensorCodegenPass
261 : public impl::SparseTensorCodegenBase {
262 SparseTensorCodegenPass() = default;
263 SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
// Stores the two caller-supplied flags in `createSparseDeallocs` and
// `enableBufferInitialization` (not declared here; presumably inherited pass
// options).
264 SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
265 createSparseDeallocs = createDeallocs;
266 enableBufferInitialization = enableInit;
267 }
268
// Pass entry point: sets up legality rules on `target`, registers conversion
// patterns, and signals failure on a path whose condition is missing here.
269 void runOnOperation() override {
274
// NOTE(review): the type lists of the next six calls were lost in extraction.
275 target.addIllegalDialect();
276 target.addLegalOp();
277 target.addLegalOp();
278
279 target.addLegalOp();
280 target.addLegalOp();
281 target.addLegalOp();
282
283 target.addLegalOptensor::FromElementsOp();
284
285
286
287
// func::FuncOp and func::CallOp are dynamically legal; their predicate bodies
// (original lines 289 and 292) are missing from this extract.
288 target.addDynamicallyLegalOpfunc::FuncOp([&](func::FuncOp op) {
290 });
291 target.addDynamicallyLegalOpfunc::CallOp([&](func::CallOp op) {
293 });
// A return op is legal when the converter accepts all of its operand types.
294 target.addDynamicallyLegalOpfunc::ReturnOp([&](func::ReturnOp op) {
295 return converter.isLegal(op.getOperandTypes());
296 });
// Tensor allocation is legal when its result type is legal; deallocation is
// legal when the type of the tensor being released is legal.
297 target.addDynamicallyLegalOpbufferization::AllocTensorOp(
298 [&](bufferization::AllocTensorOp op) {
299 return converter.isLegal(op.getType());
300 });
301 target.addDynamicallyLegalOpbufferization::DeallocTensorOp(
302 [&](bufferization::DeallocTensorOp op) {
303 return converter.isLegal(op.getTensor().getType());
304 });
305
306
// Ops and dialects that are marked legal unconditionally.
307 target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
308 target.addLegalDialect<
309 arith::ArithDialect, bufferization::BufferizationDialect,
310 complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
311 target.addLegalOp();
312
// Registers the function-signature type conversion pattern for func::FuncOp.
// The two lines after it are tails of calls whose starts are missing; the last
// one forwards both pass flags.
313 populateFunctionOpInterfaceTypeConversionPatternfunc::FuncOp(patterns,
314 converter);
316 target);
318 converter, patterns, createSparseDeallocs, enableBufferInitialization);
321 signalPassFailure();
322 }
323 };
324
// Pass struct for sparse buffer rewriting, derived from the generated
// impl::SparseBufferRewriteBase.
// NOTE(review): the base class template arguments and original lines 334-337
// (the body of runOnOperation) are missing from this extract.
325 struct SparseBufferRewritePass
326 : public impl::SparseBufferRewriteBase {
327 SparseBufferRewritePass() = default;
328 SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
// Stores the caller-supplied flag in `enableBufferInitialization` (not
// declared here; presumably an inherited pass option).
329 SparseBufferRewritePass(bool enableInit) {
330 enableBufferInitialization = enableInit;
331 }
332
// Pass entry point. Its statements are not present in this extract.
333 void runOnOperation() override {
338 }
339 };
340
// Pass struct for sparse vectorization, derived from the generated
// impl::SparseVectorizationBase.
// NOTE(review): the base class template arguments and original lines 354-356
// and 358-359 (most of runOnOperation after the guard) are missing from this
// extract.
341 struct SparseVectorizationPass
342 : public impl::SparseVectorizationBase {
343 SparseVectorizationPass() = default;
344 SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
// Stores the three caller-supplied settings in `vectorLength`,
// `enableVLAVectorization` and `enableSIMDIndex32` (not declared here;
// presumably inherited pass options).
345 SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
346 vectorLength = vl;
347 enableVLAVectorization = vla;
348 enableSIMDIndex32 = sidx32;
349 }
350
// Pass entry point.
351 void runOnOperation() override {
// A vector length of zero is rejected: the pass signals failure and returns.
352 if (vectorLength == 0)
353 return signalPassFailure();
// Tail of a call (start missing) that forwards `patterns` and the three
// settings.
357 patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
360 }
361 };
362
// Pass struct for sparse GPU codegen, derived from the generated
// impl::SparseGPUCodegenBase.
// NOTE(review): the base class template arguments and original lines 373-374,
// 376 and 378-379 are missing, so neither branch body of the `numThreads`
// check is visible in this extract.
363 struct SparseGPUCodegenPass
364 : public impl::SparseGPUCodegenBase {
365 SparseGPUCodegenPass() = default;
366 SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
// Stores the caller-supplied settings in `numThreads` and
// `enableRuntimeLibrary` (not declared here; presumably inherited pass
// options).
367 SparseGPUCodegenPass(unsigned nT, bool enableRT) {
368 numThreads = nT;
369 enableRuntimeLibrary = enableRT;
370 }
371
// Pass entry point: takes one of two paths depending on whether `numThreads`
// is zero. What each path does is not visible here.
372 void runOnOperation() override {
375 if (numThreads == 0)
377 else
380 }
381 };
382
// Pass struct for lowering storage specifiers to LLVM, derived from the
// generated impl::StorageSpecifierToLLVMBase. Unlike its siblings, no copy
// constructor is declared in the visible lines.
// NOTE(review): this is a lossy extract. Template angle brackets were
// stripped, and original lines 388-391, 396, 399, 408-411 and 413-415 are
// absent, so the declarations of `converter`, `target` and `patterns`, two
// lambda bodies, and the condition guarding signalPassFailure() are not
// visible.
383 struct StorageSpecifierToLLVMPass
384 : public impl::StorageSpecifierToLLVMBase {
385 StorageSpecifierToLLVMPass() = default;
386
// Pass entry point: sets up legality rules on `target`, registers conversion
// patterns, and signals failure on a path whose condition is missing here.
387 void runOnOperation() override {
392
393
// NOTE(review): the dialect type list of this call was lost in extraction.
394 target.addIllegalDialect();
// func::FuncOp and func::CallOp are dynamically legal; their predicate bodies
// (original lines 396 and 399) are missing from this extract.
395 target.addDynamicallyLegalOpfunc::FuncOp([&](func::FuncOp op) {
397 });
398 target.addDynamicallyLegalOpfunc::CallOp([&](func::CallOp op) {
400 });
// A return op is legal when the converter accepts all of its operand types.
401 target.addDynamicallyLegalOpfunc::ReturnOp([&](func::ReturnOp op) {
402 return converter.isLegal(op.getOperandTypes());
403 });
// The arith and LLVM dialects are marked legal unconditionally.
404 target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
405
// Registers the function-signature type conversion pattern for func::FuncOp.
// The line after it is the tail of a call whose start is missing.
406 populateFunctionOpInterfaceTypeConversionPatternfunc::FuncOp(patterns,
407 converter);
412 target);
416 signalPassFailure();
417 }
418 };
419
420 }
421
422
423
424
425
427 return std::make_unique();
428 }
429
431 return std::make_unique();
432 }
433
434 std::unique_ptr
436 SparseReinterpretMapOptions options;
438 return std::make_unique(options);
439 }
440
442 return std::make_unique();
443 }
444
446 return std::make_unique();
447 }
448
449 std::unique_ptr
451 return std::make_unique(options);
452 }
453
455 return std::make_unique();
456 }
457
459 return std::make_unique();
460 }
461
462 std::unique_ptr
464 return std::make_unique(enableRT, enableConvert);
465 }
466
468 return std::make_unique();
469 }
470
472 return std::make_unique();
473 }
474
476 return std::make_unique();
477 }
478
480 return std::make_unique();
481 }
482
483 std::unique_ptr
485 bool enableBufferInitialization) {
486 return std::make_unique(createSparseDeallocs,
487 enableBufferInitialization);
488 }
489
491 return std::make_unique();
492 }
493
494 std::unique_ptr
496 return std::make_unique(enableBufferInitialization);
497 }
498
500 return std::make_unique();
501 }
502
503 std::unique_ptr
505 bool enableVLAVectorization,
506 bool enableSIMDIndex32) {
507 return std::make_unique(
508 vectorLength, enableVLAVectorization, enableSIMDIndex32);
509 }
510
512 return std::make_unique();
513 }
514
516 bool enableRT) {
517 return std::make_unique(numThreads, enableRT);
518 }
519
521 return std::make_unique();
522 }
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class describes a specific conversion target.
Sparse tensor type converter into an actual buffer.
Sparse tensor type converter into an opaque pointer.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseVectorizationPass()
std::unique_ptr< Pass > createSparseAssembler()
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< Pass > createLowerSparseOpsToForeachPass()
std::unique_ptr< Pass > createSparseTensorCodegenPass()
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
std::unique_ptr< Pass > createSparseGPUCodegenPass()
std::unique_ptr< Pass > createSparseReinterpretMapPass()
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope)
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT)
std::unique_ptr< Pass > createSparseTensorConversionPass()
std::unique_ptr< Pass > createSparseBufferRewritePass()
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
ReinterpretMapScope
Defines a scope for reinterpret map pass.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
const FrozenRewritePatternSet & patterns
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
std::unique_ptr< Pass > createStorageSpecifierToLLVMPass()
std::unique_ptr< Pass > createPreSparsificationRewritePass()
std::unique_ptr< Pass > createLowerForeachToSCFPass()
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut)
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
std::unique_ptr< Pass > createLowerSparseIterationToSCFPass()
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads)
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
std::unique_ptr< Pass > createStageSparseOperationsPass()
std::unique_ptr< Pass > createSparsificationPass()
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Type converter for iter_space and iterator.
Options for the Sparsification pass.