MLIR: include/mlir/Bytecode/BytecodeImplementation.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14 #ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H

15 #define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H

16

22 #include "llvm/ADT/STLExtras.h"

23 #include "llvm/ADT/Twine.h"

24

25 namespace mlir {

26

27

28

29

30

31

33 public:

35 };

36

37

38

39

40

41

42

43

44

45

47 public:

49

50

52

53

54 virtual FailureOr<const DialectVersion *>

56 template

59 }

60

61

63

64

66

67

68

69

70

71 template <typename T, typename CallbackFn>

73 uint64_t size;

75 return failure();

76 result.reserve(size);

77

78 for (uint64_t i = 0; i < size; ++i) {

79

80

81 if constexpr (llvm::function_traits<std::decay_t>::num_args) {

82 T element = {};

83 if (failed(callback(element)))

84 return failure();

85 result.emplace_back(std::move(element));

86 } else {

87 FailureOr element = callback();

88 if (failed(element))

89 return failure();

90 result.emplace_back(std::move(*element));

91 }

92 }

93 return success();

94 }

95

96

97

98

99

100

102

103

105

106 template

109 }

110 template

114 return failure();

115 if ((result = dyn_cast(baseResult)))

116 return success();

117 return emitError() << "expected " << llvm::getTypeName()

118 << ", but got: " << baseResult;

119 }

120 template

124 return failure();

125 if (!baseResult)

126 return success();

127 if ((result = dyn_cast(baseResult)))

128 return success();

129 return emitError() << "expected " << llvm::getTypeName()

130 << ", but got: " << baseResult;

131 }

132

133

135 template

137 return readList(types, [this](T &type) { return readType(type); });

138 }

139 template

141 Type baseResult;

142 if (failed(readType(baseResult)))

143 return failure();

144 if ((result = dyn_cast(baseResult)))

145 return success();

146 return emitError() << "expected " << llvm::getTypeName()

147 << ", but got: " << baseResult;

148 }

149

150

151 template

154 if (failed(handle))

155 return failure();

156 if (auto *result = dyn_cast(&*handle))

157 return std::move(*result);

158 return emitError() << "provided resource handle differs from the "

159 "expected resource type";

160 }

161

162

163

164

165

166

167 virtual LogicalResult readVarInt(uint64_t &result) = 0;

168

169

174 }

175

176

177

180 return failure();

181 flag = result & 1;

182 result >>= 1;

183 return success();

184 }

185

186

187

188

189

190

191

192 template

194 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");

195 static_assert(std::is_integral::value, "expects integer");

196 uint64_t nonZeroesCount;

197 bool useSparseEncoding;

199 return failure();

200 if (nonZeroesCount == 0)

201 return success();

202 if (!useSparseEncoding) {

203

204 if (nonZeroesCount > array.size()) {

205 emitError("trying to read an array of ")

206 << nonZeroesCount << " but only " << array.size()

207 << " storage available.";

208 return failure();

209 }

210 for (int64_t index : llvm::seq<int64_t>(0, nonZeroesCount)) {

211 uint64_t value;

213 return failure();

214 array[index] = value;

215 }

216 return success();

217 }

218

219

220 uint64_t indexBitSize;

221 if (failed(readVarInt(indexBitSize)))

222 return failure();

223 constexpr uint64_t maxIndexBitSize = 8;

224 if (indexBitSize > maxIndexBitSize) {

225 emitError("reading sparse array with indexing above 8 bits: ")

226 << indexBitSize;

227 return failure();

228 }

229 for (uint32_t count : llvm::seq<uint32_t>(0, nonZeroesCount)) {

230 (void)count;

231 uint64_t indexValuePair;

232 if (failed(readVarInt(indexValuePair)))

233 return failure();

234 uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));

235 uint64_t value = indexValuePair >> indexBitSize;

236 if (index >= array.size()) {

237 emitError("reading a sparse array found index ")

238 << index << " but only " << array.size() << " storage available.";

239 return failure();

240 }

241 array[index] = value;

242 }

243 return success();

244 }

245

246

248

249

250

251 virtual FailureOr

253

254

255 virtual LogicalResult readString(StringRef &result) = 0;

256

257

259

260

261 virtual LogicalResult readBool(bool &result) = 0;

262

263 private:

264

265 virtual FailureOr readResourceHandle() = 0;

266 };

267

268

269

270

271

272

273

274

275

276

278 public:

280

281

282

283

284

285

286

