torch.mps.compile_shader — PyTorch 2.7 documentation
torch.mps.compile_shader(source)
Compiles a compute shader from source and allows one to invoke the kernels defined there from the comfort of the Python runtime.
Example:
>>> lib = torch.mps.compile_shader(
...     "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
... )
>>> x = torch.zeros(16, device="mps")
>>> lib.full(x, 3.14)