MLIR: lib/Dialect/OpenMP/IR/OpenMPDialect.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

24

25 #include "llvm/ADT/ArrayRef.h"

26 #include "llvm/ADT/BitVector.h"

27 #include "llvm/ADT/STLExtras.h"

28 #include "llvm/ADT/STLForwardCompat.h"

29 #include "llvm/ADT/SmallString.h"

30 #include "llvm/ADT/StringExtras.h"

31 #include "llvm/ADT/StringRef.h"

32 #include "llvm/ADT/TypeSwitch.h"

33 #include "llvm/Frontend/OpenMP/OMPConstants.h"

34 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"

35 #include

36 #include

37 #include

38 #include

39

40 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"

41 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"

42 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"

43 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"

44

45 using namespace mlir;

47

50 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);

51 }

52

56 }

57

58 namespace {

59 struct MemRefPointerLikeModel

60 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,

61 MemRefType> {

63 return llvm::cast(pointer).getElementType();

64 }

65 };

66

67 struct LLVMPointerPointerLikeModel

68 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,

69 LLVM::LLVMPointerType> {

71 };

72 }

73

74 void OpenMPDialect::initialize() {

75 addOperations<

76 #define GET_OP_LIST

77 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"

78 >();

79 addAttributes<

80 #define GET_ATTRDEF_LIST

81 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

82 >();

83 addTypes<

84 #define GET_TYPEDEF_LIST

85 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"

86 >();

87

88 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();

89

90 MemRefType::attachInterface(*getContext());

91 LLVM::LLVMPointerType::attachInterface(

93

94

95

96 mlir::ModuleOp::attachInterfacemlir::omp::OffloadModuleDefaultModel(

98

99

100

101

102 mlir::LLVM::GlobalOp::attachInterface<

105 mlir::LLVM::LLVMFuncOp::attachInterface<

108 mlir::func::FuncOp::attachInterface<

110 }

111

112

113

114

115

116

117

118

119

120

121

128

133 return failure();

134 allocatorVars.push_back(operand);

135 allocatorTypes.push_back(type);

137 return failure();

139 return failure();

140

141 allocateVars.push_back(operand);

142 allocateTypes.push_back(type);

143 return success();

144 });

145 }

146

147

153 for (unsigned i = 0; i < allocateVars.size(); ++i) {

154 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";

155 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";

156 p << allocateVars[i] << " : " << allocateTypes[i] << separator;

157 }

158 }

159

160

161

162

163

164 template

166 using ClauseT = decltype(std::declval().getValue());

167 StringRef enumStr;

170 return failure();

171 if (std::optional enumValue = symbolizeEnum(enumStr)) {

173 return success();

174 }

175 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";

176 }

177

178 template

180 p << stringifyEnum(attr.getValue());

181 }

182

183

184

185

186

187

188

189

201 return failure();

202

203 linearVars.push_back(var);

204 linearTypes.push_back(type);

205 linearStepVars.push_back(stepVar);

206 return success();

207 });

208 }

209

210

214 size_t linearVarsSize = linearVars.size();

215 for (unsigned i = 0; i < linearVarsSize; ++i) {

216 std::string separator = i == linearVarsSize - 1 ? "" : ", ";

217 p << linearVars[i];

218 if (linearStepVars.size() > i)

219 p << " = " << linearStepVars[i];

220 p << " : " << linearVars[i].getType() << separator;

221 }

222 }

223

224

225

226

227

230

231

233 for (const auto &it : nontemporalVars)

234 if (!nontemporalItems.insert(it).second)

235 return op->emitOpError() << "nontemporal variable used more than once";

236

237 return success();

238 }

239

240

241

242

244 std::optional alignments,

246

247 if (!alignedVars.empty()) {

248 if (!alignments || alignments->size() != alignedVars.size())

250 << "expected as many alignment values as aligned variables";

251 } else {

252 if (alignments)

253 return op->emitOpError() << "unexpected alignment values attribute";

254 return success();

255 }

256

257

259 for (auto it : alignedVars)

260 if (!alignedItems.insert(it).second)

261 return op->emitOpError() << "aligned variable used more than once";

262

263 if (!alignments)

264 return success();

265

266

267 for (unsigned i = 0; i < (*alignments).size(); ++i) {

268 if (auto intAttr = llvm::dyn_cast((*alignments)[i])) {

269 if (intAttr.getValue().sle(0))

270 return op->emitOpError() << "alignment should be greater than 0";

271 } else {

272 return op->emitOpError() << "expected integer alignment";

273 }

274 }

275

276 return success();

277 }

278

279

280

281

282 static ParseResult

286 ArrayAttr &alignmentsAttr) {

289 if (parser.parseOperand(alignedVars.emplace_back()) ||

290 parser.parseColonType(alignedTypes.emplace_back()) ||

291 parser.parseArrow() ||

292 parser.parseAttribute(alignmentVec.emplace_back())) {

293 return failure();

294 }

295 return success();

296 })))

297 return failure();

299 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);

300 return success();

301 }

302

303

306 std::optional alignments) {

307 for (unsigned i = 0; i < alignedVars.size(); ++i) {

308 if (i != 0)

309 p << ", ";

310 p << alignedVars[i] << " : " << alignedVars[i].getType();

311 p << " -> " << (*alignments)[i];

312 }

313 }

314

315

316

317

318

319 static ParseResult

322 if (modifiers.size() > 2)

323 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";

324 for (const auto &mod : modifiers) {

325

326

327 auto symbol = symbolizeScheduleModifier(mod);

328 if (!symbol)

330 << " unknown modifier type: " << mod;

331 }

332

333

334

335 if (modifiers.size() == 1) {

336 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {

337 modifiers.push_back(modifiers[0]);

338 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);

339 }

340 } else if (modifiers.size() == 2) {

341

342

343 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||

344 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)

346 << " incorrect modifier order";

347 }

348 return success();

349 }

350

351

352

353

354

355

356

357

358

359

360 static ParseResult

362 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,

363 std::optionalOpAsmParser::UnresolvedOperand &chunkSize,

364 Type &chunkType) {

365 StringRef keyword;

367 return failure();

368 std::optionalmlir::omp::ClauseScheduleKind schedule =

369 symbolizeClauseScheduleKind(keyword);

370 if (!schedule)

371 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";

372

374 switch (*schedule) {

375 case ClauseScheduleKind::Static:

376 case ClauseScheduleKind::Dynamic:

377 case ClauseScheduleKind::Guided:

381 return failure();

382 } else {

383 chunkSize = std::nullopt;

384 }

385 break;

386 case ClauseScheduleKind::Auto:

388 chunkSize = std::nullopt;

389 }

390

391

394 StringRef mod;

396 return failure();

397 modifiers.push_back(mod);

398 }

399

401 return failure();

402

403 if (!modifiers.empty()) {

405 if (std::optional mod =

406 symbolizeScheduleModifier(modifiers[0])) {

408 } else {

409 return parser.emitError(loc, "invalid schedule modifier");

410 }

411

412 if (modifiers.size() > 1) {

413 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);

415 }

416 }

417

418 return success();

419 }

420

421

423 ClauseScheduleKindAttr scheduleKind,

424 ScheduleModifierAttr scheduleMod,

425 UnitAttr scheduleSimd, Value scheduleChunk,

426 Type scheduleChunkType) {

427 p << stringifyClauseScheduleKind(scheduleKind.getValue());

428 if (scheduleChunk)

429 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();

430 if (scheduleMod)

431 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());

432 if (scheduleSimd)

433 p << ", simd";

434 }

435

436

437

438

439

440

441

443 ClauseOrderKindAttr &order,

444 OrderModifierAttr &orderMod) {

445 StringRef enumStr;

448 return failure();

449 if (std::optional enumValue =

450 symbolizeOrderModifier(enumStr)) {

453 return failure();

456 return failure();

457 }

458 if (std::optional enumValue =

459 symbolizeClauseOrderKind(enumStr)) {

461 return success();

462 }

463 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";

464 }

465

467 ClauseOrderKindAttr order,

468 OrderModifierAttr orderMod) {

469 if (orderMod)

470 p << stringifyOrderModifier(orderMod.getValue()) << ":";

471 if (order)

472 p << stringifyClauseOrderKind(order.getValue());

473 }

474

475 template <typename ClauseTypeAttr, typename ClauseType>

476 static ParseResult

478 std::optionalOpAsmParser::UnresolvedOperand &operand,

479 Type &operandType,

480 std::optional (*symbolizeClause)(StringRef),

481 StringRef clauseName) {

482 StringRef enumStr;

484 if (std::optional enumValue = symbolizeClause(enumStr)) {

487 return failure();

488 } else {

490 << "invalid " << clauseName << " modifier : '" << enumStr << "'";

491 ;

492 }

493 }

494

497 operand = var;

498 } else {

500 << "expected " << clauseName << " operand";

501 }

502

503 if (operand.has_value()) {

505 return failure();

506 }

507

508 return success();

509 }

510

511 template <typename ClauseTypeAttr, typename ClauseType>

512 static void

514 ClauseTypeAttr prescriptiveness, Value operand,

516 StringRef (*stringifyClauseType)(ClauseType)) {

517

518 if (prescriptiveness)

519 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";

520

521 if (operand)

522 p << operand << ": " << operandType;

523 }

524

525

526

527

528

529

530 static ParseResult

532 std::optionalOpAsmParser::UnresolvedOperand &grainsize,

533 Type &grainsizeType) {

534 return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(

535 parser, grainsizeMod, grainsize, grainsizeType,

536 &symbolizeClauseGrainsizeType, "grainsize");

537 }

538

540 ClauseGrainsizeTypeAttr grainsizeMod,

542 printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(

543 p, op, grainsizeMod, grainsize, grainsizeType,

544 &stringifyClauseGrainsizeType);

545 }

546

547

548

549

550

551

552 static ParseResult

554 std::optionalOpAsmParser::UnresolvedOperand &numTasks,

555 Type &numTasksType) {

556 return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(

557 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,

558 "num_tasks");

559 }

560

562 ClauseNumTasksTypeAttr numTasksMod,

564 printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(

565 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);

566 }

567

568

569

570

571

572 namespace {

573 struct MapParseArgs {

578 : vars(vars), types(types) {}

579 };

580 struct PrivateParseArgs {

583 ArrayAttr &syms;

584 UnitAttr &needsBarrier;

588 UnitAttr &needsBarrier,

590 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),

591 mapIndices(mapIndices) {}

592 };

593

594 struct ReductionParseArgs {

598 ArrayAttr &syms;

599 ReductionModifierAttr *modifier;

602 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)

