
Commit db24b3b

feat: data parallel inference sample
1 parent 9ee99fc commit db24b3b

5 files changed: +148 -0 lines changed

docsrc/index.rst

+6
@@ -111,7 +111,13 @@ Tutorials
   tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
   tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
   tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
   tutorials/_rendered_examples/dynamo/custom_kernel_plugins
+  tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
+  tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion

Python API Documentation
------------------------
README.md

+14

@@ -0,0 +1,14 @@
# Torch-TensorRT parallelism for distributed inference

Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend.

1. Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference)

Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend. In this case, the entire model
will be loaded onto each GPU and a different chunk of the batched input will be processed on each device (see the sketch after this list).

See the examples starting with `data_parallel` for more details.

2. Tensor parallel distributed inference

In development.
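The data parallel pattern is sketched below. This is a minimal, illustrative snippet rather than part of this commit: the toy `torch.nn.Linear` model and random inputs stand in for the real models used in the examples, and the script is assumed to be launched with one process per GPU (e.g. `accelerate launch --num_processes 2 sketch.py`).

import torch
from accelerate import PartialState

import torch_tensorrt  # noqa: F401  (importing registers the "torch_tensorrt" backend for torch.compile)

# One process per device; Accelerate assigns each process its own GPU
distributed_state = PartialState()

# The entire model is replicated on every device (toy stand-in model here)
model = torch.nn.Linear(8, 4).eval().to(distributed_state.device)
model = torch.compile(model, backend="torch_tensorrt", dynamic=False)

# Each process receives its own chunk of the inputs
inputs = [torch.randn(2, 8), torch.randn(2, 8)]
with distributed_state.split_between_processes(inputs) as local_inputs:
    with torch.no_grad():
        out = model(local_inputs[0].to(distributed_state.device))
    print(f"rank {distributed_state.process_index}: output shape {tuple(out.shape)}")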
data_parallel_gpt2.py

@@ -0,0 +1,64 @@
"""
.. _data_parallel_gpt2:

Torch-TensorRT Distributed Inference
======================================================

This interactive script is intended as a sample of distributed inference using data
parallelism with the Accelerate library and the Torch-TensorRT workflow on a GPT2 model.

"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
from accelerate import PartialState
from transformers import AutoTokenizer, GPT2LMHeadModel

import torch_tensorrt

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Set input prompts for different devices
prompt1 = "GPT2 is a model developed by."
prompt2 = "Llama is a model developed by "

input_id1 = tokenizer(prompt1, return_tensors="pt").input_ids
input_id2 = tokenizer(prompt2, return_tensors="pt").input_ids

distributed_state = PartialState()

# Import GPT2 model and load to distributed devices
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)


# Instantiate model with Torch-TensorRT backend
model.forward = torch.compile(
    model.forward,
    backend="torch_tensorrt",
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
    },
    dynamic=False,
)

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
    cur_input = torch.clone(prompt[0]).to(distributed_state.device)

    gen_tokens = model.generate(
        cur_input,
        do_sample=True,
        temperature=0.9,
        max_length=100,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
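This script is meant to be launched with one process per GPU, for example (assuming two GPUs) `accelerate launch --num_processes 2 data_parallel_gpt2.py`; each process then compiles GPT2 with the Torch-TensorRT backend and generates a completion for its own prompt.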
data_parallel_stable_diffusion.py

@@ -0,0 +1,61 @@
"""
.. _data_parallel_stable_diffusion:

Torch-TensorRT Distributed Inference
======================================================

This interactive script is intended as a sample of distributed inference using data
parallelism with the Accelerate library and the Torch-TensorRT workflow on a Stable Diffusion model.

"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

import torch_tensorrt

model_id = "CompVis/stable-diffusion-v1-4"

# Instantiate Stable Diffusion Pipeline with FP16 weights
pipe = DiffusionPipeline.from_pretrained(
    model_id, revision="fp16", torch_dtype=torch.float16
)

distributed_state = PartialState()
pipe = pipe.to(distributed_state.device)

backend = "torch_tensorrt"

# Optimize the UNet portion with Torch-TensorRT
pipe.unet = torch.compile(
    pipe.unet,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "precision": torch.float16,
        "debug": True,
        "use_python_runtime": True,
    },
    dynamic=False,
)
torch_tensorrt.runtime.set_multi_device_safe_mode(True)


# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    result = pipe(prompt).images[0]
    result.save(f"result_{distributed_state.process_index}.png")
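As with the GPT2 sample, this script is intended to run with one process per GPU (for example via `accelerate launch --num_processes 2 data_parallel_stable_diffusion.py`); each process optimizes the UNet with the Torch-TensorRT backend, runs the pipeline on its own prompt, and saves its image to `result_<process_index>.png`. The example also calls `torch_tensorrt.runtime.set_multi_device_safe_mode(True)`, which enables the runtime's additional device-consistency checks as a precaution for multi-GPU setups.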
requirements.txt

@@ -0,0 +1,3 @@
accelerate
transformers
diffusers
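The requirements file lists only the extra Hugging Face dependencies; torch and torch_tensorrt themselves are assumed to be installed separately, so a typical setup would be `pip install -r requirements.txt` on top of an existing Torch-TensorRT environment.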
