LLVM: lib/Target/X86/X86PartialReduction.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

22#include "llvm/IR/IntrinsicsX86.h"

26

27using namespace llvm;

28

29#define DEBUG_TYPE "x86-partial-reduction"

30

31namespace {

32

// FunctionPass subclass that implements the "X86 Partial Reduction" pass.
// NOTE(review): several member lines are missing from this listing (the
// embedded numbering skips 34-35, 40, 42, 44-45, 48 and 54), so only the
// members that are visible below are described.
33class X86PartialReduction : public FunctionPass {

36

37public:

// Pass identification member; it is defined as 0 after the class.
38 static char ID;

39

41

43

46 }

47

// NOTE(review): presumably the body of getPassName() (its signature line is
// not visible here) - confirm against the full file.
49 return "X86 Partial Reduction";

50 }

51

52private:

// Helper defined later in this file; returns true when it rewrote Op.
53 bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);

55};

56}

57

59 return new X86PartialReduction();

60}

61

62char X86PartialReduction::ID = 0;

63

65 "X86 Partial Reduction", false, false)

66

67

70 if (!ST->hasVNNI() && !ST->hasAVXVNNI())

71 return false;

72

75

76 if (isa(LHS))

78

80 if (auto *Cast = dyn_cast(Op)) {

82 (Cast->getOpcode() == Instruction::SExt ||

83 Cast->getOpcode() == Instruction::ZExt) &&

84 Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 8)

85 return true;

86 }

87

88 return isa(Op);

89 };

90

91

92

93

97 return true;

98

99 return false;

100}

101

// Tries to rewrite the multiply Op into the even/odd "multiply-add" form that
// the visible tail of this function builds. Returns true once the rewrite has
// been emitted and false on every bail-out path.
//
// Op            - candidate instruction; it must be a vector with at least 8
//                 elements of i32 to be considered.
// ReduceInOneBB - when true, the rewrite is skipped if
//                 matchVPDPBUSDPattern(ST, Mul, DL) matches.
//
// NOTE(review): this listing is missing source lines (the embedded numbering
// skips), so LHS, RHS, NewMul, Zero, Builder and the mask vectors are declared
// on lines that are not visible here.
102bool X86PartialReduction::tryMAddReplacement(Instruction *Op,

103 bool ReduceInOneBB) {

// Fix: this guard previously read `if (ST->hasSSE2())`, which rejected every
// SSE2-capable subtarget and left the hasSSE41() branch below unreachable.
104 if (!ST->hasSSE2())

105 return false;

106

107

// Reject vectors with fewer than 8 elements.
108 if (cast(Op->getType())->getNumElements() < 8)

109 return false;

110

111

// Only i32 element types are handled.
112 if (!cast(Op->getType())->getElementType()->isIntegerTy(32))

113 return false;

114

115 auto *Mul = dyn_cast(Op);

117 return false;

118

121

122

123

124

125

// When the whole reduction sits in one basic block and the VPDPBUSD pattern
// matches, leave the multiply alone.
126 if (ReduceInOneBB && matchVPDPBUSDPattern(ST, Mul, DL))

127 return false;

128

129

130

131

132

// NOTE(review): the conditions guarding the returns in both branches are on
// missing lines.
133 if (ST->hasSSE41()) {

136 return false;

137 } else {

139 return false;

141 return false;

142 }

143 }

144

// Decides whether an operand can be narrowed. The visible part accepts
// sign/zero extensions whose source scalars are at most 16 bits wide; further
// conditions sit on lines missing from this listing.
145 auto CanShrinkOp = [&](Value *Op) {

147 if (auto *Cast = dyn_cast(Op)) {

149 (Cast->getOpcode() == Instruction::SExt ||

150 Cast->getOpcode() == Instruction::ZExt) &&

151 Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16)

152 return true;

153 }

154

155 return isa(Op);

156 };

157

158

159

162 return true;

163

164

165

166 if (auto *BO = dyn_cast(Op)) {

171 return true;

172 }

173

174 return false;

175 };

176

