[MLIR][Python] add ctype python binding support for bf16 by xurui1995 · Pull Request #92489 · llvm/llvm-project (original) (raw)
@llvm/pr-subscribers-mlir
Author: Bimo (xurui1995)
Changes
Since bf16 is supported by mlir, similar to complex128/complex64/float16, we need an implementation of bf16 ctype in Python binding. Furthermore, to resolve the absence of bf16 support in NumPy, a third-party package ml_dtypes is introduced to add bf16 extension, and the same approach was used in torch-mlir
project.
See motivation and discussion in: https://discourse.llvm.org/t/how-to-run-executionengine-with-bf16-dtype-in-mlir-python-bindings/79025
Full diff: https://github.com/llvm/llvm-project/pull/92489.diff
3 Files Affected:
- (modified) mlir/python/mlir/runtime/np_to_memref.py (+10)
- (modified) mlir/python/requirements.txt (+2-1)
- (modified) mlir/test/python/execution_engine.py (+51)
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index f6b706f9bc8ae..55e6a6cc5ab3e 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -6,6 +6,7 @@
import numpy as np import ctypes +import ml_dtypes
class C128(ctypes.Structure): @@ -25,6 +26,11 @@ class F16(ctypes.Structure):
_fields_ = [("f16", ctypes.c_int16)]
+class BF16(ctypes.Structure):
- """A ctype representation for MLIR's BFloat16."""
- fields = [("bf16", ctypes.c_int16)]
https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp): @@ -35,6 +41,8 @@ def as_ctype(dtp): return C64 if dtp == np.dtype(np.float16): return F16
- if dtp == ml_dtypes.bfloat16:
return np.ctypeslib.as_ctypes_type(dtp)return BF16
@@ -46,6 +54,8 @@ def to_numpy(array): return array.view("complex64") if array.dtype == F16: return array.view("float16")
- if array.dtype == BF16:
return arrayreturn array.view("bfloat16")
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index acd6dbb25edaf..90acba8d65f09 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>=5.3.1, <=6.0.1 \ No newline at end of file +PyYAML>=5.3.1, <=6.0.1 +ml_dtypes \ No newline at end of file diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index e8b47007a8907..61d145ef24d95 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -5,6 +5,7 @@ from mlir.passmanager import * from mlir.execution_engine import * from mlir.runtime import * +from ml_dtypes import bfloat16
Log everything to stderr and flush so that we have a unified stream to match
@@ -521,6 +522,56 @@ def testComplexUnrankedMemrefAdd(): run(testComplexUnrankedMemrefAdd)
+# Test addition of two bf16 memrefs +# CHECK-LABEL: TEST: testBF16MemrefAdd +def testBF16MemrefAdd():
- with Context():
module = Module.parse(
"""
- module {
func.func @main(%arg0: memref<1xcomplex<bf16>>,
%arg1: memref<1xcomplex<bf16>>,
%arg2: memref<1xcomplex<bf16>>) attributes { llvm.emit_c_interface } {
%0 = arith.constant 0 : index
%1 = memref.load %arg0[%0] : memref<1xcomplex<bf16>>
%2 = memref.load %arg1[%0] : memref<1xcomplex<bf16>>
%3 = complex.add %1, %2 : complex<bf16>
memref.store %3, %arg2[%0] : memref<1xcomplex<bf16>>
return
}
- } """
)
arg1 = np.array([11.0]).astype(bfloat16)
arg2 = np.array([12.0]).astype(bfloat16)
arg3 = np.array([0.0]).astype(bfloat16)
arg1_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg1))
)
arg2_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg2))
)
arg3_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg3))
)
execution_engine = ExecutionEngine(lowerToLLVM(module))
execution_engine.invoke(
"main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
)
# CHECK: [11.] + [22.] = [33.]
log("{0} + {1} = {2}".format(arg1, arg2, arg3))
# test to-numpy utility
# CHECK: [33.]
npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
log(npout)
- +run(testBF16MemrefAdd)
Test addition of two 2d_memref
def testDynamicMemrefAdd2D(): CHECK-LABEL: TEST: testDynamicMemrefAdd2D