MLIR: lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10 #include "llvm/ADT/StringExtras.h"

11

12 using namespace mlir;

14

16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};

17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};

18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};

19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};

20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};

21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};

22 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};

23 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};

24 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};

25 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};

26 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};

27

28

29

31

32 }

33

// Explicit specialization that hands back the `profileComplianceMap` member.
// NOTE(review): the declared return type is a by-value std::unordered_map, so
// every call copies the entire table; consider returning a const reference.
34 template <>
36 return profileComplianceMap;
37 }

38

// Explicit specialization that hands back the `extensionComplianceMap` member.
// NOTE(review): returned by value (see the declaration), so each call copies
// the whole map.
39 template <>
42 return extensionComplianceMap;
43 }

44

45

// Generic population used for ops without a custom handler: every operand is
// passed to addValue in order, followed by `output`. Always succeeds.
46 LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
48 for (auto operand : operands)
49 addValue(operand);
50 addValue(output);
51 return success();
52 }

53

54 template <>

55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {

56 addValue(op.getInput1().front());

57 addValue(op.getOutput());

58 return success();

59 }

60

61 template <>

62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {

63 addValue(op.getInput());

64 addValue(op.getInputZp());

65 addValue(op.getOutputZp());

66 addType(op.getAccType());

67 addValue(op.getOutput());

68 return success();

69 }

70

// Shared population for convolution-style ops. The recorded order is: input,
// weight, bias, input zero point, weight zero point, accumulator type (via
// addType), then output. Always succeeds.
71 template
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73 addValue(op.getInput());
74 addValue(op.getWeight());
75 addValue(op.getBias());
76 addValue(op.getInputZp());
77 addValue(op.getWeightZp());
78 addType(op.getAccType());
79 addValue(op.getOutput());
80 return success();
81 }

82

83 template <>

84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {

85 return populateProfileInfoConv(op);

86 }

87

88 template <>

89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {

90 return populateProfileInfoConv(op);

91 }

92

93 template <>

94 LogicalResult

95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {

96 return populateProfileInfoConv(op);

97 }

98

99 template <>

100 LogicalResult

101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {

102 return populateProfileInfoConv(op);

103 }

104

105 template <>

106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {

107 addValue(op.getInput1());

108 addValue(op.getPadConst());

109 addValue(op.getOutput());

110 return success();

111 }

112

// Shared population for data-layout ops: only `input1` and `output` are
// recorded, in that order. Always succeeds.
113 template
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115 addValue(op.getInput1());
116 addValue(op.getOutput());
117 return success();
118 }

119

120 template <>

121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {

122 return populateProfileInfoDataLayout(op);

123 }

124

125 template <>

126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {

127 return populateProfileInfoDataLayout(op);

128 }

129

130 template <>

131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {

132 return populateProfileInfoDataLayout(op);

133 }

134

135 template <>

136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {

137 return populateProfileInfoDataLayout(op);

138 }

139

140 template <>

141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {

142 addValue(op.getValues());

143 addValue(op.getIndices());

144 addValue(op.getOutput());

145 return success();

146 }

147

148 template <>

149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {

150 addValue(op.getValuesIn());

151 addValue(op.getIndices());

152 addValue(op.getInput());

153 addValue(op.getValuesOut());

154 return success();

155 }

156

157 template <>

158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {

159 addValue(op.getInput1());

160 addValue(op.getInput2());

161 addValue(op.getOutput());

162 return success();

163 }

164

165 template <>

166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {

167 addValue(op.getInput());

168 addValue(op.getOutput());

169 return success();

170 }

171

172 template <>

173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {

174 addValue(op.getInputReal());

175 addValue(op.getInputImag());

176 addValue(op.getOutputReal());

177 addValue(op.getOutputImag());

178 return success();

179 }

180

181 template <>

182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {

183 addValue(op.getInputReal());

184 addValue(op.getOutputReal());

185 addValue(op.getOutputImag());

186 return success();

187 }

188

189 template <>

190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {

191 addValue(op.getInput2());

192 addValue(op.getInput3());

193 addValue(op.getOutput());

194 return success();

195 }

196

197 template <>

198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {

199 addValue(op.getInput());

200 addValue(op.getInputZp());

201 addValue(op.getOutputZp());

202 addValue(op.getOutput());

203 return success();

204 }

205

206 template <>

207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {

208 addValue(op.getA());

209 addValue(op.getB());

210 addValue(op.getAZp());

211 addValue(op.getBZp());

212 addValue(op.getOutput());

213 return success();

214 }