177

// Give up only when neither operand is shrinkable.
178 if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS))

179 return false;

180

182

183 auto *MulTy = cast(Op->getType());

184 unsigned NumElts = MulTy->getNumElements();

185

186

187

188

// EvenMask selects lanes 0, 2, 4, ...; OddMask selects lanes 1, 3, 5, ...
191 for (int i = 0, e = NumElts / 2; i != e; ++i) {

192 EvenMask[i] = i * 2;

193 OddMask[i] = i * 2 + 1;

194 }

195

196

// Split NewMul into its even and odd lanes and add them lane-wise.
198 Value *EvenElts = Builder.CreateShuffleVector(NewMul, NewMul, EvenMask);

199 Value *OddElts = Builder.CreateShuffleVector(NewMul, NewMul, OddMask);

200 Value *MAdd = Builder.CreateAdd(EvenElts, OddElts);

201

202

// Concatenate MAdd with Zero using an ascending 0, 1, 2, ... mask.
204 std::iota(ConcatMask.begin(), ConcatMask.end(), 0);

206 Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask);

207

210

211 return true;

212}

213

// Tries to replace Op, an absolute value of a subtraction of two extended i8
// vectors, with PSADBW intrinsic calls. Returns false on any mismatch; on
// success every use of Op is redirected to the new value, Op is erased and
// true is returned.
//
// NOTE(review): this listing is missing source lines (the embedded numbering
// skips), so LHS, IID, Builder, Zero, Ops, PSADBWFn and the mask vectors are
// declared on lines that are not visible here.
214bool X86PartialReduction::trySADReplacement(Instruction *Op) {

// Fix: this guard previously read `if (ST->hasSSE2())`, which rejected every
// SSE2-capable subtarget even though the fallback intrinsic chosen below is
// the SSE2 PSADBW.
215 if (!ST->hasSSE2())

216 return false;

217

218

219

// Only i32 element types are handled.
220 if (!cast(Op->getType())->getElementType()->isIntegerTy(32))

221 return false;

222

// Accept either the abs intrinsic or a select that matchSelectPattern-style
// analysis classifies as SPF_ABS (the call producing SPR is on a missing line).
224 if (match(Op, PatternMatch::m_IntrinsicIntrinsic::abs())) {

225 LHS = Op->getOperand(0);

226 } else {

227

228 auto *SI = dyn_cast(Op);

229 if (!SI)

230 return false;

231

233

235 if (SPR.Flavor != SPF_ABS)

236 return false;

237 }

238

239

// The value under the abs must be a subtraction.
240 auto *Sub = dyn_cast(LHS);

241 if (!Sub || Sub->getOpcode() != Instruction::Sub)

242 return false;

243

244

// Returns the un-extended source of Op when Op is an extension (presumably a
// zext; the template argument is lost in this listing) of a vector of i8,
// otherwise nullptr.
245 auto getZeroExtendedVal = [](Value *Op) -> Value * {

246 if (auto *ZExt = dyn_cast(Op))

247 if (cast(ZExt->getOperand(0)->getType())

248 ->getElementType()

249 ->isIntegerTy(8))

250 return ZExt->getOperand(0);

251

252 return nullptr;

253 };

254

255

// Both subtraction operands must come from i8 vectors.
256 Value *Op0 = getZeroExtendedVal(Sub->getOperand(0));

257 Value *Op1 = getZeroExtendedVal(Sub->getOperand(1));

258 if (!Op0 || !Op1)

259 return false;

260

262

263 auto *OpTy = cast(Op->getType());

264 unsigned NumElts = OpTy->getNumElements();

265

// Pick the widest PSADBW variant that the subtarget features and element count
// allow; the SSE2 form is the fallback.
266 unsigned IntrinsicNumElts;

268 if (ST->hasBWI() && NumElts >= 64) {

269 IID = Intrinsic::x86_avx512_psad_bw_512;

270 IntrinsicNumElts = 64;

271 } else if (ST->hasAVX2() && NumElts >= 32) {

272 IID = Intrinsic::x86_avx2_psad_bw;

273 IntrinsicNumElts = 32;

274 } else {

275 IID = Intrinsic::x86_sse2_psad_bw;

276 IntrinsicNumElts = 16;

277 }

278

280

// Fewer than 16 elements: widen both inputs to 16 lanes. The first NumElts
// mask entries select the input itself, the remaining ones select from Zero.
281 if (NumElts < 16) {

282

284 for (unsigned i = 0; i != NumElts; ++i)

285 ConcatMask[i] = i;

286 for (unsigned i = NumElts; i != 16; ++i)

287 ConcatMask[i] = (i % NumElts) + NumElts;

288

290 Op0 = Builder.CreateShuffleVector(Op0, Zero, ConcatMask);

291 Op1 = Builder.CreateShuffleVector(Op1, Zero, ConcatMask);

292 NumElts = 16;

293 }

294

295

296 auto *I32Ty =

298

299 assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!");

300 unsigned NumSplits = NumElts / IntrinsicNumElts;

301

302

// Emit one PSADBW call per IntrinsicNumElts-wide slice and bitcast each result
// to I32Ty.
304 for (unsigned i = 0; i != NumSplits; ++i) {

306 std::iota(ExtractMask.begin(), ExtractMask.end(), i * IntrinsicNumElts);

307 Value *ExtractOp0 = Builder.CreateShuffleVector(Op0, Op0, ExtractMask);

// Consistency fix: the second shuffle operand used to be Op0. The mask starts
// at i * IntrinsicNumElts, so it is expected to index only the first operand
// (NOTE(review): the mask length is declared on a missing line - confirm).
308 Value *ExtractOp1 = Builder.CreateShuffleVector(Op1, Op1, ExtractMask);

309 Ops[i] = Builder.CreateCall(PSADBWFn, {ExtractOp0, ExtractOp1});

310 Ops[i] = Builder.CreateBitCast(Ops[i], I32Ty);

311 }

312

// Pairwise-concatenate the per-slice results, halving their count each stage,
// until a single vector remains in Ops[0].
314 unsigned Stages = Log2_32(NumSplits);

315 for (unsigned s = Stages; s > 0; --s) {

316 unsigned NumConcatElts =

317 cast(Ops[0]->getType())->getNumElements() * 2;

318 for (unsigned i = 0; i != 1U << (s - 1); ++i) {

320 std::iota(ConcatMask.begin(), ConcatMask.end(), 0);

321 Ops[i] = Builder.CreateShuffleVector(Ops[i*2], Ops[i*2+1], ConcatMask);

322 }

323 }

324

325

326

// Match the result to Op's original element count: keep elements 0 and 1 for a
// 2-element Op, or extend with lanes from Zero when Op has 8 or more.
327 NumElts = cast(OpTy)->getNumElements();

328 if (NumElts == 2) {

329

330 Ops[0] = Builder.CreateShuffleVector(Ops[0], Ops[0], ArrayRef{0, 1});

331 } else if (NumElts >= 8) {

333 unsigned SubElts =

334 cast(Ops[0]->getType())->getNumElements();

335 for (unsigned i = 0; i != SubElts; ++i)

336 ConcatMask[i] = i;

337 for (unsigned i = SubElts; i != NumElts; ++i)

338 ConcatMask[i] = (i % SubElts) + SubElts;

339

341 Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask);

342 }

