[RFC] Add a generic way to imitate/emulate unsupported data types in a target environment

Hi All,

We want to propose a generic way to emulate/imitate unsupported data types via a GPU dialect transform pass.

Motivation:

In MLIR we often deal with situations where we use floating-point types outside the standard IEEE formats (e.g., bf16, f8, bf8, f4) to get the best out of different hardware. However, this poses a problem for the compiler community. We often face a situation where the upper-level software stack and the hardware support a specific data type, but the middle layer does not. One such example is SPIR-V, which only supports IEEE standard floating-point types. While the upper-level stack/IR (e.g., non-target dialects in MLIR) supports different floating-point types (e.g., bf16, f8), SPIR-V, and consequently the SPIR-V dialect, does not. Therefore, generating SPIR-V code for hardware that has mixed-precision floating-point support requires handling the lack of data type support in SPIR-V-like dialects.

Proposal:

Although the LLVM dialect handles these types during conversion to it, we propose a pass-based approach to solve this issue for the generic use case.

This pass imitates unsupported types with supported types of the same bitwidth by bitcasting them (similar to reinterpret_cast). Therefore, the source type and destination type must have the same bitwidth. The imitation is done using the arith.bitcast operation.
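To make the "same bits, different type" idea concrete, here is a minimal Python sketch of what a bf16 <-> i16 bitcast means at the bit level. It is only an illustration, not the pass itself; the helper names (`bf16_pattern`, `bitcast_bf16_to_i16`, `bitcast_i16_to_bf16`) are hypothetical, and `bf16_pattern` deliberately accepts only values that bf16 represents exactly so that no rounding is involved.

```python
import struct

def f32_to_bits(x: float) -> int:
    """32-bit IEEE 754 pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bf16_pattern(x: float) -> int:
    """16-bit bf16 pattern of x (hypothetical helper).

    bf16 is the upper half of an f32, so a value is exactly representable
    only if the lower 16 bits of its f32 pattern are zero.
    """
    bits = f32_to_bits(x)
    if bits & 0xFFFF:
        raise ValueError(f"{x} is not exactly representable in bf16")
    return bits >> 16

def bitcast_bf16_to_i16(pattern: int) -> int:
    """Reinterpret a 16-bit pattern as a signed two's-complement i16."""
    return pattern - 0x10000 if pattern & 0x8000 else pattern

def bitcast_i16_to_bf16(value: int) -> int:
    """Inverse reinterpretation: signed i16 back to the raw 16-bit pattern."""
    return value & 0xFFFF

one = bf16_pattern(1.0)       # pattern 0x3F80
neg_two = bf16_pattern(-2.0)  # pattern 0xC000
print(hex(one), bitcast_bf16_to_i16(one))
print(hex(neg_two), bitcast_bf16_to_i16(neg_two))
```

The round trip through i16 returns the original pattern unchanged, which is why the cast costs nothing at runtime and why both types must have the same bitwidth.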

The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V, for example: it does not support bf16, but an underlying architecture that uses SPIR-V for code generation (e.g., the Intel PVC GPU) does. Therefore, bf16 is valid neither as a GPU kernel parameter type nor inside the kernel. To use the bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 has to be bitcast (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs. float), so the computation cannot use the imitated type (i16) directly.
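A small Python sketch of why the imitated i16 values cannot be fed straight into arithmetic: integer addition adds the bit patterns, not the floating-point values. The helper name `bf16_pattern_to_float` is hypothetical, and the constants are the bf16 patterns of 1.0, 2.0, and 3.0.

```python
import struct

def bf16_pattern_to_float(pattern: int) -> float:
    """Value of a 16-bit bf16 pattern: bf16 is the upper half of an f32."""
    return struct.unpack("<f", struct.pack("<I", (pattern & 0xFFFF) << 16))[0]

ONE, TWO, THREE = 0x3F80, 0x4000, 0x4040  # bf16 patterns of 1.0, 2.0, 3.0

# Integer addition on the imitated (i16) values adds the raw bit patterns.
int_sum = (ONE + TWO) & 0xFFFF
print(hex(int_sum), bf16_pattern_to_float(int_sum))

# Floating-point addition on the decoded values gives the intended result.
float_sum = bf16_pattern_to_float(ONE) + bf16_pattern_to_float(TWO)
print(float_sum, float_sum == bf16_pattern_to_float(THREE))
```

Here the integer sum of the patterns for 1.0 and 2.0 is 0x7F80, which decodes to +infinity in bf16 rather than 3.0, so the values have to be converted to a type the target can actually compute with.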

Therefore, this transformation pass is intended to be used in conjunction with other transformation passes, such as EmulateUnsupportedFloats and ExtendUnsupportedTypes, that extend bf16 to f32 and truncate f32 back to bf16.

Finally, the target (dialect/IR) usually provides instructions that can take advantage of these generated patterns (bf16->i16->f32, f32->bf16->i16) and convert them to the supported types.
For example, Intel provides SPIR-V extension ops that can take imitated bf16 (i16) and convert them to f32 and vice-versa.

Why Pass Approach:

As we know, the LLVM dialect implements this mechanism during lowering. However, we propose a pass approach for the following reasons:

We’d like to hear what the community thinks of this approach.

PR

We already have a PR open: [mlir][gpu] Add pass for emulating unsupported types (llvm/llvm-project Pull Request #138087, by mshahneo)

Example:

Let’s use an example to show the full flow of how this pass interacts with other passes (EmulateUnsupportedFloats and ExtendUnsupportedTypes) to solve the problem:

The following example code (both host and device) does an elementwise bf16 addition:

func.func @host(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
    ...
    %gpu_arg0 = gpu.alloc  host_shared () : memref<10x20xbf16>
    memref.copy %arg0, %gpu_arg0 : memref<10x20xbf16> to memref<10x20xbf16>
    ...

    gpu.launch_func  @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1)  args(%gpu_arg0 : memref<10x20xbf16>,...)
    ...
    ...
    return %alloc : memref<10x20xbf16>
  }

gpu.module @test_kernel {
    gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>)... {
      %block_id_x = gpu.block_id  x
      %block_id_y = gpu.block_id  y
      %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16>
      %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16>
      %2 = arith.addf %0, %1 : bf16
      memref.store %2, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>
      gpu.return
    }
  }

Op-level emulation:

Now let’s say we want to generate code for a target that does not support native bf16 addition, but supports f32 addition. We would use EmulateUnsupportedFloatsPass to extend the data types and do the addition, in other words do an op-level emulation. And the result would look like this:


gpu.module @test_kernel {
    gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>)... {
      %block_id_x = gpu.block_id  x
      %block_id_y = gpu.block_id  y

      %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16>
      %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16>

      // Emulate the operands of the arith op using f32
      %f32_0 = arith.extf %0  fastmath<contract> : bf16 to f32
      %f32_1 = arith.extf %1  fastmath<contract> : bf16 to f32

      // Do the operation in f32
      %f32_2 = arith.addf %f32_0, %f32_1 : f32

      // Revert the result back to bf16,
      // since the original version returned a bf16 result.
      %2 = arith.truncf %f32_2  fastmath<contract> : f32 to bf16

      memref.store %2, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>
      gpu.return
    }
  }
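The extf/addf/truncf sequence above can be mimicked in a few lines of Python to see what the emulated addition computes on bf16 bit patterns. This is a sketch under stated assumptions: the function names mirror the arith ops but are plain Python helpers, and `truncf` assumes round-to-nearest-even, which may not match the rounding a given target or lowering actually uses.

```python
import math
import struct

def extf(pattern: int) -> float:
    """bf16 -> f32: place the 16 bits in the upper half of an f32 (exact)."""
    return struct.unpack("<f", struct.pack("<I", (pattern & 0xFFFF) << 16))[0]

def round_to_f32(x: float) -> float:
    """Round a Python float (f64) to the nearest f32 value."""
    try:
        return struct.unpack("<f", struct.pack("<f", x))[0]
    except OverflowError:
        # Finite f64 value beyond the f32 range: saturate to infinity.
        return math.copysign(math.inf, x)

def truncf(x: float) -> int:
    """f32 -> bf16 pattern, assuming round-to-nearest-even."""
    if math.isnan(x):
        return 0x7FC0  # a quiet NaN pattern
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even bf16 value
    return ((bits + bias) >> 16) & 0xFFFF

def emulated_addf(lhs: int, rhs: int) -> int:
    """bf16 addition carried out in f32, mirroring extf/addf/truncf."""
    return truncf(round_to_f32(extf(lhs) + extf(rhs)))

# 1.0 + 2.0 -> 3.0 (patterns 0x3F80 + 0x4000 -> 0x4040)
print(hex(emulated_addf(0x3F80, 0x4000)))
# 1.0 + 2**-8 is a tie between two bf16 values and rounds to the even one.
print(hex(emulated_addf(0x3F80, 0x3B80)))
```

The extension is exact, so the only place precision is lost is the final truncation back to bf16.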

Data type emulation/Imitation:

The above emulation is enough if the target IR supports/recognizes bf16 as a data type. However, the target IR may not recognize bf16 as a valid data type, which is the case for SPIR-V. In that case, we still have no way to lower the above code to SPIR-V. One possible way to handle this is to imitate/emulate the bf16 data type with another data type of the same bitwidth, such as i16, change the signature of the kernel as well as the arguments passed to it, and insert the necessary bitcasts. In other words, we have to do data type emulation/imitation. If we run the proposed pass, the output looks like the following:

func.func @host(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
    ...
    %gpu_arg0 = gpu.alloc  host_shared () : memref<10x20xbf16>
    memref.copy %arg0, %gpu_arg0 : memref<10x20xbf16> to memref<10x20xbf16>

    ...

    // gpu.launch_func is changed to use modified args: i16 is passed instead of bf16
    // Bitcast the gpu.launch_func bf16 arguments to i16 type
    %gpu_arg0_i16 = arith.bitcast %gpu_arg0 : memref<10x20xbf16> to memref<10x20xi16>
    gpu.launch_func  @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1)  args(%gpu_arg0_i16 : memref<10x20xi16>,...)
    ...
    ...
    return %alloc : memref<10x20xbf16>
}


gpu.module @test_kernel {
    // Kernel signature is modified to use i16 type args instead of bf16 args
    gpu.func @test_kernel(%arg0: memref<10x20xi16>, %arg1: memref<10x20xi16>, %arg2: memref<10x20xi16>)... {
      %block_id_x = gpu.block_id  x
      %block_id_y = gpu.block_id  y

      // Propagate the usage of i16 type
      %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xi16>
      %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xi16>

      // Add bit cast operation from i16 to bf16
      %bf16_0 = arith.bitcast %0 : i16 to bf16
      %bf16_1 = arith.bitcast %1 : i16 to bf16
      // Emulate the operands of the arith op using f32
      %f32_0 = arith.extf %bf16_0  fastmath<contract> : bf16 to f32
      %f32_1 = arith.extf %bf16_1  fastmath<contract> : bf16 to f32

      // Do the operation in f32
      %f32_2 = arith.addf %f32_0, %f32_1 : f32

      // Revert the result back to bf16,
      // since the original version returned a bf16 result.
      %bf16_2 = arith.truncf %f32_2  fastmath<contract> : f32 to bf16

      // Cast bf16 type back to i16
      %2 = arith.bitcast %bf16_2 : bf16 to i16

      // Propagate the usage of i16 type
      memref.store %2, %arg2[%block_id_x, %block_id_y] : memref<10x20xi16>
      gpu.return
    }
  }

Conversion Pattern/Pass to Target:

Now, during the final conversion to the target IR (SPIR-V), the conversion can identify these bitcast and extf/truncf combinations and generate supported operations.

  1. bf16 to f32 pattern: Look for the following pattern in device code:

%bf16_0 = arith.bitcast %0 : i16 to bf16
%f32_0 = arith.extf %bf16_0 fastmath<contract> : bf16 to f32

Replace them with:

%f32_0 = spirv.ConvertBF16ToF %0 : i16 to f32

  2. f32 to bf16 pattern: Look for the following pattern in device code:

%bf16_2 = arith.truncf %f32_2 fastmath<contract> : f32 to bf16
%2 = arith.bitcast %bf16_2 : bf16 to i16

Replace them with:

%2 = spirv.ConvertFToBF16 %f32_2 : f32 to i16

Once this step is done, all uses of the bf16 data type are removed from the device code, and the code can now be converted to SPIR-V.
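The two rewrites can be sketched as a toy peephole over a flat list of ops. This is plain Python for illustration only: the `Op` record and `fuse_bf16_conversions` are hypothetical and do not use the MLIR pattern rewriter API. The op names are the ones from the snippets in this post.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Op:
    """Toy single-result op: result name, op name, operands, and types."""
    result: str
    name: str
    operands: tuple
    src: str
    dst: str

def fuse_bf16_conversions(ops):
    """Fuse bitcast+extf and truncf+bitcast pairs, then drop dead bf16 ops."""
    defs = {op.result: op for op in ops}
    rewritten = []
    for op in ops:
        producer = defs.get(op.operands[0]) if op.operands else None
        if (op.name == "arith.extf" and producer is not None
                and producer.name == "arith.bitcast"
                and (producer.src, producer.dst) == ("i16", "bf16")):
            # bitcast(i16 -> bf16) + extf(bf16 -> f32) => one i16 -> f32 op
            op = Op(op.result, "spirv.ConvertBF16ToF", producer.operands, "i16", "f32")
        elif (op.name == "arith.bitcast" and (op.src, op.dst) == ("bf16", "i16")
                and producer is not None and producer.name == "arith.truncf"):
            # truncf(f32 -> bf16) + bitcast(bf16 -> i16) => one f32 -> i16 op
            op = Op(op.result, "spirv.ConvertFToBF16", producer.operands, "f32", "i16")
        rewritten.append(op)
    # Ops that still produce bf16 but no longer have users are dead.
    used = {value for op in rewritten for value in op.operands}
    return [op for op in rewritten if not (op.dst == "bf16" and op.result not in used)]

kernel = [
    Op("%bf16_0", "arith.bitcast", ("%0",), "i16", "bf16"),
    Op("%bf16_1", "arith.bitcast", ("%1",), "i16", "bf16"),
    Op("%f32_0", "arith.extf", ("%bf16_0",), "bf16", "f32"),
    Op("%f32_1", "arith.extf", ("%bf16_1",), "bf16", "f32"),
    Op("%f32_2", "arith.addf", ("%f32_0", "%f32_1"), "f32", "f32"),
    Op("%bf16_2", "arith.truncf", ("%f32_2",), "f32", "bf16"),
    Op("%2", "arith.bitcast", ("%bf16_2",), "bf16", "i16"),
]
fused = fuse_bf16_conversions(kernel)
for op in fused:
    print(op.result, "=", op.name, ", ".join(op.operands), ":", op.src, "to", op.dst)
```

After the rewrite, no op in the toy kernel mentions bf16 anymore. A real conversion would also need to handle bf16 values that have other users, which this sketch ignores.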