
feat: Support CUDA graphs for EAGLE3 #3176


Merged: 2 commits into NVIDIA:main from eagle3-graphs on Apr 16, 2025

Conversation

@mikeiovine (Collaborator) commented Mar 31, 2025:

Support CUDA graphs for EAGLE-3 speculative decoding. Also contains fixes for loading EAGLE-3 weights for Llama 3 70B models (previously only tested with 8B).

The graphs significantly improve performance. However, we still have a lot of work to do to eliminate the host overheads.
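
For readers unfamiliar with where the win comes from, here is a minimal PyTorch sketch of the eager-vs-replay difference (a toy stack of linears stands in for the drafter; this is not the PR's code): with many small kernels, eager execution pays a host-side launch cost per kernel, while a captured graph replays the whole sequence with a single call.

```python
import time

import torch

# Toy stand-in for the small draft model; the real drafter is a transformer stack.
layers = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024, device="cuda") for _ in range(32)])
x = torch.zeros(8, 1024, device="cuda")

# Warm up on a side stream before capture, as recommended for torch.cuda.graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        layers(x)
torch.cuda.current_stream().wait_stream(s)

# Capture the whole forward pass into one graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = layers(x)

def bench(fn, iters=100):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

print("eager :", bench(lambda: layers(x)))  # one kernel launch per layer
print("graph :", bench(g.replay))           # one replay call for all 32 layers
```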

@juney-nvidia changed the title Support CUDA graphs for EAGLE3 → feat: Support CUDA graphs for EAGLE3 on Apr 1, 2025
@mikeiovine force-pushed the eagle3-graphs branch 8 times, most recently from 1b6ec07 to 0a61d1c on April 11, 2025 20:27
@mikeiovine requested review from hlu1, lfr-0531 and QiJune on April 11, 2025 21:13
@mikeiovine marked this pull request as ready for review on April 11, 2025 21:15
@mikeiovine (Collaborator, Author) commented:

Putting this out to get early feedback and ideas. I'm not really happy with the design right now. Specifically, passing the states between the target and draft models gets pretty complicated when CUDA graphs enter the picture.
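
To illustrate the constraint being described, a minimal sketch with toy modules standing in for the actual target/drafter classes (none of this is the PR's code): once the drafter is captured, it only ever reads the fixed buffer it was captured against, so the target's hidden states must be copied into that exact tensor each step instead of being handed over as a freshly allocated output.

```python
import torch

# Toy stand-ins; the real models are full transformers.
target = torch.nn.Linear(4096, 4096, device="cuda")
drafter = torch.nn.Linear(4096, 4096, device="cuda")

# The buffer the draft graph is captured against; the hand-off must go through it.
draft_hidden_in = torch.zeros(8, 4096, device="cuda")

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        drafter(draft_hidden_in)
torch.cuda.current_stream().wait_stream(s)

draft_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(draft_graph):
    draft_logits = drafter(draft_hidden_in)

# One decode step: run the target (eagerly or via its own graph), then copy its
# hidden states into the captured buffer before replaying the draft graph.
target_hidden = target(torch.randn(8, 4096, device="cuda"))
draft_hidden_in.copy_(target_hidden)
draft_graph.replay()
```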

@hlu1 (Collaborator) left a comment:

Mostly nits.

@mikeiovine (Collaborator, Author) commented Apr 12, 2025:

Had some discussion offline with @hlu1 about how to make it cleaner:

  1. extra_model_inputs is hard to extend. A single ModelInput class that all models consume would let us add new inputs to other models easily.
  2. At the same time, spec_decode_extra_input_info can be generalized: rename it and have it return a dict mapping each extra input's name to its (shape, dtype), with sensible defaults so standard models don't have to implement it. (Originally proposed as get_input_shapes; we decided to just call it get_warmup_extra_inputs. See the sketch below.)

I think (2) should definitely be done now. (1) is a pretty big refactor and would probably have to land in follow-ups.
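
A hypothetical sketch of what (2) could look like; the class and function names, signature, and return type here are illustrative assumptions, not the code that actually landed:

```python
from typing import Dict, Tuple

import torch

class DecoderModel:
    def get_warmup_extra_inputs(self, batch_size: int) -> Dict[str, Tuple[Tuple[int, ...], torch.dtype]]:
        # Sensible default: standard models need nothing beyond the usual inputs.
        return {}

class Eagle3DraftModel(DecoderModel):
    hidden_size = 8192  # e.g. Llama 3 70B

    def get_warmup_extra_inputs(self, batch_size: int) -> Dict[str, Tuple[Tuple[int, ...], torch.dtype]]:
        # The drafter additionally consumes hidden states captured from the target model.
        return {"hidden_states": ((batch_size, self.hidden_size), torch.bfloat16)}

def allocate_warmup_inputs(model: DecoderModel, batch_size: int) -> Dict[str, torch.Tensor]:
    # The warmup/capture path allocates dummy buffers from the reported specs.
    return {
        name: torch.zeros(*shape, dtype=dtype, device="cuda")
        for name, (shape, dtype) in model.get_warmup_extra_inputs(batch_size).items()
    }
```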

@mikeiovine (Collaborator, Author):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #2196 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator):

PR_Github #2196 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #1586 completed with status: 'FAILURE'

@mikeiovine (Collaborator, Author):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #2212 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator):

PR_Github #2212 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #1601 completed with status: 'FAILURE'

@mikeiovine (Collaborator, Author):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #2348 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator):

PR_Github #2348 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #1688 completed with status: 'FAILURE'

@mikeiovine (Collaborator, Author):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #2357 [ run ] triggered by Bot

@mikeiovine (Collaborator, Author):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #2363 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator):

PR_Github #2357 [ run ] completed with state ABORTED

@hlu1 (Collaborator) left a comment:

Approve to unblock.

@tensorrt-cicd (Collaborator):

PR_Github #2363 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #1699 completed with status: 'SUCCESS'

@mikeiovine (Collaborator, Author):

/bot run --disable-fail-fast

@mikeiovine (Collaborator, Author):

Running CI one more time before merging, since it has been a while since my last rebase.

@tensorrt-cicd (Collaborator):

PR_Github #2497 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator):

PR_Github #2497 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #1791 completed with status: 'SUCCESS'

@mikeiovine (Collaborator, Author):

/bot skip --comment "Pipeline passed before last rebase"

@tensorrt-cicd (Collaborator):

PR_Github #2510 [ skip ] triggered by Bot

@mikeiovine enabled auto-merge (squash) on April 16, 2025 20:51
@tensorrt-cicd (Collaborator):

PR_Github #2510 [ skip ] completed with state SUCCESS
Skipping testing for commit efba97d

@mikeiovine merged commit 41a6c98 into NVIDIA:main on Apr 16, 2025
3 checks passed
@mikeiovine deleted the eagle3-graphs branch on April 16, 2025 20:53