MLIR: lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13 #include

14 #include

15

17

24 #include "llvm/ADT/STLExtras.h"

25 #include "llvm/ADT/SmallVectorExtras.h"

26

27 using namespace mlir;

28

29

30

31

32

33

34

36 if (!attr)

37 return std::nullopt;

38

39 if (auto boolAttr = llvm::dyn_cast(attr))

40 return boolAttr.getValue();

41 if (auto splatAttr = llvm::dyn_cast(attr))

42 if (splatAttr.getElementType().isInteger(1))

43 return splatAttr.getSplatValue<bool>();

44 return std::nullopt;

45 }

46

47

48

51

52 if (!composite)

53 return {};

54

55 if (indices.empty())

56 return composite;

57

58 if (auto vector = llvm::dyn_cast(composite)) {

59 assert(indices.size() == 1 && "must have exactly one index for a vector");

60 return vector.getValues<Attribute>()[indices[0]];

61 }

62

63 if (auto array = llvm::dyn_cast(composite)) {

64 assert(!indices.empty() && "must have at least one index for an array");

66 indices.drop_front());

67 }

68

69 return {};

70 }

71

73 bool div0 = b.isZero();

74 bool overflow = a.isMinSignedValue() && b.isAllOnes();

75

76 return div0 || overflow;

77 }

78

79

80

81

82

83 namespace {

84 #include "SPIRVCanonicalization.inc"

85 }

86

87

88

89

90

91 namespace {

92

93

94

95 struct CombineChainedAccessChain final

98

99 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,

101 auto parentAccessChainOp =

102 accessChainOp.getBasePtr().getDefiningOpspirv::AccessChainOp();

103

104 if (!parentAccessChainOp) {

105 return failure();

106 }

107

108

110 llvm::append_range(indices, accessChainOp.getIndices());

111

113 accessChainOp, parentAccessChainOp.getBasePtr(), indices);

114

115 return success();

116 }

117 };

118 }

119

120 void spirv::AccessChainOp::getCanonicalizationPatterns(

122 results.add(context);

123 }

124

125

126

127

128

129

130

133

137 Value lhs = op.getOperand1();

138 Value rhs = op.getOperand2();

140

141

143 Value constituents[2] = {rhs, lhs};

144 rewriter.replaceOpWithNewOpspirv::CompositeConstructOp(op, op.getType(),

145 constituents);

146 return success();

147 }

148

149

150

151

152

153

154

155

156

157

158

159

164 return failure();

165

166 auto adds = constFoldBinaryOp(

167 {lhsAttr, rhsAttr},

168 [](const APInt &a, const APInt &b) { return a + b; });

169 if (!adds)

170 return failure();

171

172 auto carrys = constFoldBinaryOp(

173 ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {

175 return a.ult(b) ? (zero + 1) : zero;

176 });

177

178 if (!carrys)

179 return failure();

180

182 rewriter.createspirv::ConstantOp(loc, constituentType, adds);

183

184 Value carrysVal =

185 rewriter.createspirv::ConstantOp(loc, constituentType, carrys);

186

187

188 Value undef = rewriter.createspirv::UndefOp(loc, op.getType());

189

190 Value intermediate =

191 rewriter.createspirv::CompositeInsertOp(loc, addsVal, undef, 0);

192

194 intermediate, 1);

195 return success();

196 }

197 };

198

199 void spirv::IAddCarryOp::getCanonicalizationPatterns(

202 }

203

204

205

206

207

208

209

210 template <typename MulOp, bool IsSigned>

213

217 Value lhs = op.getOperand1();

218 Value rhs = op.getOperand2();

220

221

224 Value constituents[2] = {zero, zero};

225 rewriter.replaceOpWithNewOpspirv::CompositeConstructOp(op, op.getType(),

226 constituents);

227 return success();

228 }

229

230

231

232

233

234

235

236

237

242 return failure();

243

244 auto lowBits = constFoldBinaryOp(

245 {lhsAttr, rhsAttr},

246 [](const APInt &a, const APInt &b) { return a * b; });

247

248 if (!lowBits)

249 return failure();

250

251 auto highBits = constFoldBinaryOp(

252 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {

253 if (IsSigned) {

254 return llvm::APIntOps::mulhs(a, b);

255 } else {

256 return llvm::APIntOps::mulhu(a, b);

257 }

258 });

259

