feat: [AutoDeploy] generalizing cudagraph to multiple dynamic inputs #3589


Merged: 2 commits, Apr 22, 2025
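This PR generalizes the torch-opt CUDA-graph backend from a single batched input tensor to an arbitrary number of batched, dynamic inputs: `CompiledGraph` gains a `num_batched_inputs` argument, keeps one input buffer per batched input, and keys captured graphs by the concatenation of all batched input shapes. A minimal usage sketch mirroring the new unit test (the toy model, shapes, and `Dim` bounds are illustrative assumptions, and a CUDA device is assumed):

```python
import torch

from tensorrt_llm._torch.auto_deploy.compile.backends.torch_opt import CompiledGraph
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm


class TwoInputModel(torch.nn.Module):
    """Toy model with two batched inputs (illustrative only)."""

    def forward(self, x0, x1):
        return x0 + x1


model = TwoInputModel().to("cuda")
args = (torch.randn(8, 16, device="cuda"), torch.randn(8, 16, device="cuda"))

# Export with a shared dynamic batch dimension for both inputs.
bs = torch.export.Dim("batch_size", max=8)
gm = torch_export_to_gm(model, args=args, dynamic_shapes=({0: bs}, {0: bs}))

# num_batched_inputs tells the backend how many leading flattened inputs carry a batch dimension.
compiled = CompiledGraph(gm, max_batch_size=8, num_batched_inputs=2)

with torch.inference_mode():
    compiled.capture_graph(*args)
    # Smaller batches are rounded up to the nearest captured size; shapes without a
    # captured graph fall back to the torch.compile'd graph module.
    out = compiled(args[0][:2], args[1][:2])
```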
7 changes: 6 additions & 1 deletion examples/auto_deploy/build_and_run_ad.py
@@ -96,7 +96,12 @@ def main(config: Optional[SimpleConfig] = None):
print_outputs(outs)

# run a benchmark for the model with batch_size == config.benchmark_bs
if config.benchmark:
if config.benchmark and config.runtime != "demollm":
ad_logger.warning(
f"Benchmarking with {config.runtime=} not supported. Please use `demollm` instead for "
"quick benchmarking and `trtllm-bench` for full benchmarking."
)
elif config.benchmark:
token_ids = torch.randint(0, 100, (config.benchmark_bs, config.benchmark_isl)).tolist()
sampling_params = SamplingParams(max_tokens=config.benchmark_osl, top_k=None)
keys = ["compile_backend", "attn_backend", "mla_backend"]
98 changes: 67 additions & 31 deletions tensorrt_llm/_torch/auto_deploy/compile/backends/torch_opt.py
@@ -15,15 +15,22 @@

class CompiledGraph(nn.Module):
def __init__(
self, model: GraphModule, max_batch_size: int, cuda_graph_batch_sizes: List[int] = None
self,
model: GraphModule,
max_batch_size: int,
cuda_graph_batch_sizes: List[int] = None,
num_batched_inputs: Optional[int] = 1, # number of batched, dynamic inputs...
):
super().__init__()
self._in_spec: TreeSpec = model._in_spec
self._out_spec: TreeSpec = model._out_spec
self.gm_compiled = torch.compile(model, dynamic=True)
self.max_batch_size = max_batch_size
self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1
self.graphs: Dict[Tuple[int, ...], CUDAGraph] = {}
self._input_buffer: torch.Tensor = torch.empty(0, 1)
self._input_buffers: List[torch.Tensor] = [
torch.empty(0, 1) for _ in range(self.num_batched_inputs)
]
self._out_buffer_flat: List[torch.Tensor] = None
self._args_hash: Optional[Tuple[int, ...]] = None
self.cuda_graph_batch_sizes = (
@@ -42,6 +49,10 @@ def round_up_to_closest(batch_sizes: Iterable[int], bs: int) -> Optional[int]:
return None
return min(batch_sizes, key=lambda x: (x < bs, abs(x - bs)), default=None)

def round_to_cuda_batch_size(self, bs: int) -> int:
"""Round batch size to the nearest cuda batch size."""
return self.round_up_to_closest(self.cuda_graph_batch_sizes, bs)

def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
"""Capture and return one cuda graph."""
# warm-up
@@ -78,17 +89,31 @@ def _get_graph_batch_sizes(
# return as sorted list
return sorted(batch_sizes)

def _capture_cudagraph(self, input_t: torch.Tensor, flat_args: List[Any]):
"""Capture graph for variable batch size."""
# set the args hash --> this is used to compare the inputs during graph replay
self._args_hash = self._get_hash(flat_args)
def capture_graph(self, *args, **kwargs):
"""Capture and pre-fetch the graph for variable batch size."""
# flatten args, kwargs
all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)

# extract the batched input tensors
args_batched = all_args_flat[: self.num_batched_inputs]
args_static = all_args_flat[self.num_batched_inputs :]

# set the input buffer to the max needed batch size with rest of shape as is
assert self.max_batch_size >= input_t.shape[0], "Max batch size too small."
self._input_buffer = input_t[:1].repeat_interleave(self.max_batch_size, dim=0)
# set the args hash --> this is used to compare the static inputs during graph replay
self._args_hash = self._get_hash(args_static)

# unflatten args, kwargs
args, kwargs = self._in_spec.unflatten([self._input_buffer] + flat_args)
# sanity checks on the batched inputs
msg_bs = "Max batch size too small."
msg_ndim = "Expecting at least 2D tensors for batched inputs."
assert all(self.max_batch_size >= input.shape[0] for input in args_batched), msg_bs
assert all(input.ndim > 1 for input in args_batched), msg_ndim

# repeat the batched input tensors to the max batch size
self._input_buffers = [
input[:1].repeat_interleave(self.max_batch_size, dim=0) for input in args_batched
]

# create new args, kwargs with the input buffers and static args
args, kwargs = self._in_spec.unflatten(self._input_buffers + args_static)

# capture output once with max batch size to capture output buffers
with CudaGraphWarmUpPhase():
@@ -101,35 +126,46 @@ def _capture_cudagraph(self, input_t: torch.Tensor, flat_args: List[Any]):
ad_logger.info(f"Capturing graph for batch size: {bs}")

# setup args, kwargs
input_truncated = self._input_buffer[:bs]
args, kwargs = self._in_spec.unflatten([input_truncated, *flat_args])
inputs_truncated = [in_buffer[:bs] for in_buffer in self._input_buffers]
args, kwargs = self._in_spec.unflatten(inputs_truncated + args_static)

# capture graph
self.graphs[input_truncated.shape] = self._capture_one_graph(*args, **kwargs)

def capture_graph(self, *args, **kwargs):
"""Capture and pre-fetch the graph."""
input_t, flat_args = _flatten_args(self._in_spec, *args, **kwargs)
self._capture_cudagraph(input_t, flat_args)
# capture graph for truncated inputs
combined_shape = sum((input.shape for input in inputs_truncated), start=())
self.graphs[combined_shape] = self._capture_one_graph(*args, **kwargs)

def forward(self, *args, **kwargs) -> Any:
"""Run the compiled graph."""
input_t, flat_args = _flatten_args(self._in_spec, *args, **kwargs)
bs, *other_dims = input_t.shape
# flatten args, kwargs
all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)

# round up batch size and construct rounded up shape
bs_graph = self.round_up_to_closest([shapes[0] for shapes in self.graphs.keys()], bs)
shape_rounded_up = (bs_graph, *other_dims)
# extract the batched input tensors
args_batched = all_args_flat[: self.num_batched_inputs]
args_static = all_args_flat[self.num_batched_inputs :]

# regular forward for non-matching shapes or non-matching flat_args
if shape_rounded_up not in self.graphs or self._args_hash != self._get_hash(flat_args):
# check if args_static match the stored hash
if self._args_hash != self._get_hash(args_static):
return self.gm_compiled(*args, **kwargs)

# Calculate rounded-up shapes for each input
rounded_shapes = [
(self.round_to_cuda_batch_size(input.shape[0]),) + input.shape[1:]
for input in args_batched
]
combined_shape = sum(rounded_shapes, start=())

# regular forward for non-matching shapes
if combined_shape not in self.graphs:
return self.gm_compiled(*args, **kwargs)

# copy inputs to input buffers
for i, input_tensor in enumerate(args_batched):
self._input_buffers[i][: input_tensor.shape[0]] = input_tensor

# run forward pass via graph
self._input_buffer[:bs] = input_t
self.graphs[shape_rounded_up].replay()
self.graphs[combined_shape].replay()

# retrieve output from buffer, cut to batch size, and unflatten
bs = args_batched[0].shape[0]
out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
return self._out_spec.unflatten(out_flat)

@@ -138,11 +174,11 @@ def forward(self, *args, **kwargs) -> Any:
class TorchOptCompiler(BackendCompiler):
@torch.inference_mode()
def compile(self) -> CompiledGraph:
cuda_graph_batch_sizes = self.compiler_kwargs.get("cuda_graph_batch_sizes", None)
compiled_gm = CompiledGraph(
self.gm,
max_batch_size=self.max_batch_size,
cuda_graph_batch_sizes=cuda_graph_batch_sizes,
cuda_graph_batch_sizes=self.compiler_kwargs.get("cuda_graph_batch_sizes"),
num_batched_inputs=self.compiler_kwargs.get("num_batched_inputs"),
)

# try capturing cudagraph
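A captured graph is looked up by the concatenation of the rounded-up shapes of all batched inputs, built with `sum(..., start=())`, and batch sizes are rounded with `round_up_to_closest`. A standalone sketch of this keying, with assumed values for illustration:

```python
import torch


def round_up_to_closest(batch_sizes, bs):
    """Prefer the smallest size >= bs; otherwise fall back to the closest smaller size."""
    if not batch_sizes:
        return None
    return min(batch_sizes, key=lambda x: (x < bs, abs(x - bs)), default=None)


cuda_graph_batch_sizes = [1, 2, 4, 8]
print(round_up_to_closest(cuda_graph_batch_sizes, 3))   # -> 4
print(round_up_to_closest(cuda_graph_batch_sizes, 16))  # -> 8 (nothing larger is available)

# Graphs are keyed by the concatenation of all rounded batched-input shapes.
inputs = [torch.empty(3, 10), torch.empty(3, 4096, 16)]
rounded = [
    (round_up_to_closest(cuda_graph_batch_sizes, t.shape[0]),) + tuple(t.shape[1:]) for t in inputs
]
combined_shape = sum(rounded, start=())
print(combined_shape)  # -> (4, 10, 4, 4096, 16)
```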
12 changes: 5 additions & 7 deletions tensorrt_llm/_torch/auto_deploy/compile/compiler.py
@@ -7,7 +7,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn
from torch.fx import GraphModule
from torch.fx._pytree import tree_flatten_spec
@@ -16,12 +15,10 @@
from ..utils.logger import ad_logger


def _flatten_args(in_spec, *args, **kwargs) -> Tuple[torch.Tensor, List[Any]]:
def _flatten_args(in_spec, *args, **kwargs) -> List[Any]:
"""Flatten inputs from in_spec where we assume the first input is the main input tensor."""
all_args: PyTree = (args, kwargs)
input_t, *flat_args = tree_flatten_spec(all_args, in_spec)
assert input_t.ndim > 1, "Expecting at least a 2D input tensor."
return input_t, flat_args
return tree_flatten_spec(all_args, in_spec)


class BackendRegistry:
@@ -66,8 +63,9 @@ def __init__(
if self.dynamic_shapes is not None and 0 in self.dynamic_shapes[0]:
self.max_batch_size = self.dynamic_shapes[0][0].max
else:
idxs, *_ = _flatten_args(self.gm._in_spec, *self.args, **self.kwargs)
self.max_batch_size = idxs.shape[0]
# NOTE: we assume the first input is the main input tensor with batch dimension
batched_input, *_ = _flatten_args(self.gm._in_spec, *self.args, **self.kwargs)
self.max_batch_size = batched_input.shape[0]

@abstractmethod
def compile(self) -> nn.Module:
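`_flatten_args` now simply flattens `(args, kwargs)` against the module's stored `_in_spec` and returns the full flat list; callers such as `CompiledGraph` slice off the first `num_batched_inputs` entries themselves. The flatten/unflatten round trip follows torch's pytree utilities, roughly as below (a standalone sketch using `torch.utils._pytree` instead of the fx `in_spec`, for illustration):

```python
import torch
from torch.utils import _pytree as pytree

args = (torch.zeros(2, 3), torch.zeros(2, 5))
kwargs = {"mask": torch.ones(2, 3)}

# Flatten (args, kwargs) into a flat list of tensor leaves plus a spec describing the structure.
flat, spec = pytree.tree_flatten((args, kwargs))
print([tuple(t.shape) for t in flat])  # [(2, 3), (2, 5), (2, 3)]

# After swapping leaves (e.g. for pre-allocated input buffers), unflatten restores (args, kwargs).
new_args, new_kwargs = pytree.tree_unflatten(flat, spec)
```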
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/transformations/transform.py
@@ -170,7 +170,10 @@ def __call__(self, cm: CachedSequenceInterface) -> GraphModule:
############################################################################################

cm.info._set_generate_only_batch()
compiler_kwargs = {"cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes}
compiler_kwargs = {
"cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes,
"num_batched_inputs": 1, # TODO (lucaslie): improve once we have a config system...
}
egm_compiled = compile_and_capture(
egm,
self.compile_backend,
Changes to an additional test file (path not captured)
@@ -26,7 +26,7 @@
"model_kwargs": {"num_hidden_layers": 2},
},
),
# small llama3.1-8B model with world_size 2 + trtllm runtime
# small llama3.1-8B model with world_size 2 + trtllm runtime + torch-opt
(
2,
{
Expand All @@ -36,7 +36,7 @@
),
"runtime": "trtllm",
"attn_backend": "TritonWithFlattenedInputs",
"compile_backend": "torch-simple",
"compile_backend": "torch-opt",
"model_kwargs": {"num_hidden_layers": 2},
},
),
Changes to an additional test file (path not captured)
@@ -11,6 +11,20 @@
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm


class ModelWithMultipleInputs(torch.nn.Module):
def __init__(self, base_model):
super().__init__()
self.base_model = base_model

def forward(self, x0, x1=None, x2=None):
out = self.base_model(x0)
if x1 is not None:
out = out + self.base_model(x1)
if x2 is not None:
out = out + self.base_model(x2)
return out


# Using pytest.mark.parametrize to test multiple cases
@pytest.mark.parametrize(
"lst, value, expected",
@@ -31,60 +45,98 @@ def test_round_up_to_closest(lst, value, expected):
assert CompiledGraph.round_up_to_closest(lst, value) == expected


@pytest.mark.parametrize("num_inputs", [1, 2, 3])
@pytest.mark.parametrize(
"model_type, model_cls, input_shape, captured_shape_fn, atol",
"model_type, model_cls, input_shape, atol",
[
("llm", TransformerLikeModel, (32, 10), lambda b, s: (b, s), 1e-5),
("vit", VisionTransformerLikeModel, (32, 4096, 16), lambda b, s, c: (b, s, c), 1e-3),
("llm", TransformerLikeModel, (32, 10), 1e-5),
("vit", VisionTransformerLikeModel, (32, 4096, 16), 1e-3),
],
)
def test_cudagraph_capture_replay(model_type, model_cls, input_shape, captured_shape_fn, atol):
def test_cudagraph_capture_replay(model_type, model_cls, input_shape, atol, num_inputs):
batch_size, *seq_shape = input_shape

if model_type == "llm":
vocab_size = 100 # Vocabulary size
embed_dim = 32 # Embedding dimension
hidden_dim = 64 # Hidden layer dimension
model = model_cls(vocab_size, embed_dim, hidden_dim).to("cuda")
input_data = torch.randint(0, vocab_size, input_shape).to("cuda")
captured_shape = captured_shape_fn(batch_size, seq_shape[0])
base_model = model_cls(vocab_size, embed_dim, hidden_dim).to("cuda")
model = ModelWithMultipleInputs(base_model).to("cuda")

# Create inputs for the model
input_data = [
torch.randint(0, vocab_size, input_shape).to("cuda") for _ in range(num_inputs)
]

elif model_type == "vit":
channels = 16 # Number of channels
hidden_dim = 64 # Hidden layer dimension
model = model_cls(channels, hidden_dim).to("cuda")
input_data = torch.randn(*input_shape).to("cuda")
captured_shape = captured_shape_fn(batch_size, seq_shape[0], channels)
base_model = model_cls(channels, hidden_dim).to("cuda")
model = ModelWithMultipleInputs(base_model).to("cuda")

# Create inputs for the model
input_data = [torch.randn(*input_shape).to("cuda") for _ in range(num_inputs)]

combined_shape = input_shape * num_inputs

model.eval()
dynamic_shapes = generate_dynamic_shapes(batch_size, seq_shape[0])
graph_module = torch_export_to_gm(model, args=(input_data,), dynamic_shapes=dynamic_shapes)
compiled_model = CompiledGraph(graph_module, max_batch_size=batch_size)
dynamic_shapes = generate_dynamic_shapes(batch_size, seq_shape[0]) * num_inputs

# Prepare args - include only the number of inputs needed
args = tuple(input_data[:num_inputs])
print(args)
print(dynamic_shapes)

graph_module = torch_export_to_gm(model, args=args, dynamic_shapes=dynamic_shapes)
compiled_model = CompiledGraph(
graph_module, max_batch_size=batch_size, num_batched_inputs=num_inputs
)

with torch.inference_mode():
full_args = (input_data,)
compiled_model.capture_graph(*full_args)
# Capture graph with all inputs
compiled_model.capture_graph(*args)

# Ensure the graph is stored for the batch size
assert captured_shape in compiled_model.graphs, "Graph for batch size was not captured."
# Ensure the graph is stored for the combined shape of all inputs
assert combined_shape in compiled_model.graphs, (
f"Graph for combined shape {combined_shape} was not captured."
)

# Create smaller inputs for replay
if model_type == "llm":
replay_input_data = [x[:, :1] for x in input_data[:num_inputs]]
else: # vit
replay_input_data = [x[:, :1, :] for x in input_data[:num_inputs]]

# Prepare replay args - include only the number of inputs needed
replay_args = tuple(replay_input_data)

# Get flat inputs for manual replay
all_args_flat = _flatten_args(compiled_model._in_spec, *replay_args)
input_args_flat = all_args_flat[:num_inputs] # Extract just the batched inputs

input_data_replay = input_data[:, :1] if model_type == "llm" else input_data[:, :1, :]
# Update input buffers for replay
for i, input_tensor in enumerate(input_args_flat):
compiled_model._input_buffers[i][: input_tensor.shape[0]] = input_tensor

graph = compiled_model.graphs[captured_shape]
input_data_flatten, _ = _flatten_args(compiled_model._in_spec, input_data_replay)
compiled_model._input_buffer[:] = input_data_flatten # Update input buffer
# Get the appropriate graph and replay
graph = compiled_model.graphs[combined_shape]
graph.replay()

# Get output from manual replay
replay_output = compiled_model._out_spec.unflatten(
[buf[:batch_size].detach().clone() for buf in compiled_model._out_buffer_flat]
)
replay_output2 = compiled_model.forward(input_data_replay)

# Get output from forward method
replay_output2 = compiled_model.forward(*replay_args)

# Compare outputs
assert torch.allclose(replay_output, replay_output2, atol=atol), (
"CUDAGraph replay output mismatch"
)

original_output = compiled_model.gm_compiled(input_data_replay)

# Compare with original model output
original_output = compiled_model.gm_compiled(*replay_args)
assert torch.allclose(original_output, replay_output, atol=atol), (
"CUDAGraph replay output mismatch"
)
Changes to an additional test file (path not captured)
@@ -59,6 +59,8 @@ def _generate_ds_attention_mask(b, s):
)


# TODO (svelury): update unit test to run fast
@pytest.mark.skip(reason="TODO: too slow for a unit test")
@pytest.mark.parametrize(
"model_name, module_name, patch, yarn, inputs",
[