Shape constraint representation
I agree. The higher-level the operation semantics, the easier it is to derive knowledge, because it can essentially be hard-coded. But let’s play a little. Assume we had a shape function like the one below (no claim of correctness):
slice_shape(%data_input, %start_indices, %limit_indices, %stride_indices) {
%shp_lb = shape.from_extent_tensor(%start_indices)
%shp_ub = shape.from_extent_tensor(%limit_indices)
%shp_st = shape.from_extent_tensor(%stride_indices)
%size = shape.subtract %shp_ub, %shp_lb
%result = shape.floordiv %size, %shp_st
return %result
}
and two calls
%data = ...
%shape = shape.shape_of %data // we need the shape as a vector, missing in std
%rank = rank %data
%zeros = dyn_splat 0, %rank // we need a way to express dynamic splats
%ones = dyn_splat 1, %rank
%twos = dyn_splat 2, %rank
%half = floor_div %shape, %twos
%split_l = slice(%data, %zeros, %half, %ones)
%split_r = slice(%data, %half, %shape, %ones)
if we now insert the arguments into the shape functions, we get
%shape = shape.shape_of %data // we need the shape as a vector, missing in std
%rank = rank %data
%zeros = dyn_splat 0, %rank // we need a way to express dynamic splats
%ones = dyn_splat 1, %rank
%twos = dyn_splat 2, %rank
%half = floor_div %shape, %twos
// First shape.
%shp_lb1 = shape.from_extent_tensor(%zeros)
%shp_ub1 = shape.from_extent_tensor(%half)
%shp_st1 = shape.from_extent_tensor(%ones)
%size1 = shape.subtract %shp_ub1, %shp_lb1
%result1 = shape.floordiv %size1, %shp_st1
// Second shape.
%shp_lb2 = shape.from_extent_tensor(%half)
%shp_ub2 = shape.from_extent_tensor(%shape)
%shp_st2 = shape.from_extent_tensor(%ones)
%size2 = shape.subtract %shp_ub2, %shp_lb2
%result2 = shape.floordiv %size2, %shp_st2
We need the from_extent_tensor to canonicalize away, i.e., we want these operations to be replaced by their shape dialect counterparts. So the ones, zeros and twos should become constants in the shape dialect. With that, the shape.floordiv would be folded, giving us
%shape = shape.shape_of %data
%rank = shape.rank %shape
%zeros = shape.splat 0, %rank
%twos = shape.splat 2, %rank
%half = shape.floordiv %shape, %twos
%result1 = shape.subtract %half, %zeros
%result2 = shape.subtract %shape, %half
The first shape, %result1, trivially becomes half. The second shape is defined as shape - (shape floor_div 2). That can only be rewritten to half if we know that shape is divisible by 2.
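To see the divisibility issue concretely, here is a minimal sketch in plain Python (not MLIR; slice_shape just mirrors the per-dimension arithmetic of the shape function above):

def slice_shape(start, limit, stride):
    # Per-dimension version of the shape function above.
    return (limit - start) // stride

for extent in (6, 7):                      # one even, one odd extent
    half = extent // 2                     # floor_div %shape, %twos
    left = slice_shape(0, half, 1)         # %result1
    right = slice_shape(half, extent, 1)   # %result2
    print(extent, left, right)             # 6 -> 3, 3   but   7 -> 3, 4

For even extents both results fold to half; for odd extents the right slice is one element larger, so the rewrite is only valid under the divisibility assumption.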
That would be the same property we would need on a split operation. We could have used a different formulation that is independent of that property. For example, we could have written
%data = ...
%shape = shape.shape_of %data // we need the shape as a vector, missing in std
%rank = rank %data
%zeros = dyn_splat 0, %rank // we need a way to express dynamic splats
%ones = dyn_splat 1, %rank
%twos = dyn_splat 2, %rank
%half = floor_div %shape, %twos
%two_halves = mul %half, %twos
%split_l = slice(%data, %zeros, %half, %ones)
%split_r = slice(%data, %half, %two_halves, %ones)
Now this always returns two results of the same shape, and plugging this into the above formulation of the shape function, we would get
%result2 = shape.subtract %two_halves, %half
and we would need to simplify (mul %half, %twos) - %half into %half, which seems reasonable.
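As a quick sanity check of that simplification (again plain Python, only modelling the per-dimension arithmetic): (half * 2) - half reduces to half for every extent, independent of parity.

for extent in range(10):
    half = extent // 2
    two_halves = half * 2                  # mul %half, %twos
    assert two_halves - half == half       # shape.subtract %two_halves, %half folds to %half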
So, as this shows, one needs to be careful how the semantics of operations are defined, e.g., for a split operation that produces equally sized chunks. But such an operation can be lowered to slice if one is careful and considers shape computations in the process.
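As a rough illustration of such a lowering (a NumPy sketch with a made-up helper name, not an actual MLIR pass), a two-way split into equally sized chunks can be expressed through the divisibility-independent slice formulation; a trailing remainder element is simply dropped, matching the two_halves upper bound above.

import numpy as np

def split_via_slices(data):
    # Hypothetical helper: lower a two-way 'split' into two slices along axis 0.
    extent = data.shape[0]
    half = extent // 2                     # floor_div %shape, %twos
    two_halves = half * 2                  # mul %half, %twos
    return data[0:half], data[half:two_halves]

l, r = split_via_slices(np.arange(7))
print(l.shape, r.shape)                    # (3,) (3,) -- equal shapes even for an odd extent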
Whether this scales to more complex operations is a good question. It is a similar question to how far one can take affine modelling. In the end, I believe the important property of the system is whether it degrades gracefully when knowledge is lacking.