603 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}

604 };

605

606 struct AllRegionParseArgs {

607 std::optional hasDeviceAddrArgs;

608 std::optional hostEvalArgs;

609 std::optional inReductionArgs;

610 std::optional mapArgs;

611 std::optional privateArgs;

612 std::optional reductionArgs;

613 std::optional taskReductionArgs;

614 std::optional useDeviceAddrArgs;

615 std::optional useDevicePtrArgs;

616 };

617 }

618

620 return "private_barrier";

621 }

622

628 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,

630 ReductionModifierAttr *modifier = nullptr,

631 UnitAttr *needsBarrier = nullptr) {

635 unsigned regionArgOffset = regionPrivateArgs.size();

636

638 return failure();

639

641 StringRef enumStr;

644 return failure();

645 std::optional enumValue =

646 symbolizeReductionModifier(enumStr);

647 if (!enumValue.has_value())

648 return failure();

650 if (!*modifier)

651 return failure();

652 }

653

655 if (byref)

656 isByRefVec.push_back(

657 parser.parseOptionalKeyword("byref").succeeded());

658

659 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))

660 return failure();

661

662 if (parser.parseOperand(operands.emplace_back()) ||

663 parser.parseArrow() ||

664 parser.parseArgument(regionPrivateArgs.emplace_back()))

665 return failure();

666

667 if (mapIndices) {

668 if (parser.parseOptionalLSquare().succeeded()) {

669 if (parser.parseKeyword("map_idx") || parser.parseEqual() ||

670 parser.parseInteger(mapIndicesVec.emplace_back()) ||

671 parser.parseRSquare())

672 return failure();

673 } else {

674 mapIndicesVec.push_back(-1);

675 }

676 }

677

678 return success();

679 }))

680 return failure();

681

683 return failure();

684

686 if (parser.parseType(types.emplace_back()))

687 return failure();

688

689 return success();

690 }))

691 return failure();

692

693 if (operands.size() != types.size())

694 return failure();

695

697 return failure();

698

699 if (needsBarrier) {

701 .succeeded())

703 }

704

705 auto *argsBegin = regionPrivateArgs.begin();

706 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,

707 argsBegin + regionArgOffset + types.size());

708 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {

709 prv.type = type;

710 }

711

712 if (symbols) {

715 }

716

717 if (!mapIndicesVec.empty())

718 *mapIndices =

720

721 if (byref)

723

724 return success();

725 }

726

730 StringRef keyword, std::optional mapArgs) {

732 if (!mapArgs)

733 return failure();

734

736 entryBlockArgs)))

737 return failure();

738 }

739 return success();

740 }

741

745 StringRef keyword, std::optional privateArgs) {

747 if (!privateArgs)

748 return failure();

749

751 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,

752 &privateArgs->syms, privateArgs->mapIndices, nullptr,

753 nullptr, &privateArgs->needsBarrier)))

754 return failure();

755 }

756 return success();

757 }

758

762 StringRef keyword, std::optional reductionArgs) {

764 if (!reductionArgs)

765 return failure();

767 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,

768 &reductionArgs->syms, nullptr, &reductionArgs->byref,

769 reductionArgs->modifier)))

770 return failure();

771 }

772 return success();

773 }

774

776 AllRegionParseArgs args) {

778

779 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",

780 args.hasDeviceAddrArgs)))

782 << "invalid `has_device_addr` format";

783

785 args.hostEvalArgs)))

787 << "invalid `host_eval` format";

788

790 args.inReductionArgs)))

792 << "invalid `in_reduction` format";

793

795 args.mapArgs)))

797 << "invalid `map_entries` format";

798

800 args.privateArgs)))

802 << "invalid `private` format";

803

805 args.reductionArgs)))

807 << "invalid `reduction` format";

808

810 args.taskReductionArgs)))

812 << "invalid `task_reduction` format";

813

814 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",

815 args.useDeviceAddrArgs)))

817 << "invalid `use_device_addr` format";

818

820 args.useDevicePtrArgs)))

822 << "invalid `use_device_addr` format";

823

824 return parser.parseRegion(region, entryBlockArgs);

825 }

826

827

828

843 AllRegionParseArgs args;

844 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);

845 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);

846 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,

847 inReductionByref, inReductionSyms);

848 args.mapArgs.emplace(mapVars, mapTypes);

849 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

850 privateNeedsBarrier, &privateMaps);

852 }

853

861 UnitAttr &privateNeedsBarrier) {

862 AllRegionParseArgs args;

863 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,

864 inReductionByref, inReductionSyms);

865 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

866 privateNeedsBarrier);

868 }

869

877 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,

880 ArrayAttr &reductionSyms) {

881 AllRegionParseArgs args;

882 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,

883 inReductionByref, inReductionSyms);

884 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

885 privateNeedsBarrier);

886 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,

887 reductionSyms, &reductionMod);

889 }

890

895 UnitAttr &privateNeedsBarrier) {

896 AllRegionParseArgs args;

897 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

898 privateNeedsBarrier);

900 }

901

906 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,

909 ArrayAttr &reductionSyms) {

910 AllRegionParseArgs args;

911 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

912 privateNeedsBarrier);

913 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,

914 reductionSyms, &reductionMod);

916 }

917

922 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {

923 AllRegionParseArgs args;

924 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,

925 taskReductionByref, taskReductionSyms);

927 }

928

935 AllRegionParseArgs args;

936 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);

937 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);

939 }

940

941

942

943

944

945 namespace {

946 struct MapPrintArgs {

949 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}

950 };

951 struct PrivatePrintArgs {

954 ArrayAttr syms;

955 UnitAttr needsBarrier;

959 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),

960 mapIndices(mapIndices) {}

961 };

962 struct ReductionPrintArgs {

966 ArrayAttr syms;

967 ReductionModifierAttr modifier;

969 ArrayAttr syms, ReductionModifierAttr mod = nullptr)

970 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}

971 };

972 struct AllRegionPrintArgs {

973 std::optional hasDeviceAddrArgs;

974 std::optional hostEvalArgs;

975 std::optional inReductionArgs;

976 std::optional mapArgs;

977 std::optional privateArgs;

978 std::optional reductionArgs;

979 std::optional taskReductionArgs;

980 std::optional useDeviceAddrArgs;

981 std::optional useDevicePtrArgs;

982 };

983 }

984

988 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,

990 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {

991 if (argsSubrange.empty())

992 return;

993

994 p << clauseName << "(";

995

996 if (modifier)

997 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";

998

999 if (!symbols) {

1002 }

1003

1004 if (!mapIndices) {

1007 }

1008

1009 if (!byref) {

1012 }

1013

1014 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,

1015 mapIndices.asArrayRef(),

1016 byref.asArrayRef()),

1017 p, [&p](auto t) {

1018 auto [op, arg, sym, map, isByRef] = t;

1019 if (isByRef)

1020 p << "byref ";

1021 if (sym)

1022 p << sym << " ";

1023

1024 p << op << " -> " << arg;

1025

1026 if (map != -1)

1027 p << " [map_idx=" << map << "]";

1028 });

1029 p << " : ";

1030 llvm::interleaveComma(types, p);

1031 p << ") ";

1032

1033 if (needsBarrier)

1035 }

1036

1038 StringRef clauseName, ValueRange argsSubrange,

1039 std::optional mapArgs) {

1040 if (mapArgs)

1042 mapArgs->types);

1043 }

1044

1046 StringRef clauseName, ValueRange argsSubrange,

1047 std::optional privateArgs) {

1048 if (privateArgs)

1050 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,

1051 privateArgs->syms, privateArgs->mapIndices, nullptr,

1052 nullptr, privateArgs->needsBarrier);

1053 }

1054

1055 static void

1058 std::optional reductionArgs) {

1059 if (reductionArgs)

1061 reductionArgs->vars, reductionArgs->types,

1062 reductionArgs->syms, nullptr,

1063 reductionArgs->byref, reductionArgs->modifier);

1064 }

1065

1067 const AllRegionPrintArgs &args) {

1068 auto iface = llvm::castmlir::omp::BlockArgOpenMPOpInterface(op);

1070

1072 iface.getHasDeviceAddrBlockArgs(),

1073 args.hasDeviceAddrArgs);

1075 args.hostEvalArgs);

1076 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),

1077 args.inReductionArgs);

1079 args.mapArgs);

1081 args.privateArgs);

1083 args.reductionArgs);

1085 iface.getTaskReductionBlockArgs(),

1086 args.taskReductionArgs);

1088 iface.getUseDeviceAddrBlockArgs(),

1089 args.useDeviceAddrArgs);

1091 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);

1092

1093 p.printRegion(region, false);

1094 }

1095

1096

1097

1105 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,

1107 AllRegionPrintArgs args;

1108 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);

1109 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);

1110 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,

1111 inReductionByref, inReductionSyms);

1112 args.mapArgs.emplace(mapVars, mapTypes);

1113 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

1114 privateNeedsBarrier, privateMaps);

1116 }

1117

1121 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,

1122 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {

1123 AllRegionPrintArgs args;

1124 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,

1125 inReductionByref, inReductionSyms);

1126 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

1127 privateNeedsBarrier,

1128 nullptr);

1130 }

1131

1135 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,

1136 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,

1137 ReductionModifierAttr reductionMod, ValueRange reductionVars,

1139 ArrayAttr reductionSyms) {

1140 AllRegionPrintArgs args;

1141 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,

1142 inReductionByref, inReductionSyms);

1143 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

1144 privateNeedsBarrier,

1145 nullptr);

1146 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,

1147 reductionSyms, reductionMod);

1149 }

1150

1153 ArrayAttr privateSyms,

1154 UnitAttr privateNeedsBarrier) {

1155 AllRegionPrintArgs args;

1156 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

1157 privateNeedsBarrier,

1158 nullptr);

1160 }

1161

1164 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,

1165 ReductionModifierAttr reductionMod, ValueRange reductionVars,

1167 ArrayAttr reductionSyms) {

1168 AllRegionPrintArgs args;

1169 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,

1170 privateNeedsBarrier,

1171 nullptr);

1172 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,

1173 reductionSyms, reductionMod);

1175 }

1176

1182 ArrayAttr taskReductionSyms) {

1183 AllRegionPrintArgs args;

1184 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,

1185 taskReductionByref, taskReductionSyms);

1187 }

1188

1195 AllRegionPrintArgs args;

1196 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);

1197 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);

1199 }

1200

1201

1202 static LogicalResult