215

216 template <>

217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {

218 addType(op.getType());

219 return success();

220 }

221

222 template <>

223 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {

224 addValue(op.getInput1());

225 return success();

226 }

227

228 template <>

229 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {

230 addValue(op.getCondition());

231 return success();

232 }

233

// Records only the first operand of the terminator of the condition graph's
// first block.
// NOTE(review): `terminator` is declared on a line not visible here; it is
// presumably `block->getTerminator()`. Assumes the terminator has at least
// one operand -- front() on an empty range is UB.
234 template <>
235 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
236 Block *block = &op.getCondGraph().front();
238 addValue(terminator->getOperands().front());
239 return success();
240 }

241

// Routes `op` to the matching populateProfileInfo overload according to its
// concrete TOSA op kind (via the macros below). Returns failure() when no
// handler matched.
// NOTE(review): the name "populatation" is a typo but is part of the
// interface; rename only together with all callers.
242 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {

243

// CUSTOM: op has a dedicated populateProfileInfo specialization.
244 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \

245 if (isatosa::tosaOp##Op(op)) { \

246 return populateProfileInfo(casttosa::tosaOp##Op(op)); \

247 }

248

// SKIP: op is accepted without recording any type information.
249 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \

250 if (isatosa::tosaOp##Op(op)) \

251 return success();

252

253

// COMMON: record all operands, then result #0.
// NOTE(review): op->getResult(0) requires the op to have at least one result.
254 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \

255 if (isatosa::tosaOp##Op(op)) { \

256 return populateProfileInfo(op->getOperands(), op->getResult(0)); \

257 }

258

259

260

285

286

287

337

338

339

340

343

// No macro above matched this op.
344 return failure();

345 }

346

347

348

349

350

// Looks up the compliance entries for the op's name and returns the modes of
// the first entry whose type set matches the op (see findMatchedProfile),
// storing that entry's condition in `condition`.
// NOTE(review): `return {}` on a missing name is presumably the failure state
// (callers test `failed(...)`), which differs from a successful but empty
// vector returned by findMatchedProfile -- confirm.
// NOTE(review): `complianceMap` is a full copy of the table on every call,
// because getProfileComplianceMap returns by value.
351 template
352 FailureOr<SmallVector>
353 TosaProfileCompliance::getOperatorDefinition(Operation *op,
356 const auto complianceMap = getProfileComplianceMap();
357 const auto it = complianceMap.find(opName);
358 if (it == complianceMap.end())
359 return {};
360
361 return findMatchedProfile(op, it->second, condition);
362 }

363

// Verifies that `op` is legal under the profiles/extensions enabled in
// `targetEnv`, given the spec-required candidate sets reported by the op.
// Emits an op error and returns failure() on the first violation found.
364 template
367 const SmallVector<ArrayRef> &specRequiredModeSet) {
368
369
// No spec requirement at all: nothing to check.
370 if (specRequiredModeSet.size() == 0)
371 return success();
372
374 const auto maybeOpRequiredMode = getOperatorDefinition(op, condition);
// No definition for this op: fall back to the spec-required candidate sets.
// The early-success test on the (not shown) line 382 is presumably "target
// allows some candidate set" -- confirm.
375 if (failed(maybeOpRequiredMode)) {
376
377
378
379
// NOTE(review): size_t sizes are accumulated into an int.
380 int mode_count = 0;
381 for (const auto &cands : specRequiredModeSet) {
383 return success();
384 mode_count += cands.size();
385 }
386
388 << (mode_count > 1 ? " any of " : " ") << "["
389 << llvm::join(stringifyProfile(specRequiredModeSet),
390 ", ")
391 << "] but not enabled in target\n";
392
393 return failure();
394 }
395
396
397
// NOTE(review): copies the vector; `const auto &` would suffice.
398 const auto opRequiredMode = maybeOpRequiredMode.value();
// Types matched an entry that requires no mode.
399 if (opRequiredMode.size() == 0) {
400
401 return success();
402 }
403
// "All of" requirement (the condition head on line 404 is not shown).
405 !targetEnv.allowsAllOf(opRequiredMode)) {
407 << (opRequiredMode.size() > 1 ? " all of " : " ") << "["
408 << llvm::join(stringifyProfile(opRequiredMode), ", ")
409 << "] but not enabled in target\n";
410 return failure();
411 }
412
// "Any of" requirement (the condition head on line 413 is not shown).
414 !targetEnv.allowsAnyOf(opRequiredMode)) {
416 << (opRequiredMode.size() > 1 ? " any of " : " ") << "["
417 << llvm::join(stringifyProfile(opRequiredMode), ", ")
418 << "] but not enabled in target\n";
419 return failure();
420 }
421
422
423
// Extensions only: each required extension's cooperative profiles are
// checked (the test on line 427 is not shown).
424 if constexpr (std::is_same_v<T, Extension>) {
425 for (const auto &mode : opRequiredMode) {
426 SmallVector coProfs = getCooperativeProfiles(mode);
428 op->emitOpError() << "illegal: requires ["
429 << llvm::join(stringifyProfile(coProfs),
430 ", ")
431 << "] to work with but not enabled in target\n";
432 return failure();
433 }
434 }
435 }
436
437
438
// Every required mode must appear in every spec candidate list.
439 for (const auto &cands : specRequiredModeSet) {
440 for (const auto &mode : opRequiredMode) {
441 if (!llvm::is_contained(cands, mode)) {
442 op->emitOpError() << "illegal: requires ["
443 << llvm::join(stringifyProfile(opRequiredMode),
444 ", ")
445 << "] but not included in the profile compliance ["
446 << llvm::join(
447 stringifyProfile(specRequiredModeSet), ", ")
448 << "]\n";
449 return failure();
450 }
451 }
452 }
453
454 return success();
455 }

