MLIR: lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
23 #include
24
25 namespace mlir {
26 namespace linalg {
/// Returns true iff every element of `attr` equals 1 when read as a
/// sign-extended integer (`APInt::getSExtValue`).
/// NOTE(review): this listing is a scraped copy; the signature line is
/// missing here. The declaration index at the end of the file lists it as
/// `static bool hasAllOneValues(DenseIntElementsAttr attr)`.
28 return llvm::all_of(
29 attr, [](const APInt &element) { return element.getSExtValue() == 1; });
30 }
31
/// Emits `x + y` at `loc`, selecting the op from the type of `x`:
/// arith.addi when the first type test matches, complex.add when the second
/// matches, and arith.addf otherwise.
/// NOTE(review): scraped copy. The signature line is missing; the index at
/// the end of the file lists `static Value createAdd(Location loc, Value x,
/// Value y, OpBuilder &builder)`. The template arguments of `isa<...>` and
/// `create<...>` were also stripped. Given the ops selected, the two `isa`
/// tests are presumably an integer-type check and a complex-type check;
/// confirm against upstream.
33 if (isa(x.getType()))
34 return builder.createarith::AddIOp(loc, x, y);
35 if (isa(x.getType()))
36 return builder.createcomplex::AddOp(loc, x, y);
// Fallback: floating-point add.
37 return builder.createarith::AddFOp(loc, x, y);
38 }
39
42
/// Tail of `createMul`: emits the product of `xConvert` and `yConvert`. It
/// selects complex.mul, arith.muli, or (as a fallback) arith.mulf based on
/// the accumulator type `accType`.
/// NOTE(review): scraped copy. The signature and the lines defining
/// `xConvert`/`yConvert` are missing. Those values are presumably `x`/`y`
/// converted to `accType`, since the index lists `convertScalarToDtype`;
/// confirm upstream. As in createAdd, the `isa<...>`/`create<...>` template
/// arguments were stripped. The first test presumably checks for a complex
/// type and the second for an integer type.
47 if (isa(accType))
48 return builder.createcomplex::MulOp(loc, xConvert, yConvert);
49 if (isa(accType))
50 return builder.createarith::MulIOp(loc, xConvert, yConvert);
// Fallback: floating-point multiply.
51 return builder.createarith::MulFOp(loc, xConvert, yConvert);
52 }
53
54
/// Body of `unrollIndex(OpBuilder &b, Location loc, Value index,
/// ArrayRef<int64_t> factors)` (signature per the index at the end of the
/// file). It splits the linear `index` into one index Value per entry of
/// `factors`; callers read the results positionally (e.g. mIndices[0],
/// mIndices[1]).
/// NOTE(review): scraped copy. The loop body and the right-hand side of
/// `multiIndex` are missing. Presumably each factor becomes an index
/// constant in a basis that is passed to `delinearizeIndex`; both names
/// appear in the index. Confirm upstream.
/// Asserts that `factors` is non-empty and that computing `multiIndex`
/// succeeded. The assert message says "linearize" although this splits
/// (delinearizes) the index.
57 assert(!factors.empty() && "empty factor list");
59 for (int64_t f : factors)
61 FailureOr<SmallVector> multiIndex =
63 assert(!failed(multiIndex) && "Failed to linearize img2col index");
64 return *multiIndex;
65 }
66
67
68
69
/// Fragment of `getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
/// Value fIndex, int64_t stride)` (signature per the index at the end of the
/// file).
/// NOTE(review): scraped copy; the whole body is missing. By analogy with
/// the depthwise path's `owDim * swSym + kwDim` indexing map, it presumably
/// returns `oIndex * stride + fIndex`. Confirm upstream.
71 Value fIndex, int64_t stride) {
76 }
77
/// Rewrites linalg.conv_2d_nhwc_hwcf as two linalg.generic ops:
///  1. an img2col packing op that fills a col tensor indexed (b, m, k), where
///     m enumerates output positions (oh, ow) and k enumerates filter taps
///     (fh, fw, ic);
///  2. a contraction computing out[b, m, n] += col[b, m, k] * filter[k, n]
///     on the collapsed filter and output.
/// Reports a match failure when the filter or input shape is dynamic, or
/// when dilations are not all ones (messages below).
/// Returns {img2col op, final tensor.expand_shape op}.
/// NOTE(review): scraped listing. Many lines are missing: the signature,
/// the failure calls, the shape/reassociation/indexing-map definitions, and
/// the body-builder lambda headers. The comments below describe only what
/// is visible.
78 FailureOr<std::pair<Operation *, Operation *>>
80 auto inputType = cast(convOp.getInputs()[0].getType());
81 auto filterType = cast(convOp.getInputs()[1].getType());
82 auto outputType = cast(convOp.getOutputs()[0].getType());
83
// Preconditions: static filter/input shapes and unit dilations.
84 if (!filterType.hasStaticShape())
86 convOp, "expected a static shape for the filter");
87
88 if (!inputType.hasStaticShape())
90 "expected a static shape for the input");
91
92
95 "expected all ones for dilations");
96
// Operands: input image, filter, and the output (init) tensor.
98 Value input = convOp.getInputs()[0];
99 Value filter = convOp.getInputs()[1];
100 Value output = convOp.getOutputs()[0];
101
104
// Output dims 0..3 are (n, oh, ow, oc); filter dims 0..2 are (fh, fw, ic).
105 int64_t n = outputShape[0];
106 int64_t oh = outputShape[1];
107 int64_t ow = outputShape[2];
108 int64_t oc = outputShape[3];
109 int64_t fh = filterShape[0];
110 int64_t fw = filterShape[1];
111 int64_t ic = filterShape[2];
112
113 Location loc = convOp.getLoc();
114
115
// Collapse the filter and the output into the operand shapes used by the
// contraction below. The reassociation indices are not visible in this copy.
117 auto reshapedFilterType =
119 Value reshapedFilter = rewriter.createtensor::CollapseShapeOp(
120 loc, reshapedFilterType, filter, filterReassocIndices);
121
123 RankedTensorType reshapedOutputType =
125 Value reshapedOutput = rewriter.createtensor::CollapseShapeOp(
126 loc, reshapedOutputType, output, outputReassocIndices);
127
// Destination of the img2col op: an empty tensor with the input's element
// type.
129 Value colTensor = rewriter.createtensor::EmptyOp(
130 loc, colTensorShape, inputType.getElementType());
131
132
133 auto nloops = colTensorShape.size();
134
// Iterator kinds shared by the two generic ops.
135 auto parallel = utils::IteratorType::parallel;
136 auto reduction = utils::IteratorType::reduction;
138
141
// img2col: a generic with no inputs that writes colTensor. Each element is
// computed from its own iteration indices and read from `input` with
// tensor.extract.
142 auto img2ColTensor = rewriter.createlinalg::GenericOp(
143 loc, colTensor.getType(),
144 ValueRange{}, colTensor, img2colIndexingMaps,
145 img2colIterators,
147
// Iteration indices of the col tensor: (b, m, k).
148 Value bIndex = nestedBuilder.createlinalg::IndexOp(loc, 0);
149 Value mIndex = nestedBuilder.createlinalg::IndexOp(loc, 1);
150 Value kIndex = nestedBuilder.createlinalg::IndexOp(loc, 2);
151
152
// m splits into output coordinates (oh, ow); the unroll call is missing in
// this copy. k splits into filter taps (fh, fw, ic) via unrollIndex.
155 auto ohIndex = mIndices[0];
156 auto owIndex = mIndices[1];
157
159 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
160 auto fhIndex = kIndices[0];
161 auto fwIndex = kIndices[1];
162 auto icIndex = kIndices[2];
163
// hIndex/wIndex are input coordinates derived using strides[0] and
// strides[1]. Their defining lines are only partly visible; they
// presumably call getConvolvedIndex with (ohIndex, fhIndex) and
// (owIndex, fwIndex).
164
167 convOp.getStrides().getValues<int64_t>()[0]);
170 convOp.getStrides().getValues<int64_t>()[1]);
171
172
// Gather input[b, h, w, ic] and yield it as the col tensor element.
173 SmallVector extractionIndices{bIndex, hIndex, wIndex, icIndex};
174 Value inputVal = nestedBuilder.createtensor::ExtractOp(
175 loc, input, extractionIndices);
176 nestedBuilder.createlinalg::YieldOp(nestedLoc, inputVal);
177 });
178
179
180
181
182
// Contraction over k with loop dims (b, m, n, k):
//   lhs = col[b, m, k], rhs = reshapedFilter[k, n], out[b, m, n].
// The iterator list ends with (parallel, reduction); k is the reduction dim.
184 bindDims(context, bDim, mDim, nDim, kDim);
185 auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
186 auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
187 auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
189 parallel, reduction};
190
// Body: acc + lhs * rhs. createMul receives the accumulator's type as
// accType.
191 auto genericOp = rewriter.createlinalg::GenericOp(
192 loc, reshapedOutputType,
193 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
194 ValueRange{reshapedOutput},
198 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
199 Value add = createAdd(loc, mul, args[2], nestedBuilder);
200 nestedBuilder.createlinalg::YieldOp(nestedLoc, add);
201 });
202 Value result = genericOp.getResults().front();
203
// Expand the contraction result back to the original output type.
204 auto reshapedResult = rewriter.createtensor::ExpandShapeOp(
205 loc, outputType, result, outputReassocIndices);
206
// NOTE(review): line 207 is missing. It presumably replaces convOp with
// reshapedResult (replaceOp appears in the index).
208
209 return std::make_pair(img2ColTensor.getOperation(),
210 reshapedResult.getOperation());
211 }
212
/// Rewrites linalg.depthwise_conv_2d_nhwc_hwc via img2col:
///  - the input and output are transposed with {0, 3, 1, 2}, and the filter
///    is transposed with {2, 0, 1};
///  - the transposed input is packed into a 6-D col tensor by a copy-only
///    linalg.generic whose indexing maps apply the strides;
///  - the per-channel products are computed with linalg.batch_matvec on
///    collapsed operands, and the result is expanded and transposed back
///    with {0, 2, 3, 1}.
/// Reports a match failure when the filter or input shape is dynamic, or
/// when dilations are not all ones. Returns {img2col op, the op defining the
/// final transposed result}.
/// NOTE(review): scraped listing. Many lines are missing: part of the
/// signature, the failure calls, the transposeOperand lambda header and its
/// indexing maps, the shape/reassociation definitions, and the batch_matvec
/// init operand. The comments below describe only what is visible.
213 FailureOr<std::pair<Operation *, Operation *>>
215 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
216 auto inputType = cast(convOp.getInputs()[0].getType());
217 auto filterType = cast(convOp.getInputs()[1].getType());
218 auto outputType = cast(convOp.getOutputs()[0].getType());
219
// Preconditions: static filter/input shapes and unit dilations.
220 if (!filterType.hasStaticShape())
222 convOp, "expected a static shape for the filter");
223
224 if (!inputType.hasStaticShape())
226 "expected a static shape for the input");
227
228
231 "expected all ones for dilations");
232
233 Location loc = convOp.getLoc();
234
// transposeOperand(operand, indices); its header line is missing.
// - Builds a tensor.empty whose shape is `inputShape` permuted by
//   `indices`. `inputShape` is presumably the operand's shape.
// - Runs an all-parallel linalg.generic that yields each element of
//   `operand` unchanged into that tensor.
// - Returns the generic's result.
236 auto operandTensorType = cast(operand.getType());
237 auto nloops = indices.size();
239
241 llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
243 }));
244
246 indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
247
248 Value outputTensor = rewriter.createtensor::EmptyOp(
249 loc, targetShape, operandTensorType.getElementType());
250
252 nloops, utils::IteratorType::parallel);
253
258
259 auto transposedOp = rewriter.createlinalg::GenericOp(
260 loc, outputTensor.getType(),
261 operand, outputTensor, indexingMaps,
262 loopAttributeTypes,
264 nestedBuilder.createlinalg::YieldOp(nestedLoc, args[0]);
265 });
266
267 return transposedOp.getResult(0);
268 };
269
270 Value input = convOp.getInputs()[0];
271 Value filter = convOp.getInputs()[1];
272 Value output = convOp.getOutputs()[0];
273
274
// NHWC input -> NCHW ({0, 3, 1, 2}); HWC filter -> CHW ({2, 0, 1}).
275 Value inputT = transposeOperand(input, {0, 3, 1, 2});
276 Value filterT = transposeOperand(filter, {2, 0, 1});
278 cast(filterT.getType()).getShape();
280
// NOTE(review): getShape() yields int64_t extents, but n/oh/ow/c/fh/fw are
// stored as `int`, which narrows. The other rewrites in this file use
// int64_t.
281 int n = outputShape[0];
282 int oh = outputShape[1];
283 int ow = outputShape[2];
284 int c = outputShape[3];
285 int fh = filterTShape[1];
286 int fw = filterTShape[2];
287
// The output (init) tensor is transposed with the same {0, 3, 1, 2}
// permutation.
289 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
290
// img2col loop dims: (n, c, oh, ow, kh, kw). Strides are bound as symbols.
// Their definitions are only partly visible; the visible part of the input
// map is `ow * sw + kw`.
291 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
292 bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
293
295 convOp.getStrides().getValues<int64_t>()[0]);
297 convOp.getStrides().getValues<int64_t>()[1]);
298
300 owDim * swSym + kwDim};
301
// All img2col loops are parallel (it is a pure copy).
302 auto nloops = colTensorShape.size();
303
305 nloops, utils::IteratorType::parallel);
306
310
// img2col: copy inputT through the strided maps into an empty col tensor
// that has the input's element type.
311 Value colTensor = rewriter.createtensor::EmptyOp(
312 loc, colTensorShape, inputType.getElementType());
313
314 auto img2ColTensor = rewriter.createlinalg::GenericOp(
315 loc, colTensor.getType(),
316 inputT, colTensor, indexingMaps,
317 loopAttributeTypes,
319 nestedBuilder.createlinalg::YieldOp(nestedLoc, args[0]);
320 });
321
// Collapse the 6-D col tensor with {0, 1}, {2, 3}, {4, 5} to
// (n*c, oh*ow, fh*fw). The transposed filter and output are collapsed too.
323 {0, 1}, {2, 3}, {4, 5}};
326 {2, 3}};
327
329 {n * c, oh * ow, fh * fw}, inputType.getElementType());
330 auto reshapedFilterTensorType =
332 auto reshapedOutputTensorType =
334
335 Value reshapedImg2ColTensor = rewriter.createtensor::CollapseShapeOp(
336 loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
337 img2ColTensorReassocIndices);
338 Value reshapedFilterTensor = rewriter.createtensor::CollapseShapeOp(
339 loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
340 Value reshapedoutputTensor = rewriter.createtensor::CollapseShapeOp(
341 loc, reshapedOutputTensorType, transposedOutputTensor,
342 outputReassociationIndice);
343
// Per-channel products via linalg.batch_matvec, with the result typed like
// the collapsed output. The init-operand line (347) is missing here.
344 auto batchMatVecResult = rewriter.createlinalg::BatchMatvecOp(
345 loc, TypeRange{reshapedoutputTensor.getType()},
346 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
348
350 {2, 3}};
351
// Expand back to the transposed output shape, then transpose to the
// original layout ({0, 2, 3, 1}).
352 auto batchMatVecResultReshaped = rewriter.createtensor::ExpandShapeOp(
353 loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
354 batchMatVecReassociationIndice);
355
356 Value transposedResult =
357 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
358
// NOTE(review): line 359 is missing. It presumably replaces convOp with
// transposedResult.
360 return std::make_pair(img2ColTensor.getOperation(),
361 transposedResult.getDefiningOp());
362 }
363
/// Rewrites linalg.conv_2d_nchw_fchw as two linalg.generic ops:
///  1. an img2col packing op that fills a col tensor indexed (b, k, n),
///     where k enumerates filter taps (ic, fh, fw) and n enumerates output
///     positions (oh, ow);
///  2. a contraction computing out[b, m, n] += filter[m, k] * col[b, k, n]
///     on the collapsed filter and output.
/// Reports a match failure when the filter or input shape is dynamic, or
/// when dilations are not all ones.
/// Returns {img2col op, final tensor.expand_shape op}.
/// NOTE(review): scraped listing. Many lines are missing: the signature,
/// the failure calls, the shape/reassociation/indexing-map definitions, and
/// the lambda headers. The comments below describe only what is visible.
364 FailureOr<std::pair<Operation *, Operation *>>
366 auto inputType = cast(convOp.getInputs()[0].getType());
367 auto filterType = cast(convOp.getInputs()[1].getType());
368 auto outputType = cast(convOp.getOutputs()[0].getType());
369
// Preconditions: static filter/input shapes and unit dilations.
370 if (!filterType.hasStaticShape())
372 convOp, "expected a static shape for the filter");
373
374 if (!inputType.hasStaticShape())
376 "expected a static shape for the input");
377
378
381 "expected all ones for dilations");
382
383 Value input = convOp.getInputs()[0];
384 Value filter = convOp.getInputs()[1];
385 Value output = convOp.getOutputs()[0];
386
387 auto filterShape = filterType.getShape();
388 auto outputShape = outputType.getShape();
389
// Output dims 0..3 are (n, oc, oh, ow); filter dims 1..3 are (ic, fh, fw).
390 int64_t n = outputShape[0];
391 int64_t oc = outputShape[1];
392 int64_t oh = outputShape[2];
393 int64_t ow = outputShape[3];
394 int64_t ic = filterShape[1];
395 int64_t fh = filterShape[2];
396 int64_t fw = filterShape[3];
397
398 auto loc = convOp.getLoc();
400
// Collapse the filter and the output for the contraction. The
// reassociation indices are not visible in this copy.
402 auto reshapedFilterType =
404 Value reshapedFilter = rewriter.createtensor::CollapseShapeOp(
405 loc, reshapedFilterType, filter, filterReassocIndices);
406
408 auto reshapedOutputType =
410 Value reshapedOutput = rewriter.createtensor::CollapseShapeOp(
411 loc, reshapedOutputType, output, outputReassocIndices);
412
413
// Destination of the img2col op: an empty tensor with the input's element
// type.
415 Value colTensor = rewriter.createtensor::EmptyOp(
416 loc, colTensorShape, inputType.getElementType());
417
418 auto nloops = colTensorShape.size();
419
420 auto parallel = utils::IteratorType::parallel;
421 auto reduction = utils::IteratorType::reduction;
423
426
// img2col: a generic with no inputs that writes colTensor by gathering from
// `input`.
427 auto img2ColTensor = rewriter.createlinalg::GenericOp(
428 loc, colTensor.getType(),
429 ValueRange{}, colTensor, img2colIndexingMaps,
430 img2colIterators,
432
// Iteration indices of the col tensor: (b, k, n).
433 Value bIndex = nestedBuilder.createlinalg::IndexOp(loc, 0);
434 Value kIndex = nestedBuilder.createlinalg::IndexOp(loc, 1);
435 Value nIndex = nestedBuilder.createlinalg::IndexOp(loc, 2);
436
// k splits into filter taps (ic, fh, fw). n splits into output coordinates
// (oh, ow); the unroll call for n is missing in this copy.
437
439 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
440 auto icIndex = kIndices[0];
441 auto fhIndex = kIndices[1];
442 auto fwIndex = kIndices[2];
443
446 auto ohIndex = nIndices[0];
447 auto owIndex = nIndices[1];
448
// hIndex/wIndex are input coordinates derived using strides[0] and
// strides[1]. Their defining lines are only partly visible; they
// presumably call getConvolvedIndex.
449
452 convOp.getStrides().getValues<int64_t>()[0]);
455 convOp.getStrides().getValues<int64_t>()[1]);
456
457
// Gather input[b, ic, h, w] (NCHW order) and yield it.
458 SmallVector extractionIndices{bIndex, icIndex, hIndex, wIndex};
459 Value inputVal = nestedBuilder.createtensor::ExtractOp(
460 loc, input, extractionIndices);
461 nestedBuilder.createlinalg::YieldOp(nestedLoc, inputVal);
462 });
463
464
465
466
467
// Contraction with loop dims (b, m, n, k):
//   lhs = reshapedFilter[m, k], rhs = col[b, k, n], out[b, m, n].
// The iterator list ends with (parallel, reduction); k is the reduction dim.
469 bindDims(context, bDim, mDim, nDim, kDim);
470 auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
471 auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
472 auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
474 parallel, reduction};
// Body: acc + lhs * rhs. createMul receives the accumulator's type as
// accType.
475 auto genericOp = rewriter.createlinalg::GenericOp(
476 loc, reshapedOutputType,
477 ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
478 ValueRange{reshapedOutput},
482 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
483 Value add = createAdd(loc, mul, args[2], nestedBuilder);
484 nestedBuilder.createlinalg::YieldOp(nestedLoc, add);
485 });
486 Value result = genericOp.getResults().front();
487
// Expand the contraction result back to the original output type.
488 auto reshapedResult = rewriter.createtensor::ExpandShapeOp(
489 loc, outputType, result, outputReassocIndices);
490
// NOTE(review): line 491 is missing. It presumably replaces convOp with
// reshapedResult.
492
493 return std::make_pair(img2ColTensor.getOperation(),
494 reshapedResult.getOperation());
495 }
496
/// Rewrites linalg.conv_2d_nhwc_fhwc as two linalg.generic ops:
///  1. an img2col packing op that fills a col tensor indexed (b, m, k),
///     where m enumerates output positions (oh, ow) and k enumerates filter
///     taps (fh, fw, ic);
///  2. a contraction computing out[b, m, n] += col[b, m, k] * filter[n, k].
///     Unlike the HWCF variant, the collapsed filter is indexed with the
///     output-channel dimension first.
/// Reports a match failure when the filter or input shape is dynamic, or
/// when dilations are not all ones.
/// Returns {img2col op, final tensor.expand_shape op}.
/// NOTE(review): scraped listing. Many lines are missing: the signature,
/// the failure calls, the shape/reassociation/indexing-map definitions, and
/// the lambda headers. The comments below describe only what is visible.
497 FailureOr<std::pair<Operation *, Operation *>>
499 auto inputType = cast(convOp.getInputs()[0].getType());
500 auto filterType = cast(convOp.getInputs()[1].getType());
501 auto outputType = cast(convOp.getOutputs()[0].getType());
502
// Preconditions: static filter/input shapes and unit dilations.
503 if (!filterType.hasStaticShape())
505 convOp, "expected a static shape for the filter");
506
507 if (!inputType.hasStaticShape())
509 "expected a static shape for the input");
510
511
514 "expected all ones for dilations");
515
517 Value input = convOp.getInputs()[0];
518 Value filter = convOp.getInputs()[1];
519 Value output = convOp.getOutputs()[0];
520
523
// Output dims 0..3 are (n, oh, ow, oc); filter dims 1..3 are (fh, fw, ic).
524 int64_t n = outputShape[0];
525 int64_t oh = outputShape[1];
526 int64_t ow = outputShape[2];
527 int64_t oc = outputShape[3];
528 int64_t fh = filterShape[1];
529 int64_t fw = filterShape[2];
530 int64_t ic = filterShape[3];
531
532 Location loc = convOp.getLoc();
533
534
535
// Collapse the filter and the output for the contraction. The
// reassociation indices are not visible in this copy.
537 auto reshapedFilterType =
539 Value reshapedFilter = rewriter.createtensor::CollapseShapeOp(
540 loc, reshapedFilterType, filter, filterReassocIndices);
541
543 RankedTensorType reshapedOutputType =
545 Value reshapedOutput = rewriter.createtensor::CollapseShapeOp(
546 loc, reshapedOutputType, output, outputReassocIndices);
547
// Destination of the img2col op: an empty tensor with the input's element
// type.
549 Value colTensor = rewriter.createtensor::EmptyOp(
550 loc, colTensorShape, inputType.getElementType());
551
552
553 auto nloops = colTensorShape.size();
554
555 auto parallel = utils::IteratorType::parallel;
556 auto reduction = utils::IteratorType::reduction;
558
561
// img2col: a generic with no inputs that writes colTensor by gathering from
// `input`.
562 auto img2ColTensor = rewriter.createlinalg::GenericOp(
563 loc, colTensor.getType(),
564 ValueRange{}, colTensor, img2colIndexingMaps,
565 img2colIterators,
567
// Iteration indices of the col tensor: (b, m, k).
568 Value bIndex = nestedBuilder.createlinalg::IndexOp(loc, 0);
569 Value mIndex = nestedBuilder.createlinalg::IndexOp(loc, 1);
570 Value kIndex = nestedBuilder.createlinalg::IndexOp(loc, 2);
571
// m splits into output coordinates (oh, ow); the unroll call is missing in
// this copy. k splits into filter taps (fh, fw, ic) via unrollIndex.
572
575 auto ohIndex = mIndices[0];
576 auto owIndex = mIndices[1];
577
579 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
580 auto fhIndex = kIndices[0];
581 auto fwIndex = kIndices[1];
582 auto icIndex = kIndices[2];
583
// hIndex/wIndex are input coordinates derived using strides[0] and
// strides[1]. Their defining lines are only partly visible; they
// presumably call getConvolvedIndex.
584
587 convOp.getStrides().getValues<int64_t>()[0]);
590 convOp.getStrides().getValues<int64_t>()[1]);
591
592
// Gather input[b, h, w, ic] (NHWC order) and yield it.
593 SmallVector extractionIndices{bIndex, hIndex, wIndex, icIndex};
594 Value inputVal = nestedBuilder.createtensor::ExtractOp(
595 loc, input, extractionIndices);
596 nestedBuilder.createlinalg::YieldOp(nestedLoc, inputVal);
597 });
598
599
600
601
// Contraction with loop dims (b, m, n, k):
//   lhs = col[b, m, k], rhs = reshapedFilter[n, k], out[b, m, n].
// The iterator list ends with (parallel, reduction); k is the reduction dim.
603 bindDims(context, bDim, mDim, nDim, kDim);
604 auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
605 auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
606 auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
608 parallel, reduction};
609
// Body: acc + lhs * rhs. createMul receives the accumulator's type as
// accType.
610 auto genericOp = rewriter.createlinalg::GenericOp(
611 loc, reshapedOutputType,
612 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
613 ValueRange{reshapedOutput},
617 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
618 Value add = createAdd(loc, mul, args[2], nestedBuilder);
619 nestedBuilder.createlinalg::YieldOp(nestedLoc, add);
620 });
621 Value result = genericOp.getResults().front();
622
// Expand the contraction result back to the original output type.
623 auto reshapedResult = rewriter.createtensor::ExpandShapeOp(
624 loc, outputType, result, outputReassocIndices);
625
// NOTE(review): line 626 is missing. It presumably replaces convOp with
// reshapedResult.
627
628 return std::make_pair(img2ColTensor.getOperation(),
629 reshapedResult.getOperation());
630 }
631
// File-local rewrite patterns, one per supported convolution op. Each
// matchAndRewrite returns failure() on its early-exit path and success()
// otherwise.
// NOTE(review): scraped listing. Several lines appear to be missing:
//  - the base-class clause of ConvertConv2DNhwcHwcf;
//  - the constructor lines, presumably `using
//    OpRewritePattern::OpRewritePattern;`;
//  - the condition guarding each `return failure();`, presumably
//    `if (failed(rewriteInIm2Col(rewriter, convOp)))`.
// The template brackets of the `OpRewritePattern<...>` bases were also
// stripped.
632 namespace {
633
// Pattern for linalg.conv_2d_nhwc_hwcf.
634 class ConvertConv2DNhwcHwcf final
636 public:
638
639 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
642 return failure();
643 return success();
644 }
645 };
646
// Pattern for linalg.depthwise_conv_2d_nhwc_hwc.
647 class ConvertDepthwiseConv2DNhwcHwc final
648 : public OpRewritePatternlinalg::DepthwiseConv2DNhwcHwcOp {
649 public:
651
652 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
653 PatternRewriter &rewriter) const override {
655 return failure();
656 return success();
657 }
658 };
659
// Pattern for linalg.conv_2d_nchw_fchw.
660 class ConvertConv2DNchwFchw final
661 : public OpRewritePatternlinalg::Conv2DNchwFchwOp {
662 public:
664
665 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
666 PatternRewriter &rewriter) const override {
668 return failure();
669 return success();
670 }
671 };
672
// Pattern for linalg.conv_2d_nhwc_fhwc.
673 class ConvertConv2DNhwcFhwc final
674 : public OpRewritePatternlinalg::Conv2DNhwcFhwcOp {
675 public:
677
678 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
679 PatternRewriter &rewriter) const override {
681 return failure();
682 return success();
683 }
684 };
685 }
686
// Registers the four img2col rewrite patterns with `patterns`, constructing
// each with `context`.
// NOTE(review): scraped listing. The signature (`void
// populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)` per
// the index at the end of the file) and the line obtaining `context` are
// missing.
689 patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
690 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
691 }
692 }
693 }
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
static SmallVector< Value > unrollIndex(OpBuilder &b, Location loc, Value index, ArrayRef< int64_t > factors)
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static bool hasAllOneValues(DenseIntElementsAttr attr)
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, Value fIndex, int64_t stride)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...