[Feature Request][XLA] Support fallback for the dynamo-xla bridge
🐛 Describe the bug
Now that both #87741 and pytorch/xla#4119 have landed, we can run dynamo inference with pytorch/xla like this:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch._dynamo as dynamo
@dynamo.optimize("torchxla_trace_once")
def fn_simple(x, y):
    a = torch.cos(x)
    b = torch.sin(y)
    return a + b
x = torch.tensor(100.0)
y = torch.tensor(200.0)
res = fn_simple(x, y)
However, if the function being traced contains an op that XLA does not support, the program crashes. For example, XLA does not support addmm with beta != 1.0:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch._dynamo as dynamo
@dynamo.optimize("torchxla_trace_once")
def fn_fallback(M, mat1, mat2, beta):
    # XLA currently supports only alpha == 1 and beta == 1
    return torch.addmm(M, mat1, mat2, beta=beta)
M = torch.randn(2, 3)
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
res = fn_fallback(M, mat1, mat2, 0.5)
The error message is attached below.
IMO there are two ways to handle this issue:
- PyTorch/XLA (or any backend) finds a way to tell Dynamo its supported op list (including ops supported through decomposition), and Dynamo breaks the graph at each unsupported op.
- PyTorch/XLA (or any backend) takes the whole model and generates a single hash, even though more than one graph will be created, and silently handles the fallback within the backend.
The first approach seems more general and might be cleaner.
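To make the first approach concrete, here is a minimal, torch-free sketch of what "graph break at unsupported op" means for a linear sequence of ops. The names `SUPPORTED_OPS` and `split_at_unsupported` are hypothetical and only illustrate the idea; they are not Dynamo or PyTorch/XLA APIs, and the op set is not the real XLA support list.

```python
from itertools import groupby

# Hypothetical set of ops a backend reports as supported (illustration only).
SUPPORTED_OPS = {"aten::cos", "aten::sin", "aten::add"}


def split_at_unsupported(ops, supported=SUPPORTED_OPS):
    """Group consecutive ops into (runs_on_backend, [ops]) segments.

    Each change between supported and unsupported ops corresponds to
    one graph break.
    """
    return [
        (on_backend, list(run))
        for on_backend, run in groupby(ops, key=lambda op: op in supported)
    ]


segments = split_at_unsupported(
    ["aten::cos", "aten::sin", "aten::addmm", "aten::add"]
)
for on_backend, run in segments:
    # Supported runs would be compiled by the backend; the rest run eagerly.
    print("backend" if on_backend else "eager", run)
```

In this sketch the unsupported addmm ends up in its own eager segment, and the ops before and after it form two separate backend graphs.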
Common types of fallback are:
- Certain ops are not lowered to XLA; these can be found by looking at the yaml file.
- The op itself is supported, but XLA can't handle certain input shapes or input values. For example, XLA can't handle addmm with beta != 1.0. XLA also can't handle overlapping windows for _adaptive_avg_pool3d.
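The two kinds of fallback suggest that a support query would need to look at the call's arguments as well as the op name. A rough sketch, assuming a hypothetical registry `SUPPORT_RULES` and helper `needs_fallback` (neither exists in PyTorch/XLA), and following the comment in the repro above that only alpha and beta equal to 1 are supported:

```python
# Hypothetical registry: op name -> predicate over the call's keyword arguments.
# An op missing from the registry is treated as not lowered at all.
SUPPORT_RULES = {
    "aten::cos": lambda kwargs: True,
    "aten::addmm": lambda kwargs: (
        kwargs.get("beta", 1) == 1 and kwargs.get("alpha", 1) == 1
    ),
}


def needs_fallback(op, kwargs=None):
    """Return True if this particular call would have to fall back."""
    rule = SUPPORT_RULES.get(op)
    if rule is None:
        # First kind of fallback: the op has no lowering.
        return True
    # Second kind: the op is lowered, but not for these argument values.
    return not rule(kwargs or {})


print(needs_fallback("aten::addmm", {"beta": 0.5}))
print(needs_fallback("aten::addmm"))
```

A static op list alone would cover only the first kind of fallback; the second kind depends on values that may only be known at trace time.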
FYI @wconstab @shunting314 @Krovatkin @alanwaketan @desertfire @wonjoolee95
Error logs
[2022-11-02 00:10:51,561] torch._dynamo.optimizations.backends: [ERROR] torchxla_trace_once error
Traceback (most recent call last):
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/optimizations/backends.py", line 53, in inner
return fn(model, **kwargs)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/optimizations/backends.py", line 823, in torchxla_trace_once
return integration.extract_compiled_graph(model, example_inputs)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/optimizations/torchxla_integration.py", line 94, in extract_compiled_graph
f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}"
RuntimeError: Fail to extact the compiled graph because of fallback: aten::addmm=1
Traceback (most recent call last):
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/output_graph.py", line 436, in call_user_compiler
assert callable(compiled_fn), "compiler_fn did not return callable"
AssertionError: compiler_fn did not return callable
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/test/test_dynamo.py", line 26, in <module>
res = fn_fallback(M, mat1, mat2, 0.5)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/eval_frame.py", line 157, in _fn
return fn(*args, **kwargs)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/eval_frame.py", line 236, in catch_errors
return callback(frame, cache_size)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/convert_frame.py", line 466, in _convert_frame
result = inner_convert(frame, cache_size)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/convert_frame.py", line 118, in _fn
return fn(*args, **kwargs)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/utils.py", line 92, in time_wrapper
r = func(*args, **kwargs)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/convert_frame.py", line 348, in _convert_frame_assert
frame,
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/convert_frame.py", line 394, in _compile
out_code = transform_code_object(code, transform)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/convert_frame.py", line 382, in transform
tracer.run()
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/symbolic_convert.py", line 1452, in run
super().run()
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/symbolic_convert.py", line 352, in run
and self.step()
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/symbolic_convert.py", line 322, in step
getattr(self, inst.opname)(inst)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/symbolic_convert.py", line 1514, in RETURN_VALUE
self.output.compile_subgraph(self)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/output_graph.py", line 332, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/output_graph.py", line 402, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_dynamo/output_graph.py", line 439, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: torchxla_trace_once raised AssertionError: compiler_fn did not return callable
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torchdynamo.config.suppress_errors = True
==========
real error is torch._dynamo.exc.BackendCompilerFailed: torchxla_trace_once raised AssertionError: compiler_fn did not return callable