456

// Profile check: ops implementing QueryProfileInterface are checked against
// the profiles they report; all other ops trivially pass.
457 LogicalResult
460 if (auto interface = dyn_casttosa::QueryProfileInterface(op))
461 return checkProfileOrExtension(op, targetEnv,
462 interface.getProfiles());
463
464 return success();
465 }

466

// Extension check: ops implementing QueryExtensionInterface are checked
// against the extensions they report; all other ops trivially pass.
467 LogicalResult
470 if (auto interface = dyn_casttosa::QueryExtensionInterface(op))
471 return checkProfileOrExtension(op, targetEnv,
472 interface.getExtensions());
473
474 return success();
475 }

476

// Reports an op whose operand/result types match no profile and no extension
// entry. It fires only when both lookups succeed (the op name is present)
// and both return no modes. The diagnostic suggests the closest entry,
// ranked by the number of positions whose types match.
479 const auto maybeProfDef = getOperatorDefinition(op, condition);
480 const auto maybeExtDef = getOperatorDefinition(op, condition);
481
482 if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
483 !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
484 std::string message;
485 llvm::raw_string_ostream os(message);
486 os << "illegal: operation operand/result data types did not align with any "
487 "profile or extension, got (";
488
// NOTE(review): current.back() is UB if the depot collected nothing;
// findMatchedProfile returns an empty vector in that case, which also
// satisfies the condition above. TODO: guard.
490 SmallVector current = depot.getInfo();
491 for (const auto &typeInfo : llvm::drop_end(current))
492 os << stringifyTypeInfo(typeInfo) << ",";
493 os << stringifyTypeInfo(current.back()) << ")";
494
495
496
498 int maxMatches = -1;
499 SmallVector bestTypeInfo;
// NOTE(review): `auto map` takes the whole map by value, and `map[opName]`
// default-inserts into that copy. A `const auto &map` parameter plus
// `find()` would avoid both. zip_equal asserts that the lengths agree.
500 const auto searchBestMatch = [&](auto map) {
501 for (const auto &complianceInfos : map[opName]) {
502 for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
503 const int matches = llvm::count_if(
504 llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
505 return isSameTypeInfo(std::get<0>(zipType),
506 std::get<1>(zipType));
507 });
508 if (matches > maxMatches) {
509 maxMatches = matches;
510 bestTypeInfo = typeInfos;
511 }
512 }
513 }
514 };
// NOTE(review): these two calls look identical here; presumably they pass
// different template arguments (profile vs. extension map) -- verify.
515 searchBestMatch(getProfileComplianceMap());
516 searchBestMatch(getProfileComplianceMap());
517
// NOTE(review): bestTypeInfo.back() is UB if no candidate type set existed.
518 os << ", did you mean (";
519 for (const auto &typeInfo : llvm::drop_end(bestTypeInfo))
520 os << stringifyTypeInfo(typeInfo) << ",";
521 os << stringifyTypeInfo(bestTypeInfo.back()) << ")? ";
522 os << "Otherwise, please refer to the 'supported data types' for '"
523 << opName << "' in the specification.";
525 return failure();
526 }
527
528 return success();
529 }