260 if (!highBits)

261 return failure();

262

263 Value lowBitsVal =

264 rewriter.createspirv::ConstantOp(loc, constituentType, lowBits);

265

266 Value highBitsVal =

267 rewriter.createspirv::ConstantOp(loc, constituentType, highBits);

268

269

270 Value undef = rewriter.createspirv::UndefOp(loc, op.getType());

271

272 Value intermediate =

273 rewriter.createspirv::CompositeInsertOp(loc, lowBitsVal, undef, 0);

274

275 rewriter.replaceOpWithNewOpspirv::CompositeInsertOp(op, highBitsVal,

276 intermediate, 1);

277 return success();

278 }

279 };

280

282 void spirv::SMulExtendedOp::getCanonicalizationPatterns(

285 }

286

289

293 Value lhs = op.getOperand1();

294 Value rhs = op.getOperand2();

296

297

300 Value constituents[2] = {lhs, zero};

301 rewriter.replaceOpWithNewOpspirv::CompositeConstructOp(op, op.getType(),

302 constituents);

303 return success();

304 }

305

306 return failure();

307 }

308 };

309

311 void spirv::UMulExtendedOp::getCanonicalizationPatterns(

314 }

315

316

317

318

319

320

321

322

323

324

325

326

327

328

331

334 auto prevUMod = umodOp.getOperand(0).getDefiningOpspirv::UModOp();

335 if (!prevUMod)

336 return failure();

337

338 TypedAttr prevValue;

339 TypedAttr currValue;

342 return failure();

343

344

345

346 bool isApplicable = false;

347 if (auto prevInt = dyn_cast(prevValue)) {

348 auto currInt = cast(currValue);

349 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;

350 } else if (auto prevVec = dyn_cast(prevValue)) {

351 auto currVec = cast(currValue);

352 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues(),

353 currVec.getValues()),

354 [](const auto &pair) {

355 auto &[prev, curr] = pair;

356 return prev.urem(curr) == 0;

357 });

358 }

359

360 if (!isApplicable)

361 return failure();

362

363

364

366 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));

367

368 return success();

369 }

370 };

371

375 }

376

377

378

379

380

381 OpFoldResult spirv::BitcastOp::fold(FoldAdaptor ) {

382 Value curInput = getOperand();

384 return curInput;

385

386

387 if (auto prevCast = curInput.getDefiningOpspirv::BitcastOp()) {

388 Value prevInput = prevCast.getOperand();

390 return prevInput;

391

392 getOperandMutable().assign(prevInput);

393 return getResult();

394 }

395

396

397 return {};

398 }

399

400

401

402

403

404 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {

405 Value compositeOp = getComposite();

406

407 while (auto insertOp =

408 compositeOp.getDefiningOpspirv::CompositeInsertOp()) {

409 if (getIndices() == insertOp.getIndices())

410 return insertOp.getObject();

411 compositeOp = insertOp.getComposite();

412 }

413

414 if (auto constructOp =

415 compositeOp.getDefiningOpspirv::CompositeConstructOp()) {

416 auto type = llvm::castspirv::CompositeType(constructOp.getType());

418 constructOp.getConstituents().size() == type.getNumElements()) {

419 auto i = llvm::cast(*getIndices().begin());

420 if (i.getValue().getSExtValue() <

421 static_cast<int64_t>(constructOp.getConstituents().size()))

422 return constructOp.getConstituents()[i.getValue().getSExtValue()];

423 }

424 }

425

426 auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {

427 return static_cast<unsigned>(llvm::cast(attr).getInt());

428 });

430 }

431

432

433

434

435

436 OpFoldResult spirv::ConstantOp::fold(FoldAdaptor ) {

437 return getValue();

438 }

439

440

441

442

443

444 OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {

445

447 return getOperand1();

448

449

450

451

452

453

454 return constFoldBinaryOp(

455 adaptor.getOperands(),

456 [](APInt a, const APInt &b) { return std::move(a) + b; });

457 }

458

459

460

461

462

463 OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {

464

466 return getOperand2();

467

469 return getOperand1();

470

471

472

473

474

475

476 return constFoldBinaryOp(

477 adaptor.getOperands(),

478 [](const APInt &a, const APInt &b) { return a * b; });

479 }

480

481

482

483

484

