Skip to content

[TPU][V1] Add support for top-logprobs #17072

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

NickLucche
Copy link
Contributor

@NickLucche NickLucche commented Apr 23, 2025

This PR implements top-logprobs support for TPU V1.

The main design decisions I've taken in this first version are:

  • Returning logprobs is optional, so it has a separate graph that is executed only when needed.
    Akin to what is happening on GPU, when a single request in the batch requires logprobs, the prob tensor is gathered for all requests in the batch (but only streamed back to those that need it).
  • To mitigate compilation issue and strike a balance between long compilation times and minimal computational waste at runtime, "logprobs" is a binary flag. Therefore
    a graph is generated for when the flag is off (no change from current) and another one when it's on.
  • The value for which logprobs are gathered is static and fixed at startup with model_config.max_logprobs. Default is 20 as specified by the OpenAI API. Hence (when needed) this impl will gather the top 20 logprobs values, move the batched tensor to host and then slice off the needed ones with the same logic as in GPU.

Benchmark+Compile time highlight:

pre:
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  133.86    
Total input tokens:                      1638796   
Total generated tokens:                  128000    
Request throughput (req/s):              7.47      
Output token throughput (tok/s):         956.22    
Total Token throughput (tok/s):          13198.78  
---------------Time to First Token----------------
Mean TTFT (ms):                          63739.75  
Median TTFT (ms):                        64401.80  
P99 TTFT (ms):                           129335.27 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          79.05     
Median TPOT (ms):                        81.04     
P99 TPOT (ms):                           94.23     
---------------Inter-token Latency----------------
Mean ITL (ms):                           79.05     
Median ITL (ms):                         41.94     
P99 ITL (ms):                            125.57    
==================================================

INFO 04-23 18:02:40 [tpu_model_runner.py:1033] Compiling sampling with different num_reqs.
INFO 04-23 18:02:48 [tpu_model_runner.py:1053]   -- num_seqs: 8
INFO 04-23 18:02:57 [tpu_model_runner.py:1053]   -- num_seqs: 16
INFO 04-23 18:03:07 [tpu_model_runner.py:1053]   -- num_seqs: 32
INFO 04-23 18:03:16 [tpu_model_runner.py:1053]   -- num_seqs: 64
INFO 04-23 18:03:24 [tpu_model_runner.py:1053]   -- num_seqs: 128
INFO 04-23 18:03:24 [tpu_model_runner.py:1056] Compilation finished in in 44.63 [secs].


post (I assume no logprobs on sonnet benchmark):
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  133.84    
Total input tokens:                      1638796   
Total generated tokens:                  128000    
Request throughput (req/s):              7.47      
Output token throughput (tok/s):         956.38    
Total Token throughput (tok/s):          13201.06  
---------------Time to First Token----------------
Mean TTFT (ms):                          63713.64  
Median TTFT (ms):                        64372.12  
P99 TTFT (ms):                           129280.69 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          79.16     
Median TPOT (ms):                        81.05     
P99 TPOT (ms):                           93.32     
---------------Inter-token Latency----------------
Mean ITL (ms):                           79.16     
Median ITL (ms):                         41.98     
P99 ITL (ms):                            125.63    
==================================================

INFO 04-23 17:53:50 [tpu_model_runner.py:1097] Compiling sample_from_logits with different input shapes.
INFO 04-23 17:54:09 [tpu_model_runner.py:1119]   -- num_seqs: 8
INFO 04-23 17:54:29 [tpu_model_runner.py:1119]   -- num_seqs: 16
INFO 04-23 17:54:52 [tpu_model_runner.py:1119]   -- num_seqs: 32
INFO 04-23 17:55:15 [tpu_model_runner.py:1119]   -- num_seqs: 64
INFO 04-23 17:55:38 [tpu_model_runner.py:1119]   -- num_seqs: 128
INFO 04-23 17:55:38 [tpu_model_runner.py:1122] Compilation finished in 107.82 [secs].

Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added v1 tpu Related to Google TPUs labels Apr 23, 2025
@NickLucche NickLucche mentioned this pull request Apr 18, 2025
11 tasks
@yaochengji yaochengji assigned yaochengji and unassigned yaochengji Apr 23, 2025
sampler_out = self.sampler(logits, sampling_metadata)
out_tokens = sampler_out.sampled_token_ids
logprobs_tensors = sampler_out.logprobs_tensors
return out_tokens, logprobs_tensors
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please update the function's return type.

if sampling_metadata.all_greedy:
out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
if sampling_metadata.logprobs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dumb q: what's the reason of using logprobs given that we already use greedy sampling?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tpu Related to Google TPUs v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants