MLIR: lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_

10 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_

11

12 #include

13

15

21

22 namespace mlir {

23 namespace sparse_tensor {

24

25

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

56 public:

57

58

61

62

63

66

67

68

69

70

71

72

73

74

75

76

79

81

82

83

84

85

86

87

88 void

90 bool hasOutput = false, bool isSparseOut = false,

93

95 ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,

96 bool isSparseOut = false, unsigned numLoops = 0,

99

100

101

105

106

107

108

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

127

128

130

131

132

135

136

137

138

139

140

141

142

143

144

145

149 bool isParallel = false, bool needsUniv = false);

150

152 I64BitSet caseBit, unsigned caseIdx,

154

155

156

159

160

162 return llvm::map_range(loopStack, [](const LoopInfo &li) { return li.iv; });

163 }

164

165

166

169 }

170

171

173

174

179 std::advance(it, n);

180 return *it;

181 }

182

183

184

186

187

189

191 }

192

193

195

196

198 assert(hasOutput);

200 }

201

202

205 }

206

207

210 return std::make_pair(tidLvl % nt, tidLvl / nt);

211 }

212

213

214 template

216 using EltTy = decltype(*c.begin());

217 static_assert(std::is_same_v<llvm::remove_cvref_t, TensorLevel>,

218 "Must be unpacking a TensorLevel range");

219 return llvm::map_range(std::forward(c), [this](EltTy tl) {

221 });

222 }

223

224

225

226

228

230 return {spIterVals[tid].back()};

231

232

233 SmallVector batchCrds = iters[tid].back().back()->getBatchCrds();

234 Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();

235 batchCrds.push_back(lastLvlPos);

236 return batchCrds;

237 };

239 return getCurIterator(tid, lvl).getCrd();

240 };

241 const std::vector &getValBuffer() const { return valBuffer; };

242

244 return llvm::StringLiteral("Emitted from");

245 }

246

247 private:

248

249

250

251

252

253

254 struct LoopInfo final {

256 Value iv, StringAttr loopTag)

257 : tidLvls(tidLvls), loop(loop), userCodeBlock(userBlock), iv(iv) {

258

259 if (loopTag)

261 }

262

263

264

265

267 Operation *loop;

268 Block *const userCodeBlock;

269 Value iv;

270 };

271

272 void categorizeIterators(ArrayRef tidLvls,

273 SmallVectorImpl<SparseIterator *> &raIters,

274 SmallVectorImpl<SparseIterator *> &spIters);

275

276

277

278

279 using LoopBodyBuilder = llvm::function_ref<void(OpBuilder &, Location, Value,

280 MutableArrayRef)>;

281

282

283 bool shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters);

284

285

286

287

288

289 Value genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,

291

293

294 bool isOutputTensor(TensorId tid) const {

296 }

297

298 bool isSparseOutput(TensorId tid) const {

299 return isOutputTensor(tid) && isSparseOut;

300 }

301

302 bool isValidLevel(TensorId tid, Level lvl) const {

303 return tid < lvls.size() && lvl < lvls[tid].size();

304 }

305

306

307

308 void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,

310

311

312

313

314

315

316

317 std::pair<Operation *, Value>

318 emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,

319 SparseIterator &iter, MutableArrayRef reduc,

320 bool isParallel);

321

322

323

324

325

326

327

328 std::pair<Operation *, Value>

329 emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc,

330 ArrayRef<SparseIterator *> iters,

331 MutableArrayRef reduc, bool needsUniv);

332

333

334

335

336

337

338

339

340

341

342

343

344

345

346

347

348

349

350

351

352

353

354

355

356

357 void exitForLoop(RewriterBase &rewriter, Location loc,

358 MutableArrayRef reduc);

359

360

361 void exitWhileLoop(OpBuilder &builder, Location loc,

362 MutableArrayRef reduc);

363

364

365

366

367

368 void initSubSectIterator(OpBuilder &builder, Location loc);

369

370

371 unsigned redDepOnLevel(TensorId tid, Level lvl) const {

372 return levelReducedDep[tid][lvl];

373 };

374

375 SparseIterator &getCurIterator(TensorId tid, Level lvl) const {

376 if (dependentLvlMap[tid][lvl].empty())

377 return *iters[tid][lvl].back();

378

379 assert(redDepOnLevel(tid, lvl) >= 1);

380 return *iters[tid][lvl][redDepOnLevel(tid, lvl) - 1];

381 }

382

383 std::unique_ptr

384 makeLevelIterator(OpBuilder &builder, Location loc, TensorId tid, Level l);

385

386

387

388

389 StringAttr loopTag;

390

391

392 bool hasOutput;

393 bool isSparseOut;

395

396

397

398

399

400

401 std::vector tensors;

