MLIR: lib/Pass/PassCrashRecovery.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

17 #include "llvm/ADT/STLExtras.h"

18 #include "llvm/ADT/ScopeExit.h"

19 #include "llvm/ADT/SetVector.h"

20 #include "llvm/Support/CommandLine.h"

21 #include "llvm/Support/CrashRecoveryContext.h"

22 #include "llvm/Support/ManagedStatic.h"

23 #include "llvm/Support/Mutex.h"

24 #include "llvm/Support/Signals.h"

25 #include "llvm/Support/Threading.h"

26 #include "llvm/Support/ToolOutputFile.h"

27

28 using namespace mlir;

30

31

32

33

34

35 namespace mlir {

36 namespace detail {

37

38

39

43 bool verifyPasses);

45

46

47 void generate(std::string &description);

48

49

50

51 void disable();

52

53

54 void enable();

55

56 private:

57

58 static void crashHandler(void *);

59

60

61 static void registerSignalHandler();

62

63

64 std::string pipelineElements;

65

66

68

69

70

72

73

74 bool disableThreads;

75 bool verifyPasses;

76

77

78

79

80

81 static llvm::ManagedStatic<llvm::sys::SmartMutex> reproducerMutex;

82 static llvm::ManagedStatic<

83 llvm::SmallSetVector<RecoveryReproducerContext *, 1>>

84 reproducerSet;

85 };

86 }

87 }

88

// Process-wide mutex taken by enable() and disable() before they modify
// reproducerSet.
89 llvm::ManagedStatic<llvm::sys::SmartMutex>

90 RecoveryReproducerContext::reproducerMutex;

// Process-wide set of enabled reproducer contexts: enable() inserts `this`
// (enabling llvm::CrashRecoveryContext when the set was empty) and disable()
// removes it (disabling llvm::CrashRecoveryContext once the set becomes empty).
91 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>

92 RecoveryReproducerContext::reproducerSet;

93

95 std::string passPipelineStr, Operation *op,

97 : pipelineElements(std::move(passPipelineStr)),

98 preCrashOperation(op->clone()), streamFactory(streamFactory),

99 disableThreads(!op->getContext()->isMultithreadingEnabled()),

100 verifyPasses(verifyPasses) {

102 }

103

105

106 preCrashOperation->erase();

108 }

109

112 const std::string &pipelineElements,

113 bool disableThreads, bool verifyPasses) {

114 llvm::raw_string_ostream descOS(description);

115

116

117 std::string error;

118 std::unique_ptr stream = factory(error);

119 if (!stream) {

120 descOS << "failed to create output stream: " << error;

121 return;

122 }

123 descOS << "reproducer generated at `" << stream->description() << "`";

124

125 std::string pipeline =

128 state.attachResourcePrinter(

130 builder.buildString("pipeline", pipeline);

131 builder.buildBool("disable_threading", disableThreads);

132 builder.buildBool("verify_each", verifyPasses);

133 });

134

135

136 op->print(stream->os(), state);

137 }

138

140 appendReproducer(description, preCrashOperation, streamFactory,

141 pipelineElements, disableThreads, verifyPasses);

142 }

143

145 llvm::sys::SmartScopedLock lock(*reproducerMutex);

146 reproducerSet->remove(this);

147 if (reproducerSet->empty())

148 llvm::CrashRecoveryContext::Disable();

149 }

150

152 llvm::sys::SmartScopedLock lock(*reproducerMutex);

153 if (reproducerSet->empty())

154 llvm::CrashRecoveryContext::Enable();

155 registerSignalHandler();

156 reproducerSet->insert(this);

157 }

158

159 void RecoveryReproducerContext::crashHandler(void *) {

160

161

162

164 std::string description;

165 context->generate(description);

166

167

168 emitError(context->preCrashOperation->getLoc())

169 << "A signal was caught while processing the MLIR module:"

170 << description << "; marking pass as failed";

171 }

172 }

173

174 void RecoveryReproducerContext::registerSignalHandler() {

175

176 static bool registered =

177 (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);

178 (void)registered;

179 }