485 OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {

486

487 if (getOperand1() == getOperand2())

489

490

491

492

493

494

495 return constFoldBinaryOp(

496 adaptor.getOperands(),

497 [](APInt a, const APInt &b) { return std::move(a) - b; });

498 }

499

500

501

502

503

504 OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {

505

507 return getOperand1();

508

509

510

511

512

513

514

515

516

517 bool div0OrOverflow = false;

518 auto res = constFoldBinaryOp(

519 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

520 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {

521 div0OrOverflow = true;

522 return a;

523 }

524 return a.sdiv(b);

525 });

526 return div0OrOverflow ? Attribute() : res;

527 }

528

529

530

531

532

533 OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {

534

537

538

539

540

541

542

543

544

545

546

547

548 bool div0OrOverflow = false;

549 auto res = constFoldBinaryOp(

550 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

551 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {

552 div0OrOverflow = true;

553 return a;

554 }

555 APInt c = a.abs().urem(b.abs());

556 if (c.isZero())

557 return c;

558 if (b.isNegative()) {

559 APInt zero = APInt::getZero(c.getBitWidth());

560 return a.isNegative() ? (zero - c) : (b + c);

561 }

562 return a.isNegative() ? (b - c) : c;

563 });

564 return div0OrOverflow ? Attribute() : res;

565 }

566

567

568

569

570

571 OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {

572

575

576

577

578

579

580

581

582

583

584

585

586 bool div0OrOverflow = false;

587 auto res = constFoldBinaryOp(

588 adaptor.getOperands(), [&](APInt a, const APInt &b) {

589 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {

590 div0OrOverflow = true;

591 return a;

592 }

593 return a.srem(b);

594 });

595 return div0OrOverflow ? Attribute() : res;

596 }

597

598

599

600

601

602 OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {

603

605 return getOperand1();

606

607

608

609

610

611

612

613 bool div0 = false;

614 auto res = constFoldBinaryOp(

615 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

616 if (div0 || b.isZero()) {

617 div0 = true;

618 return a;

619 }

620 return a.udiv(b);

621 });

623 }

624

625

626

627

628

629 OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {

630

633

634

635

636

637

638

639

640 bool div0 = false;

641 auto res = constFoldBinaryOp(

642 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

643 if (div0 || b.isZero()) {

644 div0 = true;

645 return a;

646 }

647 return a.urem(b);

648 });

650 }

651

652

653

654

655

656 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {

657

658 auto op = getOperand();

659 if (auto negateOp = op.getDefiningOpspirv::SNegateOp())

660 return negateOp->getOperand(0);

661

662

663

664

665 return constFoldUnaryOp(

666 adaptor.getOperands(), [](const APInt &a) {

667 APInt zero = APInt::getZero(a.getBitWidth());

668 return zero - a;

669 });

670 }

671

672

673

674

675

676 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {

677

678 auto op = getOperand();

679 if (auto notOp = op.getDefiningOpspirv::NotOp())

680 return notOp->getOperand(0);

681

682

683

684

685 return constFoldUnaryOp(adaptor.getOperands(), [&](APInt a) {

686 a.flipAllBits();

687 return a;

688 });

689 }

690

691

692

693

694

695 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {

696 if (std::optional rhs =

698

699 if (*rhs)

700 return getOperand1();

701

702

703 if (!*rhs)

704 return adaptor.getOperand2();

705 }

706

708 }

709

710

711

712

713

715 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {

716

717 if (getOperand1() == getOperand2()) {

719 if (isa(getType()))

720 return trueAttr;

721 if (auto vecTy = dyn_cast(getType()))

723 }

724

725 return constFoldBinaryOp(

726 adaptor.getOperands(), [](const APInt &a, const APInt &b) {

727 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);

728 });

729 }

730

731

732

733

734

735 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {

736 if (std::optional rhs =

738

739 if (!rhs.value())

740 return getOperand1();

741 }

742

743

744 if (getOperand1() == getOperand2()) {

746 if (isa(getType()))

747 return falseAttr;

748 if (auto vecTy = dyn_cast(getType()))

750 }

751

752 return constFoldBinaryOp(

753 adaptor.getOperands(), [](const APInt &a, const APInt &b) {

754 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);

755 });

756 }

757

758

759

760

761

