[MLIR][Python] add ctype python binding support for bf16 by xurui1995 · Pull Request #92489 · llvm/llvm-project (original) (raw)

@llvm/pr-subscribers-mlir

Author: Bimo (xurui1995)

Changes

Since bf16 is supported by mlir, similar to complex128/complex64/float16, we need an implementation of bf16 ctype in Python binding. Furthermore, to resolve the absence of bf16 support in NumPy, a third-party package ml_dtypes is introduced to add bf16 extension, and the same approach was used in torch-mlir project.

See motivation and discussion in: https://discourse.llvm.org/t/how-to-run-executionengine-with-bf16-dtype-in-mlir-python-bindings/79025


Full diff: https://github.com/llvm/llvm-project/pull/92489.diff

3 Files Affected:

diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index f6b706f9bc8ae..55e6a6cc5ab3e 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -6,6 +6,7 @@

import numpy as np import ctypes +import ml_dtypes

class C128(ctypes.Structure): @@ -25,6 +26,11 @@ class F16(ctypes.Structure):

 _fields_ = [("f16", ctypes.c_int16)]

+class BF16(ctypes.Structure):

https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype

def as_ctype(dtp): @@ -35,6 +41,8 @@ def as_ctype(dtp): return C64 if dtp == np.dtype(np.float16): return F16

@@ -46,6 +54,8 @@ def to_numpy(array): return array.view("complex64") if array.dtype == F16: return array.view("float16")

diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index acd6dbb25edaf..90acba8d65f09 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>=5.3.1, <=6.0.1 \ No newline at end of file +PyYAML>=5.3.1, <=6.0.1 +ml_dtypes \ No newline at end of file diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index e8b47007a8907..61d145ef24d95 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -5,6 +5,7 @@ from mlir.passmanager import * from mlir.execution_engine import * from mlir.runtime import * +from ml_dtypes import bfloat16

Log everything to stderr and flush so that we have a unified stream to match

@@ -521,6 +522,56 @@ def testComplexUnrankedMemrefAdd(): run(testComplexUnrankedMemrefAdd)

+# Test addition of two bf16 memrefs +# CHECK-LABEL: TEST: testBF16MemrefAdd +def testBF16MemrefAdd():