180

181

182

183

184

188

189

191

192

193

195

196

198

199

200

202

203

205 };

206

209 : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}

211

214 bool pmFlagVerifyPasses) {

215 assert((impl->localReproducer ||

217 "expected multi-threading to be disabled when generating a local "

218 "reproducer");

219

220 llvm::CrashRecoveryContext::Enable();

221 impl->pmFlagVerifyPasses = pmFlagVerifyPasses;

222

223

224

225 if (impl->localReproducer)

227 }

228

229 static void

231 std::pair<Pass *, Operation *> passOpPair) {

232 os << "`" << passOpPair.first->getName() << "` on "

233 << "'" << passOpPair.second->getName() << "' operation";

234 if (SymbolOpInterface symbol = dyn_cast(passOpPair.second))

235 os << ": @" << symbol.getName();

236 }

237

239 LogicalResult executionResult) {

240

241 if (impl->activeContexts.empty())

242 return;

243

244

245 if (succeeded(executionResult))

246 return impl->activeContexts.clear();

247

249 << "Failures have been detected while "

250 "processing an MLIR pass pipeline";

251

252

253

254 if (impl->localReproducer) {

255 assert(impl->activeContexts.size() == 1 && "expected one active context");

256

257

258 std::string description;

259 impl->activeContexts.front()->generate(description);

260

261

262 Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";

263 llvm::interleaveComma(impl->runningPasses, note,

264 [&](const std::pair<Pass *, Operation *> &value) {

265 formatPassOpReproducerMessage(note, value);

266 });

267 note << "]: " << description;

268 impl->runningPasses.clear();

269 impl->activeContexts.clear();

270 return;

271 }

272

273

274

275

276 assert(impl->activeContexts.size() == impl->runningPasses.size() &&

277 "expected running passes to match active contexts");

278

279

281 std::string description;

282 reproducerContext.generate(description);

283

284

285 Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";

287 note << ": " << description;

288

289 impl->activeContexts.clear();

290 impl->runningPasses.clear();

291 }

292

295

296

297 impl->runningPasses.insert(std::make_pair(pass, op));

298 if (impl->localReproducer)

299 return;

300

301

302

303 if (impl->activeContexts.empty())

304 impl->activeContexts.back()->disable();

305

306

309 scopes.push_back(op->getName());

310 op = parentOp;

311 }

312

313

314

315 std::string passStr;

316 llvm::raw_string_ostream passOS(passStr);

317 for (OperationName scope : llvm::reverse(scopes))

318 passOS << scope << "(";

320 for (unsigned i = 0, e = scopes.size(); i < e; ++i)

321 passOS << ")";

322

323 impl->activeContexts.push_back(std::make_unique(

324 passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses));

325 }

328 std::string passStr;

329 llvm::raw_string_ostream passOS(passStr);

330 llvm::interleaveComma(

332

333 impl->activeContexts.push_back(std::make_unique(

334 passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses));

335 }

336

339

340 impl->runningPasses.remove(std::make_pair(pass, op));

341 if (impl->localReproducer) {

342 impl->activeContexts.pop_back();

343

344

345

346 if (impl->activeContexts.empty())

347 impl->activeContexts.back()->enable();

348 }

349 }

350

351

352

353

354

355 namespace {

359 ~CrashReproducerInstrumentation() override = default;

360

361 void runBeforePass(Pass *pass, Operation *op) override {

362 if (!isa(pass))

363 generator.prepareReproducerFor(pass, op);

364 }

365

366 void runAfterPass(Pass *pass, Operation *op) override {

367 if (!isa(pass))

368 generator.removeLastReproducerFor(pass, op);

369 }

370

371 void runAfterPassFailed(Pass *pass, Operation *op) override {

372

373 if (alreadyFailed)

374 return;

375

376 alreadyFailed = true;

377 generator.finalize(op, failure());

378 }

379

380 private:

381

383 bool alreadyFailed = false;

384 };

385 }

386

387