762 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {

763

764 auto op = getOperand();

765 if (auto notOp = op.getDefiningOpspirv::LogicalNotOp())

766 return notOp->getOperand(0);

767

768

769

770

771 return constFoldUnaryOp(adaptor.getOperands(),

772 [](const APInt &a) {

773 APInt zero = APInt::getZero(1);

774 return a == 1 ? zero : (zero + 1);

775 });

776 }

777

778 void spirv::LogicalNotOp::getCanonicalizationPatterns(

780 results

781 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,

782 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(

783 context);

784 }

785

786

787

788

789

790 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {

792 if (*rhs) {

793

794 return adaptor.getOperand2();

795 }

796

797 if (!*rhs) {

798

799 return getOperand1();

800 }

801 }

802

804 }

805

806

807

808

809

810 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {

811

812 Value trueVals = getTrueValue();

813 Value falseVals = getFalseValue();

814 if (trueVals == falseVals)

815 return trueVals;

816

818

819

820

822 return *boolAttr ? trueVals : falseVals;

823

824

825 if (!operands[0] || !operands[1] || !operands[2])

827

828

829

830

831 auto condAttrs = dyn_cast(operands[0]);

832 auto trueAttrs = dyn_cast(operands[1]);

833 auto falseAttrs = dyn_cast(operands[2]);

834 if (!condAttrs || !trueAttrs || !falseAttrs)

836

837 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());

838 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),

839 falseAttrs.getValues<Attribute>());

840 for (auto [result, cond, falseRes] : iters) {

841 if (!cond.getValue())

842 result = falseRes;

843 }

844

845 auto resultType = trueAttrs.getType();

847 }

848

849

850

851

852

853 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {

854

855 if (getOperand1() == getOperand2()) {

857 if (isa(getType()))

858 return trueAttr;

859 if (auto vecTy = dyn_cast(getType()))

861 }

862

863 return constFoldBinaryOp(

864 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

865 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);

866 });

867 }

868

869

870

871

872

873 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {

874

875 if (getOperand1() == getOperand2()) {

877 if (isa(getType()))

878 return falseAttr;

879 if (auto vecTy = dyn_cast(getType()))

881 }

882

883 return constFoldBinaryOp(

884 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

885 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);

886 });

887 }

888

889

890

891

892

894 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {

895

896 if (getOperand1() == getOperand2()) {

898 if (isa(getType()))

899 return falseAttr;

900 if (auto vecTy = dyn_cast(getType()))

902 }

903

904 return constFoldBinaryOp(

905 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

906 return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);

907 });

908 }

909

910

911

912

913

914 OpFoldResult spirv::SGreaterThanEqualOp::fold(

915 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {

916

917 if (getOperand1() == getOperand2()) {

919 if (isa(getType()))

920 return trueAttr;

921 if (auto vecTy = dyn_cast(getType()))

923 }

924

925 return constFoldBinaryOp(

926 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

927 return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);

928 });

929 }

930

931

932

933

934

936 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {

937

938 if (getOperand1() == getOperand2()) {

940 if (isa(getType()))

941 return falseAttr;

942 if (auto vecTy = dyn_cast(getType()))

944 }

945

946 return constFoldBinaryOp(

947 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

948 return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);

949 });

950 }

951

952

953

954

955

956 OpFoldResult spirv::UGreaterThanEqualOp::fold(

957 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {

958

959 if (getOperand1() == getOperand2()) {

961 if (isa(getType()))

962 return trueAttr;

963 if (auto vecTy = dyn_cast(getType()))

965 }

966

967 return constFoldBinaryOp(

968 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

969 return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);

970 });

971 }

972

973

974

975

976

977 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {

978

979 if (getOperand1() == getOperand2()) {

981 if (isa(getType()))

982 return falseAttr;

983 if (auto vecTy = dyn_cast(getType()))

985 }

986

987 return constFoldBinaryOp(

988 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

989 return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);

990 });

991 }

992

993

994

995

996

998 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {

999

1000 if (getOperand1() == getOperand2()) {

1002 if (isa(getType()))

1003 return trueAttr;

1004 if (auto vecTy = dyn_cast(getType()))

1006 }

1007

1008 return constFoldBinaryOp(

1009 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

1010 return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);

1011 });

1012 }

1013

1014

1015

1016

1017

1018 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {

1019

1020 if (getOperand1() == getOperand2()) {

1022 if (isa(getType()))

1023 return falseAttr;

1024 if (auto vecTy = dyn_cast(getType()))

1026 }

1027

1028 return constFoldBinaryOp(

1029 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

1030 return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);

1031 });