1206 if (!reductionVars.empty()) {

1207 if (!reductionSyms || reductionSyms->size() != reductionVars.size())

1209 << "expected as many reduction symbol references "

1210 "as reduction variables";

1211 if (reductionByref && reductionByref->size() != reductionVars.size())

1212 return op->emitError() << "expected as many reduction variable by "

1213 "reference attributes as reduction variables";

1214 } else {

1215 if (reductionSyms)

1216 return op->emitOpError() << "unexpected reduction symbol references";

1217 return success();

1218 }

1219

1220

1221

1223 for (auto args : llvm::zip(reductionVars, *reductionSyms)) {

1224 Value accum = std::get<0>(args);

1225

1226 if (!accumulators.insert(accum).second)

1227 return op->emitOpError() << "accumulator variable used more than once";

1228

1230 auto symbolRef = llvm::cast(std::get<1>(args));

1231 auto decl =

1232 SymbolTable::lookupNearestSymbolFrom(op, symbolRef);

1233 if (!decl)

1234 return op->emitOpError() << "expected symbol reference " << symbolRef

1235 << " to point to a reduction declaration";

1236

1237 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)

1239 << "expected accumulator (" << varType

1240 << ") to be the same type as reduction declaration ("

1241 << decl.getAccumulatorType() << ")";

1242 }

1243

1244 return success();

1245 }

1246

1247

1248

1249

1250

1251

1252

1253

1260 if (parser.parseOperand(copyprivateVars.emplace_back()) ||

1261 parser.parseArrow() ||

1262 parser.parseAttribute(symsVec.emplace_back()) ||

1263 parser.parseColonType(copyprivateTypes.emplace_back()))

1264 return failure();

1265 return success();

1266 })))

1267 return failure();

1270 return success();

1271 }

1272

1273

1277 std::optional copyprivateSyms) {

1278 if (!copyprivateSyms.has_value())

1279 return;

1280 llvm::interleaveComma(

1281 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,

1282 [&](const auto &args) {

1283 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "

1284 << std::get<2>(args);

1285 });

1286 }

1287

1288

1289 static LogicalResult

1291 std::optional copyprivateSyms) {

1292 size_t copyprivateSymsSize =

1293 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;

1294 if (copyprivateSymsSize != copyprivateVars.size())

1295 return op->emitOpError() << "inconsistent number of copyprivate vars (= "

1296 << copyprivateVars.size()

1297 << ") and functions (= " << copyprivateSymsSize

1298 << "), both must be equal";

1299 if (!copyprivateSyms.has_value())

1300 return success();

1301

1302 for (auto copyprivateVarAndSym :

1303 llvm::zip(copyprivateVars, *copyprivateSyms)) {

1304 auto symbolRef =

1305 llvm::cast(std::get<1>(copyprivateVarAndSym));

1306 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>

1307 funcOp;

1308 if (mlir::func::FuncOp mlirFuncOp =

1309 SymbolTable::lookupNearestSymbolFrommlir::func::FuncOp(op,

1310 symbolRef))

1311 funcOp = mlirFuncOp;

1312 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =

1313 SymbolTable::lookupNearestSymbolFrommlir::LLVM::LLVMFuncOp(

1314 op, symbolRef))

1315 funcOp = llvmFuncOp;

1316

1317 auto getNumArguments = [&] {

1318 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);

1319 };

1320

1321 auto getArgumentType = [&](unsigned i) {

1322 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },

1323 *funcOp);

1324 };

1325

1326 if (!funcOp)

1327 return op->emitOpError() << "expected symbol reference " << symbolRef

1328 << " to point to a copy function";

1329

1330 if (getNumArguments() != 2)

1332 << "expected copy function " << symbolRef << " to have 2 operands";

1333

1334 Type argTy = getArgumentType(0);

1335 if (argTy != getArgumentType(1))

1336 return op->emitOpError() << "expected copy function " << symbolRef

1337 << " arguments to have the same type";

1338

1339 Type varType = std::get<0>(copyprivateVarAndSym).getType();

1340 if (argTy != varType)

1342 << "expected copy function arguments' type (" << argTy

1343 << ") to be the same as copyprivate variable's type (" << varType

1344 << ")";

1345 }

1346

1347 return success();

1348 }

1349

1350

1351

1352

1353

1354

1355

1356

1357 static ParseResult

1363 StringRef keyword;

1364 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||

1365 parser.parseOperand(dependVars.emplace_back()) ||

1366 parser.parseColonType(dependTypes.emplace_back()))

1367 return failure();

1368 if (std::optional keywordDepend =

1369 (symbolizeClauseTaskDepend(keyword)))

1370 kindsVec.emplace_back(

1371 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));

1372 else

1373 return failure();

1374 return success();

1375 })))

1376 return failure();

1379 return success();

1380 }

1381

1382

1385 std::optional dependKinds) {

1386

1387 for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {

1388 if (i != 0)

1389 p << ", ";

1390 p << stringifyClauseTaskDepend(

1391 llvm::castmlir::omp::ClauseTaskDependAttr((*dependKinds)[i])

1392 .getValue())

1393 << " -> " << dependVars[i] << " : " << dependTypes[i];

1394 }

1395 }

1396

1397

1399 std::optional dependKinds,

1401 if (!dependVars.empty()) {

1402 if (!dependKinds || dependKinds->size() != dependVars.size())

1403 return op->emitOpError() << "expected as many depend values"

1404 " as depend variables";

1405 } else {

1406 if (dependKinds && !dependKinds->empty())

1407 return op->emitOpError() << "unexpected depend values";

1408 return success();

1409 }

1410

1411 return success();

1412 }

1413

1414

1415

1416

1417

1418

1419

1420

1421

1423 IntegerAttr &hintAttr) {

1424 StringRef hintKeyword;

1425 int64_t hint = 0;

1428 return success();

1429 }

1430 auto parseKeyword = [&]() -> ParseResult {

1431 if (failed(parser.parseKeyword(&hintKeyword)))

1432 return failure();

1433 if (hintKeyword == "uncontended")

1434 hint |= 1;

1435 else if (hintKeyword == "contended")

1436 hint |= 2;

1437 else if (hintKeyword == "nonspeculative")

1438 hint |= 4;

1439 else if (hintKeyword == "speculative")

1440 hint |= 8;

1441 else

1443 << hintKeyword << " is not a valid hint";

1444 return success();

1445 };

1447 return failure();

1449 return success();

1450 }

1451

1452

1454 IntegerAttr hintAttr) {

1455 int64_t hint = hintAttr.getInt();

1456

1457 if (hint == 0) {

1458 p << "none";

1459 return;

1460 }

1461

1462

1463 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };

1464

1465 bool uncontended = bitn(hint, 0);

1466 bool contended = bitn(hint, 1);

1467 bool nonspeculative = bitn(hint, 2);

1468 bool speculative = bitn(hint, 3);

1469

1471 if (uncontended)

1472 hints.push_back("uncontended");

1473 if (contended)

1474 hints.push_back("contended");

1475 if (nonspeculative)

1476 hints.push_back("nonspeculative");

1477 if (speculative)

1478 hints.push_back("speculative");

1479

1480 llvm::interleaveComma(hints, p);

1481 }

1482

1483

1485

1486

1487 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };

1488

1489 bool uncontended = bitn(hint, 0);

1490 bool contended = bitn(hint, 1);

1491 bool nonspeculative = bitn(hint, 2);

1492 bool speculative = bitn(hint, 3);

1493

1494 if (uncontended && contended)

1495 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "

1496 "omp_sync_hint_contended cannot be combined";

1497 if (nonspeculative && speculative)

1498 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "

1499 "omp_sync_hint_speculative cannot be combined.";

1500 return success();

1501 }

1502

1503

1504

1505

1506

1507

1509 llvm::omp::OpenMPOffloadMappingFlags flag) {

1510 return value & llvm::to_underlying(flag);

1511 }

1512

1513

1514

1515

1516

1517

1519 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =

1520 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;

1521

1522

1523

1524 auto parseTypeAndMod = [&]() -> ParseResult {

1525 StringRef mapTypeMod;

1527 return failure();

1528

1529 if (mapTypeMod == "always")

1530 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;

1531

1532 if (mapTypeMod == "implicit")

1533 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;

1534

1535 if (mapTypeMod == "ompx_hold")

1536 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;

1537

1538 if (mapTypeMod == "close")

1539 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;

1540

1541 if (mapTypeMod == "present")

1542 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;

1543

1544 if (mapTypeMod == "to")

1545 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;

1546

1547 if (mapTypeMod == "from")

1548 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;

1549

1550 if (mapTypeMod == "tofrom")

1551 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |

1552 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;

1553

1554 if (mapTypeMod == "delete")

1555 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;

1556

1557 if (mapTypeMod == "return_param")

1558 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;

1559

1560 return success();

1561 };

1562

1564 return failure();

1565

1568 llvm::to_underlying(mapTypeBits));

1569

1570 return success();

1571 }

1572

1573

1574

1576 IntegerAttr mapType) {

1577 uint64_t mapTypeBits = mapType.getUInt();

1578

1579 bool emitAllocRelease = true;

1581

1582

1583

1585 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))

1586 mapTypeStrs.push_back("always");

1588 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))

1589 mapTypeStrs.push_back("implicit");

1591 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))

1592 mapTypeStrs.push_back("ompx_hold");

1594 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))

1595 mapTypeStrs.push_back("close");

1597 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))

1598 mapTypeStrs.push_back("present");

1599

1600

1601

1602

1604 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);

1606 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);

1607 if (to && from) {

1608 emitAllocRelease = false;

1609 mapTypeStrs.push_back("tofrom");

1610 } else if (from) {

1611 emitAllocRelease = false;

1612 mapTypeStrs.push_back("from");

1613 } else if (to) {

1614 emitAllocRelease = false;

1615 mapTypeStrs.push_back("to");

1616 }

1618 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {

1619 emitAllocRelease = false;

1620 mapTypeStrs.push_back("delete");

1621 }

1623 mapTypeBits,

1624 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {

1625 emitAllocRelease = false;

1626 mapTypeStrs.push_back("return_param");

1627 }

1628 if (emitAllocRelease)

1629 mapTypeStrs.push_back("exit_release_or_enter_alloc");

1630

1631 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {

1632 p << mapTypeStrs[i];

1633 if (i + 1 < mapTypeStrs.size()) {

1634 p << ", ";

1635 }

1636 }

1637 }

1638

1640 ArrayAttr &membersIdx) {

1642

1643 auto parseIndices = [&]() -> ParseResult {

1644 int64_t value;

1646 return failure();

1648 APInt(64, value, false)));

1649 return success();

1650 };

1651