388

389

390

391 namespace {

392

393

395 FileReproducerStream(std::unique_ptrllvm::ToolOutputFile outputFile)

396 : outputFile(std::move(outputFile)) {}

397 ~FileReproducerStream() override { outputFile->keep(); }

398

399

400 StringRef description() override { return outputFile->getFilename(); }

401

402

403 raw_ostream &os() override { return outputFile->os(); }

404

405 private:

406

407 std::unique_ptrllvm::ToolOutputFile outputFile = nullptr;

408 };

409 }

410

411

412

413

414

415 LogicalResult PassManager::runWithCrashRecovery(Operation *op,

417 crashReproGenerator->initialize(getPasses(), op, verifyPasses);

418

419

420 LogicalResult passManagerResult = failure();

421 llvm::CrashRecoveryContext recoveryContext;

422 recoveryContext.RunSafelyOnThread(

423 [&] { passManagerResult = runPasses(op, am); });

424 crashReproGenerator->finalize(op, passManagerResult);

425 return passManagerResult;

426 }

427

430

431

432 std::string filename = outputFile.str();

433 return [filename](std::string &error) -> std::unique_ptr {

434 std::unique_ptrllvm::ToolOutputFile outputFile =

436 if (!outputFile) {

437 error = "Failed to create reproducer stream: " + error;

438 return nullptr;

439 }

440 return std::make_unique(std::move(outputFile));

441 };

442 }

443

445 raw_ostream &os, StringRef anchorName,

447 bool pretty = false);

448

450 StringRef anchorName,

452 Operation *op, StringRef outputFile, bool disableThreads,

453 bool verifyPasses) {

454

455 std::string description;

456 std::string pipelineStr;

457 llvm::raw_string_ostream passOS(pipelineStr);

460 pipelineStr, disableThreads, verifyPasses);

461 return description;

462 }

463

465 bool genLocalReproducer) {

467 genLocalReproducer);

468 }

469

472 assert(!crashReproGenerator &&

473 "crash reproducer has already been initialized");

474 if (genLocalReproducer && getContext()->isMultithreadingEnabled())

475 llvm::report_fatal_error(

476 "Local crash reproduction can't be setup on a "

477 "pass-manager without disabling multi-threading first.");

478

479 crashReproGenerator = std::make_unique(

480 factory, genLocalReproducer);

481 addInstrumentation(

482 std::make_unique(*crashReproGenerator));

483 }

484

485

486

487

488

491 if (entry.getKey() == "pipeline") {

492 FailureOrstd::string value = entry.parseAsString();

493 if (succeeded(value))

494 this->pipeline = std::move(*value);

495 return value;

496 }

497 if (entry.getKey() == "disable_threading") {

498 FailureOr value = entry.parseAsBool();

499 if (succeeded(value))

500 this->disableThreading = *value;

501 return value;

502 }

503 if (entry.getKey() == "verify_each") {

504 FailureOr value = entry.parseAsBool();

505 if (succeeded(value))

506 this->verifyEach = *value;

507 return value;

508 }

509 return entry.emitError() << "unknown 'mlir_reproducer' resource key '"

510 << entry.getKey() << "'";

511 };

512 config.attachResourceParser("mlir_reproducer", parseFn);

513 }

514

516 if (pipeline.has_value()) {

518 if (failed(reproPm))

519 return failure();

520 static_cast<OpPassManager &>(pm) = std::move(*reproPm);

521 }

522

523 if (disableThreading.has_value())

525

526 if (verifyEach.has_value())

528

529 return success();

530 }

static MLIRContext * getContext(OpFoldResult val)

static const mlir::GenInfo * generator

static std::string diag(const llvm::Value &value)

static void appendReproducer(std::string &description, Operation *op, const ReproducerStreamFactory &factory, const std::string &pipelineElements, bool disableThreads, bool verifyPasses)

static void formatPassOpReproducerMessage(Diagnostic &os, std::pair< Pass *, Operation * > passOpPair)

void printAsTextualPipeline(raw_ostream &os, StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, bool pretty=false)

