MLIR: lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

24 #include "llvm/ADT/StringRef.h"

25 #include "llvm/Support/Casting.h"

26 #include "llvm/Support/Debug.h"

27 #include "llvm/Support/FileSystem.h"

28 #include "llvm/Support/SourceMgr.h"

29 #include "llvm/Support/raw_ostream.h"

30

31 using namespace mlir;

32

33 #define DEBUG_TYPE "transform-dialect-interpreter-utils"

34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

35

36

37

38

39

40

44 for (const std::string &path : paths) {

46

47 if (llvm::sys::fs::is_regular_file(path)) {

48 LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");

49 fileNames.push_back(path);

50 continue;

51 }

52

53 if (!llvm::sys::fs::is_directory(path)) {

55 << "'" << path << "' is neither a file nor a directory";

56 }

57

58 LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");

59

60 std::error_code ec;

61 for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;

62 it != itEnd && !ec; it.increment(ec)) {

63 const std::string &fileName = it->path();

64

65 if (it->type() != llvm::sys::fs::file_type::regular_file &&

66 it->type() != llvm::sys::fs::file_type::symlink_file) {

67 LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName

68 << "'\n");

69 continue;

70 }

71

72 if (!StringRef(fileName).ends_with(".mlir")) {

73 LLVM_DEBUG(DBGS() << " Skipping '" << fileName

74 << "' because it does not end with '.mlir'\n");

75 continue;

76 }

77

78 LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");

79 fileNames.push_back(fileName);

80 }

81

82 if (ec)

83 return emitError(loc) << "error while opening files in '" << path

84 << "': " << ec.message();

85 }

86

87 return success();

88 }

89

91 MLIRContext *context, llvm::StringRef transformFileName,

93 if (transformFileName.empty()) {

94 LLVM_DEBUG(

95 DBGS() << "no transform file name specified, assuming the transform "

96 "module is embedded in the IR next to the top-level\n");

97 return success();

98 }

99

100 std::string errorMessage;

101 auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);

102 if (!memoryBuffer) {

105 << "failed to open transform file: " << errorMessage;

106 }

107

108 llvm::SourceMgr sourceMgr;

109 sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());

110 transformModule =

112 if (!transformModule) {

113

114

115

116 return failure();

117 }

119 }

120

122 return context->getOrLoadDialecttransform::TransformDialect()

123 ->getLibraryModule();

124 }

125

126 transform::TransformOpInterface

128 StringRef entryPoint) {

130 if (module)

131 l.push_back(module);

133 transform::TransformOpInterface transform = nullptr;

135 [&](transform::NamedSequenceOp namedSequenceOp) {

136 if (namedSequenceOp.getSymName() == entryPoint) {

137 transform = casttransform::TransformOpInterface(

138 namedSequenceOp.getOperation());

140 }

142 });

143 if (transform)

144 return transform;

145 }

147 << "could not find a nested named sequence with name: "

148 << entryPoint;

149 return nullptr;

150 }

151

155

158 libraryFileNames)))

159 return failure();

160

161

163 for (const std::string &libraryFileName : libraryFileNames) {

166 context, libraryFileName, parsedLibrary)))

167 return failure();

168 parsedLibraries.push_back(std::move(parsedLibrary));

169 }

170

171

172 auto loc = FileLineColLoc::get(context, "", 0, 0);

174 ModuleOp::create(loc, "__transform");

175 {

176 mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",

178

181 mergedParsedLibraries.get(), std::move(parsedLibrary))))

182 return parsedLibrary->emitError()

183 << "failed to merge symbols into shared library module";

184 }

185 }

186

187 transformModule = std::move(mergedParsedLibraries);

188 return success();

189 }

190

192 Operation *payload, Operation *transformRoot, ModuleOp transformModule,

197 cast(transformRoot),

198 transformModule, options);

199 }

200

204 if (bindings.empty()) {

205 return transformRoot.emitError()

206 << "expected at least one binding for the root";

207 }

208 if (bindings.at(0).size() != 1) {

209 return transformRoot.emitError()

210 << "expected one payload to be bound to the first argument, got "

211 << bindings.at(0).size();

212 }

213 auto *payloadRoot = dyn_cast<Operation *>(bindings.at(0).front());

214 if (!payloadRoot) {

215 return transformRoot->emitError() << "expected the object bound to the "

216 "first argument to be an operation";

217 }

218

220

221

222 if (transformModule && !transformModule->isAncestor(transformRoot)) {

226 std::move(clonedTransformModule)))) {

227 return payloadRoot->emitError() << "failed to merge symbols";

228 }

229 }

230

231 LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");

232 LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n");

233

235 false);

236 }

static std::string diag(const llvm::Value &value)

static llvm::ManagedStatic< PassManagerOptions > options

static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)

MLIRContext is the top-level object for a collection of MLIR operations.

T * getOrLoadDialect()

Get (or create) a dialect for the given derived dialect type.

Operation is the basic unit of execution within MLIR.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

OpTy get() const

Allow accessing the internal op.

A 2D array where each row may have different length.

ArrayRef< T > at(size_t pos) const

void removeFront()

Removes the first subarray in-place. Invalidates iterators to all rows.

bool empty() const

Returns true if there are no rows in the 2D array.

void push_back(Range &&elements)

Appends the given range of elements as a new row to the 2D array.

static Operation * getNearestSymbolTable(Operation *from)

Returns the nearest symbol table from a given operation from.

static WalkResult advance()

static WalkResult interrupt()

Options controlling the application of transform operations by the TransformState.

LogicalResult assembleTransformLibraryFromPaths(MLIRContext *context, ArrayRef< std::string > transformLibraryPaths, OwningOpRef< ModuleOp > &transformModule)

Utility to parse, verify, aggregate and link the content of all mlir files nested under transformLibr...

LogicalResult parseTransformModuleFromFile(MLIRContext *context, llvm::StringRef transformFileName, OwningOpRef< ModuleOp > &transformModule)

Utility to parse and verify the content of a transformFileName MLIR file containing a transform diale...

ModuleOp getPreloadedTransformModule(MLIRContext *context)

Utility to load a transform interpreter module from a module that has already been preloaded in the c...

InFlightDiagnostic mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)

Merge all symbols from other into target.

LogicalResult expandPathsToMLIRFiles(ArrayRef< std::string > paths, MLIRContext *context, SmallVectorImpl< std::string > &fileNames)

Expands the given list of paths to a list of .mlir files.

TransformOpInterface findTransformEntryPoint(Operation *root, ModuleOp module, StringRef entryPoint=TransformDialect::kTransformEntryPointSymbolName)

Finds the first TransformOpInterface named kTransformEntryPointSymbolName that is either:

LogicalResult applyTransformNamedSequence(Operation *payload, Operation *transformRoot, ModuleOp transformModule, const TransformOptions &options)

Standalone util to apply the named sequence transformRoot to payload IR.

LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)

Entry point to the Transform dialect infrastructure.

Include the generated interface declarations.

std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)

Open the file specified by its name for reading.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...