1652 do {

1654 return failure();

1655

1657 return failure();

1658

1660 return failure();

1661

1663 values.clear();

1665

1666 if (!memberIdxs.empty())

1668

1669 return success();

1670 }

1671

1673 ArrayAttr membersIdx) {

1674 if (!membersIdx)

1675 return;

1676

1677 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {

1678 p << "[";

1679 auto memberIdx = cast(v);

1680 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {

1681 p << cast(v2).getInt();

1682 });

1683 p << "]";

1684 });

1685 }

1686

1688 VariableCaptureKindAttr mapCaptureType) {

1689 std::string typeCapStr;

1690 llvm::raw_string_ostream typeCap(typeCapStr);

1691 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)

1692 typeCap << "ByRef";

1693 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)

1694 typeCap << "ByCopy";

1695 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)

1696 typeCap << "VLAType";

1697 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)

1698 typeCap << "This";

1699 p << typeCapStr;

1700 }

1701

1703 VariableCaptureKindAttr &mapCaptureType) {

1704 StringRef mapCaptureKey;

1706 return failure();

1707

1708 if (mapCaptureKey == "This")

1710 parser.getContext(), mlir::omp::VariableCaptureKind::This);

1711 if (mapCaptureKey == "ByRef")

1713 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);

1714 if (mapCaptureKey == "ByCopy")

1716 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);

1717 if (mapCaptureKey == "VLAType")

1719 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);

1720

1721 return success();

1722 }

1723

1727

1728 for (auto mapOp : mapVars) {

1729 if (!mapOp.getDefiningOp())

1730 return emitError(op->getLoc(), "missing map operation");

1731

1732 if (auto mapInfoOp =

1733 mlir::dyn_castmlir::omp::MapInfoOp(mapOp.getDefiningOp())) {

1734 uint64_t mapTypeBits = mapInfoOp.getMapType();

1735

1737 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);

1739 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);

1741 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);

1742

1744 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);

1746 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);

1748 mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);

1749

1750 if ((isa(op) || isa(op)) && del)

1752 "to, from, tofrom and alloc map types are permitted");

1753

1754 if (isa(op) && (from || del))

1755 return emitError(op->getLoc(), "to and alloc map types are permitted");

1756

1757 if (isa(op) && to)

1759 "from, release and delete map types are permitted");

1760

1761 if (isa(op)) {

1762 if (del) {

1764 "at least one of to or from map types must be "

1765 "specified, other map types are not permitted");

1766 }

1767

1768 if (!to && !from) {

1770 "at least one of to or from map types must be "

1771 "specified, other map types are not permitted");

1772 }

1773

1774 auto updateVar = mapInfoOp.getVarPtr();

1775

1776 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||

1777 (from && updateToVars.contains(updateVar))) {

1780 "either to or from map types can be specified, not both");

1781 }

1782

1783 if (always || close || implicit) {

1786 "present, mapper and iterator map type modifiers are permitted");

1787 }

1788

1789 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);

1790 }

1791 } else if (!isa(op)) {

1793 "map argument is not a map entry operation");

1794 }

1795 }

1796

1797 return success();

1798 }

1799

1801 std::optional privateMapIndices =

1802 targetOp.getPrivateMapsAttr();

1803

1804

1805 if (!privateMapIndices.has_value() || !privateMapIndices.value())

1806 return success();

1807

1808 OperandRange privateVars = targetOp.getPrivateVars();

1809

1810 if (privateMapIndices.value().size() !=

1811 static_cast<int64_t>(privateVars.size()))

1812 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "

1813 "`private_maps` attribute mismatch");

1814

1815 return success();

1816 }

1817

1818

1819

1820

1821

1823 StringRef clauseName,

1825 for (Value var : vars)

1826 if (!llvm::isa_and_present(var.getDefiningOp()))

1828 << "'" << clauseName

1829 << "' arguments must be defined by 'omp.map.info' ops";

1830 return success();

1831 }

1832

1834 if (getMapperId() &&

1835 !SymbolTable::lookupNearestSymbolFromomp::DeclareMapperOp(

1836 *this, getMapperIdAttr())) {

1837 return emitError("invalid mapper id");

1838 }

1839

1841 return failure();

1842

1843 return success();

1844 }

1845

1846

1847

1848

1849

1851 const TargetDataOperands &clauses) {

1852 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,

1853 clauses.mapVars, clauses.useDeviceAddrVars,

1854 clauses.useDevicePtrVars);

1855 }

1856

1858 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&

1859 getUseDeviceAddrVars().empty()) {

1861 "At least one of map, use_device_ptr_vars, or "

1862 "use_device_addr_vars operand must be present");

1863 }

1864

1866 getUseDevicePtrVars())))

1867 return failure();

1868

1870 getUseDeviceAddrVars())))

1871 return failure();

1872

1874 }

1875

1876

1877

1878

1879

1880 void TargetEnterDataOp::build(

1884 TargetEnterDataOp::build(builder, state,

1886 clauses.dependVars, clauses.device, clauses.ifExpr,

1887 clauses.mapVars, clauses.nowait);

1888 }

1889

1891 LogicalResult verifyDependVars =

1893 return failed(verifyDependVars) ? verifyDependVars

1895 }

1896

1897

1898

1899

1900

1904 TargetExitDataOp::build(builder, state,

1906 clauses.dependVars, clauses.device, clauses.ifExpr,

1907 clauses.mapVars, clauses.nowait);

1908 }

1909

1911 LogicalResult verifyDependVars =

1913 return failed(verifyDependVars) ? verifyDependVars

1915 }

1916

1917

1918

1919

1920

1924 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),

1925 clauses.dependVars, clauses.device, clauses.ifExpr,

1926 clauses.mapVars, clauses.nowait);

1927 }

1928

1930 LogicalResult verifyDependVars =

1932 return failed(verifyDependVars) ? verifyDependVars

1934 }

1935

1936

1937

1938

1939

1941 const TargetOperands &clauses) {

1943

1944

1945 TargetOp::build(builder, state, {}, {},

1946 clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),

1947 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,

1948 clauses.hostEvalVars, clauses.ifExpr,

1949 {}, nullptr,

1950 nullptr, clauses.isDevicePtrVars,

1951 clauses.mapVars, clauses.nowait, clauses.privateVars,

1953 clauses.privateNeedsBarrier, clauses.threadLimit,

1954 nullptr);

1955 }

1956

1958 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))

1959 return failure();

1960

1962 getHasDeviceAddrVars())))

1963 return failure();

1964

1966 return failure();

1967

1969 }

1970

1971 LogicalResult TargetOp::verifyRegions() {

1972 auto teamsOps = getOps();

1973 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)

1974 return emitError("target containing multiple 'omp.teams' nested ops");

1975

1976

1977 Operation *capturedOp = getInnermostCapturedOmpOp();

1978 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);

1979 for (Value hostEvalArg :

1980 cast(getOperation()).getHostEvalBlockArgs()) {

1982 if (auto teamsOp = dyn_cast(user)) {

1983 if (llvm::is_contained({teamsOp.getNumTeamsLower(),

1984 teamsOp.getNumTeamsUpper(),

1985 teamsOp.getThreadLimit()},

1986 hostEvalArg))

1987 continue;

1988

1989 return emitOpError() << "host_eval argument only legal as 'num_teams' "

1990 "and 'thread_limit' in 'omp.teams'";

1991 }

1992 if (auto parallelOp = dyn_cast(user)) {

1993 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&

1994 parallelOp->isAncestor(capturedOp) &&

1995 hostEvalArg == parallelOp.getNumThreads())

1996 continue;

1997

1998 return emitOpError()

1999 << "host_eval argument only legal as 'num_threads' in "

2000 "'omp.parallel' when representing target SPMD";

2001 }

2002 if (auto loopNestOp = dyn_cast(user)) {

2003 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&

2004 loopNestOp.getOperation() == capturedOp &&

2005 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||

2006 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||

2007 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))

2008 continue;

2009

2010 return emitOpError() << "host_eval argument only legal as loop bounds "

2011 "and steps in 'omp.loop_nest' when trip count "

2012 "must be evaluated in the host";

2013 }

2014

2015 return emitOpError() << "host_eval argument illegal use in '"

2016 << user->getName() << "' operation";

2017 }

2018 }

2019 return success();

2020 }

2021

2025 assert(rootOp && "expected valid operation");

2026

2028 Operation *capturedOp = nullptr;

2030

2031

2032

2033

2034

2035 rootOp->walkWalkOrder::PreOrder([&](Operation *op) {

2036 if (op == rootOp)

2037 return WalkResult::advance();

2038

2039

2040

2041

2042 bool isOmpDialect = op->getDialect() == ompDialect;

2044 if (!isOmpDialect || !hasRegions)

2045 return WalkResult::skip();

2046

2047

2048

2049

2050

2051 if (checkSingleMandatoryExec) {

2054

2056 if (successor->isReachable(parentBlock))

2057 return WalkResult::interrupt();

2058

2059 for (Block &block : *parentRegion)

2061 !domInfo.dominates(parentBlock, &block))

2062 return WalkResult::interrupt();

2063 }

2064

2065

2066

2068 if (&sibling != op && !siblingAllowedFn(&sibling))

2069 return WalkResult::interrupt();

2070

2071

2072

2073 capturedOp = op;

2074 return llvm::isa(op) ? WalkResult::interrupt()

2075 : WalkResult::advance();

2076 });

2077

2078 return capturedOp;

2079 }

2080

2081 Operation *TargetOp::getInnermostCapturedOmpOp() {

2083

2084

2085

2087 *this, true, [&](Operation *sibling) {

2088 if (!sibling)

2089 return false;

2090

2091 if (ompDialect == sibling->getDialect())

2093

2094 if (auto memOp = dyn_cast(sibling)) {

2096 effects;

2097 memOp.getEffects(effects);

2098 return !llvm::any_of(

2100 return isaMemoryEffects::Write(effect.getEffect()) &&

2101 isaSideEffects::AutomaticAllocationScopeResource(

2103 });

2104 }

2105 return true;

2106 });

2107 }

2108