343

// Swap the new value in for Op and delete Op.
344 Op->replaceAllUsesWith(Ops[0]);

345 Op->eraseFromParent();

346

347 return true;

348}

349

350

351

353 bool &ReduceInOneBB) {

354 ReduceInOneBB = true;

355

356 auto *Index = dyn_cast(EE.getIndexOperand());

357 if (!Index || !Index->isNullValue())

358 return nullptr;

359

360 const auto *BO = dyn_cast(EE.getVectorOperand());

361 if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())

362 return nullptr;

363 if (EE.getParent() != BO->getParent())

364 ReduceInOneBB = false;

365

366 unsigned NumElems = cast(BO->getType())->getNumElements();

367

369 return nullptr;

370

372 unsigned Stages = Log2_32(NumElems);

373 for (unsigned i = 0; i != Stages; ++i) {

374 const auto *BO = dyn_cast(Op);

375 if (!BO || BO->getOpcode() != Instruction::Add)

376 return nullptr;

377 if (EE.getParent() != BO->getParent())

378 ReduceInOneBB = false;

379

380

381

382 if (i != 0 && !BO->hasNUses(2))

383 return nullptr;

384

385 Value *LHS = BO->getOperand(0);

386 Value *RHS = BO->getOperand(1);

387

388 auto *Shuffle = dyn_cast(LHS);

389 if (Shuffle) {

391 } else {

392 Shuffle = dyn_cast(RHS);

394 }

395

396

397

398 if (!Shuffle || Shuffle->getOperand(0) != Op)

399 return nullptr;

400

401

402 unsigned MaskEnd = 1 << i;

403 for (unsigned Index = 0; Index < MaskEnd; ++Index)

404 if (Shuffle->getMaskValue(Index) != (int)(MaskEnd + Index))

405 return nullptr;

406 }