402 std::vector loopHighs;

403 std::vector<std::vector<std::unique_ptr>> lvls;

404 std::vector<std::vector<std::vector<std::unique_ptr>>> iters;

405 std::vector valBuffer;

406

407

408

409 std::vector<std::vector<std::vector<std::pair<LoopId, unsigned>>>>

410 dependentLvlMap;

411

412

413

414 std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta;

415

416

417 std::vector<std::vector> levelReducedDep;

418

419

420

421

422

423

424

425 std::vector loopStack;

426

427

428

429 std::vector<std::pair<Value, std::vector>> loopSeqStack;

430

431

432

433

434

435

436 std::vector<std::vector> spIterVals;

437 };

438

439

440

441

442

443

444 std::pair<Operation *, Value> genCoIteration(OpBuilder &builder, Location loc,

445 ArrayRef<SparseIterator *> iters,

446 MutableArrayRef reduc,

447 Value uniIdx,

448 bool userReducFirst = false);

449

450 }

451 }

452

453 #endif

Base type for affine expression.

Block represents an ordered list of Operations.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

This class helps build Operations.

Operation is the basic unit of execution within MLIR.

void setAttr(StringAttr name, Attribute value)

If an attribute exists with the specified name, change it to the new value.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....

void exitCurrentLoop(RewriterBase &rewriter, Location loc, MutableArrayRef< Value > reduc={})

Generates code to exit the current loop (e.g., generates yields, forwards loop induction variables,...

constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()

void locateLvlAtAffineAddress(OpBuilder &builder, Location loc, TensorLevel tidLvl, AffineExpr lvlExpr)

Emits the address for a dense level based on the value evaluated by the provided affine expression.

const std::vector< Value > & getValBuffer() const

void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef< TensorLevel > tidLvls)

Enters a new loop sequence; the loops within the same sequence start from the break points of the previo...

Value genAffine(OpBuilder &builder, Location loc, AffineExpr a)

Generates code to compute an affine expression whose variables are LoopIds (i.e., cast...

Region * enterCurrentCoIterationCase(OpBuilder &builder, Location loc, I64BitSet caseBit, unsigned caseIdx, MutableArrayRef< Value > reduc)

Operation * enterCoIterationOverTensorsAtLvls(OpBuilder &builder, Location loc, ArrayRef< TensorLevel > tidLvls, unsigned numCases, MutableArrayRef< Value > reduc={}, bool isParallel=false, bool needsUniv=false)

Emits a co-iteration loop over a set of tensors.

TensorId getOutTensorId() const

Gets the TensorId for output tensor.

TensorLevel makeTensorLevel(TensorId t, Level l) const

Compresses a TensorId and Level into a TensorLevel.

unsigned getNumManifestTensors() const

Gets the total number of manifest tensors (excluding the synthetic tensor).

void initialize(ValueRange tensors, StringAttr loopTag=nullptr, bool hasOutput=false, bool isSparseOut=false, unsigned numLoops=0, DependentLvlGetter getter=nullptr, SparseEmitStrategy emitStrategy=SparseEmitStrategy::kFunctional)

Takes an array of input tensors, which the generated loops will iterate over.

Value getLoopIV(LoopId n) const

Gets loop induction variable for the given loop.

std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tidLvl) const

De-compresses a TensorLevel back to a pair of TensorId and Level.

auto unpackTensorLevelRange(ContainerTy &&c) const

Converts a range of TensorLevel to a range of std::pair<TensorId, Level>

SmallVector< Value > getValPosits(TensorId tid) const

Getters.

unsigned getNumTensors() const

Gets the total number of tensors that loopEmitter is operating on.

SmallVector< Value > getLoopIVs() const

Returns the loop induction variables for all loops in the current loop-stack.

auto getLoopIVsRange() const

Get the range of values for all induction variables.

void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)

Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.

LoopId getCurrentDepth() const

Gets the current depth of the loop-stack.

void exitCurrentLoopSeq(OpBuilder &builder, Location loc)

Exits the current loop sequence; this resets the universal index to 0.

TensorId getSynTensorId() const

Gets the TensorId for synthetic tensor.

Value getCoord(TensorId tid, Level lvl) const

uint64_t Level

The type of level identifiers and level-ranks.

unsigned LoopId

Loop identifiers.

std::pair< Operation *, Value > genCoIteration(OpBuilder &builder, Location loc, ArrayRef< SparseIterator * > iters, MutableArrayRef< Value > reduc, Value uniIdx, bool userReducFirst=false)

unsigned TensorId

Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...

Include the generated interface declarations.

SparseEmitStrategy

Defines the strategy used when emitting sparse loops (passed to initialize as emitStrategy; defaults to SparseEmitStrategy::kFunctional).