1032 }

1033

1034

1035

1036

1037

1039 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {

1040

1041 if (getOperand1() == getOperand2()) {

1043 if (isa(getType()))

1044 return trueAttr;

1045 if (auto vecTy = dyn_cast(getType()))

1047 }

1048

1049 return constFoldBinaryOp(

1050 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {

1051 return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);

1052 });

1053 }

1054

1055

1056

1057

1058

1059 OpFoldResult spirv::ShiftLeftLogicalOp::fold(

1060 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {

1061

1063 return getOperand1();

1064 }

1065

1066

1067

1068

1069

1070

1071

1072

1073

1074 bool shiftToLarge = false;

1075 auto res = constFoldBinaryOp(

1076 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

1077 if (shiftToLarge || b.uge(a.getBitWidth())) {

1078 shiftToLarge = true;

1079 return a;

1080 }

1081 return a << b;

1082 });

1083 return shiftToLarge ? Attribute() : res;

1084 }

1085

1086

1087

1088

1089

1090 OpFoldResult spirv::ShiftRightArithmeticOp::fold(

1091 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {

1092

1094 return getOperand1();

1095 }

1096

1097

1098

1099

1100

1101

1102

1103

1104

1105 bool shiftToLarge = false;

1106 auto res = constFoldBinaryOp(

1107 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

1108 if (shiftToLarge || b.uge(a.getBitWidth())) {

1109 shiftToLarge = true;

1110 return a;

1111 }

1112 return a.ashr(b);

1113 });

1114 return shiftToLarge ? Attribute() : res;

1115 }

1116

1117

1118

1119

1120

1121 OpFoldResult spirv::ShiftRightLogicalOp::fold(

1122 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {

1123

1125 return getOperand1();

1126 }

1127

1128

1129

1130

1131

1132

1133

1134

1135

1136 bool shiftToLarge = false;

1137 auto res = constFoldBinaryOp(

1138 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

1139 if (shiftToLarge || b.uge(a.getBitWidth())) {

1140 shiftToLarge = true;

1141 return a;

1142 }

1143 return a.lshr(b);

1144 });

1145 return shiftToLarge ? Attribute() : res;

1146 }

1147

1148

1149

1150

1151

1153 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {

1154

1155 if (getOperand1() == getOperand2()) {

1156 return getOperand1();

1157 }

1158

1159 APInt rhsMask;

1161

1162 if (rhsMask.isZero())

1163 return getOperand2();

1164

1165

1166 if (rhsMask.isAllOnes())

1167 return getOperand1();

1168

1169

1170 if (auto zext = getOperand1().getDefiningOpspirv::UConvertOp()) {

1171 int valueBits =

1173 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())

1174 return getOperand1();

1175 }

1176 }

1177

1178

1179

1180

1181

1182

1183 return constFoldBinaryOp(

1184 adaptor.getOperands(),

1185 [](const APInt &a, const APInt &b) { return a & b; });

1186 }

1187

1188

1189

1190

1191

1192 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {

1193

1194 if (getOperand1() == getOperand2()) {

1195 return getOperand1();

1196 }

1197

1198 APInt rhsMask;

1200

1201 if (rhsMask.isZero())

1202 return getOperand1();

1203

1204

1205 if (rhsMask.isAllOnes())

1206 return getOperand2();

1207 }

1208

1209

1210

1211

1212

1213

1214 return constFoldBinaryOp(

1215 adaptor.getOperands(),

1216 [](const APInt &a, const APInt &b) { return a | b; });

1217 }

1218

1219

1220

1221

1222

1224 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {

1225

1227 return getOperand1();

1228 }

1229

1230

1231 if (getOperand1() == getOperand2())

1233

1234

1235

1236

1237

1238

1239 return constFoldBinaryOp(

1240 adaptor.getOperands(),

1241 [](const APInt &a, const APInt &b) { return a ^ b; });

1242 }

1243

1244

1245

1246

1247

