Numerical Precision Difference (f32) between CPU (MKL) and GPU (cuDNN)

Hello NVIDIA Developer Community,

I am conducting an experiment to understand the numerical precision differences between CPU and GPU operations in PyTorch, specifically for a simple CNN.

I have created a fully deterministic, minimal, reproducible example that eliminates all randomness from weight initialization and input data. My goal is to isolate the algorithmic differences between the CPU (presumably Intel MKL) and the GPU (cuDNN).

When I use “perfectly representable” inputs like 1.0 or 0.5, the float32 and float64 outputs from the CPU and GPU are bit-for-bit identical (zero difference).

However, when I use an input value that is not perfectly representable in binary (like 1/7), I see a divergence.
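
For concreteness, here is a small standalone check (plain Python plus numpy, independent of the script below) of how far the float32 and float64 roundings of 1/7 sit from the true value:

import numpy as np
from fractions import Fraction

true_val = Fraction(1, 7)

# 0.5 is exactly representable in binary; 1/7 is not.
print(Fraction(0.5) == Fraction(1, 2))          # True

# Nearest float64 and float32 values to 1/7
f64 = 1 / 7
f32 = np.float32(1 / 7)

# Exact rounding error of each representation
print(float(Fraction(f64) - true_val))          # ~ -7.9e-18
print(float(Fraction(float(f32)) - true_val))   # ~ +6.4e-9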

This experiment was run with torch.use_deterministic_algorithms(True), torch.set_float32_matmul_precision('highest'), and torch.backends.cudnn.allow_tf32 = False. The matmul precision setting alone does not affect cuDNN convolutions, so both settings are needed to disable TF32 and ensure a fair float32 vs. float32 comparison.

This brings me to my question: is a float32 divergence of this magnitude (1.52e-05) expected, even with all deterministic settings enabled?
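
For scale, a quick standalone check (a sketch using numpy) shows that 1.52587890625e-05 is 2**-16, which is exactly the spacing between adjacent float32 values around 185; in other words, the two results are a single ULP apart at the output's magnitude:

import numpy as np

# Gap between adjacent float32 values at the output's magnitude
ulp = np.spacing(np.float32(185.14285))
print(float(ulp))              # 1.52587890625e-05
print(float(ulp) == 2.0**-16)  # True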

I assume this is due to cuDNN using different algorithms for the convolution (e.g., Winograd, FFT, or simply a different summation order) than the CPU library, and that in float32 the initial rounding error of 1/7 accumulates differently along each path.
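
To illustrate the accumulation-order effect in isolation, here is a toy sketch (not what either library literally does) that reduces the same 144 float32 addends two ways; depending on the values and platform, the sequential and pairwise results can land on different neighboring floats:

import numpy as np

x = np.float32(1 / 7)

# With an all-ones 3x3 kernel, each conv output element is a 9-term sum
# of x; the linear layer then adds up 144 such elements.
addends = np.full(144, np.float32(9) * x, dtype=np.float32)

seq = np.float32(0)
for a in addends:            # strictly sequential accumulation
    seq = seq + a

pairwise = addends.sum()     # numpy reduces pairwise internally

print(float(seq), float(pairwise))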

Could you confirm if this is expected behavior? And if so, could you provide any insight into which specific algorithms in cuDNN (for a 3x3 convolution) are the likely source of this different accumulation path compared to the CPU implementation?
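
In case it helps the discussion: one way I could try to see which cuDNN kernel is actually dispatched (a sketch; it assumes the MiniNet class from the script below is lifted to module scope, and kernel names vary by cuDNN version) is to profile the forward pass and inspect the CUDA kernel names, e.g. looking for 'winograd' or 'implicit_gemm':

import torch
from torch.profiler import profile, ProfilerActivity

model = MiniNet().to('cuda')   # hypothetical: MiniNet moved to module scope
inp = torch.full((1, 1, 8, 8), 1 / 7, device='cuda')

with torch.no_grad(), profile(activities=[ProfilerActivity.CUDA]) as prof:
    model(inp)

for evt in prof.key_averages():
    print(evt.key)   # kernel names often hint at the chosen conv algorithm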

Thank you for your time and expertise.

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# This input value (1/7) is not perfectly representable in binary
# floating point, which helps expose precision differences.
input_val = 1/7

def run_deterministic_test(device, dtype):
    """
    Runs a forward pass with a fixed network, weights, and input
    on a specific device and dtype.
    
    This isolates algorithmic differences (e.g., Intel MKL vs. Nvidia cuDNN)
    by removing all initialization randomness.
    """
    print("\n" + "="*70)
    print(f"--- RUNNING TEST: Device={str(device).upper()} | Type={str(dtype)} ---")
    print("="*70)

    # 1. Network Definition
    class MiniNet(nn.Module):
        def __init__(self):
            super().__init__()
            # 3x3 kernel, 1 in channel, 4 out channels, no padding
            self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=0)
            # Input (8x8) -> Conv (3x3) -> Output (6x6)
            # 4 channels * 6 * 6 = 144
            self.linear = nn.Linear(4 * 6 * 6, 2)
        
        def forward(self, x):
            x = F.relu(self.conv(x))
            x = x.view(x.size(0), -1) # Flatten
            return self.linear(x)

    # 2. Model Creation and Device Transfer
    model = MiniNet().to(device=device, dtype=dtype)
    
    # IMPORTANT: Set to evaluation mode
    model.eval()

    # 3. Manual Weight & Bias Initialization
    # Set all weights to 1.0 and biases to 0.0 for a predictable calculation.
    with torch.no_grad():
        nn.init.ones_(model.conv.weight)
        nn.init.zeros_(model.conv.bias)
        nn.init.ones_(model.linear.weight)
        nn.init.zeros_(model.linear.bias)

    # 4. Manual Input Creation
    inp = torch.full((1, 1, 8, 8), input_val, device=device, dtype=dtype)

    # 5. Forward Pass Execution
    print(f"Input value ({input_val:.17f}) as {str(dtype)}: {inp[0, 0, 0, 0].item():.17f}")
    
    with torch.no_grad():
        final_output = model(inp)
    
    print(f"\n--- FINAL OUTPUT ({str(device).upper()}) ---")
    print(final_output.cpu().numpy())
    
    # Return output to CPU for comparison
    return final_output.cpu()

# --- Main Execution Block ---
if __name__ == '__main__':
    # Deterministic cuBLAS GEMMs (used by nn.Linear on CUDA) refuse to run
    # unless this workspace config is set before the first CUDA call.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    # Force PyTorch to use deterministic algorithms
    torch.use_deterministic_algorithms(True)
    
    if torch.cuda.is_available():
        print("CUDA (GPU) available.")
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        
        # IMPORTANT: For float32 precision debugging.
        # set_float32_matmul_precision('highest') only governs matmul;
        # TF32 in cuDNN convolutions is controlled separately via
        # torch.backends.cudnn.allow_tf32.
        print("Disabling TF32 for matmul and cuDNN convolutions...")
        torch.set_float32_matmul_precision('highest')
        torch.backends.cudnn.allow_tf32 = False
    else:
        print("WARNING: CUDA (GPU) not found. Test will be limited.")

    # Set Numpy print precision
    np.set_printoptions(precision=17, suppress=True)

    # --- Float64 (Double Precision) Test ---
    print("\n\n" + "#"*80)
    print("### STARTING FLOAT64 (Double Precision) COMPARISON ###")
    print("#"*80)
    
    out_cpu64 = run_deterministic_test('cpu', torch.float64)
    
    if torch.cuda.is_available():
        out_gpu64 = run_deterministic_test('cuda', torch.float64)
        print("\n--- Comparison Result (Float64) ---")
        diff64 = torch.abs(out_cpu64 - out_gpu64)
        print(f"Max Absolute Difference (f64): {torch.max(diff64).item():.17e}")
    
    # --- Float32 (Single Precision) Test ---
    print("\n\n" + "#"*80)
    print("### STARTING FLOAT32 (Single Precision) COMPARISON ###")
    print("#"*80)
    
    out_cpu32 = run_deterministic_test('cpu', torch.float32)

    if torch.cuda.is_available():
        out_gpu32 = run_deterministic_test('cuda', torch.float32)
        print("\n--- Comparison Result (Float32) ---")
        diff32 = torch.abs(out_cpu32 - out_gpu32)
        print(f"Max Absolute Difference (f32): {torch.max(diff32).item():.17e}")
    
    print("\n" + "="*80)
    print("TEST COMPLETE.")

Environment:

Output:

################################################################################
### STARTING FLOAT64 (Double Precision) COMPARISON ###
################################################################################

======================================================================
--- RUNNING TEST: Device=CPU | Type=torch.float64 ---
======================================================================
Input value (0.14285714285714285) as torch.float64: 0.14285714285714285

--- FINAL OUTPUT (CPU) ---
[[185.14285714285708 185.14285714285708]]

======================================================================
--- RUNNING TEST: Device=CUDA | Type=torch.float64 ---
======================================================================
Input value (0.14285714285714285) as torch.float64: 0.14285714285714285

--- FINAL OUTPUT (CUDA) ---
[[185.14285714285703 185.14285714285703]]

--- Comparison Result (Float64) ---
Max Absolute Difference (f64): 5.68434188608080149e-14


################################################################################
### STARTING FLOAT32 (Single Precision) COMPARISON ###
################################################################################

======================================================================
--- RUNNING TEST: Device=CPU | Type=torch.float32 ---
======================================================================
Input value (0.14285714285714285) as torch.float32: 0.14285714924335480

--- FINAL OUTPUT (CPU) ---
[[185.14287 185.14287]]

======================================================================
--- RUNNING TEST: Device=CUDA | Type=torch.float32 ---
======================================================================
Input value (0.14285714285714285) as torch.float32: 0.14285714924335480

--- FINAL OUTPUT (CUDA) ---
[[185.14285 185.14285]]

--- Comparison Result (Float32) ---
Max Absolute Difference (f32): 1.52587890625000000e-05

================================================================================
TEST COMPLETE.