[RFC][mlir][vector] Modify/remove bit-width specific flattening/linearization
Hi all,
This RFC proposes changing the way bit-width dependent flattening/linearization works, or removing it entirely.
Question I’d most like answered
Are there still users of the bit-width logic, either for linearization or for the flattening of transfer_read and transfer_write? If not, maybe we can jump straight to proposal 1 (remove).
History of bit-width dependent flattening
This is just what I have uncovered! The logic for bit-width dependent flattening was introduced in this PR. Shortly thereafter the logic was extended to the closely related linearization pass in this PR.
There were comments left by reviewers/authors in those PRs about future work to have partial unrolling/flattening, such as this one. These seem closely related to this RFC, specifically proposal 3 below.
I refactored the logic for linearization here and separated the bit-width specific logic from the core linearization. That separation was possible because linearization uses type conversion (unlike most of Vector).
What is bit-width dependent flattening
The idea is that ops whose operands/results have inner dimensions larger (in bits) than some threshold are left unchanged. For example, if the bit-width threshold is 8 or less in:
%0 = vector.extract %arg0[1]: vector<32x8xi1> from vector<2x32x8xi1>
then the op is left unchanged (the inner dimension is 8 x i1 = 8 bits); otherwise it is linearized to a rank-1 shuffle. The logic for this is here.
For transfer_read and transfer_write in the original PR, it is essentially the same: if the vector that is read or written has a large enough inner dimension, it is not flattened at all. Logic here.
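The check described above can be modeled roughly as follows. This is a Python sketch with made-up names; the actual logic is C++ pattern code in MLIR, so treat every name here as hypothetical.

```python
# Hypothetical model of the bit-width check described above; the real
# logic is C++ pattern code in MLIR, and these names are made up.

def trailing_dim_bits(shape, elem_bits):
    """Bit-width of the innermost dimension of a vector type."""
    return shape[-1] * elem_bits

def is_linearized(shape, elem_bits, threshold_bits):
    """The op is flattened only if its inner dimension is under the threshold."""
    return trailing_dim_bits(shape, elem_bits) < threshold_bits

# vector<2x32x8xi1>: the inner dimension is 8 x i1 = 8 bits.
assert not is_linearized((2, 32, 8), 1, 8)   # left unchanged at threshold 8
assert is_linearized((2, 32, 8), 1, 16)      # linearized at threshold 16

# A trailing unit dim makes the inner dimension tiny (1 bit), so the
# otherwise-identical vector<2x32x8x1xi1> *is* linearized at threshold 8:
assert is_linearized((2, 32, 8, 1), 1, 8)
```

The last assertion previews the unintuitive case discussed next: adding a unit dimension flips the decision.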
Unintuitive behavior
The current behavior is ok, but could be better in my opinion. To me, it is strange that we can have 2 ops that would be linearized to exactly the same thing, but only one is linearized. Example:
%0 = vector.extract %arg0[1]: vector<32x8xi1> from vector<2x32x8xi1>
might be unchanged (if the threshold is 8 bits or less), while
%0 = vector.extract %arg0[1]: vector<32x8x1xi1> from vector<2x32x8x1xi1>
is linearized/flattened to
%1 = vector.shuffle [...] [indices] : vector<512xi1>, vector<512xi1>
with surrounding shape_cast ops changing the rank. %0 would be linearized to exactly the same shuffle if the threshold was larger than 8.
Proposal 1
Remove bit-width specific logic entirely. First need to check who relies on it. PR: [mlir][vector] Remove bit-width logic tests by newling · Pull Request #143007 · llvm/llvm-project · GitHub
Proposal 2
Make bit-width specific logic depend on the total number of elements in the vector, not just the inner-dimension.
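To make the difference concrete, here is a hedged sketch (hypothetical names) contrasting the current inner-dimension measure with a total-size measure:

```python
# Hypothetical sketch contrasting the current rule with Proposal 2.

def inner_dim_bits(shape, elem_bits):
    # Current rule: only the innermost dimension is measured.
    return shape[-1] * elem_bits

def total_bits(shape, elem_bits):
    # Proposal 2: the whole vector is measured.
    n = 1
    for d in shape:
        n *= d
    return n * elem_bits

# vector<32x8x1xi1>: the inner dim is only 1 bit, but the vector holds
# 256 bits in total, so the two rules can disagree about any threshold
# in between.
assert inner_dim_bits((32, 8, 1), 1) == 1
assert total_bits((32, 8, 1), 1) == 256
```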
Proposal 3
Legalize ops to their ‘nearest’ legal form. In the case above
%0 = vector.extract %arg0[1]: vector<32x8x1xi1> from vector<2x32x8x1xi1>
would be converted to
%0 = vector.extract %arg0[1]: vector<32x8xi1> from vector<2x32x8xi1>
In other words, this proposal is to reduce the rank of operands/results by incrementally collapsing the inner 2 dimensions, until the bit-width threshold is met (or the rank is 1).
A similar logic could be applied to transfer_read/transfer_write flattening. I think this is similar to what is suggested by @hanchung here.
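The incremental collapsing can be sketched as a loop over shape tuples. This is pure Python with hypothetical names; an actual MLIR pattern would operate on types and insert shape_cast ops.

```python
def collapse_inner(shape, n=2):
    """Collapse the innermost n dimensions of a shape into one."""
    tail = 1
    for d in shape[-n:]:
        tail *= d
    return shape[:-n] + (tail,)

def legalize(shape, elem_bits, threshold_bits):
    # Collapse the inner two dims until the inner dim meets the
    # bit-width threshold, or the shape is already rank 1.
    while len(shape) > 1 and shape[-1] * elem_bits < threshold_bits:
        shape = collapse_inner(shape)
    return shape

# The example above, with a threshold of 8 bits and i1 elements:
# vector<2x32x8x1xi1> collapses once, to vector<2x32x8xi1>, then stops
# because the new inner dim (8 bits) meets the threshold.
assert legalize((2, 32, 8, 1), 1, 8) == (2, 32, 8)
```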
Proposal 4
Leave the logic as it is. This is ok, although it blocks me from implementing some improvements to linearization (rough description here).
Thank you for reading!
CC @dcaballe who implemented the initial logic.
Thanks for writing this up and for providing all the context - that’s very helpful!
First, an apology: in your PR, I commented in support of removing the bit-width-related logic. However, with the additional context (and my memory refreshed), I realise this was all part of a broader design we never fully implemented. So I’d like to retract that support for now.
This seems to stem from the presence of the trailing unit dimension. In other words, we’re hitting the flattening logic before certain shape normalization or pre-processing steps have occurred. We might indeed be missing a few canonicalization patterns here, but that’s the first direction I’d explore.
In fact, this is essentially what you describe in Proposal 3:
That said, I don’t fully follow the phrase “inner 2 dimensions” - in your example, it looks like only a single dimension is being collapsed. But that’s just a nit.
This last point makes me wonder: shouldn’t the bit-width logic be a no-op when targetVectorBitWidth is not provided? If not, maybe that’s something we should fix directly?
My overall thinking is that there is a lot of nuance and unfinished work here. Once we have a clear mental model, we should be able to converge towards something that unblocks you.
Let me know what you think - and thanks again for putting this together! I think it’s a good time to revisit this area.
-Andrzej
newling June 16, 2025, 6:01pm 3
So I’d like to retract that support for now.
That’s totally fine. I agree that the decision should be based on this broader scope. Writing this RFC was a good exercise, it forced me to reveal more of the context.
That said, I don’t fully follow the phrase “inner 2 dimensions” - in your example, it looks like only a single dimension is being collapsed. But that’s just a nit.
I think we mean the same thing, by collapse the inner N dimensions (of a rank M+N thing) I mean flatten N dimensions into a single dimension. So go from rank M+N to rank M+1. I’ll be more explicit next time to avoid this ambiguity!
This seems to stem from the presence of the trailing unit dimension.
I’d rather say that the unit dimension reveals the extreme case. To me the unintuitive behavior persists when the trailing dim isn’t 1. Consider 3 types:
T1) 1600xi1
T2) 100x16xi1
T3) 100x4x4xi1
With a bit-width threshold of 8, transfer_read operations are converted as follows: T1 -> T1, T2 -> T2, and T3 -> T1. Which is better, T2 or T1? If T1 is better, then why not T2 -> T1? If T2 is better, then why not T3 -> T2?
The current algorithm is:
If the rank-1 form is better, convert to the rank-1 form.
The gradual lowering (proposal 3) algorithm:
While the rank N-1 form is better, convert from rank N to rank N-1 form.
The latter seems better to me. I can’t provide a strong argument for why.
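The two algorithms can be contrasted on the T1/T2/T3 shapes above. Again a Python sketch with made-up names, not the actual MLIR implementation.

```python
# Hypothetical sketch of the two strategies on i1 shapes, with a
# bit-width threshold of 8; names are made up for illustration.

def current(shape, elem_bits, threshold):
    # Current algorithm: if the inner dim is under the threshold,
    # jump straight to the rank-1 form.
    if len(shape) > 1 and shape[-1] * elem_bits < threshold:
        n = 1
        for d in shape:
            n *= d
        return (n,)
    return shape

def gradual(shape, elem_bits, threshold):
    # Proposal 3: collapse one level at a time, stopping as soon as
    # the inner dim meets the threshold.
    while len(shape) > 1 and shape[-1] * elem_bits < threshold:
        shape = shape[:-2] + (shape[-2] * shape[-1],)
    return shape

T1, T2, T3 = (1600,), (100, 16), (100, 4, 4)
assert current(T3, 1, 8) == (1600,)    # T3 -> T1: jumps all the way
assert gradual(T3, 1, 8) == (100, 16)  # T3 -> T2: inner dim now 16 bits
assert current(T2, 1, 8) == (100, 16)  # T2 unchanged under either algorithm
```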
This last point makes me wonder: shouldn’t the bit-width logic be a no-op when targetVectorBitWidth is not provided? If not, maybe that’s something we should fix directly?
It is, yes. The issue I have is that if I change the patterns to flatten more gradually, the bit-width related linearization tests fail - linearization gets stuck at T2 and doesn’t go to T1. Just to clarify, I am not blocked from making progress if we don’t change the logic (proposal 4). It would just mean less compact code, because I need to retain the ‘direct path’ patterns.
Another thought I’ve had is that flattening of transfer_read/transfer_write (code) is a kind of linearization, and maybe Vector could be simplified by moving it to linearization (with options to choose subsets of patterns).
Hey, thanks for bringing this up! Sorry, I had missed the PR!
Could you elaborate a bit more on the blocking aspect of this? The bitwidth threshold is an optional knob. I’m not sure I follow after reading the PR description.
Unit dimensions have been problematic for many transformations all over the compiler and, unfortunately, dealing with them gracefully in each and every pass is complex and expensive. If this is a major source of problems for this transformation, documenting that removing redundant unit dimensions is a pre-requisite to this transformation should help set expectations.
Yes, proposal 3 seems like the right way to go. Vector unrolling and linearization are used in different ways by different projects, but one of the common goals is to make sure we optimize the use of physical vector registers, which is what my comment was about at the time.