MLIR: include/mlir/Dialect/Quant/IR/QuantTypes.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9 #ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H

10 #define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H

11

18 #include "llvm/Support/MathExtras.h"

19

20 namespace mlir {

21 namespace quant {

22 namespace detail {

23

30

31 }

32

33

34 namespace QuantizationFlags {

36

37

39 };

40 }

41

42

43

44

45

46

47

48

49

51 public:

54

55

57

58 static LogicalResult

60 Type storageType, Type expressedType, int64_t storageTypeMin,

61 int64_t storageTypeMax);

62

63

65

66

67

69 unsigned integralWidth) {

71 return llvm::minIntN(integralWidth);

72 }

73 return 0;

74 }

75

76

77

79 unsigned integralWidth) {

81 return llvm::maxIntN(integralWidth);

82 }

83 return llvm::maxUIntN(integralWidth);

84 }

85

86

87

88

89

90

91

92

93

95

96

97

99

100

101

102

106 }

107

108

109

111

112

114

115

117

118

119

121

122

123

125

126

127

128

129

130

131

132

134

135

136

137

138

139

140

142

143

144

145

146

147

148

150

151

152

153

155

156

157

158

159

160

161

163

164

165

166

168

169

170

171

172

173

174

176

177 private:

178

179

180

181

188 };

189

190

191

192

193

194

195

196

197

198

200 : public Type::TypeBase<AnyQuantizedType, QuantizedType,

201 detail::AnyQuantizedTypeStorage> {

202 public:

204 using Base::getChecked;

205

206 static constexpr StringLiteral name = "quant.any";

207

208

209

211 Type expressedType, int64_t storageTypeMin,

212 int64_t storageTypeMax);

213

214

215

218 Type storageType, Type expressedType, int64_t storageTypeMin,

219 int64_t storageTypeMax);

220

221

222 static LogicalResult

224 Type storageType, Type expressedType, int64_t storageTypeMin,

225 int64_t storageTypeMax);

226 };

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

261 : public Type::TypeBase<UniformQuantizedType, QuantizedType,

262 detail::UniformQuantizedTypeStorage> {

263 public:

265 using Base::getChecked;

266

267 static constexpr StringLiteral name = "quant.uniform";

268

269

270

272 Type expressedType, double scale,

273 int64_t zeroPoint, int64_t storageTypeMin,

274 int64_t storageTypeMax);

275

276

277

280 Type storageType, Type expressedType, double scale,

281 int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);

282

283

284 static LogicalResult

286 Type storageType, Type expressedType, double scale,

287 int64_t zeroPoint, int64_t storageTypeMin,

288 int64_t storageTypeMax);

289

290

291

293

294

295

297

298

299

300

301

302

304 };

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

321 : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,

322 detail::UniformQuantizedPerAxisTypeStorage> {

323 public:

325 using Base::getChecked;

326

327 static constexpr StringLiteral name = "quant.uniform_per_axis";

328

329

330

332 get(unsigned flags, Type storageType, Type expressedType,

334 int32_t quantizedDimension, int64_t storageTypeMin,

335 int64_t storageTypeMax);

336

337

338

343 int64_t storageTypeMin, int64_t storageTypeMax);

344

345

346 static LogicalResult

348 Type storageType, Type expressedType,

350 int32_t quantizedDimension, int64_t storageTypeMin,

351 int64_t storageTypeMax);

352

353

354

355

356

358

359

360

361

363

364

365

366

367

368

369

370

371

373

374

375

376

377

378

380 if (!isSigned())

381 return false;

383 }

384 };

385

386

387

388

389

390

391

392

393

394

395

396

397

398

399

400

401

402

403

404

406 : public Type::TypeBase<UniformQuantizedSubChannelType, QuantizedType,

407 detail::UniformQuantizedSubChannelTypeStorage> {

408 public:

410 using Base::getChecked;

411

412 static constexpr StringLiteral name = "quant.uniform_sub_channel";

413

414

415

417 get(unsigned flags, Type storageType, Type expressedType,

420 int64_t storageTypeMin, int64_t storageTypeMax);

421

422

423

430 int64_t storageTypeMax);

431

432

433 static LogicalResult

435 Type storageType, Type expressedType,

439 int64_t storageTypeMax);

440

441

442

443

444

445

446

447

448

449

450

451

452

454

455

456

457

458

459

460

461

462

463

464

465

466

467

469

470

471

472

473

474

475

476

477

478

479

480

481

482

483

484

485

486

487

488

489

490

491

492

494

495

496

497

498

499

501

502

503

504

505

506

507

508

509

510

511

512

514 };

515

516

517

518

519