530

531

532

// Returns the mode list of the first compliance entry that has a type set
// matching the op's collected type info position by position. That entry's
// condition is written to `condition`. Returns an empty list when the depot
// collected nothing or no entry matches.
533 template
537 assert(compInfo.size() != 0 &&
538 "profile-based compliance information is empty");
539
540
542 SmallVector present = depot.getInfo();
543 if (present.size() == 0)
544 return {};
545
546 for (size_t i = 0; i < compInfo.size(); i++) {
// NOTE(review): `sets` and each `expected` are copied per iteration;
// `const auto &` would avoid the allocations.
547 SmallVector<SmallVector> sets = compInfo[i].operandTypeInfoSet;
548 for (SmallVector expected : sets) {
// NOTE(review): the concatenated message below contains a double space
// ("from  the operation").
549 assert(present.size() == expected.size() &&
550 "the entries for profile-based compliance do not match between "
551 "the generated metadata and the type definition retrieved from "
552 " the operation");
553
554 bool is_found = true;
555
556
557 for (size_t j = 0; j < expected.size(); j++) {
558 if (!isSameTypeInfo(present[j], expected[j])) {
559
560 is_found = false;
561 break;
562 }
563 }
564
565 if (is_found == true) {
566 condition = compInfo[i].condition;
567 return compInfo[i].mode;
568 }
569 }
570 }
571
572 return {};
573 }

574

575

// Converts each enum value to its name: tosa::stringifyProfile when T is
// Profile, otherwise tosa::stringifyExtension. Order is preserved.
576 template
579 SmallVector debugStrings;
580 for (const auto &profile : profiles) {
581 if constexpr (std::is_same_v<T, Profile>)
582 debugStrings.push_back(tosa::stringifyProfile(profile));
583 else
584 debugStrings.push_back(tosa::stringifyExtension(profile));
585 }
586 return debugStrings;
587 }

588

// Flattens a list of candidate lists into a single list of names, in order.
// Group boundaries are not preserved in the result.
589 template
591 const SmallVector<ArrayRef> &profileSet) {
592 SmallVector debugStrings;
593
594 for (const auto &profiles : profileSet) {
595 auto tempStrings = stringifyProfile(profiles);
596 llvm::append_range(debugStrings, tempStrings);
597 }
598
599 return debugStrings;
600 }

601

// Short display name for a TypeInfo: integers print as "i<bitWidth>", and
// each supported float TypeID has a fixed name (bitWidth is ignored for
// floats). Any other TypeID reaches llvm_unreachable.
604 if (typeInfo.typeID == mlir::IntegerType::getTypeID()) {
605 return {"i" + llvm::utostr(typeInfo.bitWidth)};
606 } else if (typeInfo.typeID == mlir::Float16Type::getTypeID()) {
607 return {"f16"};
608 } else if (typeInfo.typeID == mlir::Float32Type::getTypeID()) {
609 return {"f32"};
610 } else if (typeInfo.typeID == mlir::BFloat16Type::getTypeID()) {
611 return {"bf16"};
612 } else if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID()) {
613 return {"fp8e4m3"};
614 } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
615 return {"fp8e5m2"};
616 }
617 llvm_unreachable("unknown type");
618 }

#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)

#define POPULATE_PROFILE_INFO_COMMON(tosaOp)

#define POPULATE_PROFILE_INFO_SKIP(tosaOp)

std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap

std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap

SmallVector< TypeInfo > getInfo()

std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()

LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)

SmallVector< T > findMatchedProfile(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo, CheckCondition &condition)

LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)

LogicalResult checkInvalid(Operation *op)

SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)

static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)

LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)

Block represents an ordered list of Operations.

Operation * getTerminator()

Get the terminator operation of this block.

StringRef getStringRef() const

Return the name of this operation. This always succeeds.

Operation is the basic unit of execution within MLIR.

OperationName getName()

The name of an operation is the key identifier for it.

operand_range getOperands()

Returns an iterator on the underlying Value's.

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

This class represents the capability enabled in the target implementation such as profile,...

bool allowsAllOf(ArrayRef< Profile > profs) const

bool allowsAnyOf(ArrayRef< Profile > profs) const

NestedPattern If(const NestedPattern &child)

Include the generated interface declarations.

Eliminates variable at the specified position using Fourier-Motzkin variable elimination.