2109 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {

2110

2111

2112 TargetOp targetOp =

2113 capturedOp ? capturedOp->getParentOfType() : nullptr;

2114 assert((!capturedOp ||

2115 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&

2116 "unexpected captured op");

2117

2118

2119 if (!isa_and_present(capturedOp))

2120 return TargetRegionFlags::generic;

2121

2122

2124 cast(capturedOp).gatherWrappers(loopWrappers);

2125 assert(!loopWrappers.empty());

2126

2127 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();

2128 if (isa(innermostWrapper))

2129 innermostWrapper = std::next(innermostWrapper);

2130

2131 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());

2132 if (numWrappers != 1 && numWrappers != 2)

2133 return TargetRegionFlags::generic;

2134

2135

2136 if (numWrappers == 2) {

2137 if (!isa(innermostWrapper))

2138 return TargetRegionFlags::generic;

2139

2140 innermostWrapper = std::next(innermostWrapper);

2141 if (!isa(innermostWrapper))

2142 return TargetRegionFlags::generic;

2143

2145 if (!isa_and_present(parallelOp))

2146 return TargetRegionFlags::generic;

2147

2149 if (!isa_and_present(teamsOp))

2150 return TargetRegionFlags::generic;

2151

2152 if (teamsOp->getParentOp() == targetOp.getOperation())

2153 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;

2154 }

2155

2156 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {

2158 if (!isa_and_present(teamsOp))

2159 return TargetRegionFlags::generic;

2160

2161 if (teamsOp->getParentOp() != targetOp.getOperation())

2162 return TargetRegionFlags::generic;

2163

2164 if (isa(innermostWrapper))

2165 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;

2166

2167

2168

2169

2170

2171

2172

2173

2174

2175 Dialect *ompDialect = targetOp->getDialect();

2177 capturedOp, false,

2179 return sibling && (ompDialect != sibling->getDialect() ||

2181 });

2182

2183 TargetRegionFlags result =

2184 TargetRegionFlags::generic | TargetRegionFlags::trip_count;

2185

2186 if (!nestedCapture)

2187 return result;

2188

2189 while (nestedCapture->getParentOp() != capturedOp)

2190 nestedCapture = nestedCapture->getParentOp();

2191

2192 return isa(nestedCapture) ? result | TargetRegionFlags::spmd

2193 : result;

2194 }

2195

2196 else if (isa(innermostWrapper)) {

2198 if (!isa_and_present(parallelOp))

2199 return TargetRegionFlags::generic;

2200

2201 if (parallelOp->getParentOp() == targetOp.getOperation())

2202 return TargetRegionFlags::spmd;

2203 }

2204

2205 return TargetRegionFlags::generic;

2206 }

2207

2208

2209

2210

2211

2214 ParallelOp::build(builder, state, ValueRange(),

2215 ValueRange(), nullptr,

2216 nullptr, ValueRange(),

2217 nullptr, nullptr,

2218 nullptr,

2219 nullptr, ValueRange(),

2220 nullptr, nullptr);

2221 state.addAttributes(attributes);

2222 }

2223

2225 const ParallelOperands &clauses) {

2227 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,

2228 clauses.ifExpr, clauses.numThreads, clauses.privateVars,

2230 clauses.privateNeedsBarrier, clauses.procBindKind,

2231 clauses.reductionMod, clauses.reductionVars,

2234 }

2235

2236 template

2238 auto privateVars = op.getPrivateVars();

2239 auto privateSyms = op.getPrivateSymsAttr();

2240

2241 if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))

2242 return success();

2243

2244 auto numPrivateVars = privateVars.size();

2245 auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();

2246

2247 if (numPrivateVars != numPrivateSyms)

2248 return op.emitError() << "inconsistent number of private variables and "

2249 "privatizer op symbols, private vars: "

2250 << numPrivateVars

2251 << " vs. privatizer op symbols: " << numPrivateSyms;

2252

2253 for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {

2254 Type varType = std::get<0>(privateVarInfo).getType();

2255 SymbolRefAttr privateSym = cast(std::get<1>(privateVarInfo));

2256 PrivateClauseOp privatizerOp =

2257 SymbolTable::lookupNearestSymbolFrom(op, privateSym);

2258

2259 if (privatizerOp == nullptr)

2260 return op.emitError() << "failed to lookup privatizer op with symbol: '"

2261 << privateSym << "'";

2262

2263 Type privatizerType = privatizerOp.getArgType();

2264

2265 if (privatizerType && (varType != privatizerType))

2266 return op.emitError()

2267 << "type mismatch between a "

2268 << (privatizerOp.getDataSharingType() ==

2269 DataSharingClauseType::Private

2270 ? "private"

2271 : "firstprivate")

2272 << " variable and its privatizer op, var type: " << varType

2273 << " vs. privatizer op type: " << privatizerType;

2274 }

2275

2276 return success();

2277 }

2278

2280 if (getAllocateVars().size() != getAllocatorVars().size())

2282 "expected equal sizes for allocate and allocator variables");

2283

2285 return failure();

2286

2288 getReductionByref());

2289 }

2290

2291 LogicalResult ParallelOp::verifyRegions() {

2292 auto distChildOps = getOps();

2293 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());

2294 if (numDistChildOps > 1)

2296 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";

2297

2298 if (numDistChildOps == 1) {

2299 if (!isComposite())

2301 << "'omp.composite' attribute missing from composite operation";

2302

2304 Operation &distributeOp = **distChildOps.begin();

2305 for (Operation &childOp : getOps()) {

2306 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())

2307 continue;

2308

2310 return emitError() << "unexpected OpenMP operation inside of composite "

2311 "'omp.parallel': "

2312 << childOp.getName();

2313 }

2314 } else if (isComposite()) {

2316 << "'omp.composite' attribute present in non-composite operation";

2317 }

2318 return success();

2319 }

2320

2321

2322

2323

2324

2327 if (isa(op->getDialect()))

2328 return false;

2329 return true;

2330 }

2331

2333 const TeamsOperands &clauses) {

2335

2336 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,

2337 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,

2338 {}, nullptr,

2339 nullptr, clauses.reductionMod,

2340 clauses.reductionVars,

2343 clauses.threadLimit);

2344 }

2345

2347

2348

2349

2350

2351

2353 if (!isa(op->getParentOp()) &&

2355 return emitError("expected to be nested inside of omp.target or not nested "

2356 "in any OpenMP dialect operations");

2357

2358

2359 if (auto numTeamsLowerBound = getNumTeamsLower()) {

2360 auto numTeamsUpperBound = getNumTeamsUpper();

2361 if (!numTeamsUpperBound)

2362 return emitError("expected num_teams upper bound to be defined if the "

2363 "lower bound is defined");

2364 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())

2366 "expected num_teams upper bound and lower bound to be the same type");

2367 }

2368

2369

2370 if (getAllocateVars().size() != getAllocatorVars().size())

2372 "expected equal sizes for allocate and allocator variables");

2373

2375 getReductionByref());

2376 }

2377

2378

2379

2380

2381

2383 return getParentOp().getPrivateVars();

2384 }

2385

2386 OperandRange SectionOp::getReductionVars() {

2387 return getParentOp().getReductionVars();

2388 }

2389

2390

2391

2392

2393

2395 const SectionsOperands &clauses) {

2397

2398 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,

2399 clauses.nowait, {},

2400 nullptr, nullptr,

2401 clauses.reductionMod, clauses.reductionVars,

2404 }

2405

2407 if (getAllocateVars().size() != getAllocatorVars().size())

2409 "expected equal sizes for allocate and allocator variables");

2410

2412 getReductionByref());

2413 }

2414

2415 LogicalResult SectionsOp::verifyRegions() {

2416 for (auto &inst : *getRegion().begin()) {

2417 if (!(isa(inst) || isa(inst))) {

2418 return emitOpError()

2419 << "expected omp.section op or terminator op inside region";

2420 }

2421 }

2422

2423 return success();

2424 }

2425

2426

2427

2428

2429

2431 const SingleOperands &clauses) {

2433

2434 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,

2435 clauses.copyprivateVars,

2436 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,

2437 {}, nullptr,

2438 nullptr);

2439 }

2440

2442

2443 if (getAllocateVars().size() != getAllocatorVars().size())

2445 "expected equal sizes for allocate and allocator variables");

2446

2448 getCopyprivateSyms());

2449 }

2450

2451

2452

2453

2454

2456 const WorkshareOperands &clauses) {

2457 WorkshareOp::build(builder, state, clauses.nowait);

2458 }

2459

2460

2461

2462

2463

2465 if (!(*this)->getParentOfType())

2466 return emitOpError() << "must be nested in an omp.workshare";

2467 return success();

2468 }

2469

2470 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {

2471 if (isa_and_nonnull((*this)->getParentOp()) ||

2472 getNestedWrapper())

2473 return emitOpError() << "expected to be a standalone loop wrapper";

2474

2475 return success();

2476 }

2477

2478

2479

2480

2481

2482 LogicalResult LoopWrapperInterface::verifyImpl() {

2483 Operation *op = this->getOperation();

2486 return emitOpError() << "loop wrapper must also have the `NoTerminator` "

2487 "and `SingleBlock` traits";

2488

2490 return emitOpError() << "loop wrapper does not contain exactly one region";

2491

2493 if (range_size(region.getOps()) != 1)

2494 return emitOpError()

2495 << "loop wrapper does not contain exactly one nested op";

2496

2498 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))

2499 return emitOpError() << "nested in loop wrapper is not another loop "

2500 "wrapper or `omp.loop_nest`";

2501

2502 return success();

2503 }

2504

2505

2506

2507

2508

2510 const LoopOperands &clauses) {

2512

2513 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,

2515 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,

2516 clauses.reductionMod, clauses.reductionVars,

2519 }

2520

2523 getReductionByref());

2524 }

2525

2526 LogicalResult LoopOp::verifyRegions() {

2527 if (llvm::isa_and_nonnull((*this)->getParentOp()) ||

2528 getNestedWrapper())

2529 return emitOpError() << "expected to be a standalone loop wrapper";

2530

2531 return success();

2532 }

2533

2534

2535

2536

2537

2540 build(builder, state, {}, {},

2542 false, nullptr, nullptr,

2543 nullptr, {}, nullptr,

2544 false,

2545 nullptr, ValueRange(),

2546 nullptr,

2547 nullptr, nullptr,

2548 nullptr, nullptr,

2549 false);

2550 state.addAttributes(attributes);

2551 }

2552

2554 const WsloopOperands &clauses) {

2556

2557 WsloopOp::build(

2558 builder, state,

2559 {}, {}, clauses.linearVars,

2560 clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,

2561 clauses.ordered, clauses.privateVars,

2562 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,

2563 clauses.reductionMod, clauses.reductionVars,

2565 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,

2566 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);

2567 }

2568

2571 getReductionByref());

2572 }

2573