1248 namespace {

1249

1250

1251

1252

1253

1254

1255

1256

1257

1258

1259

1260

1261

1262

1263

1264

1265

1266

1267

1268

1269

1270

1271

1272

1273

1274 struct ConvertSelectionOpToSelect final : OpRewritePatternspirv::SelectionOp {

1276

1277 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,

1279 Operation *op = selectionOp.getOperation();

1281

1282 if (body.empty()) {

1283 return failure();

1284 }

1285

1286

1287

1288 if (llvm::range_size(body) != 4) {

1289 return failure();

1290 }

1291

1292 Block *headerBlock = selectionOp.getHeaderBlock();

1293 if (!onlyContainsBranchConditionalOp(headerBlock)) {

1294 return failure();

1295 }

1296

1297 auto brConditionalOp =

1298 castspirv::BranchConditionalOp(headerBlock->front());

1299

1302 Block *mergeBlock = selectionOp.getMergeBlock();

1303

1304 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))

1305 return failure();

1306

1307 Value trueValue = getSrcValue(trueBlock);

1308 Value falseValue = getSrcValue(falseBlock);

1309 Value ptrValue = getDstPtr(trueBlock);

1310 auto storeOpAttributes =

1311 castspirv::StoreOp(trueBlock->front())->getAttrs();

1312

1313 auto selectOp = rewriter.createspirv::SelectOp(

1314 selectionOp.getLoc(), trueValue.getType(),

1315 brConditionalOp.getCondition(), trueValue, falseValue);

1316 rewriter.createspirv::StoreOp(selectOp.getLoc(), ptrValue,

1317 selectOp.getResult(), storeOpAttributes);

1318

1319

1321 return success();

1322 }

1323

1324 private:

1325

1326

1327

1328

1329

1330

1331 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,

1332 Block *mergeBlock) const;

1333

1334 bool onlyContainsBranchConditionalOp(Block *block) const {

1335 return llvm::hasSingleElement(*block) &&

1336 isaspirv::BranchConditionalOp(block->front());

1337 }

1338

1339 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {

1340 return lhs->getDiscardableAttrDictionary() ==

1341 rhs->getDiscardableAttrDictionary() &&

1342 lhs.getProperties() == rhs.getProperties();

1343 }

1344

1345

1346 Value getSrcValue(Block *block) const {

1347 auto storeOp = castspirv::StoreOp(block->front());

1348 return storeOp.getValue();

1349 }

1350

1351

1352 Value getDstPtr(Block *block) const {

1353 auto storeOp = castspirv::StoreOp(block->front());

1354 return storeOp.getPtr();

1355 }

1356 };

1357

1358 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(

1359 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {

1360

1361 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {

1362 return failure();

1363 }

1364

1365 auto trueBrStoreOp = dyn_castspirv::StoreOp(trueBlock->front());

1366 auto trueBrBranchOp =

1367 dyn_castspirv::BranchOp(*std::next(trueBlock->begin()));

1368 auto falseBrStoreOp = dyn_castspirv::StoreOp(falseBlock->front());

1369 auto falseBrBranchOp =

1370 dyn_castspirv::BranchOp(*std::next(falseBlock->begin()));

1371

1372 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||

1373 !falseBrBranchOp) {

1374 return failure();

1375 }

1376

1377

1378

1379

1380

1381

1382 bool isScalarOrVector =

1383 llvm::castspirv::SPIRVType(trueBrStoreOp.getValue().getType())

1384 .isScalarOrVector();

1385

1386

1387

1388 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||

1389 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {

1390 return failure();

1391 }

1392

1393 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||

1394 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {

1395 return failure();

1396 }

1397

1398 return success();

1399 }

1400 }

1401

1402 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,

1404 results.add(context);

1405 }

static Value getZero(OpBuilder &b, Location loc, Type elementType)

Get zero value for an element type.

static uint64_t zext(uint32_t arg)

static MLIRContext * getContext(OpFoldResult val)

static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)

MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold

static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)

Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...

static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

Block * getSuccessor(unsigned i)

Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.

This class is a general helper class for creating context-global objects like types,...

TypedAttr getZeroAttr(Type type)

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

Operation is the basic unit of execution within MLIR.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class contains a list of basic blocks and a link to the parent operation it is attached to.

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Operation::operand_range getIndices(Operation *op)

Get the indices that the given load/store operation is operating on.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

detail::constant_int_predicate_matcher m_Zero()

Matches a constant scalar / vector splat / tensor splat integer zero.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

detail::constant_int_predicate_matcher m_One()

Matches a constant scalar / vector splat / tensor splat integer one.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override

LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override

LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override

LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...

LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final

Wrapper around the RewritePattern method that passes the derived op type.