feat: Cache sin cos in model instead of global LRU cache. #3378


Merged (3 commits) on Apr 14, 2025
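
The change, in short: the RoPE sin/cos tables (and related attention state) were previously held in process-wide caches such as `functools.lru_cache` and a `threading.local`; this PR moves them into a per-model "extra attrs" dict, with the cache holding only weak references, so the tensors live and die with the model. Below is a minimal sketch of the before/after pattern, with illustrative names rather than the actual TensorRT-LLM helpers:

```python
# Minimal sketch of the caching change in this PR; names are illustrative,
# not the actual TensorRT-LLM APIs.
import weakref
from functools import lru_cache

import torch


def _build_cos_sin(max_positions: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Plain sinusoidal table standing in for the real RoPE helpers.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.arange(max_positions, dtype=torch.float32)[:, None] * inv_freq
    return torch.cat([angles.cos(), angles.sin()], dim=-1)


# Before: a process-wide LRU cache owns the table for the life of the process.
@lru_cache(maxsize=1)
def cos_sin_global(max_positions: int, dim: int) -> torch.Tensor:
    return _build_cos_sin(max_positions, dim)


# After: the table is cached in a per-model "extra attrs" dict; the cache holds
# only a weak reference, and the attention layers that call this keep the strong
# reference, so the tensor is freed together with the model.
def cos_sin_per_model(extra_attrs: dict, max_positions: int, dim: int) -> torch.Tensor:
    cache = extra_attrs.setdefault("rope_const_params", {})
    ref = cache.get((max_positions, dim))
    table = ref() if ref is not None else None
    if table is None:
        table = _build_cos_sin(max_positions, dim)
        cache[(max_positions, dim)] = weakref.ref(table)
    return table
```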
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -263,6 +263,7 @@ class AttentionOp
return mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPTJ
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kLONG_ROPE
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kYARN
|| mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_M;
}

9 changes: 4 additions & 5 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -136,11 +136,10 @@ class Runner : public RunnerBase

if (op.isRoPE())
{
rotary_inv_freq_ptr = rotary_inv_freq.value().data_ptr<float>();
}

if (op.isRoPE() || op.isMLAEnabled())
{
if (rotary_inv_freq.has_value())
{
rotary_inv_freq_ptr = rotary_inv_freq.value().data_ptr<float>();
}
rotary_cos_sin_ptr = static_cast<float2 const*>(rotary_cos_sin.value().data_ptr());
}

26 changes: 13 additions & 13 deletions tensorrt_llm/_torch/attention_backend/flashinfer.py
@@ -1,5 +1,5 @@
import os
import threading
import weakref
from dataclasses import dataclass, field
from typing import Dict, Literal, Optional

@@ -10,6 +10,7 @@
from tensorrt_llm.functional import AttentionMaskType
from tensorrt_llm.models.modeling_utils import QuantConfig

from ..utils import get_global_attrs, get_model_extra_attrs
from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
PredefinedAttentionMask)
from .vanilla import VanillaAttention
@@ -168,7 +169,11 @@ def page_size(self) -> int:
return self.kv_cache_manager.tokens_per_block

def prepare(self) -> None:
_thread_local.metadata = self
extra_attrs = get_model_extra_attrs()
if extra_attrs is not None:
extra_attrs["attention_metadata"] = weakref.ref(self)
else:
get_global_attrs().attention_metadata = weakref.ref(self)
# start and end indices of each sequence in the ragged query
torch.cumsum(self.seq_lens_cuda,
dim=0,
@@ -391,16 +396,6 @@ def decode_plan():
return plan_params


_thread_local = threading.local()


def get_metadata() -> FlashInferAttentionMetadata:
try:
return _thread_local.metadata
except AttributeError:
return None


class FlashInferAttention(AttentionBackend[FlashInferAttentionMetadata]):

Metadata = FlashInferAttentionMetadata
@@ -442,7 +437,12 @@ def forward_pattern(
otherwise it will graph break when calling `metadata.num_contexts` since it converts the tensor's sum directly to int.
'''
# torch.compile does not support custom object as arguments, so we have to use global function to get the metadata.
metadata = get_metadata()
extra_attrs = get_model_extra_attrs()
if extra_attrs is not None:
metadata_ref = extra_attrs.get("attention_metadata", None)
metadata = metadata_ref() if metadata_ref is not None else None
else:
metadata = get_global_attrs().attention_metadata()

q = q.view(-1, num_heads, head_dim)
if k is not None:
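
For the flashinfer backend, the net effect of the hunks above is that `prepare()` registers a `weakref.ref` to the metadata in the model's extra attrs (falling back to a global attrs object), and the torch.compile-friendly `forward_pattern` dereferences it through a plain module-level lookup instead of the old `threading.local`. A hedged sketch of that register/lookup pair, with stand-in globals in place of `get_model_extra_attrs`/`get_global_attrs`:

```python
# Sketch of the register/lookup pattern used above, with stand-in globals;
# the real code goes through get_model_extra_attrs()/get_global_attrs() from
# tensorrt_llm._torch.utils.
import weakref
from types import SimpleNamespace
from typing import Optional

_model_extra_attrs: Optional[dict] = None                 # stand-in for get_model_extra_attrs()
_global_attrs = SimpleNamespace(attention_metadata=None)  # stand-in for get_global_attrs()


class Metadata:
    """Placeholder for FlashInferAttentionMetadata."""

    def __init__(self, num_contexts: int):
        self.num_contexts = num_contexts


def register_metadata(metadata: Metadata) -> None:
    # Called from prepare(): stash a weak reference, never the object itself,
    # so the registry does not keep the metadata alive.
    if _model_extra_attrs is not None:
        _model_extra_attrs["attention_metadata"] = weakref.ref(metadata)
    else:
        _global_attrs.attention_metadata = weakref.ref(metadata)


def lookup_metadata() -> Optional[Metadata]:
    # Called from the compiled forward path: a plain module-level function, so
    # torch.compile never sees a custom object passed as an argument.
    if _model_extra_attrs is not None:
        ref = _model_extra_attrs.get("attention_metadata")
        return ref() if ref is not None else None
    ref = _global_attrs.attention_metadata
    return ref() if ref is not None else None
```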
101 changes: 61 additions & 40 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -1,7 +1,8 @@
import copy
import weakref
from collections import namedtuple
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from functools import lru_cache
from typing import (Generic, List, Optional, Protocol, Tuple, Type, TypeVar,
Union)

@@ -15,6 +16,7 @@

from ..metadata import KVCacheParams
from ..pyexecutor.resource_manager import KVCacheManager
from ..utils import get_model_extra_attrs


@dataclass
@@ -334,6 +336,7 @@ def from_config(config) -> "RopeParams":
# rotary embedding dim.
rope_params.dim = (getattr(config, 'rotary_dim', None)
or getattr(config, 'rotary_emb_base', None)
or getattr(config, 'qk_rope_head_dim', None)
or int(head_dim * rope_percentage))
# rotary scaling.
rope_params.scale_type = RotaryScalingType.none
@@ -354,57 +357,75 @@ def from_config(config) -> "RopeParams":
rope_params.beta_slow = rope_scaling.get("beta_slow", 1)
rope_params.mscale = rope_scaling.get("mscale", 1.0)
rope_params.mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
# Workaround for DeepSeek V3 Lite since its rope_scaling is null in config.json.
elif config.model_type == "deepseek_v3":
rope_params.scale_type = RotaryScalingType.yarn

return rope_params

@lru_cache(maxsize=1)
def create_rope_const_params(self):
if self.dim == 0:
return None, None
assert self.scale_type != RotaryScalingType.longrope, "Long RoPE is not yet supported."
rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
self.max_positions,
self.dim,
self.theta,
self.scale,
self.scale_type,
rope_scaling_config={
"factor": self.scale,
"low_freq_factor": self.low_freq_factor,
"high_freq_factor": self.high_freq_factor,
"original_max_position_embeddings": self.original_max_positions,
})
rope_inv_freq = torch.torch.tensor(
rope_inv_freq,
dtype=torch.float32,
device='cuda',
)
rope_cos_sin = torch.torch.tensor(
rope_cos_sin,
dtype=torch.float32,
device='cuda',
)
return rope_inv_freq, rope_cos_sin

@lru_cache(maxsize=1)
def create_deepseek_rope_const_params(self, qk_rope_head_dim: int):
rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_deepseek_attention_plugin(
self.max_positions,
qk_rope_head_dim,
self.theta,
self.scale,
self.original_max_positions,
self.beta_fast,
self.beta_slow,
self.mscale,
self.mscale_all_dim,
)
RopeConstParams = namedtuple("RopeConstParams", ["inv_freq", "cos_sin"])
extra_attrs = get_model_extra_attrs()
if extra_attrs is not None:
cache = extra_attrs.setdefault("rope_const_params", {})
rope_const_params = cache.get(self, None)
if rope_const_params is not None and rope_const_params.cos_sin(
) is not None:
return (
rope_const_params.inv_freq()
if rope_const_params.inv_freq is not None else None,
rope_const_params.cos_sin(),
)

if self.scale_type == RotaryScalingType.yarn:
rope_inv_freq = None
rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
self.max_positions,
self.dim,
self.theta,
self.scale,
self.original_max_positions,
self.beta_fast,
self.beta_slow,
self.mscale,
self.mscale_all_dim,
)
elif self.scale_type == RotaryScalingType.longrope:
raise NotImplementedError("Long RoPE is not supported.")
else:
rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
self.max_positions,
self.dim,
self.theta,
self.scale,
self.scale_type,
rope_scaling_config={
"factor": self.scale,
"low_freq_factor": self.low_freq_factor,
"high_freq_factor": self.high_freq_factor,
"original_max_position_embeddings":
self.original_max_positions,
})
if rope_inv_freq is not None:
rope_inv_freq = torch.torch.tensor(
rope_inv_freq,
dtype=torch.float32,
device='cuda',
)
rope_cos_sin = torch.torch.tensor(
rope_cos_sin,
dtype=torch.float32,
device='cuda',
)
rope_inv_freq = None
if extra_attrs is not None:
cache[self] = RopeConstParams(
weakref.ref(rope_inv_freq)
if rope_inv_freq is not None else None,
weakref.ref(rope_cos_sin),
)
return rope_inv_freq, rope_cos_sin


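
The rewritten `create_rope_const_params` drops `@lru_cache(maxsize=1)` and instead keys a per-model dict by the `RopeParams` object itself (so it must be hashable), storing `RopeConstParams` namedtuples of weak references; a hit only counts if the weakly referenced cos/sin tensor is still alive. A small sketch of that hit/expiry check under those assumptions (`RopeKey` and `build` are illustrative stand-ins):

```python
# Sketch of the weakref-valued cache with hit/expiry check used above;
# RopeKey and build() are illustrative stand-ins for RopeParams and the
# RopeEmbeddingUtils helpers.
import weakref
from collections import namedtuple
from dataclasses import dataclass

RopeConstParams = namedtuple("RopeConstParams", ["inv_freq", "cos_sin"])


@dataclass(frozen=True)  # frozen here for simplicity: it makes the key hashable
class RopeKey:
    max_positions: int
    dim: int
    theta: float


def get_or_create(cache: dict, key: RopeKey, build):
    entry = cache.get(key)
    # A cached entry is only usable if the weakly referenced cos/sin tensor
    # is still alive (i.e. some attention layer still holds it).
    if entry is not None and entry.cos_sin() is not None:
        inv_freq = entry.inv_freq() if entry.inv_freq is not None else None
        return inv_freq, entry.cos_sin()
    inv_freq, cos_sin = build(key)  # inv_freq may be None (e.g. the YaRN path)
    cache[key] = RopeConstParams(
        weakref.ref(inv_freq) if inv_freq is not None else None,
        weakref.ref(cos_sin),
    )
    return inv_freq, cos_sin
```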
50 changes: 13 additions & 37 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -3,8 +3,7 @@

import torch

from tensorrt_llm.functional import (AttentionMaskType, RopeEmbeddingUtils,
RotaryScalingType)
from tensorrt_llm.functional import AttentionMaskType
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig

@@ -85,16 +84,14 @@ def __init__(
pos_embd_params (PositionalEmbeddingParams): Optional parameters defining how positional embedding should be applied.
quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
"""
rope_params = None
if pos_embd_params is not None:
rope_params = pos_embd_params.rope
else:
self.rotary_inv_freq = None
self.rotary_cos_sin = None
rope_params = RopeParams()
rope_params = rope_params or RopeParams()
self.rope_params = rope_params

self.is_mla_enable = mla_params is not None
self.q_scaling = q_scaling or 1.0
self.mla_rope_params = None
self.predicted_tokens_per_seq = 1

if self.is_mla_enable:
Expand All @@ -104,25 +101,15 @@ def __init__(
self.qk_rope_head_dim = mla_params.qk_rope_head_dim
self.v_head_dim = mla_params.v_head_dim
self.predicted_tokens_per_seq = mla_params.predicted_tokens_per_seq

self.rotary_embedding_dim = 0
self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_deepseek_rope_const_params(
self.qk_rope_head_dim)
self.rotary_embedding_scale_type = RotaryScalingType.none
self.rotary_embedding_scale = 1.0
self.mla_rope_params = rope_params
else:
self.q_lora_rank = None
self.kv_lora_rank = None
self.qk_nope_head_dim = None
self.qk_rope_head_dim = None
self.v_head_dim = None

self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_rope_const_params(
)
self.rotary_embedding_dim = rope_params.dim
self.rotary_embedding_scale_type = int(rope_params.scale_type)
self.rotary_embedding_scale = rope_params.scale
self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_rope_const_params(
)

self.layer_idx = layer_idx
self.num_heads = num_heads
Expand All @@ -132,7 +119,10 @@ def __init__(
self.quant_mode = int(quant_config.layer_quant_mode)
self.position_embedding_type = int(
pos_embd_params.type) if pos_embd_params is not None else 0
self.rotary_embedding_dim = rope_params.dim
self.rotary_embedding_base = rope_params.theta
self.rotary_embedding_scale_type = int(rope_params.scale_type)
self.rotary_embedding_scale = rope_params.scale
self.rotary_embedding_short_m_scale = rope_params.short_m_scale
self.rotary_embedding_long_m_scale = rope_params.long_m_scale
self.rotary_embedding_max_positions = rope_params.max_positions
@@ -230,24 +220,10 @@ def plan(
self.kwargs.update(kwargs)
self.block_ids_per_seq = block_ids_per_seq

if self.is_mla_enable:
# max_context_length will increment 1 when overlap scheduler enabled
if self.max_context_length > (self.rotary_cos_sin.shape[1] /
(2 * self.qk_rope_head_dim) + 1):
rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_deepseek_attention_plugin(
self.max_context_length,
self.qk_rope_head_dim,
self.mla_rope_params.theta,
self.mla_rope_params.scale,
self.mla_rope_params.original_max_positions,
self.mla_rope_params.beta_fast,
self.mla_rope_params.beta_slow,
self.mla_rope_params.mscale,
self.mla_rope_params.mscale_all_dim,
)
self.rotary_cos_sin = torch.tensor(rope_cos_sin,
dtype=torch.float32,
device="cuda")
if max_sequence_length > self.rope_params.max_positions:
self.rope_params.max_positions = max_sequence_length
self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
)

def run(
self,
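
With the DeepSeek-specific branch gone, `plan()` above only has to grow `rope_params.max_positions` and regenerate the tables when the incoming `max_sequence_length` exceeds what the cached tables cover. A minimal sketch of that grow-and-rebuild check (`RopeTableOwner` is an illustrative stand-in for the attention wrapper):

```python
# Sketch of the grow-and-rebuild check performed in plan() above; RopeTableOwner
# stands in for the attention wrapper that owns rope_params and the cached tensors.
class RopeTableOwner:
    def __init__(self, rope_params):
        self.rope_params = rope_params
        self.rotary_inv_freq, self.rotary_cos_sin = (
            rope_params.create_rope_const_params())

    def ensure_capacity(self, max_sequence_length: int) -> None:
        # Regenerate the const tables only when the requested sequence length
        # outgrows the positions they were generated for.
        if max_sequence_length > self.rope_params.max_positions:
            self.rope_params.max_positions = max_sequence_length
            self.rotary_inv_freq, self.rotary_cos_sin = (
                self.rope_params.create_rope_const_params())
```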
7 changes: 6 additions & 1 deletion tensorrt_llm/_torch/models/modeling_auto.py
@@ -1,6 +1,7 @@
from typing import Generic

from ..model_config import ModelConfig
from ..utils import model_extra_attrs
from .modeling_utils import (MODEL_CLASS_MAPPING, DecoderModelForCausalLM,
TConfig, TModel)

@@ -25,4 +26,8 @@ def from_config(
)
if issubclass(cls, DecoderModelForCausalLM):
config.skip_create_weights = True
return cls(config)
extra_attrs = {}
with model_extra_attrs(extra_attrs):
model = cls(config)
model.extra_attrs = extra_attrs
return model
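
`from_config` now builds the model inside a `model_extra_attrs` context so that any submodule constructed during `cls(config)` can register state in the same per-model dict, which is then attached as `model.extra_attrs`. A sketch of how such a context manager could look; this is an illustration of the pattern, not the implementation in `tensorrt_llm._torch.utils`:

```python
# Sketch of a model_extra_attrs-style context manager; an illustration of the
# pattern, not the implementation in tensorrt_llm._torch.utils.
import contextlib
import contextvars
from typing import Optional

_extra_attrs_var = contextvars.ContextVar("model_extra_attrs", default=None)


def get_model_extra_attrs() -> Optional[dict]:
    # Returns the dict of the model currently being built or run, if any.
    return _extra_attrs_var.get()


@contextlib.contextmanager
def model_extra_attrs(attrs: dict):
    token = _extra_attrs_var.set(attrs)
    try:
        yield attrs
    finally:
        _extra_attrs_var.reset(token)


# Usage mirroring from_config above: submodules built inside the context can
# stash per-model state into the same dict via get_model_extra_attrs().
extra_attrs = {}
with model_extra_attrs(extra_attrs):
    ...  # model = cls(config)
```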
13 changes: 9 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -32,7 +32,7 @@
from ..models.modeling_utils import MetaInitMode, timing
from ..pipeline_interface import PipelineInterface
from ..speculative import SpecConfig, SpecMetadata, get_spec_metadata
from ..utils import set_torch_compiling
from ..utils import set_torch_compiling, with_model_extra_attrs
from .config import LoadFormat, PyTorchConfig
from .cuda_graph_runner import DecodingCUDAGraphRunner
from .distributed import MPIDist
@@ -263,6 +263,9 @@ def __init__(
max_num_tokens=max_num_tokens,
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
)
# In case that some tests use stub models and override `_load_model`.
if not hasattr(self.model, 'extra_attrs'):
self.model.extra_attrs = {}
if self.pytorch_backend_config.enable_layerwise_nvtx_marker:
layerwise_nvtx_marker = LayerwiseNvtxMarker()
module_prefix = 'Model'
@@ -1572,6 +1575,7 @@ def _prepare_inputs(
new_tensors_device)

@torch.inference_mode()
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
def forward(self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
@@ -1694,9 +1698,10 @@ def _forward_step(self, inputs: Dict[str, Any],

# For simplicity, just return all the logits if we have special gather_ids
# from speculative decoding.
logits = self.model_forward(**inputs,
return_context_logits=gather_ids
is not None)
logits = self.model_forward(
**inputs,
return_context_logits=gather_ids is not None,
)
if gather_ids is not None:
return {'logits': logits[gather_ids]}
else:
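
On the executor side, `forward` is wrapped with `with_model_extra_attrs(lambda self: self.model.extra_attrs)`, so every forward call re-enters the dict that was populated at construction time (stub models that override `_load_model` get an empty dict so the decorator still finds one). A sketch of such a decorator, building on the context manager sketched after the modeling_auto.py diff:

```python
# Sketch of a with_model_extra_attrs-style decorator (illustrative, not the
# tensorrt_llm implementation); model_extra_attrs is the context manager from
# the sketch after the modeling_auto.py diff above.
import functools


def with_model_extra_attrs(get_attrs):
    """Run the wrapped method inside the extra-attrs context of get_attrs(self)."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            with model_extra_attrs(get_attrs(self)):
                return func(self, *args, **kwargs)
        return wrapper
    return decorator


# Usage mirroring the decorated forward above:
# @with_model_extra_attrs(lambda self: self.model.extra_attrs)
# def forward(self, scheduled_requests, resource_manager, ...): ...
```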