MLIR: lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FileSystem.h"
28 #include "llvm/Support/SourceMgr.h"
29 #include "llvm/Support/raw_ostream.h"
30
31 using namespace mlir;
32
33 #define DEBUG_TYPE "transform-dialect-interpreter-utils"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
35
36
37
38
39
40
44 for (const std::string &path : paths) {
46
47 if (llvm::sys::fs::is_regular_file(path)) {
48 LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
49 fileNames.push_back(path);
50 continue;
51 }
52
53 if (!llvm::sys::fs::is_directory(path)) {
55 << "'" << path << "' is neither a file nor a directory";
56 }
57
58 LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
59
60 std::error_code ec;
61 for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
62 it != itEnd && !ec; it.increment(ec)) {
63 const std::string &fileName = it->path();
64
65 if (it->type() != llvm::sys::fs::file_type::regular_file &&
66 it->type() != llvm::sys::fs::file_type::symlink_file) {
67 LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
68 << "'\n");
69 continue;
70 }
71
72 if (!StringRef(fileName).ends_with(".mlir")) {
73 LLVM_DEBUG(DBGS() << " Skipping '" << fileName
74 << "' because it does not end with '.mlir'\n");
75 continue;
76 }
77
78 LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
79 fileNames.push_back(fileName);
80 }
81
82 if (ec)
83 return emitError(loc) << "error while opening files in '" << path
84 << "': " << ec.message();
85 }
86
87 return success();
88 }
89
91 MLIRContext *context, llvm::StringRef transformFileName,
93 if (transformFileName.empty()) {
94 LLVM_DEBUG(
95 DBGS() << "no transform file name specified, assuming the transform "
96 "module is embedded in the IR next to the top-level\n");
97 return success();
98 }
99
100 std::string errorMessage;
101 auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
102 if (!memoryBuffer) {
105 << "failed to open transform file: " << errorMessage;
106 }
107
108 llvm::SourceMgr sourceMgr;
109 sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
110 transformModule =
112 if (!transformModule) {
113
114
115
116 return failure();
117 }
119 }
120
122 return context->getOrLoadDialecttransform::TransformDialect()
123 ->getLibraryModule();
124 }
125
126 transform::TransformOpInterface
128 StringRef entryPoint) {
130 if (module)
131 l.push_back(module);
133 transform::TransformOpInterface transform = nullptr;
135 [&](transform::NamedSequenceOp namedSequenceOp) {
136 if (namedSequenceOp.getSymName() == entryPoint) {
137 transform = casttransform::TransformOpInterface(
138 namedSequenceOp.getOperation());
140 }
142 });
143 if (transform)
144 return transform;
145 }
147 << "could not find a nested named sequence with name: "
148 << entryPoint;
149 return nullptr;
150 }
151
155
158 libraryFileNames)))
159 return failure();
160
161
163 for (const std::string &libraryFileName : libraryFileNames) {
166 context, libraryFileName, parsedLibrary)))
167 return failure();
168 parsedLibraries.push_back(std::move(parsedLibrary));
169 }
170
171
172 auto loc = FileLineColLoc::get(context, "", 0, 0);
174 ModuleOp::create(loc, "__transform");
175 {
176 mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
178
181 mergedParsedLibraries.get(), std::move(parsedLibrary))))
182 return parsedLibrary->emitError()
183 << "failed to merge symbols into shared library module";
184 }
185 }
186
187 transformModule = std::move(mergedParsedLibraries);
188 return success();
189 }
190
192 Operation *payload, Operation *transformRoot, ModuleOp transformModule,
197 cast(transformRoot),
198 transformModule, options);
199 }
200
204 if (bindings.empty()) {
205 return transformRoot.emitError()
206 << "expected at least one binding for the root";
207 }
208 if (bindings.at(0).size() != 1) {
209 return transformRoot.emitError()
210 << "expected one payload to be bound to the first argument, got "
211 << bindings.at(0).size();
212 }
213 auto *payloadRoot = dyn_cast<Operation *>(bindings.at(0).front());
214 if (!payloadRoot) {
215 return transformRoot->emitError() << "expected the object bound to the "
216 "first argument to be an operation";
217 }
218
220
221
222 if (transformModule && !transformModule->isAncestor(transformRoot)) {
226 std::move(clonedTransformModule)))) {
227 return payloadRoot->emitError() << "failed to merge symbols";
228 }
229 }
230
231 LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
232 LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n");
233
235 false);
236 }
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy get() const
Allow accessing the internal op.
A 2D array where each row may have different length.
ArrayRef< T > at(size_t pos) const
void removeFront()
Removes the first subarray in-place. Invalidates iterators to all rows.
bool empty() const
Returns true if there are no rows in the 2D array.
void push_back(Range &&elements)
Appends the given range of elements as a new row to the 2D array.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
static WalkResult advance()
static WalkResult interrupt()
Options controlling the application of transform operations by the TransformState.
LogicalResult assembleTransformLibraryFromPaths(MLIRContext *context, ArrayRef< std::string > transformLibraryPaths, OwningOpRef< ModuleOp > &transformModule)
Utility to parse, verify, aggregate and link the content of all mlir files nested under transformLibr...
LogicalResult parseTransformModuleFromFile(MLIRContext *context, llvm::StringRef transformFileName, OwningOpRef< ModuleOp > &transformModule)
Utility to parse and verify the content of a transformFileName MLIR file containing a transform diale...
ModuleOp getPreloadedTransformModule(MLIRContext *context)
Utility to load a transform interpreter module from a module that has already been preloaded in the c...
InFlightDiagnostic mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)
Merge all symbols from other into target.
LogicalResult expandPathsToMLIRFiles(ArrayRef< std::string > paths, MLIRContext *context, SmallVectorImpl< std::string > &fileNames)
Expands the given list of paths to a list of .mlir files.
TransformOpInterface findTransformEntryPoint(Operation *root, ModuleOp module, StringRef entryPoint=TransformDialect::kTransformEntryPointSymbolName)
Finds the first TransformOpInterface named kTransformEntryPointSymbolName that is either:
LogicalResult applyTransformNamedSequence(Operation *payload, Operation *transformRoot, ModuleOp transformModule, const TransformOptions &options)
Standalone util to apply the named sequence transformRoot to payload IR.
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
Include the generated interface declarations.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...