407

408 return const_cast<Value *>(Op);

409}

410

411

412

413

415

416 if (!Phi->hasOneUse())

417 return false;

418

419 Instruction *U = cast(*Phi->user_begin());

420 if (U == BO)

421 return true;

422

423 while (U->hasOneUse() && U->getOpcode() == BO->getOpcode())

424 U = cast(*U->user_begin());

425

426 return U == BO;

427}

428

429

430

431

432

437

438 while (!Worklist.empty()) {

440 if (!Visited.insert(V).second)

441 continue;

442

443 if (auto *PN = dyn_cast(V)) {

444

445

446 if (!PN->hasNUses(PN == Root ? 2 : 1))

447 break;

448

449

450 append_range(Worklist, PN->incoming_values());

451

452 continue;

453 }

454

455 if (auto *BO = dyn_cast(V)) {

456 if (BO->getOpcode() == Instruction::Add) {

457

458 if (BO->hasNUses(BO == Root ? 2 : 1)) {

460 continue;

461 }

462

463

464

465 if (BO->hasNUses(BO == Root ? 3 : 2)) {

467 for (auto *U : BO->users())

468 if (auto *P = dyn_cast(U))

469 if (!Visited.count(P))

470 PN = P;

471

472

473

475 continue;

476

477

479 continue;

480

481

483 }

484 }

485 }

486

487

488 if (auto *I = dyn_cast(V)) {

489 if (!V->hasNUses(I == Root ? 2 : 1))

490 continue;

491

492

494 }

495 }

496}

497

// Pass entry point: walks every instruction of F looking for candidates and
// returns whether any replacement was made.
// NOTE(review): this listing skips several source lines (the embedded
// numbering jumps), so TM and Root are declared on lines that are not visible
// here, and one of the three closing braces at 538-540 belongs to a scope that
// is opened on a missing line.
498bool X86PartialReduction::runOnFunction(Function &F) {

// Nothing to do when skipFunction() says F should not be processed.
499 if (skipFunction(F))

500 return false;

501

// Without the analysis fetched here the pass reports no change.
502 auto *TPC = getAnalysisIfAvailable();

503 if (!TPC)

504 return false;

505

// ST and DL are read by the replacement helpers in this file.
507 ST = TM.getSubtargetImpl(F);

508

509 DL = &F.getDataLayout();

510

511 bool MadeChange = false;

512 for (auto &BB : F) {

513 for (auto &I : BB) {

// Only instructions accepted by this dyn_cast are considered; presumably
// ExtractElementInst (the template argument is lost in this listing) - TODO
// confirm.
514 auto *EE = dyn_cast(&I);

515 if (!EE)

516 continue;

517

518 bool ReduceInOneBB;

519

520

// No Root means no candidate was found for this instruction.
522 if (!Root)

523 continue;

524

527

// First try the multiply-add rewrite; on success move on to the next candidate.
529 if (tryMAddReplacement(I, ReduceInOneBB)) {

530 MadeChange = true;

531 continue;

532 }

533

534

535

// Otherwise try the SAD rewrite, but never on Root itself.
536 if (I != Root && trySADReplacement(I))

537 MadeChange = true;

538 }

539 }

540 }

