MLIR: lib/Dialect/GPU/Utils/DistributionUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

17

18 #include

19

20 using namespace mlir;

22

23 WarpExecuteOnLane0Op

25 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,

27

30 auto newWarpOp = rewriter.create(

31 warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),

32 warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

33

34 Region &opBody = warpOp.getBodyRegion();

35 Region &newOpBody = newWarpOp.getBodyRegion();

36 Block &newOpFirstBlock = newOpBody.front();

38 rewriter.eraseBlock(&newOpFirstBlock);

39 assert(newWarpOp.getWarpRegion().hasOneBlock() &&

40 "expected WarpOp with single block");

41

42 auto yield =

43 castgpu::YieldOp(newOpBody.getBlocks().begin()->getTerminator());

44

46 yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });

47 return newWarpOp;

48 }

49

50 WarpExecuteOnLane0Op

52 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,

56 warpOp.getResultTypes().end());

57 auto yield = castgpu::YieldOp(

58 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

59 llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),

60 yield.getOperands().end());

61 for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {

62 if (yieldValues.insert(value)) {

63 types.push_back(type);

64 indices.push_back(yieldValues.size() - 1);

65 } else {

66

67 for (auto [idx, yieldOperand] :

69 if (yieldOperand == value) {

70 indices.push_back(idx);

71 break;

72 }

73 }

74 }

75 }

76 yieldValues.insert_range(newYieldedValues);

78 rewriter, warpOp, yieldValues.getArrayRef(), types);

80 newWarpOp.getResults().take_front(warpOp.getNumResults()));

81 return newWarpOp;

82 }

83

85 WarpExecuteOnLane0Op warpOp,

87 auto yield = castgpu::YieldOp(

88 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

89 for (OpOperand &yieldOperand : yield->getOpOperands()) {

90 Value yieldValues = yieldOperand.get();

92 if (definedOp && fn(definedOp)) {

93 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())

94 return &yieldOperand;

95 }

96 }

97 return nullptr;

98 }

99

104

105

106

107 if (originalShape == distributedShape) {

108 delinearizedIds.clear();

109 return true;

110 }

111

113 for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {

114 if (large % small != 0)

115 return false;

116 sizes.push_back(large / small);

117 }

118 if (std::accumulate(sizes.begin(), sizes.end(), 1,

119 std::multiplies<int64_t>()) != warpSize)

120 return false;

121

124

125 int64_t usedThreads = 1;

126

128 delinearizedIds.assign(sizes.size(), zero);

129

130 for (int i = sizes.size() - 1; i >= 0; --i) {

131 usedThreads *= sizes[i];

132 if (usedThreads == warpSize) {

133

134

135 delinearizedIds[i] = laneId;

136 break;

137 }

138 delinearizedIds[i] =

141 builder, loc, s0.floorDiv(usedThreads), {laneId});

142 }

143 return true;

144 }

Base type for affine expression.

AffineExpr floorDiv(uint64_t v) const

Block represents an ordered list of Operations.

MLIRContext * getContext() const

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents an operand of an operation.

Operation is the basic unit of execution within MLIR.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

BlockListType & getBlocks()

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

virtual void eraseBlock(Block *block)

This method erases all operations in a block.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)

Move the blocks that belong to "region" before the given position in another region "parent".

This class provides an abstraction over the various different ranges of value types.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Specialization of arith.constant op that returns an integer of index type.

AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Include the generated interface declarations.

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)]

WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const

Helper to create a new WarpExecuteOnLane0Op region with extra outputs.

bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const

Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...

WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const

Helper to create a new WarpExecuteOnLane0Op with different signature.

OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const

Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.