2574 LogicalResult WsloopOp::verifyRegions() {

2575 bool isCompositeChildLeaf =

2576 llvm::dyn_cast_if_present((*this)->getParentOp());

2577

2578 if (LoopWrapperInterface nested = getNestedWrapper()) {

2579 if (!isComposite())

2581 << "'omp.composite' attribute missing from composite wrapper";

2582

2583

2584

2585 if (!isa(nested))

2586 return emitError() << "only supported nested wrapper is 'omp.simd'";

2587

2588 } else if (isComposite() && !isCompositeChildLeaf) {

2590 << "'omp.composite' attribute present in non-composite wrapper";

2591 } else if (!isComposite() && isCompositeChildLeaf) {

2593 << "'omp.composite' attribute missing from composite wrapper";

2594 }

2595

2596 return success();

2597 }

2598

2599

2600

2601

2602

2604 const SimdOperands &clauses) {

2606

2607 SimdOp::build(builder, state, clauses.alignedVars,

2608 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,

2609 {}, {},

2610 clauses.nontemporalVars, clauses.order, clauses.orderMod,

2611 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),

2612 clauses.privateNeedsBarrier, clauses.reductionMod,

2613 clauses.reductionVars,

2615 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,

2616 clauses.simdlen);

2617 }

2618

2620 if (getSimdlen().has_value() && getSafelen().has_value() &&

2621 getSimdlen().value() > getSafelen().value())

2622 return emitOpError()

2623 << "simdlen clause and safelen clause are both present, but the "

2624 "simdlen value is not less than or equal to safelen value";

2625

2626 if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())

2627 return failure();

2628

2630 return failure();

2631

2632 bool isCompositeChildLeaf =

2633 llvm::dyn_cast_if_present((*this)->getParentOp());

2634

2635 if (!isComposite() && isCompositeChildLeaf)

2637 << "'omp.composite' attribute missing from composite wrapper";

2638

2639 if (isComposite() && !isCompositeChildLeaf)

2641 << "'omp.composite' attribute present in non-composite wrapper";

2642

2643 return success();

2644 }

2645

2646 LogicalResult SimdOp::verifyRegions() {

2647 if (getNestedWrapper())

2648 return emitOpError() << "must wrap an 'omp.loop_nest' directly";

2649

2650 return success();

2651 }

2652

2653

2654

2655

2656

2658 const DistributeOperands &clauses) {

2659 DistributeOp::build(builder, state, clauses.allocateVars,

2660 clauses.allocatorVars, clauses.distScheduleStatic,

2661 clauses.distScheduleChunkSize, clauses.order,

2662 clauses.orderMod, clauses.privateVars,

2664 clauses.privateNeedsBarrier);

2665 }

2666

2668 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())

2669 return emitOpError() << "chunk size set without "

2670 "dist_schedule_static being present";

2671

2672 if (getAllocateVars().size() != getAllocatorVars().size())

2674 "expected equal sizes for allocate and allocator variables");

2675

2676 return success();

2677 }

2678

2679 LogicalResult DistributeOp::verifyRegions() {

2680 if (LoopWrapperInterface nested = getNestedWrapper()) {

2681 if (!isComposite())

2683 << "'omp.composite' attribute missing from composite wrapper";

2684

2685

2686 if (isa(nested)) {

2688 if (!llvm::dyn_cast_if_present(parentOp) ||

2689 !cast(parentOp).isComposite()) {

2690 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "

2691 "when a composite 'omp.parallel' is the direct "

2692 "parent";

2693 }

2694 } else if (!isa(nested))

2695 return emitError() << "only supported nested wrappers are 'omp.simd' and "

2696 "'omp.wsloop'";

2697 } else if (isComposite()) {

2699 << "'omp.composite' attribute present in non-composite wrapper";

2700 }

2701

2702 return success();

2703 }

2704

2705

2706

2707

2708

2711 }

2712

2713 LogicalResult DeclareMapperOp::verifyRegions() {

2714 if (!llvm::isa_and_present(

2715 getRegion().getBlocks().front().getTerminator()))

2716 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";

2717

2718 return success();

2719 }

2720

2721

2722

2723

2724

2725 LogicalResult DeclareReductionOp::verifyRegions() {

2726 if (!getAllocRegion().empty()) {

2727 for (YieldOp yieldOp : getAllocRegion().getOps()) {

2728 if (yieldOp.getResults().size() != 1 ||

2729 yieldOp.getResults().getTypes()[0] != getType())

2730 return emitOpError() << "expects alloc region to yield a value "

2731 "of the reduction type";

2732 }

2733 }

2734

2735 if (getInitializerRegion().empty())

2736 return emitOpError() << "expects non-empty initializer region";

2737 Block &initializerEntryBlock = getInitializerRegion().front();

2738

2740 if (!getAllocRegion().empty())

2741 return emitOpError() << "expects two arguments to the initializer region "

2742 "when an allocation region is used";

2743 } else if (initializerEntryBlock.getNumArguments() == 2) {

2744 if (getAllocRegion().empty())

2745 return emitOpError() << "expects one argument to the initializer region "

2746 "when no allocation region is used";

2747 } else {

2748 return emitOpError()

2749 << "expects one or two arguments to the initializer region";

2750 }

2751

2753 if (arg.getType() != getType())

2754 return emitOpError() << "expects initializer region argument to match "

2755 "the reduction type";

2756

2757 for (YieldOp yieldOp : getInitializerRegion().getOps()) {

2758 if (yieldOp.getResults().size() != 1 ||

2759 yieldOp.getResults().getTypes()[0] != getType())

2760 return emitOpError() << "expects initializer region to yield a value "

2761 "of the reduction type";

2762 }

2763

2764 if (getReductionRegion().empty())

2765 return emitOpError() << "expects non-empty reduction region";

2766 Block &reductionEntryBlock = getReductionRegion().front();

2771 return emitOpError() << "expects reduction region with two arguments of "

2772 "the reduction type";

2773 for (YieldOp yieldOp : getReductionRegion().getOps()) {

2774 if (yieldOp.getResults().size() != 1 ||

2775 yieldOp.getResults().getTypes()[0] != getType())

2776 return emitOpError() << "expects reduction region to yield a value "

2777 "of the reduction type";

2778 }

2779

2780 if (!getAtomicReductionRegion().empty()) {

2781 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();

2785 return emitOpError() << "expects atomic reduction region with two "

2786 "arguments of the same type";

2787 auto ptrType = llvm::dyn_cast(

2789 if (!ptrType ||

2790 (ptrType.getElementType() && ptrType.getElementType() != getType()))

2791 return emitOpError() << "expects atomic reduction region arguments to "

2792 "be accumulators containing the reduction type";

2793 }

2794

2795 if (getCleanupRegion().empty())

2796 return success();

2797 Block &cleanupEntryBlock = getCleanupRegion().front();

2800 return emitOpError() << "expects cleanup region with one argument "

2801 "of the reduction type";

2802

2803 return success();

2804 }

2805

2806

2807

2808

2809

2811 const TaskOperands &clauses) {

2813 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,

2814 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,

2815 clauses.final, clauses.ifExpr, clauses.inReductionVars,

2817 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,

2818 clauses.priority, clauses.privateVars,

2819 makeArrayAttr(ctx, clauses.privateSyms),

2820 clauses.privateNeedsBarrier, clauses.untied,

2821 clauses.eventHandle);

2822 }

2823

2825 LogicalResult verifyDependVars =

2827 return failed(verifyDependVars)

2828 ? verifyDependVars

2830 getInReductionVars(),

2831 getInReductionByref());

2832 }

2833

2834

2835

2836

2837

2839 const TaskgroupOperands &clauses) {

2841 TaskgroupOp::build(builder, state, clauses.allocateVars,

2842 clauses.allocatorVars, clauses.taskReductionVars,

2845 }

2846

2849 getTaskReductionVars(),

2850 getTaskReductionByref());

2851 }

2852

2853

2854

2855

2856

2858 const TaskloopOperands &clauses) {

2860 TaskloopOp::build(

2861 builder, state, clauses.allocateVars, clauses.allocatorVars,

2862 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,

2863 clauses.inReductionVars,

2865 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,

2866 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,

2867 clauses.privateVars,

2868 makeArrayAttr(ctx, clauses.privateSyms),

2869 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,

2871 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);

2872 }

2873

2875 if (getAllocateVars().size() != getAllocatorVars().size())

2877 "expected equal sizes for allocate and allocator variables");

2879 getReductionVars(), getReductionByref())) ||

2881 getInReductionVars(),

2882 getInReductionByref())))

2883 return failure();

2884

2885 if (!getReductionVars().empty() && getNogroup())

2886 return emitError("if a reduction clause is present on the taskloop "

2887 "directive, the nogroup clause must not be specified");

2888 for (auto var : getReductionVars()) {

2889 if (llvm::is_contained(getInReductionVars(), var))

2890 return emitError("the same list item cannot appear in both a reduction "

2891 "and an in_reduction clause");

2892 }

2893

2894 if (getGrainsize() && getNumTasks()) {

2896 "the grainsize clause and num_tasks clause are mutually exclusive and "

2897 "may not appear on the same taskloop directive");

2898 }

2899

2900 return success();

2901 }

2902

2903 LogicalResult TaskloopOp::verifyRegions() {

2904 if (LoopWrapperInterface nested = getNestedWrapper()) {

2905 if (!isComposite())

2907 << "'omp.composite' attribute missing from composite wrapper";

2908

2909

2910

2911 if (!isa(nested))

2912 return emitError() << "only supported nested wrapper is 'omp.simd'";

2913 } else if (isComposite()) {

2915 << "'omp.composite' attribute present in non-composite wrapper";

2916 }

2917

2918 return success();

2919 }

2920

2921

2922

2923

2924

2926

2929 Type loopVarType;

2930 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||

2932

2934 parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||

2936 parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))

2937 return failure();

2938

2939 for (auto &iv : ivs)

2940 iv.type = loopVarType;

2941

2942

2946

2947

2950 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))

2951 return failure();

2952

2953

2956 return failure();

2957

2958

2962 return failure();

2963

2964

2966 }

2967

2969 Region &region = getRegion();

2971 p << " (" << args << ") : " << args[0].getType() << " = ("

2972 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";

2973 if (getLoopInclusive())

2974 p << "inclusive ";

2975 p << "step (" << getLoopSteps() << ") ";

2976 p.printRegion(region, false);

2977 }

2978

2980 const LoopNestOperands &clauses) {

2981 LoopNestOp::build(builder, state, clauses.loopLowerBounds,

2982 clauses.loopUpperBounds, clauses.loopSteps,

2983 clauses.loopInclusive);

2984 }

2985

2987 if (getLoopLowerBounds().empty())

2988 return emitOpError() << "must represent at least one loop";

