MLIR: include/mlir/Bytecode/BytecodeImplementation.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14 #ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
15 #define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
16
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/Twine.h"
24
25 namespace mlir {
26
27
28
29
30
31
33 public:
35 };
36
37
38
39
40
41
42
43
44
45
47 public:
49
50
52
53
54 virtual FailureOr<const DialectVersion *>
56 template
59 }
60
61
63
64
66
67
68
69
70
71 template <typename T, typename CallbackFn>
73 uint64_t size;
75 return failure();
76 result.reserve(size);
77
78 for (uint64_t i = 0; i < size; ++i) {
79
80
81 if constexpr (llvm::function_traits<std::decay_t>::num_args) {
82 T element = {};
83 if (failed(callback(element)))
84 return failure();
85 result.emplace_back(std::move(element));
86 } else {
87 FailureOr element = callback();
88 if (failed(element))
89 return failure();
90 result.emplace_back(std::move(*element));
91 }
92 }
93 return success();
94 }
95
96
97
98
99
100
102
103
105
106 template
109 }
110 template
114 return failure();
115 if ((result = dyn_cast(baseResult)))
116 return success();
117 return emitError() << "expected " << llvm::getTypeName()
118 << ", but got: " << baseResult;
119 }
120 template
124 return failure();
125 if (!baseResult)
126 return success();
127 if ((result = dyn_cast(baseResult)))
128 return success();
129 return emitError() << "expected " << llvm::getTypeName()
130 << ", but got: " << baseResult;
131 }
132
133
135 template
137 return readList(types, [this](T &type) { return readType(type); });
138 }
139 template
141 Type baseResult;
142 if (failed(readType(baseResult)))
143 return failure();
144 if ((result = dyn_cast(baseResult)))
145 return success();
146 return emitError() << "expected " << llvm::getTypeName()
147 << ", but got: " << baseResult;
148 }
149
150
151 template
154 if (failed(handle))
155 return failure();
156 if (auto *result = dyn_cast(&*handle))
157 return std::move(*result);
158 return emitError() << "provided resource handle differs from the "
159 "expected resource type";
160 }
161
162
163
164
165
166
167 virtual LogicalResult readVarInt(uint64_t &result) = 0;
168
169
174 }
175
176
177
180 return failure();
181 flag = result & 1;
182 result >>= 1;
183 return success();
184 }
185
186
187
188
189
190
191
192 template
194 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
195 static_assert(std::is_integral::value, "expects integer");
196 uint64_t nonZeroesCount;
197 bool useSparseEncoding;
199 return failure();
200 if (nonZeroesCount == 0)
201 return success();
202 if (!useSparseEncoding) {
203
204 if (nonZeroesCount > array.size()) {
205 emitError("trying to read an array of ")
206 << nonZeroesCount << " but only " << array.size()
207 << " storage available.";
208 return failure();
209 }
210 for (int64_t index : llvm::seq<int64_t>(0, nonZeroesCount)) {
211 uint64_t value;
213 return failure();
214 array[index] = value;
215 }
216 return success();
217 }
218
219
220 uint64_t indexBitSize;
221 if (failed(readVarInt(indexBitSize)))
222 return failure();
223 constexpr uint64_t maxIndexBitSize = 8;
224 if (indexBitSize > maxIndexBitSize) {
225 emitError("reading sparse array with indexing above 8 bits: ")
226 << indexBitSize;
227 return failure();
228 }
229 for (uint32_t count : llvm::seq<uint32_t>(0, nonZeroesCount)) {
230 (void)count;
231 uint64_t indexValuePair;
232 if (failed(readVarInt(indexValuePair)))
233 return failure();
234 uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));
235 uint64_t value = indexValuePair >> indexBitSize;
236 if (index >= array.size()) {
237 emitError("reading a sparse array found index ")
238 << index << " but only " << array.size() << " storage available.";
239 return failure();
240 }
241 array[index] = value;
242 }
243 return success();
244 }
245
246
248
249
250
251 virtual FailureOr
253
254
255 virtual LogicalResult readString(StringRef &result) = 0;
256
257
259
260
261 virtual LogicalResult readBool(bool &result) = 0;
262
263 private:
264
265 virtual FailureOr readResourceHandle() = 0;
266 };
267
268
269
270
271
272
273
274
275
276
278 public:
280
281
282
283
284
285
286
287 template <typename RangeT, typename CallbackFn>
288 void writeList(RangeT &&range, CallbackFn &&callback) {
290 for (auto &element : range)
291 callback(element);
292 }
293
294
297 template
300 }
301
302
304 template
307 }
308
309
310 virtual void
312
313
314
315
316
317
318
320
321
322
326 }
327
328
330 writeVarInt((value << 1) | (flag ? 1 : 0));
331 }
332
333
334
335
336
337
338
339 template
341 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
342 static_assert(std::is_integral::value, "expects integer");
343 uint32_t size = array.size();
344 uint32_t nonZeroesCount = 0, lastIndex = 0;
345 for (uint32_t index : llvm::seq<uint32_t>(0, size)) {
346 if (!array[index])
347 continue;
348 nonZeroesCount++;
349 lastIndex = index;
350 }
351
352
353 if (lastIndex > 256 || nonZeroesCount > size / 2) {
354
356 for (const T &elt : array)
358 return;
359 }
360
361
363 if (nonZeroesCount == 0)
364 return;
365
366 int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1);
368 for (uint32_t index : llvm::seq<uint32_t>(0, lastIndex + 1)) {
369 T value = array[index];
370 if (!value)
371 continue;
372 uint64_t indexValuePair = (value << indexBitSize) | (index);
374 }
375 }
376
377
378
379
380
381
382
384
385
386
387
389
390
391
392
393
395
396
397
398
400
401
403
404
406
407
408 virtual FailureOr<const DialectVersion *>
410
411 template
414 }
415 };
416
417
418
419
420
423 public:
425
426
427
428
429
430
431
432
434 reader.emitError() << "dialect " << getDialect()->getNamespace()
435 << " does not support reading attributes from bytecode";
437 }
438
439
440
441
443 reader.emitError() << "dialect " << getDialect()->getNamespace()
444 << " does not support reading types from bytecode";
445 return Type();
446 }
447
448
449
450
451
452
453
454
455
458 return failure();
459 }
460
461
462
463
464
467 return failure();
468 }
469
470
472
473
474
475 virtual std::unique_ptr
477 reader.emitError("Dialect does not support versioning");
478 return nullptr;
479 }
480
481
482
483
484
485 virtual LogicalResult
488 return success();
489 }
490 };
491
492
493 template <typename T, typename... Ts>
495 FailureOr &value, Ts &&...params) {
497 if (failed(handle))
498 return failure();
499 if (auto *result = dyn_cast(&*handle)) {
500 value = std::move(*result);
501 return success();
502 }
503 return failure();
504 }
505
506
507
508 template <typename T, typename... Ts>
510
511 if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
512 (void)context;
513 return T::get(std::forward(params)...);
516 return T::get(context, std::forward(params)...);
517 } else {
518
519 return T::Base::get(context, std::forward(params)...);
520 }
521 }
522
523 }
524
525 #endif
This class represents an opaque handle to a dialect resource entry.
Attributes are known-constant values of operations.
virtual Type readType(DialectBytecodeReader &reader) const
Read a type belonging to this dialect from the given reader.
virtual LogicalResult upgradeFromVersion(Operation *topLevelOp, const DialectVersion &version) const
Hook invoked after parsing has completed, if a version directive was present and included an entry for this dialect.
virtual Attribute readAttribute(DialectBytecodeReader &reader) const
Read an attribute belonging to this dialect from the given reader.
virtual std::unique_ptr< DialectVersion > readVersion(DialectBytecodeReader &reader) const
virtual LogicalResult writeAttribute(Attribute attr, DialectBytecodeWriter &writer) const
Write the given attribute, which belongs to this dialect, to the given writer.
virtual LogicalResult writeType(Type type, DialectBytecodeWriter &writer) const
Write the given type, which belongs to this dialect, to the given writer.
virtual void writeVersion(DialectBytecodeWriter &writer) const
Write the version of this dialect to the given writer.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the bytecode reader.
virtual ~DialectBytecodeReader()=default
virtual LogicalResult readBlob(ArrayRef< char > &result)=0
Read a blob from the bytecode.
LogicalResult readAttributes(SmallVectorImpl< T > &attrs)
FailureOr< ResourceT > readResourceHandle()
Read a handle to a dialect resource.
virtual MLIRContext * getContext() const =0
Retrieve the context associated to the reader.
virtual FailureOr< APInt > readAPIntWithKnownWidth(unsigned bitWidth)=0
Read an APInt that is known to have been encoded with the given width.
LogicalResult readTypes(SmallVectorImpl< T > &types)
virtual LogicalResult readBool(bool &result)=0
Read a bool from the bytecode.
virtual LogicalResult readVarInt(uint64_t &result)=0
Read a variable width integer.
virtual LogicalResult readType(Type &result)=0
Read a reference to the given type.
virtual uint64_t getBytecodeVersion() const =0
Return the bytecode version being read.
LogicalResult readType(T &result)
LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag)
Parse a variable-length encoded integer whose low bit is used to encode an unrelated flag, i.e. the encoded value is `(integerValue << 1) | (flag ? 1 : 0)`.
LogicalResult readSignedVarInts(SmallVectorImpl< int64_t > &result)
LogicalResult readOptionalAttribute(T &result)
FailureOr< const DialectVersion * > getDialectVersion() const
virtual LogicalResult readOptionalAttribute(Attribute &attr)=0
Read an optional reference to the given attribute.
LogicalResult readAttribute(T &result)
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
LogicalResult readSparseArray(MutableArrayRef< T > array)
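Read a "small" sparse array of integer elements of at most 32 bits, where non-zero elements may be stored as compressed index/value pairs. Only the positions present in the bytecode are written; all other elements of the array are left untouched. An error is emitted if the stored data does not fit in the provided array.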
virtual FailureOr< APFloat > readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics)=0
Read an APFloat that is known to have been encoded with the given semantics.
virtual FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const =0
Retrieve the dialect version by name if available.
virtual LogicalResult readString(StringRef &result)=0
Read a string from the bytecode.
virtual LogicalResult readSignedVarInt(int64_t &result)=0
Read a signed variable width integer.
LogicalResult readList(SmallVectorImpl< T > &result, CallbackFn &&callback)
Read out a list of elements, invoking the provided callback for each element.
virtual LogicalResult readAttribute(Attribute &result)=0
Read a reference to the given attribute.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the bytecode writer.
virtual void writeOptionalAttribute(Attribute attr)=0
FailureOr< const DialectVersion * > getDialectVersion() const
virtual void writeVarInt(uint64_t value)=0
Write a variable width integer to the output stream.
void writeVarIntWithFlag(uint64_t value, bool flag)
Write a VarInt and a flag packed together.
void writeList(RangeT &&range, CallbackFn &&callback)
Write out a list of elements, invoking the provided callback for each element.
void writeSparseArray(ArrayRef< T > array)
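Write out a "small" sparse array of integer elements of at most 32 bits. When at most half of the elements are non-zero and the last non-zero index does not exceed 256, only the non-zero elements are emitted, each as a packed index/value pair; otherwise every element is written.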
virtual void writeType(Type type)=0
Write a reference to the given type.
virtual FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const =0
Retrieve the dialect version by name if available.
virtual void writeAPIntWithKnownWidth(const APInt &value)=0
Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.
virtual void writeOwnedBlob(ArrayRef< char > blob)=0
Write a blob to the bytecode. The blob is owned by the caller, who guarantees that it stays alive until bytecode writing has finished.
virtual void writeAttribute(Attribute attr)=0
Write a reference to the given attribute.
virtual ~DialectBytecodeWriter()=default
void writeAttributes(ArrayRef< T > attrs)
virtual void writeSignedVarInt(int64_t value)=0
Write a signed variable width integer to the output stream.
virtual void writeResourceHandle(const AsmDialectResourceHandle &resource)=0
Write the given handle to a dialect resource.
virtual void writeAPFloatWithKnownSemantics(const APFloat &value)=0
Write an APFloat to the bytecode stream whose semantics will be known externally at read time.
void writeSignedVarInts(ArrayRef< int64_t > value)
virtual void writeOwnedBool(bool value)=0
Write a bool to the output stream.
virtual int64_t getBytecodeVersion() const =0
Return the bytecode version being emitted for.
virtual void writeOwnedString(StringRef str)=0
Write a string to the bytecode. The string is owned by the caller, who guarantees that it stays alive until bytecode writing has finished.
void writeTypes(ArrayRef< T > types)
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
virtual ~DialectVersion()=default
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
The base class used for all derived interface types.
decltype(T::get(std::declval< Ts >()...)) has_get_method
Include the generated interface declarations.
static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)
Helper for resource handle reading that returns LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.