CUDA OOM leads to unhandled thrust::system exception · Issue #357 · NVIDIA/MinkowskiEngine
Describe the bug
ME raises a C++ `thrust::system::system_error` exception that cannot be caught from Python and crashes the program. The error occurs non-deterministically during training, especially in long-running jobs after a few days, so the whole training pipeline fails.
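Because the failure kills the interpreter instead of surfacing as a Python exception, a parent process can only observe it through the child's exit status. A minimal workaround sketch, assuming the training script can run in its own process and resumes from its latest checkpoint when restarted (`supervise` is a hypothetical helper, not part of ME): the parent relaunches the child until it exits cleanly or the restart budget runs out. The `os.abort()` child only stands in for a crashing training run.

```python
import subprocess
import sys


def supervise(argv, max_restarts=3):
    """Run argv in a child process, restarting it while it exits abnormally.

    A crash inside native code takes down the child interpreter, so the
    parent sees it only as a non-zero (on POSIX: negative, signal) exit code.
    The child is assumed to resume from its own latest checkpoint.
    Returns the list of exit codes, one per attempt.
    """
    codes = []
    for _ in range(max_restarts + 1):
        code = subprocess.run(argv).returncode
        codes.append(code)
        if code == 0:
            break  # clean exit: stop supervising
    return codes


if __name__ == "__main__":
    # Stand-in for a training script that dies on the uncaught exception.
    crashing = [sys.executable, "-c", "import os; os.abort()"]
    print(supervise(crashing, max_restarts=1))
```

This only contains the damage; it does not replace the fix requested in this issue, since every crash still loses the progress made since the last checkpoint.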
As `parallel_for` is not called directly anywhere in the repository, one of the functions used by `MinkowskiConvolution` most likely calls a Thrust built-in that uses it internally. That call should be wrapped with `THRUST_CHECK`, analogous to `CUDA_CHECK`, so that the failure becomes an exception that can be interpreted in Python.
To Reproduce
The problem is GPU-dependent: the code below deterministically produces the error on a 16 GB Tesla V100. To reproduce it on other GPUs (mostly depending on VRAM size), one needs to find a suitable `point_count` in the code below.
```python
import MinkowskiEngine as ME
import torch
import torch.nn as nn
from MinkowskiEngine import SparseTensor


class TestNet(ME.MinkowskiNetwork):
    def __init__(self, in_feat, out_feat, D, layers=80):
        super(TestNet, self).__init__(D)
        convs = [out_feat for _ in range(layers)]
        # Plain Python list: these layers are not registered as submodules,
        # hence the custom cuda() below.
        self.convs = []
        prev = in_feat
        for outchannels in convs:
            layer = nn.Sequential(
                ME.MinkowskiConvolution(
                    in_channels=prev,
                    out_channels=outchannels,
                    kernel_size=3,
                    stride=2,
                    dilation=1,
                    bias=True,
                    dimension=D,
                ),
                ME.MinkowskiReLU(),
            )
            self.convs.append(layer)
            prev = outchannels
        self.relu = ME.MinkowskiReLU()

    def forward(self, x):
        temp = x
        for convlayer in self.convs:
            temp = convlayer(temp)
        return temp

    def cuda(self):
        super(TestNet, self).cuda()
        # Move the unregistered layers to the GPU explicitly.
        self.convs = [c.cuda() for c in self.convs]
        return self


# Tune this value for GPUs other than a 16 GB V100.
point_count = 6000000
in_channels, out_channels, D = 2, 3, 3
# Column 0 is the batch index, columns 1..D are the spatial coordinates.
coords, feats = (
    torch.randint(low=-1000, high=1000, size=(point_count, D + 1)).int().cuda(),
    torch.rand(size=(point_count, in_channels)).cuda(),
)
coords[:, 0] = 0  # put every point in batch 0

testnetwork = TestNet(in_channels, 32, 3).cuda()
for i in range(5):
    print(f"starting {i}")
    xt = SparseTensor(feats, coordinates=coords, device="cuda")
    torch.cuda.synchronize()
    print("run forward")
    res = testnetwork(xt)
    loss = res.F.sum()
    torch.cuda.synchronize()
    print("run backward")
    loss.backward()
```
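Searching for a `point_count` that triggers the crash on another GPU can be automated with a sweep. The sketch below does not assume the failure is monotonic in `point_count`, since a different allocation might fail first and raise an ordinary, catchable out-of-memory error instead. `run_repro(n)` is a hypothetical callable: it is assumed to run the script above with `point_count = n` in a separate process (the crash kills the interpreter) and to report `"ok"`, `"oom"` (a catchable error) or `"crash"`.

```python
def scan_point_counts(run_repro, start, stop, step):
    """Run the reproduction for point counts start, start + step, ... < stop.

    run_repro(n) is a hypothetical callable returning "ok", "oom" or "crash".
    Returns the outcome for every count and the counts that ended in a crash.
    """
    outcomes = {}
    for n in range(start, stop, step):
        outcomes[n] = run_repro(n)
    crashing = [n for n, outcome in outcomes.items() if outcome == "crash"]
    return outcomes, crashing


if __name__ == "__main__":
    # Purely illustrative stand-in runner with a made-up crash window.
    def fake_run(n):
        if n < 5_000_000:
            return "ok"
        if n < 7_000_000:
            return "crash"
        return "oom"

    _, crashing = scan_point_counts(fake_run, 1_000_000, 10_000_000, 1_000_000)
    print(crashing)  # → [5000000, 6000000]
```

A coarse step narrows the range cheaply; a second sweep with a finer step around the first crashing count can then refine it.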
Expected behavior
A `thrust::system::system_error` exception should be converted to a Python `RuntimeError` or `MemoryError` so that it can be caught with a `try ... except` block in Python.
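With such a conversion in place, a training loop could survive an occasional OOM by skipping the affected batch. A minimal sketch, assuming a per-batch `step_fn` and treating any `MemoryError`, or a `RuntimeError` whose message mentions "out of memory", as an OOM (a heuristic, not a guaranteed error format); `guarded_step` and `is_oom_error` are hypothetical helpers. In a PyTorch loop, `on_oom` could be `torch.cuda.empty_cache`.

```python
import logging


def is_oom_error(exc):
    """Heuristic: MemoryError, or a RuntimeError mentioning out-of-memory."""
    if isinstance(exc, MemoryError):
        return True
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc).lower()


def guarded_step(step_fn, batch, on_oom=None):
    """Run one training step; skip the batch instead of dying on an OOM.

    Returns the step's result, or None if the step ran out of memory.
    Errors that are not OOMs are re-raised unchanged.
    """
    try:
        return step_fn(batch)
    except (RuntimeError, MemoryError) as exc:
        if not is_oom_error(exc):
            raise
        logging.warning("skipping batch after OOM: %s", exc)
        if on_oom is not None:
            on_oom()  # e.g. release cached GPU memory before the next step
        return None
```

Re-raising non-OOM errors keeps genuine bugs (shape mismatches, bad inputs) from being silently swallowed.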
Server (running inside NVIDIA Docker):
```
==========System==========
Linux-5.4.0-1047-aws-x86_64-with-glibc2.10
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.5 LTS"
3.8.5 (default, Sep 4 2020, 07:30:14)
[GCC 7.3.0]
==========Pytorch==========
1.7.1
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 460.73.01
CUDA Version 11.2
VBIOS Version 88.00.4F.00.09
Image Version G503.0201.00.03
==========NVCC==========
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
==========CC==========
/usr/bin/c++
c++ (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
==========MinkowskiEngine==========
0.5.4 (master of 05/26/2021)
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 10020
CUDART version MinkowskiEngine is compiled: 10020
```