2989

2990 if (getLoopLowerBounds().size() != getIVs().size())

2991 return emitOpError() << "number of range arguments and IVs do not match";

2992

2993 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {

2994 if (lb.getType() != iv.getType())

2995 return emitOpError()

2996 << "range argument type does not match corresponding IV type";

2997 }

2998

2999 if (!llvm::dyn_cast_if_present((*this)->getParentOp()))

3000 return emitOpError() << "expects parent op to be a loop wrapper";

3001

3002 return success();

3003 }

3004

3005 void LoopNestOp::gatherWrappers(

3008 while (auto wrapper =

3009 llvm::dyn_cast_if_present(parent)) {

3010 wrappers.push_back(wrapper);

3012 }

3013 }

3014

3015

3016

3017

3018

3020 const CriticalDeclareOperands &clauses) {

3021 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);

3022 }

3023

3026 }

3027

3028 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

3029 if (getNameAttr()) {

3030 SymbolRefAttr symbolRef = getNameAttr();

3032 *this, symbolRef);

3033 if (!decl) {

3034 return emitOpError() << "expected symbol reference " << symbolRef

3035 << " to point to a critical declaration";

3036 }

3037 }

3038

3039 return success();

3040 }

3041

3042

3043

3044

3045

3049 if (!loopOp) {

3050 if (hasRegion)

3051 return success();

3052

3053

3054

3055 return op.emitOpError() << "must be nested inside of a loop";

3056 }

3057

3059 if (auto wsloopOp = dyn_cast(wrapper)) {

3060 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();

3061 if (!orderedAttr)

3062 return op.emitOpError() << "the enclosing worksharing-loop region must "

3063 "have an ordered clause";

3064

3065 if (hasRegion && orderedAttr.getInt() != 0)

3066 return op.emitOpError() << "the enclosing loop's ordered clause must not "

3067 "have a parameter present";

3068

3069 if (!hasRegion && orderedAttr.getInt() == 0)

3070 return op.emitOpError() << "the enclosing loop's ordered clause must "

3071 "have a parameter present";

3072 } else if (!isa(wrapper)) {

3073 return op.emitOpError() << "must be nested inside of a worksharing, simd "

3074 "or worksharing simd loop";

3075 }

3076 return success();

3077 }

3078

3080 const OrderedOperands &clauses) {

3081 OrderedOp::build(builder, state, clauses.doacrossDependType,

3082 clauses.doacrossNumLoops, clauses.doacrossDependVars);

3083 }

3084

3087 return failure();

3088

3089 auto wrapper = (*this)->getParentOfType();

3090 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())

3091 return emitOpError() << "number of variables in depend clause does not "

3092 << "match number of iteration variables in the "

3093 << "doacross loop";

3094

3095 return success();

3096 }

3097

3099 const OrderedRegionOperands &clauses) {

3100 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);

3101 }

3102

3104

3105

3106

3107

3108

3110 const TaskwaitOperands &clauses) {

3111

3112 TaskwaitOp::build(builder, state, nullptr,

3113 {}, nullptr);

3114 }

3115

3116

3117

3118

3119

3121 if (verifyCommon().failed())

3122 return mlir::failure();

3123

3124 if (auto mo = getMemoryOrder()) {

3125 if (*mo == ClauseMemoryOrderKind::Acq_rel ||

3126 *mo == ClauseMemoryOrderKind::Release) {

3128 "memory-order must not be acq_rel or release for atomic reads");

3129 }

3130 }

3132 }

3133

3134

3135

3136

3137

3139 if (verifyCommon().failed())

3140 return mlir::failure();

3141

3142 if (auto mo = getMemoryOrder()) {

3143 if (*mo == ClauseMemoryOrderKind::Acq_rel ||

3144 *mo == ClauseMemoryOrderKind::Acquire) {

3146 "memory-order must not be acq_rel or acquire for atomic writes");

3147 }

3148 }

3150 }

3151

3152

3153

3154

3155

3156 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,

3158 if (op.isNoOp()) {

3160 return success();

3161 }

3162 if (Value writeVal = op.getWriteOpVal()) {

3164 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());

3165 return success();

3166 }

3167 return failure();

3168 }

3169

3171 if (verifyCommon().failed())

3172 return mlir::failure();

3173

3174 if (auto mo = getMemoryOrder()) {

3175 if (*mo == ClauseMemoryOrderKind::Acq_rel ||

3176 *mo == ClauseMemoryOrderKind::Acquire) {

3178 "memory-order must not be acq_rel or acquire for atomic updates");

3179 }

3180 }

3181

3183 }

3184

/// Region verifier for AtomicUpdateOp. It performs no checks of its own and
/// simply returns the result of the shared verifyRegionsCommon() helper.
3185 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }

3186

3187

3188

3189

3190

/// Returns the AtomicReadOp nested in this capture op, whichever of the two
/// nested operations it is.
/// getFirstOp() is tried first. If it casts, it is returned; otherwise the
/// result of casting getSecondOp() is returned. That result is a null op when
/// the second operation does not cast either (dyn_cast semantics).
/// NOTE(review): the dyn_cast template arguments are missing from this listing;
/// presumably dyn_cast<AtomicReadOp>, matching the return type. Confirm
/// against the original source.
3191 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {

3192 if (auto op = dyn_cast(getFirstOp()))

3193 return op;

3194 return dyn_cast(getSecondOp());

3195 }

3196

/// Returns the AtomicWriteOp nested in this capture op, whichever of the two
/// nested operations it is.
/// getFirstOp() is tried first. If it casts, it is returned; otherwise the
/// result of casting getSecondOp() is returned. That result is a null op when
/// the second operation does not cast either (dyn_cast semantics).
/// NOTE(review): the dyn_cast template arguments are missing from this listing;
/// presumably dyn_cast<AtomicWriteOp>, matching the return type. Confirm
/// against the original source.
3197 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {

3198 if (auto op = dyn_cast(getFirstOp()))

3199 return op;

3200 return dyn_cast(getSecondOp());

3201 }

3202

/// Returns the AtomicUpdateOp nested in this capture op, whichever of the two
/// nested operations it is.
/// getFirstOp() is tried first. If it casts, it is returned; otherwise the
/// result of casting getSecondOp() is returned. That result is a null op when
/// the second operation does not cast either (dyn_cast semantics).
/// NOTE(review): the dyn_cast template arguments are missing from this listing;
/// presumably dyn_cast<AtomicUpdateOp>, matching the return type. Confirm
/// against the original source.
3203 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {

3204 if (auto op = dyn_cast(getFirstOp()))

3205 return op;

3206 return dyn_cast(getSecondOp());

3207 }

3208

3211 }

3212

/// Region verifier for AtomicCaptureOp.
/// It runs the shared verifyRegionsCommon() checks first and propagates any
/// failure. It then rejects the op if either nested operation (getFirstOp() or
/// getSecondOp()) carries its own "hint" or "memory_order" attribute.
/// NOTE(review): presumably these clauses are meant to appear only on the
/// enclosing capture op; confirm against the OpenMP spec and the op definition.
3213 LogicalResult AtomicCaptureOp::verifyRegions() {

3214 if (verifyRegionsCommon().failed())

3215 return mlir::failure();

3216

// Neither nested operation may have a "hint" attribute attached.
3217 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))

3218 return emitOpError(

3219 "operations inside capture region must not have hint clause");

3220

// Likewise, neither nested operation may have a "memory_order" attribute.
3221 if (getFirstOp()->getAttr("memory_order") ||

3222 getSecondOp()->getAttr("memory_order"))

3223 return emitOpError(

3224 "operations inside capture region must not have memory_order clause");

3225 return success();

3226 }

3227

3228

3229

3230

3231

3233 const CancelOperands &clauses) {

3234 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);

3235 }

3236

3239 while (parent) {

3241 return parent;

3243 }

3244 return nullptr;

3245 }

3246

3248 ClauseCancellationConstructType cct = getCancelDirective();

3249

3251 if (!structuralParent)

3252 return emitOpError() << "Orphaned cancel construct";

3253

3254 if ((cct == ClauseCancellationConstructType::Parallel) &&

3255 !mlir::isa(structuralParent)) {

3256 return emitOpError() << "cancel parallel must appear "

3257 << "inside a parallel region";

3258 }

3259 if (cct == ClauseCancellationConstructType::Loop) {

3260

3261

3262 auto wsloopOp = mlir::dyn_cast(structuralParent->getParentOp());

3263

3264 if (!wsloopOp) {

3265 return emitOpError()

3266 << "cancel loop must appear inside a worksharing-loop region";

3267 }

3268 if (wsloopOp.getNowaitAttr()) {

3269 return emitError() << "A worksharing construct that is canceled "

3270 << "must not have a nowait clause";

3271 }

3272 if (wsloopOp.getOrderedAttr()) {

3273 return emitError() << "A worksharing construct that is canceled "

3274 << "must not have an ordered clause";

3275 }

3276

3277 } else if (cct == ClauseCancellationConstructType::Sections) {

3278

3279

3280 auto sectionsOp =

3281 mlir::dyn_cast(structuralParent->getParentOp());

3282 if (!sectionsOp) {

3283 return emitOpError() << "cancel sections must appear "

3284 << "inside a sections region";

3285 }

3286 if (sectionsOp.getNowait()) {

3287 return emitError() << "A sections construct that is canceled "

3288 << "must not have a nowait clause";

3289 }

3290 }

3291 if ((cct == ClauseCancellationConstructType::Taskgroup) &&

3292 (!mlir::isaomp::TaskOp(structuralParent) &&

3293 !mlir::isaomp::TaskloopOp(structuralParent->getParentOp()))) {

3294 return emitOpError() << "cancel taskgroup must appear "

3295 << "inside a task region";

3296 }

3297 return success();

3298 }

3299

3300

3301

3302

3303

3305 const CancellationPointOperands &clauses) {

3306 CancellationPointOp::build(builder, state, clauses.cancelDirective);

3307 }

3308

3310 ClauseCancellationConstructType cct = getCancelDirective();

3311

3313 if (!structuralParent)

3314 return emitOpError() << "Orphaned cancellation point";

3315

3316 if ((cct == ClauseCancellationConstructType::Parallel) &&

3317 !mlir::isa(structuralParent)) {

3318 return emitOpError() << "cancellation point parallel must appear "

3319 << "inside a parallel region";

3320 }

3321

3322

3323 if ((cct == ClauseCancellationConstructType::Loop) &&

3324 !mlir::isa(structuralParent->getParentOp())) {

3325 return emitOpError() << "cancellation point loop must appear "

3326 << "inside a worksharing-loop region";

3327 }

