MLIR: lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp Source File
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"

using namespace mlir;

#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

//===----------------------------------------------------------------------===//
// MatchStructuredOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  // Matchers for structured ops only apply to Linalg payload operations.
  if (!isa<linalg::LinalgOp>(current)) {
    if (getFailurePropagationMode().value_or(
            FailurePropagationMode::Propagate) ==
        FailurePropagationMode::Propagate) {
      return emitSilenceableError() << "expected a Linalg op";
    }
    // If errors are suppressed, succeed and set all results to empty lists.
    LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
    return DiagnosedSilenceableFailure::success();
  }

  // Bind the current payload operation to the block argument of the body.
  auto scope = state.make_region_scope(getBodyRegion());
  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
                                    {current}))) {
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  // Apply the nested matchers in order, forwarding definite failures.
  for (Operation &nested : getBody()->without_terminator()) {
    DiagnosedSilenceableFailure diag =
        state.applyTransform(cast<TransformOpInterface>(nested));
    if (diag.isDefiniteFailure())
      return diag;
    if (diag.succeeded())
      continue;

    // A nested matcher failed silenceably; propagate it if requested.
    assert(diag.isSilenceableFailure());
    if (getFailurePropagationMode().value_or(
            FailurePropagationMode::Propagate) ==
        FailurePropagationMode::Propagate) {
      return diag;
    }

    // Failure suppression mode: forward the values that were already defined
    // by matchers that ran before the failing one, set the remaining results
    // to empty lists, silence the diagnostic and report success. Terminator
    // operands defined at or after the failing matcher are treated as
    // undefined.
    LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
                      << "\n");
    (void)diag.silence();
    SmallVector<OpOperand *> undefinedOperands;
    for (OpOperand &terminatorOperand :
         getBody()->getTerminator()->getOpOperands()) {
      Operation *definingOp = terminatorOperand.get().getDefiningOp();
      if (!definingOp)
        continue;
      if (definingOp->getBlock() != getBody())
        continue;
      if (definingOp->isBeforeInBlock(&nested))
        continue;

      undefinedOperands.push_back(&terminatorOperand);
    }

    SmallVector<SmallVector<transform::MappedValue>> mappings;
    auto filtered = llvm::make_filter_range(
        getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
          return !llvm::is_contained(undefinedOperands, &opOperand);
        });
    SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
        filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
    detail::prepareValueMappings(mappings, definedOperands, state);
    for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
      results.setMappedValues(getResults()[operand.getOperandNumber()],
                              mapping);
    }
    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
    return DiagnosedSilenceableFailure::success();
  }

  // All nested matchers succeeded: forward the payload mappings of the
  // terminator operands to the results of this op.
  detail::forwardTerminatorOperands(getBody(), state, results);
  return DiagnosedSilenceableFailure::success();
}
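
// Note on the matcher control flow above (summary, not upstream
// documentation): a definite failure of a nested matcher aborts matching
// altogether; a silenceable failure either propagates to the caller or, when
// failures are suppressed, is silenced, the results defined before the failing
// matcher are forwarded, and the remaining results are set to empty lists so
// that a driver such as transform.foreach_match can move on to the next
// payload candidate.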

void transform::MatchStructuredOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The operand accessor name below is assumed from the op's `current` operand.
  onlyReadsHandle(getCurrentMutable(), effects);
  onlyReadsPayload(effects);
  producesHandle(getOperation()->getOpResults(), effects);
}

LogicalResult transform::MatchStructuredOp::verify() {
  if (getBody()->getNumArguments() != 1)
    return emitOpError() << "expected one body argument";
  if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
    return emitOpError() << "expected body argument to implement "
                            "TransformHandleTypeInterface";
  }
  for (Operation &nested : getBody()->without_terminator()) {
    if (isa<MatchOpInterface>(nested))
      continue;
    InFlightDiagnostic diag =
        emitOpError()
        << "expects nested operations to implement MatchOpInterface";
    diag.attachNote(nested.getLoc()) << "offending operation";
    return diag;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpPredicateOpTrait
//===----------------------------------------------------------------------===//

LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
    Operation *op, Value structuredOpHandle) {
  if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
    return op->emitOpError() << "expects parent op to be '"
                             << MatchStructuredOp::getOperationName() << "'";
  }

  // Bail out here and let the verifier of the parent op complain about a
  // malformed body region.
  Operation *parent = op->getParentOp();
  if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
      parent->getRegion(0).front().getNumArguments() < 1)
    return success();

  if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
    return op->emitOpError()
           << "expected predicate to apply to the surrounding structured op";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MatchStructuredBodyOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  if (std::optional<uint64_t> position = getReductionPosition()) {
    SmallVector<Operation *> combinerOps;
    if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
                        combinerOps)) {
      return emitSilenceableError() << "could not match reduction";
    }
    if (combinerOps.size() != 1) {
      return emitSilenceableError() << "reduction combiner is not a single op";
    }
    return DiagnosedSilenceableFailure::success();
  }
  if (getPassthrough()) {
    Block &body = linalgOp->getRegion(0).front();
    if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
      return emitSilenceableError() << "not a passthrough";
    }
    return DiagnosedSilenceableFailure::success();
  }
  if (getElementwise()) {
    if (!linalg::isElementwise(linalgOp))
      return emitSilenceableError() << "not elementwise";
    return DiagnosedSilenceableFailure::success();
  }
  if (std::optional<ArrayAttr> contractionOps = getContraction()) {
    Block &body = linalgOp->getRegion(0).front();
    std::string message;
    llvm::raw_string_ostream os(message);
    bool result = linalg::detail::isContractionBody(
        body,
        [&](Operation *elem, Operation *red) {
          return elem->getName().getStringRef() ==
                     cast<StringAttr>((*contractionOps)[0]).getValue() &&
                 red->getName().getStringRef() ==
                     cast<StringAttr>((*contractionOps)[1]).getValue();
        },
        os);
    if (result)
      return DiagnosedSilenceableFailure::success();
    return emitSilenceableError() << "contraction: " << message;
  }
  return DiagnosedSilenceableFailure::success();
}
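
// For illustration (an assumption about typical usage, not taken from this
// file): the `contraction` predicate receives the names of the elementwise
// and reduction ops expected in the body, e.g. ["arith.mulf", "arith.addf"]
// for a floating-point matmul-like body; `reduction_position`, `passthrough`
// and `elementwise` are the other, mutually exclusive, body predicates
// handled above.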

LogicalResult transform::MatchStructuredBodyOp::verify() {
  int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
                       getElementwise() + getContraction().has_value();

  if (numOptions > 1) {
    StringAttr attributeNames[] = {
        getReductionPositionAttrName(), getPassthroughAttrName(),
        getElementwiseAttrName(), getContractionAttrName()};
    return emitOpError() << "only one of {" << llvm::interleaved(attributeNames)
                         << "} is allowed";
  }

  if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
    if (contractionAttr->size() != 2) {
      return emitOpError() << "expects " << getContractionAttrName()
                           << " to contain two elements";
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MatchStructuredClassifyContractionDimsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  FailureOr<linalg::ContractionDimensions> contractionDims =
      linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
  if (failed(contractionDims))
    return emitSilenceableError() << "could not infer contraction dimensions";

  MLIRContext *context = current->getContext();
  Builder builder(context);
  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
    return llvm::to_vector(
        llvm::map_range(values, [&](unsigned value) -> Attribute {
          return builder.getI64IntegerAttr(value);
        }));
  };
  results.setParams(cast<OpResult>(getBatch()),
                    makeI64Attrs(contractionDims->batch));
  results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
  results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
  results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
  return DiagnosedSilenceableFailure::success();
}
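
// Worked example (illustrative, not taken from this file): for a plain matmul
// with iteration space (i, j, k), where i and j are parallel and k is a
// reduction, inferContractionDims is expected to classify batch = {}, m = {0},
// n = {1} and k = {2}, so the four results above would be bound to the
// parameter lists [], [0], [1] and [2], respectively.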

//===----------------------------------------------------------------------===//
// MatchStructuredClassifyConvolutionDimsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  FailureOr<linalg::ConvolutionDimensions> convolutionDims =
      linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
  if (failed(convolutionDims))
    return emitSilenceableError() << "could not infer convolution dimensions";

  MLIRContext *context = current->getContext();
  Builder builder(context);
  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
    return llvm::to_vector(
        llvm::map_range(values, [&](unsigned value) -> Attribute {
          return builder.getI64IntegerAttr(value);
        }));
  };
  results.setParams(cast<OpResult>(getBatch()),
                    makeI64Attrs(convolutionDims->batch));
  results.setParams(cast<OpResult>(getOutputImage()),
                    makeI64Attrs(convolutionDims->outputImage));
  results.setParams(cast<OpResult>(getOutputChannel()),
                    makeI64Attrs(convolutionDims->outputChannel));
  results.setParams(cast<OpResult>(getFilterLoop()),
                    makeI64Attrs(convolutionDims->filterLoop));
  results.setParams(cast<OpResult>(getInputChannel()),
                    makeI64Attrs(convolutionDims->inputChannel));
  results.setParams(cast<OpResult>(getDepth()),
                    makeI64Attrs(convolutionDims->depth));

  auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
    return llvm::to_vector(
        llvm::map_range(values, [&](int64_t value) -> Attribute {
          return builder.getI64IntegerAttr(value);
        }));
  };
  results.setParams(cast<OpResult>(getStrides()),
                    makeI64AttrsFromI64(convolutionDims->strides));
  results.setParams(cast<OpResult>(getDilations()),
                    makeI64AttrsFromI64(convolutionDims->dilations));
  return DiagnosedSilenceableFailure::success();
}
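
// Worked example (illustrative, assumed loop order): for a 2-D NHWC/HWCF
// convolution whose loops are ordered (n, oh, ow, f, kh, kw, c), the results
// above would typically be bound to batch = [0], output_image = [1, 2],
// output_channel = [3], filter_loop = [4, 5], input_channel = [6] and
// depth = [], with strides and dilations taken from the payload op itself.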

//===----------------------------------------------------------------------===//
// MatchStructuredDimOp
//===----------------------------------------------------------------------===//

/// Checks if all values from `list` are also contained in `reference`. Returns
/// a silenceable error with the given message at the given location otherwise;
/// the message is a formatv pattern with a `{0}` placeholder for the value.
static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
                                               ArrayRef<int64_t> list,
                                               Location loc,
                                               const char *message) {
  for (int64_t value : list) {
    if (llvm::any_of(reference, [&](unsigned ref) {
          return static_cast<int64_t>(ref) == value;
        })) {
      continue;
    }
    return emitSilenceableFailure(loc) << llvm::formatv(message, value);
  }
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> dimensions;
  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
  if (!diag.succeeded())
    return diag;

  // Check if the dimensions are parallel/reduction as requested.
  if (getParallel() || getReduction()) {
    SmallVector<unsigned> reference;
    if (getParallel())
      linalgOp.getParallelDims(reference);
    else if (getReduction())
      linalgOp.getReductionDims(reference);

    DiagnosedSilenceableFailure diag =
        containsAll(reference, dimensions, getLoc(),
                    getParallel() ? "expects dimension #{0} to be parallel"
                                  : "expects dimension #{0} to be reduction");
    if (!diag.succeeded())
      return diag;
  }

  // Bind the captured dimension sizes if requested.
  if (!getResult())
    return DiagnosedSilenceableFailure::success();

  SmallVector<int64_t> ranges = linalgOp.getStaticLoopRanges();
  Builder builder(current);
  SmallVector<Attribute> captured = llvm::to_vector(
      llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
        return builder.getI64IntegerAttr(ranges[dim]);
      }));
  results.setParams(cast<OpResult>(getResult()), captured);
  return DiagnosedSilenceableFailure::success();
}
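
// For illustration (assumption about typical usage, not taken from this file):
// the dimension list accepts positional specifications such as "all", negative
// indices counting from the back, or an inverted ("all except") list; the
// expanded positions are optionally checked to be parallel or reduction loops
// and, when a result is present, their static sizes are returned as i64
// parameters.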

DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
  DiagnosedSilenceableFailure diag =
      expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
                                getRawDimList(), op.getNumLoops(), dims);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(op->getLoc())
        << "while considering dimensions of this payload operation";
  }
  return diag;
}

LogicalResult transform::MatchStructuredDimOp::verify() {
  if (getParallel() && getReduction()) {
    return emitOpError() << "cannot request the same dimension to be both "
                            "parallel and reduction";
  }
  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
                                    getIsInverted(), getIsAll());
}

//===----------------------------------------------------------------------===//
// MatchStructuredElementalBitwidthOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::MatchStructuredElementalBitwidthOp::matchValue(
    Value current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto setupResult = [&](int64_t bitwidth) {
    Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
    results.setParams(cast<OpResult>(getResult()), {attr});
    return DiagnosedSilenceableFailure::success();
  };

  Type type = current.getType();
  if (type.isIntOrFloat())
    return setupResult(type.getIntOrFloatBitWidth());

  if (auto shapedType = dyn_cast<ShapedType>(type)) {
    if (shapedType.getElementType().isIntOrFloat())
      return setupResult(shapedType.getElementTypeBitWidth());
  }
  return emitSilenceableError()
         << "unsupported type for bitwidth extraction: " << type;
}
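
// For illustration (not taken from this file): matching an `f32` scalar or a
// `tensor<4x8xf32>` value binds the result parameter to 32, while a value of
// `index` type or another non-int/float element type produces the silenceable
// "unsupported type" error above.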

//===----------------------------------------------------------------------===//
// MatchStructuredInputOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> positions;
  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
  if (!diag.succeeded())
    return diag;

  SmallVector<transform::MappedValue> operandMapping;
  operandMapping.reserve(positions.size());
  for (int64_t position : positions) {
    AffineMap indexingMap =
        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
    if (getPermutation() && !indexingMap.isPermutation()) {
      return emitSilenceableError() << "the indexing map for input #"
                                    << position << " is not a permutation";
    }
    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
      return emitSilenceableError()
             << "the indexing map for input #" << position
             << " is not a projected permutation";
    }

    // If capture is not requested, skip the rest of the loop body.
    if (!getResult())
      continue;

    if (isa<AffineMapParamType>(getResult().getType())) {
      operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
      continue;
    }

    Value operand = linalgOp.getDpsInputOperand(position)->get();
    if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
      operandMapping.emplace_back(operand);
      continue;
    }

    Operation *operandProducer = operand.getDefiningOp();
    if (!operandProducer) {
      return emitSilenceableError()
             << "input #" << position << " is not produced by an operation";
    }
    operandMapping.emplace_back(operandProducer);
  }
  if (getResult())
    results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
  return DiagnosedSilenceableFailure::success();
}
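
// Summary of the capture logic above: an affine-map parameter result receives
// the operand's indexing map, a value handle result receives the operand value
// itself, and an operation handle result receives the operand's defining op,
// failing silenceably for operands that are block arguments.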

DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
  DiagnosedSilenceableFailure diag = expandTargetSpecification(
      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
      op.getNumDpsInputs(), positions);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(op->getLoc())
        << "while considering DPS inputs of this payload operation";
  }
  return diag;
}

/// Verifies a matcher op for structured input or output, specifically the
/// attributes specifying the operand positions.
template <typename OpTy>
static LogicalResult verifyStructuredOperandOp(OpTy op) {
  if (op.getPermutation() && op.getProjectedPermutation()) {
    return op.emitOpError()
           << op.getPermutationAttrName() << " and "
           << op.getProjectedPermutationAttrName() << " are mutually exclusive";
  }
  if (op.getRawPositionList().size() > 1 && op.getResult()) {
    return op.emitOpError()
           << "cannot bind multiple inputs/inits to the same value";
  }

  return success();
}

LogicalResult transform::MatchStructuredInputOp::verify() {
  if (failed(verifyStructuredOperandOp(*this)))
    return failure();
  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                    getIsInverted(), getIsAll());
}

//===----------------------------------------------------------------------===//
// MatchStructuredInitOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> positions;
  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
  if (!diag.succeeded())
    return diag;

  SmallVector<transform::MappedValue> operandMapping;
  operandMapping.reserve(positions.size());
  for (int64_t position : positions) {
    AffineMap indexingMap =
        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
    if (getPermutation() && !indexingMap.isPermutation()) {
      return emitSilenceableError() << "the indexing map for output(init) #"
                                    << position << " is not a permutation";
    }
    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
      return emitSilenceableError()
             << "the indexing map for output(init) #" << position
             << " is not a projected permutation";
    }

    // If capture is not requested, skip the rest of the loop body.
    if (!getResult())
      continue;

    if (isa<AffineMapParamType>(getResult().getType())) {
      operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
      continue;
    }

    Value operand = linalgOp.getDpsInitOperand(position)->get();
    if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
      operandMapping.emplace_back(operand);
      continue;
    }

    Operation *operandProducer = operand.getDefiningOp();
    if (!operandProducer) {
      return emitSilenceableError() << "output(init) #" << position
                                    << " is not produced by an operation";
    }
    operandMapping.emplace_back(operandProducer);
  }
  if (getResult())
    results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
  DiagnosedSilenceableFailure diag = expandTargetSpecification(
      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
      op.getNumDpsInits(), positions);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(op->getLoc())
        << "while considering DPS inits (outputs) of this payload operation";
  }
  return diag;
}

LogicalResult transform::MatchStructuredInitOp::verify() {
  if (failed(verifyStructuredOperandOp(*this)))
    return failure();
  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                    getIsInverted(), getIsAll());
}

//===----------------------------------------------------------------------===//
// MatchStructuredNumInputsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::MatchStructuredNumInputsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  Attribute attr =
      Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
  results.setParams(cast<OpResult>(getResult()), {attr});
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MatchStructuredNumInitsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::MatchStructuredNumInitsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  Attribute attr =
      Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
  results.setParams(cast<OpResult>(getResult()), {attr});
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MatchStructuredRankOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  int64_t numLoops = linalgOp.getNumLoops();
  Attribute attr = Builder(current).getI64IntegerAttr(numLoops);
  results.setParams(cast<OpResult>(getRank()), {attr});
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MatchStructuredResultOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
    Operation *op, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  int64_t position;
  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
  if (!diag.succeeded())
    return diag;

  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
    results.setValues(cast<OpResult>(getResult()), {result});
    return DiagnosedSilenceableFailure::success();
  }

  if (result.getUsers().empty()) {
    return emitSilenceableError()
           << "no users of the result #" << getPosition();
  }
  Operation *firstUser = *result.getUsers().begin();
  if (getAny()) {
    results.set(cast<OpResult>(getResult()), {firstUser});
    return DiagnosedSilenceableFailure::success();
  }
  if (getSingle()) {
    if (!llvm::hasSingleElement(result.getUsers())) {
      return emitSilenceableError()
             << "more than one result user with single user requested";
    }
    results.set(cast<OpResult>(getResult()), {firstUser});
    return DiagnosedSilenceableFailure::success();
  }
}
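
// Summary of the dispatch above: a value handle result is bound to the matched
// OpResult itself; with the `any` keyword the operation handle is bound to an
// arbitrary (first) user, and with `single` it is bound to the unique user,
// failing silenceably when the result has no users or more than one user.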

DiagnosedSilenceableFailure
transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
                                                   int64_t &position) {
  auto rawPosition = static_cast<int64_t>(getPosition());
  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
  if (position >= op.getNumDpsInits() || position < 0) {
    return emitSilenceableError()
           << "position " << rawPosition
           << " overflows the number of results(inits) of the payload operation";
  }
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::MatchStructuredResultOp::verify() {
  if ((getAny() || getSingle()) ^
      isa<TransformHandleTypeInterface>(getResult().getType())) {
    return emitOpError() << "expects either the any/single keyword or the "
                            "value handle result type";
  }
  if (getAny() && getSingle()) {
    return emitOpError() << "'any' and 'single' are mutually exclusive";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MatchStructuredYieldOp
//===----------------------------------------------------------------------===//

void transform::MatchStructuredYieldOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The operand accessor name below is assumed from the op's `handles` operands.
  onlyReadsHandle(getHandlesMutable(), effects);
  onlyReadsPayload(effects);
}

void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
                                              OperationState &state) {
  build(builder, state, ValueRange());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
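
// The GET_OP_CLASSES include above pulls in the TableGen-generated method
// bodies (parsers, printers, accessors) for the ops declared in the
// corresponding LinalgMatchOps.td definition file; only the hand-written
// matching and verification logic lives in this translation unit.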
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
The state maintained across applications of various ops implementing the TransformOpInterface.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op, Value structuredOpHandle)
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.