static ReproducerStreamFactory makeReproducerStreamFactory(StringRef outputFile)

This class represents an analysis manager for a particular operation instance.

This class represents a single parsed resource entry.

This class is used to build resource entries for use by the printer.

virtual void buildString(StringRef key, StringRef data)=0

Build a resource entry represented by the given human-readable string value.

virtual void buildBool(StringRef key, bool data)=0

Build a resource entry represented by the given bool.

This class provides management for the lifetime of the state used when printing the IR.

This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.

This class represents a diagnostic that is inflight and set to be reported.

void disableMultithreading(bool disable=true)

Set the flag specifying if multi-threading is disabled by the context.

bool isMultithreadingEnabled()

Return true if multi-threading is enabled by the context.

This class represents a pass manager that runs passes on either a specific operation type,...

StringRef getStringRef() const

Return the name of this operation. This always succeeds.

Operation is the basic unit of execution within MLIR.

void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)

MLIRContext * getContext()

Return the context this operation is associated with.

Location getLoc()

The source location the operation was defined or derived from.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.

OperationName getName()

The name of an operation is the key identifier for it.

void erase()

Remove this operation from its parent block and delete it.

This class represents a configuration for the MLIR assembly parser.

PassInstrumentation provides several entry points into the pass manager infrastructure.

The main pass manager and pipeline builder.

MLIRContext * getContext() const

Return an instance of the context.

void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer=false)

Enable support for the pass manager to generate a reproducer in the event of a crash or a pass failure.

void enableVerifier(bool enabled=true)

Runs the verifier after each individual pass.

The abstract base pass class.

void printAsTextualPipeline(raw_ostream &os, bool pretty=false)

Prints out the pass in the textual representation of pipelines.

void initialize(iterator_range< PassManager::pass_iterator > passes, Operation *op, bool pmFlagVerifyPasses)

Initialize the generator in preparation for reproducer generation.

void removeLastReproducerFor(Pass *pass, Operation *op)

Remove the last recorded reproducer anchored at the given pass and operation.

void finalize(Operation *rootOp, LogicalResult executionResult)

Finalize the current run of the generator, generating any necessary reproducers if the provided execution result is a failure.

void prepareReproducerFor(Pass *pass, Operation *op)

Prepare a new reproducer for the given pass, operating on op.

~PassCrashReproducerGenerator()

PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory, bool localReproducer)

Include the generated interface declarations.

const FrozenRewritePatternSet GreedyRewriteConfig config

std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)

Open the file specified by its name for writing.

std::string makeReproducer(StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, Operation *op, StringRef outputFile, bool disableThreads=false, bool verifyPasses=false)

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

std::function< std::unique_ptr< ReproducerStream >(std::string &error)> ReproducerStreamFactory

Method type for constructing ReproducerStream.

Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)

LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())

Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.

bool pmFlagVerifyPasses

Various pass manager flags that get emitted when generating a reproducer.

ReproducerStreamFactory streamFactory

The factory to use when generating a crash reproducer.

SetVector< std::pair< Pass *, Operation * > > runningPasses

The set of all currently running passes.

bool localReproducer

Flag indicating if reproducer generation should be localized to the failing pass.

Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)

SmallVector< std::unique_ptr< RecoveryReproducerContext > > activeContexts

A record of all of the currently active reproducer contexts.

void attachResourceParser(ParserConfig &config)

Attach an assembly resource parser to 'config' that collects the MLIR reproducer configuration into this instance.

LogicalResult apply(PassManager &pm) const

Apply the reproducer options to 'pm' and its context.

Streams on which to output crash reproducer.

This class contains all of the context for generating a recovery reproducer.

void disable()

Disable this reproducer context.

~RecoveryReproducerContext()

RecoveryReproducerContext(std::string passPipelineStr, Operation *op, ReproducerStreamFactory &streamFactory, bool verifyPasses)

void generate(std::string &description)

Generate a reproducer with the current context.

void enable()

Enable a previously disabled reproducer context.