3328 if ((cct == ClauseCancellationConstructType::Sections) &&

3329 !mlir::isaomp::SectionOp(structuralParent)) {

3330 return emitOpError() << "cancellation point sections must appear "

3331 << "inside a sections region";

3332 }

3333 if ((cct == ClauseCancellationConstructType::Taskgroup) &&

3334 !mlir::isaomp::TaskOp(structuralParent)) {

3335 return emitOpError() << "cancellation point taskgroup must appear "

3336 << "inside a task region";

3337 }

3338 return success();

3339 }

3340

3341

3342

3343

3344

3346 auto extent = getExtent();

3348 if (!extent && !upperbound)

3349 return emitError("expected extent or upperbound.");

3350 return success();

3351 }

3352

3354 TypeRange , StringAttr symName,

3355 TypeAttr type) {

3356 PrivateClauseOp::build(

3357 odsBuilder, odsState, symName, type,

3359 DataSharingClauseType::Private));

3360 }

3361

3362 LogicalResult PrivateClauseOp::verifyRegions() {

3363 Type argType = getArgType();

3364 auto verifyTerminator = [&](Operation *terminator,

3365 bool yieldsValue) -> LogicalResult {

3367 return success();

3368

3369 if (!llvm::isa(terminator))

3371 << "expected exit block terminator to be an `omp.yield` op.";

3372

3373 YieldOp yieldOp = llvm::cast(terminator);

3374 TypeRange yieldedTypes = yieldOp.getResults().getTypes();

3375

3376 if (!yieldsValue) {

3377 if (yieldedTypes.empty())

3378 return success();

3379

3381 << "Did not expect any values to be yielded.";

3382 }

3383

3384 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)

3385 return success();

3386

3388 << "Invalid yielded value. Expected type: " << argType

3389 << ", got: ";

3390

3391 if (yieldedTypes.empty())

3392 error << "None";

3393 else

3394 error << yieldedTypes;

3395

3396 return error;

3397 };

3398

3400 StringRef regionName,

3401 bool yieldsValue) -> LogicalResult {

3402 assert(!region.empty());

3403

3406 << "`" << regionName << "`: "

3407 << "expected " << expectedNumArgs

3408 << " region arguments, got: " << region.getNumArguments();

3409

3410 for (Block &block : region) {

3411

3412 if (!block.mightHaveTerminator())

3413 continue;

3414

3415 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))

3416 return failure();

3417 }

3418

3419 return success();

3420 };

3421

3422

3423 for (Region *region : getRegions())

3424 for (Type ty : region->getArgumentTypes())

3425 if (ty != argType)

3426 return emitError() << "Region argument type mismatch: got " << ty

3427 << " expected " << argType << ".";

3428

3430 if (!initRegion.empty() &&

3431 failed(verifyRegion(getInitRegion(), 2, "init",

3432 true)))

3433 return failure();

3434

3435 DataSharingClauseType dsType = getDataSharingType();

3436

3437 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())

3438 return emitError("`private` clauses do not require a `copy` region.");

3439

3440 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())

3442 "`firstprivate` clauses require at least a `copy` region.");

3443

3444 if (dsType == DataSharingClauseType::FirstPrivate &&

3445 failed(verifyRegion(getCopyRegion(), 2, "copy",

3446 true)))

3447 return failure();

3448

3449 if (!getDeallocRegion().empty() &&

3450 failed(verifyRegion(getDeallocRegion(), 1, "dealloc",

3451 false)))

3452 return failure();

3453

3454 return success();

3455 }

3456

3457

3458

3459

3460

3462 const MaskedOperands &clauses) {

3463 MaskedOp::build(builder, state, clauses.filteredThreadId);

3464 }

3465

3466

3467

3468

3469

3471 const ScanOperands &clauses) {

3472 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);

3473 }

3474

3476 if (hasExclusiveVars() == hasInclusiveVars())

3478 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");

3479 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType()) {

3480 if (parentWsLoopOp.getReductionModAttr() &&

3481 parentWsLoopOp.getReductionModAttr().getValue() ==

3482 ReductionModifier::inscan)

3483 return success();

3484 }

3485 if (SimdOp parentSimdOp = (*this)->getParentOfType()) {

3486 if (parentSimdOp.getReductionModAttr() &&

3487 parentSimdOp.getReductionModAttr().getValue() ==

3488 ReductionModifier::inscan)

3489 return success();

3490 }

3491 return emitError("SCAN directive needs to be enclosed within a parent "

3492 "worksharing loop construct or SIMD construct with INSCAN "

3493 "reduction modifier");

3494 }

3495

3496 #define GET_ATTRDEF_CLASSES

3497 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

3498

3499 #define GET_OP_CLASSES

3500 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"

3501

3502 #define GET_TYPEDEF_CLASSES

3503 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"

static std::optional< int64_t > getUpperBound(Value iv)

Gets the constant upper bound on an affine.for iv.

static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)

static void visit(Operation *op, DenseSet< Operation * > &visited)

Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.

static MLIRContext * getContext(OpFoldResult val)

void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)

static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)

static constexpr StringRef getPrivateNeedsBarrierSpelling()

static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)

static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)

static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)

static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)

static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)

Print allocate clause.

static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)

static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)

static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)

static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)

static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)

static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)

static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)

Parses a Synchronization Hint clause.

uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)

static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)

linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...

static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)

Print schedule clause.

static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)

Print Copyprivate clause.

static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)

static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)

Print Aligned Clause.

static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)

Verifies a synchronization hint clause.

static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)

static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)

Print Linear Clause.

static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)

static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)

static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)

Prints a Synchronization Hint clause.

static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))

static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)

Print Depend clause.

static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)

static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)

Verifies CopyPrivate Clause.

static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)

static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)

static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)

static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)

static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)

static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)

static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)

static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)

Parses a map_entries map type from a string format back into its numeric value.

static LogicalResult verifyOrderedParent(Operation &op)

static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)

static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)

static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)

static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)

static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)

static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)

schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...

static Operation * getParentInSameDialect(Operation *thisOp)

static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)

Parse an allocate clause with allocators and a list of operands with types.

static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)

static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)

static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)

Verifies Reduction Clause.

static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)

static bool opInGlobalImplicitParallelRegion(Operation *op)

static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)

static LogicalResult verifyPrivateVarList(OpType &op)

static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)

Prints a map_entries map type from its numeric value out into its string format.

static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)

static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)

static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)

aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...

static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)

static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)

static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)

static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)

static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)

static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)

copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...

static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)

Verifies Depend clause.

static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)

depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...

static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)

static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)

Walks the given type hierarchy with the given indices, potentially down to component granularity,...

This base class exposes generic asm parser hooks, usable across the various derived parsers.

virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0

Parse a list of comma-separated items with an optional delimiter.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attributes.

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual ParseResult parseOptionalEqual()=0

Parse a = token if present.

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

MLIRContext * getContext() const

virtual ParseResult parseRParen()=0

Parse a ) token.

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

virtual ParseResult parseOptionalColon()=0

Parse a : token if present.

virtual ParseResult parseLSquare()=0

Parse a [ token.

virtual ParseResult parseRSquare()=0

Parse a ] token.

ParseResult parseInteger(IntT &result)

Parse an integer value from the stream.

virtual ParseResult parseEqual()=0

Parse a = token.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token.

virtual ParseResult parseOptionalComma()=0

Parse a , token if present.

virtual ParseResult parseColon()=0

Parse a : token.

virtual SMLoc getNameLoc() const =0

Return the location of the original name token.

virtual ParseResult parseArrow()=0

Parse a '->' token.

virtual ParseResult parseLParen()=0

Parse a ( token.

virtual ParseResult parseComma()=0

Parse a , token.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

ValueTypeRange< BlockArgListType > getArgumentTypes()

Return a range containing the types of the arguments for this block.

BlockArgument getArgument(unsigned i)

unsigned getNumArguments()

SuccessorRange getSuccessors()

BlockArgListType getArguments()

IntegerAttr getIntegerAttr(Type type, int64_t value)

IntegerType getIntegerType(unsigned width)

MLIRContext * getContext() const

Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.

A class for computing basic dominance information.

bool dominates(Operation *a, Operation *b) const

Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.

MLIRContext is the top-level object for a collection of MLIR operations.

Dialect * getLoadedDialect(StringRef name)

Get a registered IR dialect with the given namespace.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0

Parses a region.

virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0

Parse zero or more arguments with a specified surrounding delimiter.

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0

Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.

virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0

Prints a region.

This class helps build Operations.

This class provides the API for ops that are known to be terminators.

This class indicates that the regions associated with this op don't have terminators.

This class implements the operand iterators for the Operation class.

Operation is the basic unit of execution within MLIR.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g. `hasTrait<OperandsAreSignlessIntegerLike>()`.

Dialect * getDialect()

Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

MLIRContext * getContext()

Return the context this operation is associated with.

unsigned getNumRegions()

Returns the number of regions held by this operation.

Location getLoc()

The source location the operation was defined or derived from.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.

Block * getBlock()

Returns the operation block that contains this operation.

OpTy getParentOfType()

Return the closest surrounding parent operation that is of type 'OpTy'.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

user_range getUsers()

Returns a range of all users.

Region * getParentRegion()

Returns the region to which the instruction belongs.

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

iterator_range< OpIterator > getOps()

BlockArgListType getArguments()

OpIterator op_begin()

Return iterators that walk the operations nested directly within this region.

unsigned getNumArguments()

Location getLoc()

Return a location for this region.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (replace cannot fail).

This class represents a specific instance of an effect.

Resource * getResource() const

Return the resource that the effect applies to.

EffectT * getEffect() const

Return the effect being applied.

This class represents a collection of SymbolTables.

virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)

Returns the operation registered with the given symbol name within the closest parent operation of,...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class provides an abstraction over the different types of ranges over Values.

type_range getType() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...

static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)

Builder from ArrayRef.

bool isReachableFromEntry(Block *a) const

Return true if the specified block is reachable from the entry block of its region.

Runtime

Potential runtimes for AMD GPU kernels.

TargetEnterDataOperands TargetEnterExitUpdateDataOperands

omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

Include the generated interface declarations.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

This is the representation of an operand reference.

This class provides APIs and verifiers for ops with regions having a single block.

This represents an operation in an abstracted form, suitable for use with the builder APIs.

SmallVector< Value, 4 > operands

void addAttribute(StringRef name, Attribute attr)

Add an attribute with the specified name.

Region * addRegion()

Create a region that should be attached to the operation.