541

542 return MadeChange;

543}

MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL

This file contains the declarations for the subclasses of Constant, which represent the different fla...

#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)

assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())

static SymbolRef::Type getType(const Symbol *Sym)

Target-Independent Code Generator Pass Configuration Options pass.

static constexpr int Concat[]

static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO)

BinaryOperator const DataLayout * DL

if(isa< SExtInst >(LHS)) std auto IsFreeTruncation

static Value * matchAddReduction(const ExtractElementInst &EE, bool &ReduceInOneBB)

static void collectLeaves(Value *Root, SmallVectorImpl< Instruction * > &Leaves)

Represent the analysis usage information of a pass.

void setPreservesCFG()

This function should be called by the pass, iff they do not:

ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...

BinaryOps getOpcode() const

static Constant * getNullValue(Type *Ty)

Constructor to create a '0' constant of arbitrary type.

This class represents an Operation in the Expression.

A parsed version of the target data layout string and methods for querying it.

static FixedVectorType * get(Type *ElementType, unsigned NumElts)

FunctionPass class - This class is used to implement most global optimizations.

virtual bool runOnFunction(Function &F)=0

runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.

This provides a uniform API for creating instructions and inserting them into a basic block: either a...

InstListType::iterator eraseFromParent()

This method unlinks 'this' from the containing basic block and deletes it.

unsigned getNumIncomingValues() const

Return the number of incoming edges.

virtual void getAnalysisUsage(AnalysisUsage &) const

getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...

virtual StringRef getPassName() const

getPassName - Return a nice clean name for a pass.

size_type count(ConstPtrType Ptr) const

count - Return 1 if the specified pointer is in the set, 0 otherwise.

std::pair< iterator, bool > insert(PtrType Ptr)

Inserts Ptr if and only if there is no element in the container equal to Ptr.

SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.

This class consists of common code factored out of the SmallVector class to reduce code duplication b...

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

StringRef - Represent a constant reference to a string, i.e.

Value * getOperand(unsigned i) const

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

bool hasOneUse() const

Return true if there is exactly one use of this value.

void replaceAllUsesWith(Value *V)

Change all uses of this to point to a new Value.

bool hasNUses(unsigned N) const

Return true if this Value has exactly N uses.

const ParentTy * getParent() const

unsigned ID

LLVM IR allows to use arbitrary numbers as calling convention identifiers.

Function * getOrInsertDeclaration(Module *M, ID id, ArrayRef< Type * > Tys={})

Look up the Function declaration of the intrinsic id in the Module M.

bool match(Val *V, const Pattern &P)

This is an optimization pass for GlobalISel generic memory operations.

void append_range(Container &C, Range &&R)

Wrapper function to append range R to container C.

unsigned Log2_32(uint32_t Value)

Return the floor log base 2 of the specified value, -1 if the value is zero.

@ SPF_ABS

Absolute value.

constexpr bool isPowerOf2_32(uint32_t Value)

Return true if the argument is a power of two > 0.

SelectPatternResult matchSelectPattern(Value *V, Value *&LHS, Value *&RHS, Instruction::CastOps *CastOp=nullptr, unsigned Depth=0)

Pattern match integer [SU]MIN, [SU]MAX and ABS idioms, returning the kind and providing the out param...

FunctionPass * createX86PartialReductionPass()

This pass optimizes arithmetic based on knowledge that is only used by a reduction sequence and is th...

void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)

Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...

DWARFExpression::Operation Op

unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)

Return the number of times the sign bit of the register is replicated into the other bits.

unsigned ComputeMaxSignificantBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr)

Get the upper bound on bit size for this Value Op as a signed integer.

void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)

Implement std::swap in terms of BitVector swap.