feat: data parallel inference examples by bowang007 · Pull Request #2805 · pytorch/TensorRT

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py	2024-05-02 00:29:27.054073+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py	2024-05-02 00:31:18.785078+00:00
@@ -13,12 +13,26 @@

 distributed_state = PartialState()

 model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)

-model.forward = torch.compile(model.forward, backend="torch_tensorrt", options={"truncate_long_and_double": True, "enabled_precisions": {torch.float16}, "debug": True}, dynamic=False,)
+model.forward = torch.compile(
+    model.forward,
+    backend="torch_tensorrt",
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float16},
+        "debug": True,
+    },
+    dynamic=False,
+)

 with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
     cur_input = torch.clone(prompt[0]).to(distributed_state.device)

-    gen_tokens = model.generate(cur_input, do_sample=True, temperature=0.9, max_length=100,)
+    gen_tokens = model.generate(
+        cur_input,
+        do_sample=True,
+        temperature=0.9,
+        max_length=100,
+    )
     gen_text = tokenizer.batch_decode(gen_tokens)[0]
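The hunk above only covers the middle of the example, so names like tokenizer, input_id1, and input_id2 are defined outside it. For context, here is a minimal self-contained sketch of how the full script plausibly fits together after this change. The torch.compile and generation code comes from the diff; the tokenizer setup and the two prompt strings are assumptions filled in for illustration and may differ from the actual file in the PR.

# data_parallel_gpt2.py -- sketch; prompt strings and tokenizer setup are
# assumptions, only the compile/generate code is taken from the diff above.
import torch
from accelerate import PartialState
from transformers import AutoTokenizer, GPT2LMHeadModel

import torch_tensorrt  # registers the "torch_tensorrt" compile backend

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# One prompt per process; the texts here are illustrative placeholders.
input_id1 = tokenizer("GPT2 is a model developed by", return_tensors="pt").input_ids
input_id2 = tokenizer("Llama is a model developed by", return_tensors="pt").input_ids

distributed_state = PartialState()

# Each process loads its own copy of the model onto its assigned device.
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)

# Compile the forward pass with the Torch-TensorRT backend.
model.forward = torch.compile(
    model.forward,
    backend="torch_tensorrt",
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
    },
    dynamic=False,
)

# Assuming 2 processes: each rank receives one element of the input list.
with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
    cur_input = torch.clone(prompt[0]).to(distributed_state.device)

    gen_tokens = model.generate(
        cur_input,
        do_sample=True,
        temperature=0.9,
        max_length=100,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    print(gen_text)

A script like this would be launched with, e.g., accelerate launch --num_processes=2 data_parallel_gpt2.py; since PartialState reads the launcher's environment variables, torchrun --nproc_per_node=2 should work as well. Each process then compiles and runs the model on its own GPU while split_between_processes hands it a distinct prompt.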