MLIR: lib/Dialect/Linalg/Transforms/Specialize.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

23 #include "llvm/Support/Debug.h"

24

25 namespace mlir {

26 #define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS

27 #include "mlir/Dialect/Linalg/Passes.h.inc"

28 }

29

30 #define DEBUG_TYPE "linalg-specialization"

31

// Replaces `genericOp` with the named elementwise binary op NEWOP.
// OPERANDS_SWAP indicates the body consumed the two block arguments in
// reversed order, so the dps inputs are swapped to compensate.
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

// Replaces `genericOp` with the named elementwise unary op NEWOP, forwarding
// the single dps input and init.
#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))

43

44 using namespace mlir;

46

47

48

49

50

51

52

53

54

55

56

57

59 Block *body = genericOp.getBody();

61 bool swapped = false;

63 swapped = true;

66 "binary op uses just one block arg");

67 }

68 return swapped;

69 }

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96 namespace {

97 enum class IndexMatchResult {

98 Match = 0,

99 Transposed,

100 Mismatch

101 };

102

103

104

105

106

107

108

109

110

111 static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,

112 unsigned expectedPosOfRowDim,

113 unsigned expectedPosOfColDim) {

114

115 auto exprOfRowDim = map.getResults()[rowDimIdx];

116 auto exprOfColDim = map.getResults()[rowDimIdx + 1];

117

118

121 return IndexMatchResult::Mismatch;

122

123 auto posRowDim = cast(exprOfRowDim).getPosition();

124 auto posColDim = cast(exprOfColDim).getPosition();

125

126 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)

127 return IndexMatchResult::Match;

128

129 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)

130 return IndexMatchResult::Transposed;

131

132 return IndexMatchResult::Mismatch;

133 }

134

135

136

137

138

139 template

140 static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {

142 op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},

144 return namedOp;

145 }

146

147

148 static FailureOr specializeLinalgContractions(RewriterBase &rewriter,

149 GenericOp genericOp) {

150 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)

151 return failure();

152

153

154 auto mapRange = genericOp.getIndexingMapsArray();

155 if (llvm::any_of(mapRange,

156 [](AffineMap m) { return !m.isProjectedPermutation(); }))

157 return failure();

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

179 if (!succeeded(res))

180 return failure();

181 auto dims = *res;

182 if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)

183 return failure();

184

187 if ((isaarith::MulFOp(first) && isaarith::AddFOp(second)) ||

188 (isaarith::MulIOp(first) && isaarith::AddIOp(second)) ||

189 (isacomplex::MulOp(first) && isacomplex::AddOp(second)))

190 return true;

191 return false;

192 }))

193 return failure();

194

195

196 auto indexingMaps = genericOp.getIndexingMapsArray();

197 if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {

198 return m.getResults().size() !=

199 dims.batch.size() + 2 ;

200 }))

201 return failure();

202

203 auto numOfBatchDims = dims.batch.size();

204 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)

205 return failure();

206

207 if (numOfBatchDims) {

208

209

210

211 if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {

212 for (unsigned i = 0; i < numOfBatchDims; ++i) {

213 auto expr = m.getResults()[i];

215 cast(expr).getPosition() != i)

216 return true;

217 }

218 return false;

219 }))

220 return failure();

221 }

222

223 auto a =

224 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);

225 auto b =

226 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);

227 auto c =

228 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

229

230 if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))

231 return failure();

232

233 if (c != IndexMatchResult::Match ||

234 (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))

235 return failure();

236

237

238 if (numOfBatchDims) {

239 if (a == IndexMatchResult::Transposed)

240 return replaceWithMatmulVariant(rewriter,

241 genericOp);

242 if (b == IndexMatchResult::Transposed)

243 return replaceWithMatmulVariant(rewriter,

244 genericOp);

245 return replaceWithMatmulVariant(rewriter, genericOp);

246 }

247

248 if (a == IndexMatchResult::Transposed)

249 return replaceWithMatmulVariant(rewriter, genericOp);

250 if (b == IndexMatchResult::Transposed)

251 return replaceWithMatmulVariant(rewriter, genericOp);

252 return replaceWithMatmulVariant(rewriter, genericOp);

253 }

254

255 }

256

257

258

259

261 GenericOp genericOp) {

262

265 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);

266 return namedOp;

267 }

268

269

272 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);

273 return namedOp;

274 }

275

276

277 std::optional<SmallVector<int64_t>> equivalentToBroadcast =

279 if (equivalentToBroadcast) {

280 auto dims = *equivalentToBroadcast;

282 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],

283 dims);

284 return namedOp;

285 }

286

287

288 std::optional<SmallVector<int64_t>> equivalentToTranspose =

290 if (equivalentToTranspose) {

291 auto permutation = *equivalentToTranspose;

293 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],

294 permutation);

295 return namedOp;

296 }

297

298

300 Operation *op = &genericOp.getBody()->front();

301 if (isamath::ExpOp(op)) {

303 return namedOp;

304 }

305 }

306

307

310 Operation *op = &genericOp.getBody()->front();

311 if (isaarith::AddFOp(op)) {

313 return namedOp;

314 }

315 if (isaarith::SubFOp(op)) {

317 return namedOp;

318 }

319 if (isaarith::MulFOp(op)) {

321 return namedOp;

322 }

323 if (isaarith::DivFOp(op)) {

325 return namedOp;

326 }

327 }

328

329

331 return specializeLinalgContractions(rewriter, genericOp);

332 }

333 return failure();

334 }

335

336 namespace {

337 struct LinalgSpecializeGenericOpsPass

338 : public impl::LinalgSpecializeGenericOpsPassBase<

339 LinalgSpecializeGenericOpsPass> {

340

341 using impl::LinalgSpecializeGenericOpsPassBase<

342 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;

343 void runOnOperation() override;

344 };

345 }

346

347 void LinalgSpecializeGenericOpsPass::runOnOperation() {

351

353 signalPassFailure();

354 }

355

359 }

static MLIRContext * getContext(OpFoldResult val)

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)

static bool areBinOpsSwapped(GenericOp genericOp)

#define REPLACE_UNARY_OP(NEWOP)

A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.

ArrayRef< AffineExpr > getResults() const

Block represents an ordered list of Operations.

BlockArgument getArgument(unsigned i)

IRValueT get() const

Return the current value being used by this operand.

Operation is the basic unit of execution within MLIR.

OpOperand & getOpOperand(unsigned idx)

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the different types of ranges over Values.

bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())

Returns true if the block contains a contraction of the following form:

std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)

Checks whether genericOp is semantically equivalent to a linalg.transpose.

bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)

Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.

bool isaCopyOpInterface(LinalgOp linalgOp)

Checks whether linalgOp is semantically equivalent to a linalg.copyOp.

void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)

Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.

FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)

Create a namedOp from the given GenericOp and replace the GenericOp.

std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)

Checks whether genericOp is semantically equivalent to a linalg.broadcast.

FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)

Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation.

bool isaContractionOpInterface(LinalgOp linalgOp)

Checks whether linalgOp conforms to ContractionOpInterface.

void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns)

Populates patterns with patterns to convert linalg.generic ops to named ops where possible.

std::optional< Value > isaFillOpInterface(GenericOp genericOp)

Checks whether genericOp is semantically equivalent to a linalg.fill.

bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)

Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....

Include the generated interface declarations.

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.

@ DimId

Dimensional identifier.

const FrozenRewritePatternSet & patterns