MLIR: lib/Pass/PassCrashRecovery.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/CrashRecoveryContext.h"
22 #include "llvm/Support/ManagedStatic.h"
23 #include "llvm/Support/Mutex.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/Support/Threading.h"
26 #include "llvm/Support/ToolOutputFile.h"
27
28 using namespace mlir;
30
31
32
33
34
35 namespace mlir {
36 namespace detail {
37
38
39
// RecoveryReproducerContext: per the index text, this class "contains all of
// the context for generating a recovery reproducer".
// NOTE(review): the class header and the start of the constructor declaration
// (original lines 40-44) are missing from this extraction; only the final
// parameter survives below. Per the index, the full constructor is
//   RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
//                             ReproducerStreamFactory &streamFactory,
//                             bool verifyPasses)
43 bool verifyPasses);
45
46
/// Generate a reproducer with the current context. A human-readable status
/// (where the reproducer was written, or why the output stream could not be
/// created) is appended to `description`.
47 void generate(std::string &description);
48
49
50
/// Disable this reproducer context: remove it from the process-wide set of
/// enabled contexts, and turn LLVM crash recovery off once that set is empty.
51 void disable();
52
53
/// Enable a previously disabled reproducer context: turn LLVM crash recovery
/// on if no context is currently enabled, make sure the signal handler is
/// registered, and insert this context into the process-wide set.
54 void enable();
55
56 private:
57
/// Signal handler installed through llvm::sys::AddSignalHandler; writes a
/// reproducer for an enabled context and emits an error diagnostic.
58 static void crashHandler(void *);
59
60
/// Install `crashHandler` as a process signal handler (at most once).
61 static void registerSignalHandler();
62
63
/// Textual pass pipeline this context reproduces (the constructor's
/// `passPipelineStr`).
64 std::string pipelineElements;
65
66
// NOTE(review): original lines 67 and 71 are missing here. The constructor
// initializes `preCrashOperation` (a clone of the op) and `streamFactory`,
// so those members are presumably declared on those lines. Confirm upstream.
68
69
70
72
73
/// True if multithreading was disabled on the op's context when this context
/// was created. Recorded in the reproducer as `disable_threading`.
74 bool disableThreads;
/// Verification flag passed to the constructor. Recorded in the reproducer as
/// `verify_each`.
75 bool verifyPasses;
76
77
78
79
80
/// Guards every access to `reproducerSet`.
81 static llvm::ManagedStatic<llvm::sys::SmartMutex> reproducerMutex;
/// Process-wide set of currently enabled reproducer contexts. enable() turns
/// LLVM crash recovery on when inserting into an empty set; disable() turns
/// it off when the set becomes empty.
82 static llvm::ManagedStatic<
83 llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
84 reproducerSet;
85 };
86 }
87 }
88
// Out-of-line definitions of RecoveryReproducerContext's static members.
// ManagedStatic defers construction until first use.
// NOTE(review): `llvm::sys::SmartMutex` is a class template in
// llvm/Support/Mutex.h (e.g. `SmartMutex<true>`). Its template argument
// appears to have been lost in extraction. Confirm against upstream.
89 llvm::ManagedStatic<llvm::sys::SmartMutex>
90 RecoveryReproducerContext::reproducerMutex;
91 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
92 RecoveryReproducerContext::reproducerSet;
93
// RecoveryReproducerContext constructor.
// NOTE(review): the first line of the signature (original line 94) and the
// `ReproducerStreamFactory &streamFactory, bool verifyPasses)` line (96) are
// missing from this extraction.
95 std::string passPipelineStr, Operation *op,
97 : pipelineElements(std::move(passPipelineStr)),
// Clone the operation up front. Contexts are created before a pass runs
// (see prepareReproducerFor), so the reproducer reflects the IR as it was
// before the pass could mutate it.
98 preCrashOperation(op->clone()), streamFactory(streamFactory),
// Record the context's threading mode so the reproducer can restore it.
99 disableThreads(!op->getContext()->isMultithreadingEnabled()),
100 verifyPasses(verifyPasses) {
// NOTE(review): the constructor body (original line 101) is missing here;
// it presumably calls enable() to register this context. Confirm upstream.
102 }
103
// ~RecoveryReproducerContext().
// NOTE(review): the signature line (original 104) and line 107 are missing
// from this extraction.
105
// Delete the pre-crash clone created in the constructor.
106 preCrashOperation->erase();
108 }
109
// appendReproducer: print `op`, together with the pipeline configuration, to
// a stream produced by `factory`, and append a status message to
// `description`. Full signature per the index:
//   static void appendReproducer(std::string &description, Operation *op,
//       const ReproducerStreamFactory &factory,
//       const std::string &pipelineElements, bool disableThreads,
//       bool verifyPasses)
// NOTE(review): the start of the signature (original lines 110-111) is
// missing from this extraction.
112 const std::string &pipelineElements,
113 bool disableThreads, bool verifyPasses) {
114 llvm::raw_string_ostream descOS(description);
115
116
117 std::string error;
// The factory signals failure by returning null and filling `error`.
// NOTE(review): `std::unique_ptr` lost its template argument in extraction.
118 std::unique_ptr stream = factory(error);
119 if (!stream) {
120 descOS << "failed to create output stream: " << error;
121 return;
122 }
123 descOS << "reproducer generated at `" << stream->description() << "`";
124
// NOTE(review): lines 126-127 and 129 are missing here. They presumably
// finish building `pipeline` from `pipelineElements`, declare the printing
// state `state`, and open the resource-printer lambda that takes `builder`.
// Confirm upstream.
125 std::string pipeline =
128 state.attachResourcePrinter(
130 builder.buildString("pipeline", pipeline);
131 builder.buildBool("disable_threading", disableThreads);
132 builder.buildBool("verify_each", verifyPasses);
133 });
134
135
// Print the IR. The attached resource printer embeds the pipeline, threading
// and verification settings under the keys that attachResourceParser reads
// back.
136 op->print(stream->os(), state);
137 }
138
// RecoveryReproducerContext::generate. The signature (original line 139) is
// missing here. Writes this context's reproducer and appends a status message
// to `description`.
140 appendReproducer(description, preCrashOperation, streamFactory,
141 pipelineElements, disableThreads, verifyPasses);
142 }
143
// RecoveryReproducerContext::disable. The signature (original line 144) is
// missing here. Stops tracking this context; once no enabled contexts
// remain, LLVM crash recovery is turned off.
145 llvm::sys::SmartScopedLock lock(*reproducerMutex);
146 reproducerSet->remove(this);
147 if (reproducerSet->empty())
148 llvm::CrashRecoveryContext::Disable();
149 }
150
// RecoveryReproducerContext::enable. The signature (original line 151) is
// missing here. Starts tracking this context. The first context enabled
// while the set is empty turns LLVM crash recovery on. The signal handler is
// registered at most once (see registerSignalHandler).
152 llvm::sys::SmartScopedLock lock(*reproducerMutex);
153 if (reproducerSet->empty())
154 llvm::CrashRecoveryContext::Enable();
155 registerSignalHandler();
156 reproducerSet->insert(this);
157 }
158
// Signal handler installed by registerSignalHandler(). It writes a reproducer
// for a recovery context and reports the crash as an error diagnostic.
159 void RecoveryReproducerContext::crashHandler(void *) {
160
161
162
// NOTE(review): original line 163 is missing from this extraction. It binds
// `context` and opens the scope that line 171 closes, presumably a loop over
// `*reproducerSet` (the enabled contexts). Confirm upstream.
164 std::string description;
165 context->generate(description);
166
167
// Report at the location of the operation captured before the pass ran.
168 emitError(context->preCrashOperation->getLoc())
169 << "A signal was caught while processing the MLIR module:"
170 << description << "; marking pass as failed";
171 }
172 }
173
174 void RecoveryReproducerContext::registerSignalHandler() {
175
176 static bool registered =
177 (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
178 (void)registered;
179 }
180
// PassCrashReproducerGenerator::Impl.
// NOTE(review): almost the entire body of this struct is missing from this
// extraction; only the closing `};` survives. Per the index text, it holds:
//   - Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
//   - streamFactory: the factory used when generating a crash reproducer.
//   - localReproducer: whether reproducer generation is localized to the
//     failing pass.
//   - activeContexts: SmallVector<std::unique_ptr<RecoveryReproducerContext>>,
//     a record of all currently active reproducer contexts.
//   - runningPasses: SetVector<std::pair<Pass *, Operation *>>, the set of
//     all currently running passes.
//   - pmFlagVerifyPasses: a pass-manager flag emitted when generating a
//     reproducer.
181
182
183
184
188
189
191
192
193
195
196
198
199
200
202
203
205 };
206
// PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory,
//                              bool localReproducer)
// NOTE(review): the constructor's signature lines (original 207-208) and the
// destructor (line 210) are missing from this extraction. Only the member
// initializer that creates the Impl survives.
209 : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
211
// PassCrashReproducerGenerator::initialize: "Initialize the generator in
// preparation for reproducer generation." Signature per the index:
//   void initialize(iterator_range<PassManager::pass_iterator> passes,
//                   Operation *op, bool pmFlagVerifyPasses)
// NOTE(review): signature lines 212-213 are missing from this extraction.
214 bool pmFlagVerifyPasses) {
// NOTE(review): the operand before `->` and line 216 were lost in
// extraction. Based on the message below, this is presumably
// `!impl->localReproducer || !op->getContext()->isMultithreadingEnabled()`.
// Confirm upstream.
215 assert((->localReproducer ||
217 "expected multi-threading to be disabled when generating a local "
218 "reproducer");
219
220 llvm::CrashRecoveryContext::Enable();
// Remembered so every context created later records it as `verify_each`.
221 impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
222
223
224
// NOTE(review): the condition's operand and the guarded statement (line 226)
// are missing. A non-local generator presumably prepares one reproducer for
// the whole pipeline via prepareReproducerFor(passes, op). Confirm upstream.
225 if (->localReproducer)
227 }
228
// Append "`<pass name>` on '<op name>' operation" to the diagnostic `os`. If
// the operation implements SymbolOpInterface, also append ": @<symbol name>".
// Signature per the index:
//   static void formatPassOpReproducerMessage(Diagnostic &os,
//       std::pair<Pass *, Operation *> passOpPair)
// NOTE(review): signature line 230 is missing here, and the `dyn_cast` below
// lost its template argument in extraction.
229 static void
231 std::pair<Pass *, Operation *> passOpPair) {
232 os << "`" << passOpPair.first->getName() << "` on "
233 << "'" << passOpPair.second->getName() << "' operation";
234 if (SymbolOpInterface symbol = dyn_cast(passOpPair.second))
235 os << ": @" << symbol.getName();
236 }
237
// PassCrashReproducerGenerator::finalize: "Finalize the current run of the
// generator, generating any necessary reproducers if the provided execution
// result is a failure." Signature per the index:
//   void finalize(Operation *rootOp, LogicalResult executionResult)
// NOTE(review): signature line 238 is missing from this extraction.
239 LogicalResult executionResult) {
240
// Nothing to report if no reproducer context is active.
241 if (impl->activeContexts.empty())
242 return;
243
244
// On success, the captured contexts are discarded without writing anything.
245 if (succeeded(executionResult))
246 return impl->activeContexts.clear();
247
// NOTE(review): line 248, which creates the diagnostic `diag` (presumably an
// error at `rootOp`'s location), is missing from this extraction.
249 << "Failures have been detected while "
250 "processing an MLIR pass pipeline";
251
252
253
// NOTE(review): the operand before `->` was lost in extraction. This branch
// expects exactly one active context, i.e. a single whole-pipeline
// reproducer, so the condition is presumably `!impl->localReproducer`.
// Confirm upstream. In both branches, `¬e` is a mis-decoded `&note`.
254 if (->localReproducer) {
255 assert(impl->activeContexts.size() == 1 && "expected one active context");
256
257
258 std::string description;
259 impl->activeContexts.front()->generate(description);
260
261
// List every pass that was running when the failure was detected.
262 Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing [";
263 llvm::interleaveComma(impl->runningPasses, note,
264 [&](const std::pair<Pass *, Operation *> &value) {
265 formatPassOpReproducerMessage(note, value);
266 });
267 note << "]: " << description;
268 impl->runningPasses.clear();
269 impl->activeContexts.clear();
270 return;
271 }
272
273
274
275
// Localized mode: there is one active context per running pass.
276 assert(impl->activeContexts.size() == impl->runningPasses.size() &&
277 "expected running passes to match active contexts");
278
279
// NOTE(review): line 280, which declares `reproducerContext`, is missing. It
// presumably refers to the most recently pushed (innermost) context, and
// line 286 presumably formats the matching running pass. Confirm upstream.
281 std::string description;
282 reproducerContext.generate(description);
283
284
285 Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing ";
287 note << ": " << description;
288
289 impl->activeContexts.clear();
290 impl->runningPasses.clear();
291 }
292
// PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass,
// Operation *op): "Prepare a new reproducer for the given pass, operating on
// op."
// NOTE(review): signature lines 293-294 are missing from this extraction.
295
296
// Always record the running pass so a later failure can name it.
297 impl->runningPasses.insert(std::make_pair(pass, op));
// NOTE(review): the operand before `->` was lost in extraction. Since the
// code below pushes a per-pass context, this is presumably
// `!impl->localReproducer`: non-local generators only track running passes.
// Confirm upstream.
298 if (->localReproducer)
299 return;
300
301
302
// NOTE(review): presumably `!impl->activeContexts.empty()` (operand lost in
// extraction). Suspends the enclosing pass's context while this nested pass
// runs; removeLastReproducerFor re-enables it.
303 if (->activeContexts.empty())
304 impl->activeContexts.back()->disable();
305
306
// NOTE(review): lines 307-308 are missing. They presumably declare `scopes`
// and open a loop that binds `parentOp` from `op->getParentOp()`. The loop
// collects the name of each nested op while walking `op` up to its
// top-level ancestor.
309 scopes.push_back(op->getName());
310 op = parentOp;
311 }
312
313
314
// Build a nested pipeline string "outer(...(inner(<pass>))...)" that anchors
// the pass at this op's nesting. NOTE(review): line 319 (presumably printing
// `pass` as a textual pipeline into `passOS`) is missing.
315 std::string passStr;
316 llvm::raw_string_ostream passOS(passStr);
317 for (OperationName scope : llvm::reverse(scopes))
318 passOS << scope << "(";
320 for (unsigned i = 0, e = scopes.size(); i < e; ++i)
321 passOS << ")";
322
// After the walk above, `op` is the top-level ancestor, so the new context
// clones the whole enclosing IR.
// NOTE(review): `std::make_unique` lost its template argument in extraction.
323 impl->activeContexts.push_back(std::make_unique(
324 passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses));
325 }
// Second prepareReproducerFor overload. It prints a list of passes as one
// comma-separated pipeline string and pushes a single context for it.
// NOTE(review): the signature (original lines 326-327) and the callback
// argument on line 331 are missing from this extraction. Presumably it takes
// the pass range given to initialize() plus an `Operation *op`. Confirm
// upstream.
328 std::string passStr;
329 llvm::raw_string_ostream passOS(passStr);
330 llvm::interleaveComma(
332
333 impl->activeContexts.push_back(std::make_unique(
334 passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses));
335 }
336
// PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass,
// Operation *op): "Remove the last recorded reproducer anchored at the given
// pass and operation."
// NOTE(review): signature lines 337-338 are missing from this extraction.
339
340 impl->runningPasses.remove(std::make_pair(pass, op));
// Per-pass contexts are only pushed in localized mode.
341 if (impl->localReproducer) {
342 impl->activeContexts.pop_back();
343
344
345
// NOTE(review): presumably `!impl->activeContexts.empty()` (operand lost in
// extraction). Re-enables the enclosing pass's context that
// prepareReproducerFor disabled.
346 if (->activeContexts.empty())
347 impl->activeContexts.back()->enable();
348 }
349 }
350
351
352
353
354
355 namespace {
// CrashReproducerInstrumentation: a pass instrumentation (registered via
// addInstrumentation in enableCrashReproducerGeneration) that drives a
// PassCrashReproducerGenerator from the pass-manager hooks.
// NOTE(review): the class header, the constructor and the `generator` member
// (original lines 356-358 and 382) are missing from this extraction. The
// `isa` calls below lost their template arguments; they presumably exclude
// pass-manager adaptor passes. Confirm upstream.
359 ~CrashReproducerInstrumentation() override = default;
360
// Prepare a reproducer before each non-excluded pass runs.
361 void runBeforePass(Pass *pass, Operation *op) override {
362 if (!isa(pass))
363 generator.prepareReproducerFor(pass, op);
364 }
365
// Drop that pass's reproducer once it completes.
366 void runAfterPass(Pass *pass, Operation *op) override {
367 if (!isa(pass))
368 generator.removeLastReproducerFor(pass, op);
369 }
370
// On the first reported failure, finalize the generator with a failure
// result, which emits the reproducer(s). Any later failure reported to this
// instrumentation instance is ignored.
371 void runAfterPassFailed(Pass *pass, Operation *op) override {
372
373 if (alreadyFailed)
374 return;
375
376 alreadyFailed = true;
377 generator.finalize(op, failure());
378 }
379
380 private:
381
383 bool alreadyFailed = false;
384 };
385 }
386
387
388
389
390
391 namespace {
392
393
// FileReproducerStream: a reproducer stream (it overrides description() and
// os()) backed by an llvm::ToolOutputFile.
// NOTE(review): the struct header (original line 394) is missing, and
// `std::unique_ptrllvm::ToolOutputFile` is a mis-extracted
// `std::unique_ptr<llvm::ToolOutputFile>`.
395 FileReproducerStream(std::unique_ptrllvm::ToolOutputFile outputFile)
396 : outputFile(std::move(outputFile)) {}
// keep() tells ToolOutputFile not to delete the file when it is destroyed,
// so the reproducer remains on disk.
397 ~FileReproducerStream() override { outputFile->keep(); }
398
399
// Describe the stream by the path of the file being written.
400 StringRef description() override { return outputFile->getFilename(); }
401
402
// Stream that receives the printed reproducer IR.
403 raw_ostream &os() override { return outputFile->os(); }
404
405 private:
406
407 std::unique_ptrllvm::ToolOutputFile outputFile = nullptr;
408 };
409 }
410
411
412
413
414
// PassManager::runWithCrashRecovery: run the passes inside an
// llvm::CrashRecoveryContext so that a crash can be recovered from and
// reported through the reproducer generator.
// NOTE(review): signature line 416 (presumably declaring the analysis manager
// `am`) is missing from this extraction.
415 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
417 crashReproGenerator->initialize(getPasses(), op, verifyPasses);
418
419
// Default to failure. If the passes crash, the lambda never assigns the
// result, so the crash is reported as a pipeline failure.
420 LogicalResult passManagerResult = failure();
421 llvm::CrashRecoveryContext recoveryContext;
422 recoveryContext.RunSafelyOnThread(
423 [&] { passManagerResult = runPasses(op, am); });
// Emit reproducers if the run failed, and reset the generator's state.
424 crashReproGenerator->finalize(op, passManagerResult);
425 return passManagerResult;
426 }
427
// makeReproducerStreamFactory(StringRef outputFile): return a factory that
// opens `outputFile` for writing each time a reproducer stream is requested.
// NOTE(review): the signature (original lines 428-429) and the file-opening
// call on line 435 are missing from this extraction. `std::unique_ptr` and
// `std::make_unique` also lost their template arguments.
430
431
// Copy into an owned string, because the returned lambda may outlive the
// caller's StringRef.
432 std::string filename = outputFile.str();
433 return [filename](std::string &error) -> std::unique_ptr {
434 std::unique_ptrllvm::ToolOutputFile outputFile =
// On failure, the open call (line 435, presumably openOutputFile with
// `&error`) leaves a message in `error`; it is prefixed with context here.
436 if (!outputFile) {
437 error = "Failed to create reproducer stream: " + error;
438 return nullptr;
439 }
440 return std::make_unique(std::move(outputFile));
441 };
442 }
443
// Forward declaration of a helper defined elsewhere. Full signature per the
// index:
//   void printAsTextualPipeline(raw_ostream &os, StringRef anchorName,
//       const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
//       bool pretty = false)
// NOTE(review): original lines 444 and 446 are missing from this extraction.
445 raw_ostream &os, StringRef anchorName,
447 bool pretty = false);
448
// makeReproducer: write a reproducer for `op` running `passes` (anchored on
// `anchorName`) to `outputFile` and return the status description.
// Full signature per the index:
//   std::string makeReproducer(StringRef anchorName,
//       const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
//       Operation *op, StringRef outputFile, bool disableThreads = false,
//       bool verifyPasses = false)
// NOTE(review): original lines 449, 451, 458 and 459 are missing. Lines
// 458-459 presumably print the pipeline into `passOS` and call
// appendReproducer with a stream factory for `outputFile`. Confirm upstream.
450 StringRef anchorName,
452 Operation *op, StringRef outputFile, bool disableThreads,
453 bool verifyPasses) {
454
455 std::string description;
456 std::string pipelineStr;
457 llvm::raw_string_ostream passOS(pipelineStr);
460 pipelineStr, disableThreads, verifyPasses);
461 return description;
462 }
463
// PassManager::enableCrashReproducerGeneration(StringRef outputFile,
// bool genLocalReproducer = false): convenience overload taking a file path.
// NOTE(review): original lines 464 and 466 are missing. The body presumably
// forwards to the factory overload with
// makeReproducerStreamFactory(outputFile). Confirm upstream.
465 bool genLocalReproducer) {
467 genLocalReproducer);
468 }
469
// PassManager::enableCrashReproducerGeneration (factory overload): create the
// reproducer generator and register the instrumentation that drives it.
// May be called at most once per pass manager.
// NOTE(review): signature lines 470-471 are missing. The `std::make_unique`
// calls also lost their template arguments; they presumably create a
// PassCrashReproducerGenerator and a CrashReproducerInstrumentation.
472 assert(!crashReproGenerator &&
473 "crash reproducer has already been initialized");
// Local reproducers require multithreading to be disabled; initialize()
// asserts the same. Fail hard at setup time in all build modes.
474 if (genLocalReproducer && getContext()->isMultithreadingEnabled())
475 llvm::report_fatal_error(
476 "Local crash reproduction can't be setup on a "
477 "pass-manager without disabling multi-threading first.");
478
479 crashReproGenerator = std::make_unique(
480 factory, genLocalReproducer);
481 addInstrumentation(
482 std::make_unique(*crashReproGenerator));
483 }
484
485
486
487
488
// attachResourceParser(ParserConfig &config): "Attach an assembly resource
// parser to 'config' that collects the MLIR reproducer configuration into
// this instance." The keys it accepts ("pipeline", "disable_threading",
// "verify_each") are the ones appendReproducer writes.
// NOTE(review): the signature and the opening of `parseFn` (original lines
// 489-490) are missing. `FailureOr` lost its template arguments in
// extraction (`FailureOrstd::string` is `FailureOr<std::string>`).
491 if (entry.getKey() == "pipeline") {
492 FailureOrstd::string value = entry.parseAsString();
493 if (succeeded(value))
494 this->pipeline = std::move(*value);
495 return value;
496 }
497 if (entry.getKey() == "disable_threading") {
498 FailureOr value = entry.parseAsBool();
499 if (succeeded(value))
500 this->disableThreading = *value;
501 return value;
502 }
503 if (entry.getKey() == "verify_each") {
504 FailureOr value = entry.parseAsBool();
505 if (succeeded(value))
506 this->verifyEach = *value;
507 return value;
508 }
// Any other key inside the `mlir_reproducer` resource is an error.
509 return entry.emitError() << "unknown 'mlir_reproducer' resource key '"
510 << entry.getKey() << "'";
511 };
512 config.attachResourceParser("mlir_reproducer", parseFn);
513 }
514
// apply(PassManager &pm) const: "Apply the reproducer options to 'pm' and
// its context." Each option is applied only if it was present in the parsed
// resource.
// NOTE(review): the signature (original line 515) and lines 517, 524 and 527
// are missing. Presumably 517 parses `*pipeline` into `reproPm`, 524 sets the
// context's threading mode, and 527 enables the verifier. Confirm upstream.
516 if (pipeline.has_value()) {
518 if (failed(reproPm))
519 return failure();
// Overwrite only the OpPassManager (pipeline) part of `pm`.
520 static_cast<OpPassManager &>(pm) = std::move(*reproPm);
521 }
522
523 if (disableThreading.has_value())
525
526 if (verifyEach.has_value())
528
529 return success();
530 }
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static void appendReproducer(std::string &description, Operation *op, const ReproducerStreamFactory &factory, const std::string &pipelineElements, bool disableThreads, bool verifyPasses)
static void formatPassOpReproducerMessage(Diagnostic &os, std::pair< Pass *, Operation * > passOpPair)
void printAsTextualPipeline(raw_ostream &os, StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, bool pretty=false)
static ReproducerStreamFactory makeReproducerStreamFactory(StringRef outputFile)
This class represents an analysis manager for a particular operation instance.
This class represents a single parsed resource entry.
This class is used to build resource entries for use by the printer.
virtual void buildString(StringRef key, StringRef data)=0
Build a resource entry represented by the given human-readable string value.
virtual void buildBool(StringRef key, bool data)=0
Build a resource entry represented by the given bool.
This class provides management for the lifetime of the state used when printing the IR.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a diagnostic that is inflight and set to be reported.
void disableMultithreading(bool disable=true)
Set the flag specifying if multi-threading is disabled by the context.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
OperationName getName()
The name of an operation is the key identifier for it.
void erase()
Remove this operation from its parent block and delete it.
This class represents a configuration for the MLIR assembly parser.
PassInstrumentation provides several entry points into the pass manager infrastructure.
The main pass manager and pipeline builder.
MLIRContext * getContext() const
Return an instance of the context.
void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer=false)
Enable support for the pass manager to generate a reproducer on the event of a crash or a pass failure.
void enableVerifier(bool enabled=true)
Runs the verifier after each individual pass.
The abstract base pass class.
void printAsTextualPipeline(raw_ostream &os, bool pretty=false)
Prints out the pass in the textual representation of pipelines.
void initialize(iterator_range< PassManager::pass_iterator > passes, Operation *op, bool pmFlagVerifyPasses)
Initialize the generator in preparation for reproducer generation.
void removeLastReproducerFor(Pass *pass, Operation *op)
Remove the last recorded reproducer anchored at the given pass and operation.
void finalize(Operation *rootOp, LogicalResult executionResult)
Finalize the current run of the generator, generating any necessary reproducers if the provided execution result is a failure.
void prepareReproducerFor(Pass *pass, Operation *op)
Prepare a new reproducer for the given pass, operating on op.
~PassCrashReproducerGenerator()
PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory, bool localReproducer)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
std::string makeReproducer(StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, Operation *op, StringRef outputFile, bool disableThreads=false, bool verifyPasses=false)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::function< std::unique_ptr< ReproducerStream >(std::string &error)> ReproducerStreamFactory
Method type for constructing ReproducerStream.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool pmFlagVerifyPasses
Various pass manager flags that get emitted when generating a reproducer.
ReproducerStreamFactory streamFactory
The factory to use when generating a crash reproducer.
SetVector< std::pair< Pass *, Operation * > > runningPasses
The set of all currently running passes.
bool localReproducer
Flag indicating if reproducer generation should be localized to the failing pass.
Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
SmallVector< std::unique_ptr< RecoveryReproducerContext > > activeContexts
A record of all of the currently active reproducer contexts.
void attachResourceParser(ParserConfig &config)
Attach an assembly resource parser to 'config' that collects the MLIR reproducer configuration into this instance.
LogicalResult apply(PassManager &pm) const
Apply the reproducer options to 'pm' and its context.
Streams on which to output a crash reproducer.
This class contains all of the context for generating a recovery reproducer.
void disable()
Disable this reproducer context.
~RecoveryReproducerContext()
RecoveryReproducerContext(std::string passPipelineStr, Operation *op, ReproducerStreamFactory &streamFactory, bool verifyPasses)
void generate(std::string &description)
Generate a reproducer with the current context.
void enable()
Enable a previously disabled reproducer context.