INT8 convolution using cuDNN Python Frontend
Hi,
We are working on bringing a simple INT8 conv2d operator into PyTorch using the cuDNN Python frontend (version 1.14, backend 90501).
However, when adapting the sample FP16 convolution notebook (00_introduction.ipynb) to INT8, we get wrong results compared to PyTorch's conv2d:
pytorch: tensor([[[[ 10581, -49822, 9887],
[ -5654, 11015, -20480],
[ -5404, 9559, -1994]]]], device='cuda:0', dtype=torch.int32)
cudnn: tensor([[[[-2139127681, 2139127935, 128],
[ 0, 0, 0],
[ 0, 0, 0]]]], device='cuda:0',
dtype=torch.int32)
Here is the converted sample code we used:
import cudnn
import torch

print(cudnn.backend_version())

handle = cudnn.create_handle()

graph = cudnn.pygraph(
    handle=handle,
    name="cudnn_graph_0",
    io_data_type=cudnn.data_type.INT8,
    compute_data_type=cudnn.data_type.INT32,
)

# NCHW dims with NHWC (channels_last) element strides
X = graph.tensor(
    name="X",
    dim=[1, 1, 5, 5],
    stride=[5 * 5 * 1, 1, 5 * 1, 1],
    data_type=cudnn.data_type.INT8,
)
W = graph.tensor(name="W", dim=[1, 1, 3, 3], stride=[3 * 3 * 1, 1, 3 * 1, 1])

Y = graph.conv_fprop(
    X,
    W,
    padding=[0, 0],
    stride=[1, 1],
    dilation=[1, 1],
    compute_data_type=cudnn.data_type.INT32,
)
Y.set_output(True)

graph.build([cudnn.heur_mode.A])

X_gpu = torch.randint(
    -128, 127, (1, 1, 5, 5), requires_grad=False, device="cuda", dtype=torch.int8
).to(memory_format=torch.channels_last)
W_gpu = torch.randint(
    -128, 127, (1, 1, 3, 3), requires_grad=False, device="cuda", dtype=torch.int8
).to(memory_format=torch.channels_last)
Y_gpu = torch.zeros(
    1, 1, 3, 3, requires_grad=False, device="cuda", dtype=torch.int32
).to(memory_format=torch.channels_last)

# Scratch workspace for graph execution
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.int32)
graph.execute({X: X_gpu, W: W_gpu, Y: Y_gpu}, workspace, handle=handle)

# Reference result from PyTorch (conv in float32, then cast back to int32)
truth = torch.nn.functional.conv2d(X_gpu.to(torch.float32), W_gpu.to(torch.float32)).to(torch.int32)

print("pytorch:", truth)
print("cudnn:", Y_gpu)
This is surprising, as PyTorch also uses the NCHW representation and the stride computation looks correct. The values are correctly aligned to 16 bits.
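In case it helps, here is the quick sanity check we can run to compare the actual PyTorch strides with the strides declared to graph.tensor() (this just reuses the variables from the snippet above):

# Compare the live PyTorch strides against the strides declared to the graph
# (element strides, NCHW dim order).
print("X_gpu:", X_gpu.shape, X_gpu.stride())  # declared to cuDNN as [25, 1, 5, 1]
print("W_gpu:", W_gpu.shape, W_gpu.stride())  # declared to cuDNN as [9, 1, 3, 1]
print("Y_gpu:", Y_gpu.shape, Y_gpu.stride())  # output buffer, allocated as channels_last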
The GPU is an RTX 4080, in case that's relevant.
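For completeness, the device info can be confirmed with:

print(torch.cuda.get_device_name(0))        # reports the GeForce RTX 4080
print(torch.cuda.get_device_capability(0))  # should report (8, 9), i.e. sm_89 (Ada)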
Are we missing a step, or is there a known limitation for INT8 forward prop?
Thanks a lot for the help.