[Q][GPU][BF16] torch.mul is lowered to HLO as an f32 multiply · Issue #8545 · pytorch/xla
❓ Questions and Help
torch 2.5.1
torch_xla 2.5.1
CUDA 12.4
GPU NVIDIA L4
The following example uses torch.mul with both operands in bf16, but in the HLO graph I see an f32 multiply operation.
export XLA_FLAGS="--xla_dump_to=/tmp/dump --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=.*"
import torch
import torch_xla as xla
device = xla.device(0)
def foo(a, b):
    y = torch.mul(a, b)
    return y
a = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)
b = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)
y = foo(a, b)
print(y)
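For reference, the multiply is staged lazily and only compiled when print(y) forces a sync, which is what produces the dump below. A minimal sketch of triggering the sync explicitly instead, using xm.mark_step() from torch_xla.core.xla_model:

import torch_xla.core.xla_model as xm

# Materialize the pending lazy graph explicitly; this triggers the same
# compilation (and HLO dump) that the print(y) above would.
xm.mark_step()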
HLO dump (module_0000.SyncTensorsGraph.16.before_optimizations.txt):
HloModule SyncTensorsGraph.16, entry_computation_layout={()->(bf16[5,9216,64]{2,1,0})}
ENTRY SyncTensorsGraph.16 {
constant.7 = bf16[] constant(1)
reshape.8 = bf16[1,1,1]{2,1,0} reshape(constant.7)
broadcast.9 = bf16[1,1,1]{2,1,0} broadcast(reshape.8), dimensions={0,1,2}
reshape.10 = bf16[] reshape(broadcast.9)
broadcast.11 = bf16[5,9216,64]{2,1,0} broadcast(reshape.10), dimensions={}
convert.12 = f32[5,9216,64]{2,1,0} convert(broadcast.11)
constant.1 = bf16[] constant(1)
reshape.2 = bf16[1,1,1]{2,1,0} reshape(constant.1)
broadcast.3 = bf16[1,1,1]{2,1,0} broadcast(reshape.2), dimensions={0,1,2}
reshape.4 = bf16[] reshape(broadcast.3)
broadcast.5 = bf16[5,9216,64]{2,1,0} broadcast(reshape.4), dimensions={}
convert.6 = f32[5,9216,64]{2,1,0} convert(broadcast.5)
multiply.13 = f32[5,9216,64]{2,1,0} multiply(convert.12, convert.6)
convert.14 = bf16[5,9216,64]{2,1,0} convert(multiply.13)
ROOT tuple.15 = (bf16[5,9216,64]{2,1,0}) tuple(convert.14)
} // SyncTensorsGraph.16
I was able to get a bf16 multiply by setting export XLA_USE_BF16=1, but I received the following warning:
XLA_USE_BF16 will be deprecated after the 2.5 release, please convert your model to bf16 directly
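My reading of that message is that the suggested migration is to cast modules and tensors to bf16 explicitly in the script instead of flipping the env var. A minimal sketch of that, with a hypothetical nn.Linear standing in for a real model (note that in my repro above the operands are already bf16, which is exactly why I'm unsure what else there is to convert):

import torch
import torch_xla as xla

device = xla.device(0)
# Cast the module's parameters to bf16 explicitly, per the warning's advice,
# rather than relying on XLA_USE_BF16.
model = torch.nn.Linear(64, 64).to(torch.bfloat16).to(device)
x = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)
y = model(x)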
I'm not sure what the correct way is to get a bf16 multiply in the generated HLO without using the deprecated flag.