287 template <typename RangeT, typename CallbackFn>

288 void writeList(RangeT &&range, CallbackFn &&callback) {

290 for (auto &element : range)

291 callback(element);

292 }

293

294

297 template

300 }

301

302

304 template

307 }

308

309

310 virtual void

312

313

314

315

316

317

318

320

321

322

326 }

327

328

330 writeVarInt((value << 1) | (flag ? 1 : 0));

331 }

332

333

334

335

336

337

338

339 template

341 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");

342 static_assert(std::is_integral::value, "expects integer");

343 uint32_t size = array.size();

344 uint32_t nonZeroesCount = 0, lastIndex = 0;

345 for (uint32_t index : llvm::seq<uint32_t>(0, size)) {

346 if (!array[index])

347 continue;

348 nonZeroesCount++;

349 lastIndex = index;

350 }

351

352

353 if (lastIndex > 256 || nonZeroesCount > size / 2) {

354

356 for (const T &elt : array)

358 return;

359 }

360

361

363 if (nonZeroesCount == 0)

364 return;

365

366 int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1);

368 for (uint32_t index : llvm::seq<uint32_t>(0, lastIndex + 1)) {

369 T value = array[index];

370 if (!value)

371 continue;

372 uint64_t indexValuePair = (value << indexBitSize) | (index);

374 }

375 }

376

377

378

379

380

381

382

384

385

386

387

389

390

391

392

393

395

396

397

398

400

401

403

404

406

407

408 virtual FailureOr<const DialectVersion *>

410

411 template

414 }

415 };

416

417

418

419

420

423 public:

425

426

427

428

429

430

431

432

434 reader.emitError() << "dialect " << getDialect()->getNamespace()

435 << " does not support reading attributes from bytecode";

437 }

438

439

440

441

443 reader.emitError() << "dialect " << getDialect()->getNamespace()

444 << " does not support reading types from bytecode";

445 return Type();

446 }

447

448

449

450

451

452

453

454

455

458 return failure();

459 }

460

461

462

463

464

467 return failure();

468 }

469

470

472

473

474

475 virtual std::unique_ptr

477 reader.emitError("Dialect does not support versioning");

478 return nullptr;

479 }

480

481

482

483

484

485 virtual LogicalResult

488 return success();

489 }

490 };

491

492

493 template <typename T, typename... Ts>

495 FailureOr &value, Ts &&...params) {

497 if (failed(handle))

498 return failure();

499 if (auto *result = dyn_cast(&*handle)) {

500 value = std::move(*result);

501 return success();

502 }

503 return failure();

504 }

505

506

507

508 template <typename T, typename... Ts>

510

511 if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {

512 (void)context;

513 return T::get(std::forward(params)...);

516 return T::get(context, std::forward(params)...);

517 } else {

518

519 return T::Base::get(context, std::forward(params)...);

520 }

521 }

522

523 }

524

525 #endif

This class represents an opaque handle to a dialect resource entry.

Attributes are known-constant values of operations.

virtual Type readType(DialectBytecodeReader &reader) const

Read a type belonging to this dialect from the given reader.

virtual LogicalResult upgradeFromVersion(Operation *topLevelOp, const DialectVersion &version) const

Hook invoked after parsing completed, if a version directive was present and included an entry for the current dialect.

virtual Attribute readAttribute(DialectBytecodeReader &reader) const

Read an attribute belonging to this dialect from the given reader.

virtual std::unique_ptr< DialectVersion > readVersion(DialectBytecodeReader &reader) const

virtual LogicalResult writeAttribute(Attribute attr, DialectBytecodeWriter &writer) const

Write the given attribute, which belongs to this dialect, to the given writer.

virtual LogicalResult writeType(Type type, DialectBytecodeWriter &writer) const

Write the given type, which belongs to this dialect, to the given writer.

virtual void writeVersion(DialectBytecodeWriter &writer) const

Write the version of this dialect to the given writer.

This class defines a virtual interface for reading a bytecode stream, providing hooks into the bytecode reader.

virtual ~DialectBytecodeReader()=default

virtual LogicalResult readBlob(ArrayRef< char > &result)=0

Read a blob from the bytecode.

LogicalResult readAttributes(SmallVectorImpl< T > &attrs)

FailureOr< ResourceT > readResourceHandle()

Read a handle to a dialect resource.

virtual MLIRContext * getContext() const =0

Retrieve the context associated to the reader.

