feat: data parallel inference examples by bowang007 · Pull Request #2805 · pytorch/TensorRT
--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py	2024-05-02 00:29:27.054073+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py	2024-05-02 00:31:18.785078+00:00
@@ -13,12 +13,26 @@
 distributed_state = PartialState()
 
 model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)
 
-model.forward = torch.compile(model.forward, backend="torch_tensorrt", options={"truncate_long_and_double": True, "enabled_precisions": {torch.float16}, "debug": True}, dynamic=False,)
+model.forward = torch.compile(
+    model.forward,
+    backend="torch_tensorrt",
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float16},
+        "debug": True,
+    },
+    dynamic=False,
+)
 
 with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
     cur_input = torch.clone(prompt[0]).to(distributed_state.device)
 
-    gen_tokens = model.generate(cur_input, do_sample=True, temperature=0.9, max_length=100,)
+    gen_tokens = model.generate(
+        cur_input,
+        do_sample=True,
+        temperature=0.9,
+        max_length=100,
+    )
     gen_text = tokenizer.batch_decode(gen_tokens)[0]
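For context, a minimal self-contained sketch of what the full example script looks like after this formatting change. The `torch.compile` and `generate` calls match the diff above; the tokenizer setup, the prompt strings, and the final print are assumptions filled in for illustration, not shown in the diff.

```python
# Sketch of examples/distributed_inference/data_parallel_gpt2.py (post-reformat).
# Prompt text and the print at the end are hypothetical; everything else
# follows the diff above. Launch with one process per GPU, e.g.:
#   torchrun --nproc_per_node=2 data_parallel_gpt2.py

import torch
import torch_tensorrt  # noqa: F401  (importing registers the "torch_tensorrt" backend)
from accelerate import PartialState
from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Two prompts, one per process (hypothetical prompt text).
input_id1 = tokenizer("GPT-2 is a language model", return_tensors="pt").input_ids
input_id2 = tokenizer("TensorRT accelerates inference", return_tensors="pt").input_ids

# PartialState exposes each process's rank and device without a full Accelerator.
distributed_state = PartialState()

model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)

# Compile the forward pass with the Torch-TensorRT backend for torch.compile.
model.forward = torch.compile(
    model.forward,
    backend="torch_tensorrt",
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
    },
    dynamic=False,
)

# Each process receives a disjoint slice of the prompt list: data parallelism.
with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
    cur_input = torch.clone(prompt[0]).to(distributed_state.device)

    gen_tokens = model.generate(
        cur_input,
        do_sample=True,
        temperature=0.9,
        max_length=100,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    print(f"Rank {distributed_state.process_index}: {gen_text}")
```

The design point of the example: rather than sharding the model, each process holds a full compiled copy of GPT-2 and `split_between_processes` hands it a different prompt, so throughput scales with the number of GPUs.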