MLIR: lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10 #include "llvm/ADT/StringExtras.h"
11
12 using namespace mlir;
14
// Type descriptors, each pairing an MLIR TypeID with a bit width; booleans are
// described as 1-bit integers. NOTE(review): presumably these feed the
// compliance tables. The line that opens the enclosing definition (source line
// 15) and the content at source line 30 are not visible in this view.
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
23 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
24 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
25 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
26 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
27
28
29
31
32 }
33
// Specialization that hands back the stored profileComplianceMap.
// NOTE(review): its signature (source line 35) is not visible in this view.
34 template <>
36 return profileComplianceMap;
37 }
38
// Specialization that hands back the stored extensionComplianceMap.
// NOTE(review): its signature (source lines 40-41) is not visible in this view.
39 template <>
42 return extensionComplianceMap;
43 }
44
45
46 LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
48 for (auto operand : operands)
49 addValue(operand);
50 addValue(output);
51 return success();
52 }
53
54 template <>
55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
56 addValue(op.getInput1().front());
57 addValue(op.getOutput());
58 return success();
59 }
60
61 template <>
62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
63 addValue(op.getInput());
64 addValue(op.getInputZp());
65 addValue(op.getOutputZp());
66 addType(op.getAccType());
67 addValue(op.getOutput());
68 return success();
69 }
70
71 template
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73 addValue(op.getInput());
74 addValue(op.getWeight());
75 addValue(op.getBias());
76 addValue(op.getInputZp());
77 addValue(op.getWeightZp());
78 addType(op.getAccType());
79 addValue(op.getOutput());
80 return success();
81 }
82
83 template <>
84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
85 return populateProfileInfoConv(op);
86 }
87
88 template <>
89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
90 return populateProfileInfoConv(op);
91 }
92
93 template <>
94 LogicalResult
95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
96 return populateProfileInfoConv(op);
97 }
98
99 template <>
100 LogicalResult
101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
102 return populateProfileInfoConv(op);
103 }
104
105 template <>
106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
107 addValue(op.getInput1());
108 addValue(op.getPadConst());
109 addValue(op.getOutput());
110 return success();
111 }
112
113 template
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115 addValue(op.getInput1());
116 addValue(op.getOutput());
117 return success();
118 }
119
120 template <>
121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
122 return populateProfileInfoDataLayout(op);
123 }
124
125 template <>
126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
127 return populateProfileInfoDataLayout(op);
128 }
129
130 template <>
131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
132 return populateProfileInfoDataLayout(op);
133 }
134
135 template <>
136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
137 return populateProfileInfoDataLayout(op);
138 }
139
140 template <>
141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142 addValue(op.getValues());
143 addValue(op.getIndices());
144 addValue(op.getOutput());
145 return success();
146 }
147
148 template <>
149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
150 addValue(op.getValuesIn());
151 addValue(op.getIndices());
152 addValue(op.getInput());
153 addValue(op.getValuesOut());
154 return success();
155 }
156
157 template <>
158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
159 addValue(op.getInput1());
160 addValue(op.getInput2());
161 addValue(op.getOutput());
162 return success();
163 }
164
165 template <>
166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
167 addValue(op.getInput());
168 addValue(op.getOutput());
169 return success();
170 }
171
172 template <>
173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
174 addValue(op.getInputReal());
175 addValue(op.getInputImag());
176 addValue(op.getOutputReal());
177 addValue(op.getOutputImag());
178 return success();
179 }
180
181 template <>
182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
183 addValue(op.getInputReal());
184 addValue(op.getOutputReal());
185 addValue(op.getOutputImag());
186 return success();
187 }
188
189 template <>
190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191 addValue(op.getInput2());
192 addValue(op.getInput3());
193 addValue(op.getOutput());
194 return success();
195 }
196
197 template <>
198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199 addValue(op.getInput());
200 addValue(op.getInputZp());
201 addValue(op.getOutputZp());
202 addValue(op.getOutput());
203 return success();
204 }
205
206 template <>
207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
208 addValue(op.getA());
209 addValue(op.getB());
210 addValue(op.getAZp());
211 addValue(op.getBZp());
212 addValue(op.getOutput());
213 return success();
214 }
215
216 template <>
217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
218 addType(op.getType());
219 return success();
220 }
221
222 template <>
223 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
224 addValue(op.getInput1());
225 return success();
226 }
227
228 template <>
229 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
230 addValue(op.getCondition());
231 return success();
232 }
233
234 template <>
235 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
236 Block *block = &op.getCondGraph().front();
238 addValue(terminator->getOperands().front());
239 return success();
240 }
241
// Records type information for `op` by testing it against each TOSA op kind
// listed in the macro invocations below. Ops matched by none of them fall
// through to failure().
// NOTE(review): the macro invocation lists (source lines ~260-343) are not
// visible in this view.
242 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
243
// CUSTOM: the op has its own populateProfileInfo specialization; cast and
// forward to it.
// NOTE(review): `isatosa::` / `casttosa::` look like `isa<tosa::...>` /
// `cast<tosa::...>` with the angle brackets lost -- confirm against upstream.
244 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
245 if (isatosa::tosaOp##Op(op)) { \
246 return populateProfileInfo(casttosa::tosaOp##Op(op)); \
247 }
248
// SKIP: succeed without recording any type information for this op.
249 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
250 if (isatosa::tosaOp##Op(op)) \
251 return success();
252
253
// COMMON: record every operand followed by the first result through the
// generic (ValueRange, Value) overload.
254 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
255 if (isatosa::tosaOp##Op(op)) { \
256 return populateProfileInfo(op->getOperands(), op->getResult(0)); \
257 }
258
259
260
285
286
287
337
338
339
340
343
// No macro matched this op.
344 return failure();
345 }
346
347
348
349
350
// Looks up the compliance entries registered under `opName` in the map
// returned by getProfileComplianceMap() (presumably the Profile or Extension
// map, chosen by a template argument that is not visible here). If an entry
// exists, returns findMatchedProfile's result; that helper also overwrites
// `condition` when a type set matches.
// NOTE(review): source lines 354-355 (the rest of the signature and,
// presumably, the definition of `opName`) are not visible here. `return {};`
// for an unknown op is presumably meant to yield a failed FailureOr, since
// callers test the result with failed() -- confirm.
351 template
352 FailureOr<SmallVector>
353 TosaProfileCompliance::getOperatorDefinition(Operation *op,
356 const auto complianceMap = getProfileComplianceMap();
357 const auto it = complianceMap.find(opName);
358 if (it == complianceMap.end())
359 return {};
360
361 return findMatchedProfile(op, it->second, condition);
362 }
363
// Checks `op` against the caller-supplied requirement lists in
// `specRequiredModeSet` and against the modes implied by its operand/result
// types, emitting an op error and returning failure on the first violation.
// NOTE(review): signature lines 365-366 and several body lines (373, 382, 387,
// 404, 406, 413, 415, 427) are not visible in this view.
364 template
367 const SmallVector<ArrayRef> &specRequiredModeSet) {
368
369
// Nothing required by the spec definition: trivially compliant.
370 if (specRequiredModeSet.size() == 0)
371 return success();
372
// Modes required by the op's actual types; `condition` is passed through to
// findMatchedProfile, which overwrites it with the matched entry's condition.
374 const auto maybeOpRequiredMode = getOperatorDefinition(op, condition);
375 if (failed(maybeOpRequiredMode)) {
376
377
378
379
// No compliance definition was found. Presumably this succeeds if the target
// enables one of the candidates (the test on line 382 is not visible);
// otherwise it reports every spec-required mode. `mode_count` only picks the
// "any of" wording.
380 int mode_count = 0;
381 for (const auto &cands : specRequiredModeSet) {
383 return success();
384 mode_count += cands.size();
385 }
386
388 << (mode_count > 1 ? " any of " : " ") << "["
389 << llvm::join(stringifyProfile(specRequiredModeSet),
390 ", ")
391 << "] but not enabled in target\n";
392
393 return failure();
394 }
395
396
397
// An empty mode list is accepted here. It can mean no type set matched;
// unmatched types are reported separately by the invalid-type check.
398 const auto opRequiredMode = maybeOpRequiredMode.value();
399 if (opRequiredMode.size() == 0) {
400
401 return success();
402 }
403
// Presumably guarded by an "all of" condition (line 404 not visible): every
// required mode must be enabled in the target.
405 !targetEnv.allowsAllOf(opRequiredMode)) {
407 << (opRequiredMode.size() > 1 ? " all of " : " ") << "["
408 << llvm::join(stringifyProfile(opRequiredMode), ", ")
409 << "] but not enabled in target\n";
410 return failure();
411 }
412
// Presumably guarded by an "any of" condition (line 413 not visible): at least
// one required mode must be enabled in the target.
414 !targetEnv.allowsAnyOf(opRequiredMode)) {
416 << (opRequiredMode.size() > 1 ? " any of " : " ") << "["
417 << llvm::join(stringifyProfile(opRequiredMode), ", ")
418 << "] but not enabled in target\n";
419 return failure();
420 }
421
422
423
// Extensions only: each required extension's cooperative profiles are
// fetched and, if the (not visible) test on line 427 fails, reported.
424 if constexpr (std::is_same_v<T, Extension>) {
425 for (const auto &mode : opRequiredMode) {
426 SmallVector coProfs = getCooperativeProfiles(mode);
428 op->emitOpError() << "illegal: requires ["
429 << llvm::join(stringifyProfile(coProfs),
430 ", ")
431 << "] to work with but not enabled in target\n";
432 return failure();
433 }
434 }
435 }
436
437
438
// Every mode the op requires must appear in each spec-defined requirement
// list.
439 for (const auto &cands : specRequiredModeSet) {
440 for (const auto &mode : opRequiredMode) {
441 if (!llvm::is_contained(cands, mode)) {
442 op->emitOpError() << "illegal: requires ["
443 << llvm::join(stringifyProfile(opRequiredMode),
444 ", ")
445 << "] but not included in the profile compliance ["
446 << llvm::join(
447 stringifyProfile(specRequiredModeSet), ", ")
448 << "]\n";
449 return failure();
450 }
451 }
452 }
453
454 return success();
455 }
456
// If `op` implements the profile query interface, validates it against the
// profiles that interface reports; any other op is accepted.
// NOTE(review): signature lines 458-459 are not visible here, and
// `dyn_casttosa::` looks like `dyn_cast<tosa::...>` with the angle brackets
// lost -- confirm against upstream.
457 LogicalResult
460 if (auto interface = dyn_casttosa::QueryProfileInterface(op))
461 return checkProfileOrExtension(op, targetEnv,
462 interface.getProfiles());
463
464 return success();
465 }
466
// If `op` implements the extension query interface, validates it against the
// extensions that interface reports; any other op is accepted.
// NOTE(review): signature lines 468-469 are not visible here, and
// `dyn_casttosa::` looks like `dyn_cast<tosa::...>` with the angle brackets
// lost -- confirm against upstream.
467 LogicalResult
470 if (auto interface = dyn_casttosa::QueryExtensionInterface(op))
471 return checkProfileOrExtension(op, targetEnv,
472 interface.getExtensions());
473
474 return success();
475 }
476
// Reports `op` as illegal when both lookups succeed but both yield an empty
// mode list, i.e. its operand/result types match no profile or extension
// entry. The diagnostic lists the gathered types and the closest table entry.
// NOTE(review): this is presumably the body of checkInvalid(Operation *op);
// source lines 477-478, 489, 497 and 524 are not visible in this view.
// NOTE(review): the two calls below are textually identical; presumably the
// <Profile> / <Extension> template arguments were lost -- confirm.
479 const auto maybeProfDef = getOperatorDefinition(op, condition);
480 const auto maybeExtDef = getOperatorDefinition(op, condition);
481
482 if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
483 !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
484 std::string message;
485 llvm::raw_string_ostream os(message);
486 os << "illegal: operation operand/result data types did not align with any "
487 "profile or extension, got (";
488
// NOTE(review): findMatchedProfile returns an empty list when no type info was
// gathered, which also satisfies the condition above. In that case
// `current.back()` would be called on an empty vector -- confirm whether that
// is reachable.
490 SmallVector current = depot.getInfo();
491 for (const auto &typeInfo : llvm::drop_end(current))
492 os << stringifyTypeInfo(typeInfo) << ",";
493 os << stringifyTypeInfo(current.back()) << ")";
494
495
496
// Find the table entry with the most position-wise matching types as a
// suggestion.
// NOTE(review): `auto map` copies the whole compliance map on each call, and
// `map[opName]` default-inserts into that copy if absent; a `const auto &`
// parameter plus find() would avoid both.
498 int maxMatches = -1;
499 SmallVector bestTypeInfo;
500 const auto searchBestMatch = [&](auto map) {
501 for (const auto &complianceInfos : map[opName]) {
502 for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
503 const int matches = llvm::count_if(
504 llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
505 return isSameTypeInfo(std::get<0>(zipType),
506 std::get<1>(zipType));
507 });
508 if (matches > maxMatches) {
509 maxMatches = matches;
510 bestTypeInfo = typeInfos;
511 }
512 }
513 }
514 };
// NOTE(review): as shown, the same map is searched twice; presumably the
// second call should search the Extension map -- confirm.
515 searchBestMatch(getProfileComplianceMap());
516 searchBestMatch(getProfileComplianceMap());
517
// NOTE(review): if no entry was found, bestTypeInfo stays empty and `.back()`
// below is called on an empty vector.
518 os << ", did you mean (";
519 for (const auto &typeInfo : llvm::drop_end(bestTypeInfo))
520 os << stringifyTypeInfo(typeInfo) << ",";
521 os << stringifyTypeInfo(bestTypeInfo.back()) << ")? ";
522 os << "Otherwise, please refer to the 'supported data types' for '"
523 << opName << "' in the specification.";
525 return failure();
526 }
527
528 return success();
529 }
530
531
532
// Returns the mode list of the first compliance entry that has a type set
// equal, position by position, to the type info gathered for `op`, and copies
// that entry's condition into `condition`. Returns an empty list if no type
// info was gathered or if no type set matches.
// NOTE(review): the signature (source lines 534-536) and the line defining
// `depot` (541, presumably a ProfileInfoDepot built from `op`) are not visible.
533 template
537 assert(compInfo.size() != 0 &&
538 "profile-based compliance information is empty");
539
540
542 SmallVector present = depot.getInfo();
543 if (present.size() == 0)
544 return {};
545
// NOTE(review): `sets` and each `expected` are copied on every iteration;
// binding them as `const auto &` would avoid the copies.
546 for (size_t i = 0; i < compInfo.size(); i++) {
547 SmallVector<SmallVector> sets = compInfo[i].operandTypeInfoSet;
548 for (SmallVector expected : sets) {
// NOTE(review): the concatenated message contains a double space
// ("from  the operation").
549 assert(present.size() == expected.size() &&
550 "the entries for profile-based compliance do not match between "
551 "the generated metadata and the type definition retrieved from "
552 " the operation");
553
554 bool is_found = true;
555
556
557 for (size_t j = 0; j < expected.size(); j++) {
558 if (!isSameTypeInfo(present[j], expected[j])) {
559
560 is_found = false;
561 break;
562 }
563 }
564
565 if (is_found == true) {
566 condition = compInfo[i].condition;
567 return compInfo[i].mode;
568 }
569 }
570 }
571
572 return {};
573 }
574
575
// Converts each element of `profiles` to its name, preserving order. It uses
// tosa::stringifyProfile when T is Profile and tosa::stringifyExtension for
// any other T.
// NOTE(review): signature lines 577-578 are not visible here.
576 template
579 SmallVector debugStrings;
580 for (const auto &profile : profiles) {
581 if constexpr (std::is_same_v<T, Profile>)
582 debugStrings.push_back(tosa::stringifyProfile(profile));
583 else
584 debugStrings.push_back(tosa::stringifyExtension(profile));
585 }
586 return debugStrings;
587 }
588
// Flattens a set of profile/extension lists into one list of names, in
// iteration order. Duplicates across inner lists are kept.
// NOTE(review): signature line 590 is not visible here.
589 template
591 const SmallVector<ArrayRef> &profileSet) {
592 SmallVector debugStrings;
593
594 for (const auto &profiles : profileSet) {
595 auto tempStrings = stringifyProfile(profiles);
596 llvm::append_range(debugStrings, tempStrings);
597 }
598
599 return debugStrings;
600 }
601
// Produces a short type name for diagnostics: integers become "i<bitWidth>"
// (so booleans print as "i1"). The supported float types map to fixed names,
// and any other TypeID is treated as unreachable.
// NOTE(review): signature lines 602-603 are not visible here.
604 if (typeInfo.typeID == mlir::IntegerType::getTypeID()) {
605 return {"i" + llvm::utostr(typeInfo.bitWidth)};
606 } else if (typeInfo.typeID == mlir::Float16Type::getTypeID()) {
607 return {"f16"};
608 } else if (typeInfo.typeID == mlir::Float32Type::getTypeID()) {
609 return {"f32"};
610 } else if (typeInfo.typeID == mlir::BFloat16Type::getTypeID()) {
611 return {"bf16"};
612 } else if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID()) {
613 return {"fp8e4m3"};
614 } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
615 return {"fp8e5m2"};
616 }
617 llvm_unreachable("unknown type");
618 }
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< T > findMatchedProfile(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo, CheckCondition &condition)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
NestedPattern If(const NestedPattern &child)
Include the generated interface declarations.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.