virtual FailureOr< APInt > readAPIntWithKnownWidth(unsigned bitWidth)=0

Read an APInt that is known to have been encoded with the given width.

LogicalResult readTypes(SmallVectorImpl< T > &types)

virtual LogicalResult readBool(bool &result)=0

Read a bool from the bytecode.

virtual LogicalResult readVarInt(uint64_t &result)=0

Read a variable width integer.

virtual LogicalResult readType(Type &result)=0

Read a reference to the given type.

virtual uint64_t getBytecodeVersion() const =0

Return the bytecode version being read.

LogicalResult readType(T &result)

LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag)

Parse a variable length encoded integer whose low bit is used to encode an unrelated flag, i.e. the encoded value is `(integerValue << 1) | (flag ? 1 : 0)`.

LogicalResult readSignedVarInts(SmallVectorImpl< int64_t > &result)

LogicalResult readOptionalAttribute(T &result)

FailureOr< const DialectVersion * > getDialectVersion() const

virtual LogicalResult readOptionalAttribute(Attribute &attr)=0

Read an optional reference to the given attribute.

LogicalResult readAttribute(T &result)

virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0

Emit an error to the reader.

LogicalResult readSparseArray(MutableArrayRef< T > array)

Read a "small" sparse array of integer elements (at most 32 bits wide), where index/value pairs can be compressed when the array is small.

virtual FailureOr< APFloat > readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics)=0

Read an APFloat that is known to have been encoded with the given semantics.

virtual FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const =0

Retrieve the dialect version by name if available.

virtual LogicalResult readString(StringRef &result)=0

Read a string from the bytecode.

virtual LogicalResult readSignedVarInt(int64_t &result)=0

Read a signed variable width integer.

LogicalResult readList(SmallVectorImpl< T > &result, CallbackFn &&callback)

Read out a list of elements, invoking the provided callback for each element.

virtual LogicalResult readAttribute(Attribute &result)=0

Read a reference to the given attribute.

This class defines a virtual interface for writing to a bytecode stream, providing hooks into the bytecode writer.

virtual void writeOptionalAttribute(Attribute attr)=0

FailureOr< const DialectVersion * > getDialectVersion() const

virtual void writeVarInt(uint64_t value)=0

Write a variable width integer to the output stream.

void writeVarIntWithFlag(uint64_t value, bool flag)

Write a VarInt and a flag packed together.

void writeList(RangeT &&range, CallbackFn &&callback)

Write out a list of elements, invoking the provided callback for each element.

void writeSparseArray(ArrayRef< T > array)

Write out a "small" sparse array of integer elements (at most 32 bits wide), where index/value pairs can be compressed when the array is small.

virtual void writeType(Type type)=0

Write a reference to the given type.

virtual FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const =0

Retrieve the dialect version by name if available.

virtual void writeAPIntWithKnownWidth(const APInt &value)=0

Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.

virtual void writeOwnedBlob(ArrayRef< char > blob)=0

Write a blob to the bytecode, which is owned by the caller and is guaranteed to not die before the end of the bytecode process.

virtual void writeAttribute(Attribute attr)=0

Write a reference to the given attribute.

virtual ~DialectBytecodeWriter()=default

void writeAttributes(ArrayRef< T > attrs)

virtual void writeSignedVarInt(int64_t value)=0

Write a signed variable width integer to the output stream.

virtual void writeResourceHandle(const AsmDialectResourceHandle &resource)=0

Write the given handle to a dialect resource.

virtual void writeAPFloatWithKnownSemantics(const APFloat &value)=0

Write an APFloat to the bytecode stream whose semantics will be known externally at read time.

void writeSignedVarInts(ArrayRef< int64_t > value)

virtual void writeOwnedBool(bool value)=0

Write a bool to the output stream.

virtual int64_t getBytecodeVersion() const =0

Return the bytecode version being emitted for.

virtual void writeOwnedString(StringRef str)=0

Write a string to the bytecode, which is owned by the caller and is guaranteed to not die before the end of the bytecode process.

void writeTypes(ArrayRef< T > types)

This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.

virtual ~DialectVersion()=default

This class represents a diagnostic that is inflight and set to be reported.

MLIRContext is the top-level object for a collection of MLIR operations.

Operation is the basic unit of execution within MLIR.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

The base class used for all derived interface types.

decltype(T::get(std::declval< Ts >()...)) has_get_method

Include the generated interface declarations.

static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)

Helper for resource handle reading that returns LogicalResult.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.