521 : public Type::TypeBase<CalibratedQuantizedType, QuantizedType,

522 detail::CalibratedQuantizedTypeStorage> {

523 public:

525 using Base::getChecked;

526

527 static constexpr StringLiteral name = "quant.calibrated";

528

529

530

532 double max);

533

534

535

538 double min, double max);

539

540

541 static LogicalResult

543 Type expressedType, double min, double max);

544 double getMin() const;

545 double getMax() const;

546 };

547

548 }

549 }

550

551 #endif

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)

An attribute that represents a reference to a dense vector or tensor object.

This class represents a diagnostic that is inflight and set to be reported.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isInteger() const

Return true if this is an integer type (with the specified width).

Utility class for implementing users of storage classes uniqued by a StorageUniquer.

A quantized type that maps storage to/from expressed types in an unspecified way.

static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)

Gets an instance of the type with all parameters specified but not checked.

static constexpr StringLiteral name

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)

Verifies construction invariants and issues errors/warnings.

static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)

Gets an instance of the type with all specified parameters checked.

A quantized type that infers its range from given min/max values.

static constexpr StringLiteral name

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)

Verifies construction invariants and issues errors/warnings.

static CalibratedQuantizedType get(Type expressedType, double min, double max)

Gets an instance of the type with all parameters specified but not checked.

static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)

Gets an instance of the type with all specified parameters checked.

Base class for all quantized types known to this dialect.

Type getExpressedType() const

Gets the original expressed type that this quantized type approximates.

static constexpr unsigned MaxStorageBits

The maximum number of bits supported for storage types.

bool hasStorageTypeBounds() const

Return whether the storage type has explicit min or max boundaries different from the minimum and max...

static Type castToStorageType(Type quantizedType)

Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...

Type castExpressedToStorageType(Type candidateType)

Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...

static Type castToExpressedType(Type quantizedType)

Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...

bool isSigned() const

Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...

static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)

Returns the element type as a QuantizedType or nullptr if it is not a quantized type.

unsigned getFlags() const

Gets the flags associated with this type.

int64_t getStorageTypeMax() const

The maximum value that storageType can take.

static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)

Gets the maximum possible value stored by a storageType.

unsigned getStorageTypeIntegralWidth() const

Gets the integral bit width that the underlying storage type can exactly represent.

static bool classof(Type type)

Support method to enable LLVM-style type casting.

Type castFromStorageType(Type candidateType)

Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...

int64_t getStorageTypeMin() const

The minimum value that storageType can take.

static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)

Gets the minimum possible value stored by a storageType.

Type getStorageType() const

Gets the underlying type used to store values.

Type castFromExpressedType(Type candidateType)

Casts from a type based on the expressedType to a corresponding type based on this type (returns null...

bool isCompatibleExpressedType(Type candidateExpressedType)

Returns whether the candidateExpressedType is a match for this QuantizedType.

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)

Represents per-axis (also known as per-channel) quantization.

static constexpr StringLiteral name

static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)

Gets an instance of the type with all specified parameters checked.

bool isFixedPoint() const

Fixed point values are real numbers divided by a scale.

static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)

Gets an instance of the type with all parameters specified but not checked.

int32_t getQuantizedDimension() const

Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.

ArrayRef< int64_t > getZeroPoints() const

Gets the storage values corresponding to the real value 0 in the affine equation.

ArrayRef< double > getScales() const

Gets the quantization scales.

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)

Verifies construction invariants and issues errors/warnings.

Represents sub-channel (also known as blockwise) quantization.

static constexpr StringLiteral name

ArrayRef< int32_t > getQuantizedDimensions() const

Gets the quantized dimensions.

DenseElementsAttr getZeroPoints() const

Gets the quantization zero-points.

ArrayRef< int64_t > getBlockSizes() const

Gets the block sizes for the quantized dimensions.

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)

Verifies construction invariants and issues errors/warnings.

const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const

Gets the block size information.

static UniformQuantizedSubChannelType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)

Gets an instance of the type with all specified parameters checked.

static UniformQuantizedSubChannelType get(unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)

Gets an instance of the type with all parameters specified but not checked.

DenseElementsAttr getScales() const

Gets the quantization scales.

Represents a family of uniform, quantized types.

double getScale() const

Gets the scale term.

bool isFixedPoint() const

int64_t getZeroPoint() const

Gets the storage value corresponding to the real value 0 in the affine equation.

static constexpr StringLiteral name

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)

Verifies construction invariants and issues errors/warnings.

static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)

Gets an instance of the type with all specified parameters checked.

static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)

Gets an instance of the type with all parameters specified but not checked.

Include the generated interface declarations.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.