interface.py (forked from NVIDIA/TensorRT-LLM)
import copy
import weakref
from collections import namedtuple
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from typing import (Generic, List, Optional, Protocol, Tuple, Type, TypeVar,
Union)
import torch
from typing_extensions import Self
from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
RotaryScalingType)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from ..metadata import KVCacheParams
from ..pyexecutor.resource_manager import KVCacheManager
from ..utils import get_model_extra_attrs
@dataclass
class AttentionRuntimeFeatures:
chunked_prefill: bool = False
cache_reuse: bool = False
has_speculative_draft_tokens: bool = False
# The type of requests in the qkv passed to attention.
# Please keep in sync with AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp.
class AttentionInputType(IntEnum):
mixed = 0 # contains both context and generation
context_only = 1
generation_only = 2
@dataclass(kw_only=True)
class AttentionMetadata:
"""
Metadata for the attention module.
"""
# The max number of requests in a single batch.
max_num_requests: int
# The max number of tokens in all requests in a single batch.
max_num_tokens: int
# The KV cache manager.
kv_cache_manager: KVCacheManager
mapping: Optional[Mapping] = None
enable_flash_mla: bool = False
# Whether CUDA graph is enabled.
is_cuda_graph: bool = field(default=False, repr=False)
# The length of each query sequence in the batch.
# The shape is (batch_size), and the tensor lives in CPU memory.
# For the cross-attention sub-metadata, this is automatically
# initialized to the parent metadata's seq_lens.
seq_lens: Optional[torch.Tensor] # Implemented using property
# The number of context-phase sequences in the batch.
num_contexts: int # Implemented using property
# The position of each token in each sequence.
# May be None if positional embedding is applied outside the backend.
position_ids: Optional[torch.Tensor] = None
# The number of context-phase sequences in the batch.
_num_contexts: int = field(init=False, default=0, repr=False)
# The parameters for the KV cache.
kv_cache_params: Optional[KVCacheParams] = None
# The length of each key/value sequence in the batch.
# The shape is (batch_size), and the tensor lives in CPU memory.
# It defaults to seq_lens if not set, and should only be set
# explicitly for cross attention.
seq_lens_kv: Optional[torch.Tensor] # Implemented using property
# Actual storage for seq_lens and seq_lens_kv
_seq_lens: Optional[torch.Tensor] = field(init=False,
repr=False,
default=None)
_seq_lens_kv: Optional[torch.Tensor] = field(init=False,
repr=False,
default=None)
# A copy of seq_lens stored on the GPU. Used in the logits
# processor. Keeping two copies avoids a lot of extraneous
# copies in flashinfer's prepare() implementation.
_seq_lens_cuda: Optional[torch.Tensor] = field(init=False,
repr=False,
default=None)
_seq_lens_kv_cuda: Optional[torch.Tensor] = field(init=False,
repr=False,
default=None)
# For self attention, this is the sub-metadata of the cross attention
# that works alongside it in the same model.
# For cross attention, this is automatically initialized to self
# and must not be set explicitly.
cross: Optional["AttentionMetadata"] = None
# The request ID of each sequence in the batch.
# The shape is (batch_size).
request_ids: Optional[List[int]] = None
# The prompt length of each sequence in the batch.
# For a context-phase sequence, the value is its token count, which is the same as `context_lens`.
# For a generation-phase sequence, the value is the token count of its context phase.
# The shape is (batch_size) if provided.
prompt_lens: Optional[List[int]] = None
# These fields indicate whether the runtime can use various features.
# The kernels may or may not have different behaviors when these
# are enabled.
runtime_features: AttentionRuntimeFeatures = field(
default_factory=AttentionRuntimeFeatures)
all_rank_num_tokens: Optional[List[int]] = None
# These fields are set when changing seq_lens and _num_contexts to avoid computation
# during execution. If the calculation happened during execution, torch.compile would
# treat it as a data-dependent shape (DDS) and fail to compile.
_num_generations: int = field(init=False, default=0, repr=False)
_num_ctx_tokens: int = field(init=False, default=0, repr=False)
_num_tokens: int = field(init=False, default=0, repr=False)
is_dummy_attention: bool = False
def __post_init__(self) -> None:
if self.is_cross:
assert self.cross is None or self.cross is self, "Cross attention metadata should not have sub metadata"
self.cross = self
return
assert self.cross is None or type(self) is type(
self.cross
), "Top level and cross attention sub metadata type mismatched"
def on_update(self):
if (self._seq_lens is not None
and self._seq_lens.shape[0] >= self.num_contexts
and self.num_contexts >= 0):
self._num_ctx_tokens = self._seq_lens[:self.num_contexts].sum(
).item()
self._num_generations = self._seq_lens.shape[0] - self.num_contexts
if self._seq_lens_kv is not None:
self._num_tokens = int(self._seq_lens_kv.sum())
elif self._seq_lens is not None:
self._num_tokens = int(self._seq_lens.sum())
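# Worked example (illustrative values only): with seq_lens = [12, 9, 1, 1] and
# num_contexts = 2, on_update() caches num_ctx_tokens = 12 + 9 = 21,
# num_generations = 4 - 2 = 2, and num_tokens = 23; if seq_lens_kv is set
# explicitly (cross attention), num_tokens is the sum of seq_lens_kv instead.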
@property
def seq_lens(self) -> Optional[torch.Tensor]:
return self._seq_lens
@seq_lens.setter
def seq_lens(self, value: Optional[torch.Tensor]):
# If the value is not explicitly given, the dataclass tries to initialize it from the class attribute.
value = value if value is not AttentionMetadata.seq_lens else None
self._seq_lens = value
self.on_update()
# The model executor sets seq_lens to None initially.
if self._seq_lens is not None:
self._seq_lens = self._seq_lens.pin_memory()
self._seq_lens_cuda = self._seq_lens.cuda(non_blocking=True)
if self.has_cross_sub_metadata:
self.cross._seq_lens = self._seq_lens
self.cross._seq_lens_cuda = self._seq_lens_cuda
@property
def num_contexts(self) -> int:
return self._num_contexts
@num_contexts.setter
def num_contexts(self, value: int):
value = value if value is not AttentionMetadata.num_contexts else 0
self._num_contexts = value
self.on_update()
@property
def num_generations(self) -> int:
return self._num_generations
@num_generations.setter
def num_generations(self, value: int):
value = value if value is not AttentionMetadata.num_generations else 0
self._num_generations = value
self.on_update()
@property
def seq_lens_cuda(self):
return self._seq_lens_cuda
@property
def seq_lens_kv(self) -> Optional[torch.Tensor]:
return self._seq_lens_kv if self._seq_lens_kv is not None else self._seq_lens
@seq_lens_kv.setter
def seq_lens_kv(self, value: Optional[torch.Tensor]):
value = value if value is not AttentionMetadata.seq_lens_kv else None
self._seq_lens_kv = value
self.on_update()
# The model executor sets seq_lens_kv to None initially.
if self._seq_lens_kv is not None:
self._seq_lens_kv = self._seq_lens_kv.pin_memory()
self._seq_lens_kv_cuda = self._seq_lens_kv.cuda(non_blocking=True)
@property
def seq_lens_kv_cuda(self):
return self._seq_lens_kv_cuda if self._seq_lens_kv_cuda is not None else self._seq_lens_cuda
@property
def context_lens(self) -> torch.Tensor:
"""
The length of each context-phase query sequence in the batch.
The shape is (num_contexts), where num_contexts is the number of context-phase sequences in the batch.
"""
return self.seq_lens[:self.num_contexts]
@property
def num_seqs(self) -> int:
"""
The number of sequences in the batch.
"""
return self.seq_lens.shape[0]
@property
def is_cross(self) -> bool:
"""
Whether this metadata is for cross attention.
"""
return self.seq_lens is not self.seq_lens_kv
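# Example: in an encoder-decoder model, a cross-attention metadata object carries
# decoder query lengths in seq_lens and encoder lengths in seq_lens_kv, so the two
# tensors differ and is_cross is True; for self attention, seq_lens_kv falls back
# to seq_lens and is_cross is False.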
@property
def has_cross_sub_metadata(self) -> bool:
return self.cross is not None and self.cross is not self
@property
def num_ctx_tokens(self) -> int:
return self._num_ctx_tokens
@property
def num_tokens(self) -> int:
return self._num_tokens
def prepare(self):
"""
Hook to be called before the forward step of the model.
"""
def create_cuda_graph_metadata(self,
max_batch_size: int,
sub_cross_metadata: bool = False,
max_draft_tokens: int = 0) -> Self:
"""
Creates metadata for CUDA graph execution.
CUDA graphs require pre-allocated buffers for every tensor stored in these fields.
Do not re-allocate any tensors held by AttentionMetadata after the initial
warmup run when CUDA graphs are in use.
"""
if self.is_cuda_graph:
return self
cuda_graph_metadata = copy.copy(self)
cuda_graph_metadata.is_cuda_graph = True
if self.has_cross_sub_metadata:
cuda_graph_metadata.cross = cuda_graph_metadata.cross.create_cuda_graph_metadata(
max_batch_size, True)
if not sub_cross_metadata:
cuda_graph_metadata.seq_lens = torch.ones(
(max_batch_size, ), dtype=torch.int) * (1 + max_draft_tokens)
if self.is_cross:
cuda_graph_metadata.seq_lens_kv = torch.zeros((max_batch_size, ),
dtype=torch.int)
if self.enable_flash_mla:
if self.kv_cache_manager is not None:
cuda_graph_metadata.block_ids_per_seq = torch.zeros(
[
self.kv_cache_manager.max_batch_size,
self.kv_cache_manager.max_blocks_per_seq
],
dtype=torch.int32,
device='cuda',
)
cuda_graph_metadata.num_contexts = 0
cuda_graph_metadata.max_num_requests = max_batch_size
cuda_graph_metadata.max_num_tokens = max_batch_size * (1 +
max_draft_tokens)
cuda_graph_metadata.__post_init__()
return cuda_graph_metadata
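# Usage sketch (illustrative only; not part of the original API): during warmup a
# backend can derive a graph-safe copy of its live metadata and then only update the
# pre-allocated buffers in place for every captured replay.
def _example_cuda_graph_metadata(live_metadata: AttentionMetadata,
                                 max_batch_size: int) -> AttentionMetadata:
    # Copies the metadata, marks it as CUDA-graph metadata, and sizes its buffers
    # for `max_batch_size` requests with one token each (no draft tokens assumed).
    graph_metadata = live_metadata.create_cuda_graph_metadata(max_batch_size)
    graph_metadata.prepare()  # base-class hook; concrete backends refresh buffers here
    return graph_metadata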
class PositionalEmbedder(Protocol):
"""
A callable that applies a positional embedding to the query and key tensors.
"""
def __call__(self, position_ids: torch.Tensor, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
...
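# Minimal sketch of a PositionalEmbedder-compatible callable (illustrative only):
# any object whose __call__ matches this signature satisfies the protocol, for
# example a module that applies RoPE to q and k outside the attention backend.
class _IdentityEmbedder:
    """Example embedder that returns q and k unchanged."""

    def __call__(self, position_ids: torch.Tensor, q: torch.Tensor,
                 k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return q, k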
@dataclass(kw_only=True, unsafe_hash=True)
class RopeParams:
dim: int = 0
theta: float = 10000.0
scale_type: RotaryScalingType = RotaryScalingType.none
scale: float = 1.0
low_freq_factor: float = 1.0
high_freq_factor: float = 4.0
short_m_scale: float = 1.0
long_m_scale: float = 1.0
max_positions: int = 1024
original_max_positions: int = 1024
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.0
mscale_all_dim: float = 0.0
@staticmethod
def from_config(config) -> "RopeParams":
rope_params = RopeParams()
# get rotary parameters.
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
head_dim = getattr(config, 'head_dim',
hidden_size // num_attention_heads)
rope_scaling = getattr(config, 'rope_scaling', None)
rope_params.max_positions = config.max_position_embeddings
rope_params.theta = getattr(config, 'rope_theta', 10000.0)
rope_percentage = (getattr(config, 'rotary_pct', None)
or getattr(config, 'partial_rotary_factor', None)
or 1.0)
# rotary embedding dim.
rope_params.dim = (getattr(config, 'rotary_dim', None)
or getattr(config, 'rotary_emb_base', None)
or getattr(config, 'qk_rope_head_dim', None)
or int(head_dim * rope_percentage))
# rotary scaling.
rope_params.scale_type = RotaryScalingType.none
rope_params.scale = 1.0
if rope_scaling is not None:
rotary_scaling_type = rope_scaling.get(
"type", None) or rope_scaling.get("rope_type")
rope_params.scale_type = RotaryScalingType.from_string(
rotary_scaling_type)
rope_params.scale = rope_scaling.get("factor", 1.0)
rope_params.low_freq_factor = rope_scaling.get(
"low_freq_factor", 1.0)
rope_params.high_freq_factor = rope_scaling.get(
"high_freq_factor", 4.0)
rope_params.original_max_positions = rope_scaling.get(
"original_max_position_embeddings", 1024)
rope_params.beta_fast = rope_scaling.get("beta_fast", 32)
rope_params.beta_slow = rope_scaling.get("beta_slow", 1)
rope_params.mscale = rope_scaling.get("mscale", 1.0)
rope_params.mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
# Workaround for DeepSeek V3 Lite since its rope_scaling is null in config.json.
elif config.model_type == "deepseek_v3":
rope_params.scale_type = RotaryScalingType.yarn
return rope_params
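# Example (hypothetical config object): for a checkpoint whose config carries
# rope_scaling = {"type": "yarn", "factor": 4.0,
#                 "original_max_position_embeddings": 4096},
# from_config() returns RopeParams with scale_type=RotaryScalingType.yarn,
# scale=4.0, and original_max_positions=4096, while the remaining fields keep
# their defaults unless the config overrides them.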
def create_rope_const_params(self):
if self.dim == 0:
return None, None
RopeConstParams = namedtuple("RopeConstParams", ["inv_freq", "cos_sin"])
extra_attrs = get_model_extra_attrs()
if extra_attrs is not None:
cache = extra_attrs.setdefault("rope_const_params", {})
rope_const_params = cache.get(self, None)
if rope_const_params is not None and rope_const_params.cos_sin(
) is not None:
return (
rope_const_params.inv_freq()
if rope_const_params.inv_freq is not None else None,
rope_const_params.cos_sin(),
)
if self.scale_type == RotaryScalingType.yarn:
rope_inv_freq = None
rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
self.max_positions,
self.dim,
self.theta,
self.scale,
self.original_max_positions,
self.beta_fast,
self.beta_slow,
self.mscale,
self.mscale_all_dim,
)
elif self.scale_type == RotaryScalingType.longrope:
raise NotImplementedError("Long RoPE is not supported.")
else:
rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
self.max_positions,
self.dim,
self.theta,
self.scale,
self.scale_type,
rope_scaling_config={
"factor": self.scale,
"low_freq_factor": self.low_freq_factor,
"high_freq_factor": self.high_freq_factor,
"original_max_position_embeddings":
self.original_max_positions,
})
if rope_inv_freq is not None:
rope_inv_freq = torch.tensor(
rope_inv_freq,
dtype=torch.float32,
device='cuda',
)
rope_cos_sin = torch.tensor(
rope_cos_sin,
dtype=torch.float32,
device='cuda',
)
if extra_attrs is not None:
cache[self] = RopeConstParams(
weakref.ref(rope_inv_freq)
if rope_inv_freq is not None else None,
weakref.ref(rope_cos_sin),
)
return rope_inv_freq, rope_cos_sin
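# Consumption sketch (illustrative only): backends fetch the RoPE constants once and
# reuse them; repeated calls with an identical RopeParams reuse the tensors cached in
# the model's extra attrs for as long as those tensors stay alive (weak references).
def _example_rope_const_params(rope: RopeParams):
    rope_inv_freq, rope_cos_sin = rope.create_rope_const_params()
    # With rope.dim > 0: rope_inv_freq is None for YaRN scaling, and rope_cos_sin
    # is a float32 CUDA tensor; with rope.dim == 0, both are None.
    return rope_inv_freq, rope_cos_sin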
@dataclass(kw_only=True, frozen=True)
class PositionalEmbeddingParams:
type: PositionEmbeddingType
embedder: Optional[PositionalEmbedder] = None
# RoPE params
rope: Optional[RopeParams] = None
def __post_init__(self) -> None:
if self.type.is_deferred():
assert self.embedder is not None, f"{self.type} requires a non-None external embedder"
else:
assert self.embedder is None, f"Embedder must be None for {self.type}"
if self.type.is_rope():
assert self.rope is not None, f"{self.type} requires non-None rope params"
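# Construction sketch (illustrative; assumes a RoPE-style model and that
# PositionEmbeddingType.rope_gpt_neox is applied inside the backend, i.e. not a
# deferred type): for RoPE the external embedder stays None and only `rope` is
# provided; deferred embedding types instead pass an embedder and leave `rope`
# unset, as enforced by __post_init__ above.
def _example_rope_positional_params(config) -> PositionalEmbeddingParams:
    return PositionalEmbeddingParams(
        type=PositionEmbeddingType.rope_gpt_neox,
        rope=RopeParams.from_config(config),
    )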
TMetadata = TypeVar("TMetadata", bound=AttentionMetadata)
class PredefinedAttentionMask(str, Enum):
"""
Predefined attention mask types.

Attributes:
    CAUSAL: use a causal mask.
    FULL: do not use any mask.
"""
CAUSAL = "causal"
FULL = "full"
# May be extended with custom attention mask types.
AttentionMask = Union[PredefinedAttentionMask]
class AttentionBackend(Generic[TMetadata]):
"""
Base class for attention backends.
"""
Metadata: Type[TMetadata] = AttentionMetadata
def __init__(
self,
layer_idx: int,
num_heads: int,
head_dim: int,
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
**kwargs,
):
"""
Initialize the backend.
Args:
layer_idx (int): The index of the attention layer in the model.
num_heads (int): The number of query heads.
head_dim (int): The size of each attention head (hidden_size // num_heads).
num_kv_heads (int): The number of kv heads. Defaults to num_heads if None.
quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
"""
self.layer_idx = layer_idx
self.num_heads = num_heads
self.head_dim = head_dim
self.num_kv_heads = num_kv_heads or self.num_heads
self.quant_config = quant_config
def forward(self,
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
metadata: TMetadata,
*,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
**kwargs) -> torch.Tensor:
"""
Update KV Cache and perform the attention operation.
Args:
q (torch.Tensor): Query tensor with shape (num_q_tokens, num_heads * head_dim),
or QKV tensor with shape (num_tokens, (num_heads + 2 * num_kv_heads) * head_dim).
k (Optional[torch.Tensor]): Key tensor with shape (num_new_kv_tokens, num_kv_heads * head_dim),
or KV tensor with shape (num_new_kv_tokens, (2 * num_kv_heads) * head_dim),
or None if a fused QKV tensor is provided or there are no new KV tokens.
v (Optional[torch.Tensor]): Value tensor with shape (num_new_kv_tokens, num_kv_heads * head_dim),
or None if a fused QKV tensor is provided or there are no new KV tokens.
metadata (AttentionMetadata): Metadata for the attention operation.
attention_mask (AttentionMask): Attention mask. See definition of `AttentionMask` for accepted types. Defaults to predefined causal mask.
Returns:
torch.Tensor with shape (num_q_tokens, num_heads * head_dim)
"""
raise NotImplementedError
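# Skeleton of a concrete backend (illustrative only, not a registered TensorRT-LLM
# backend): subclasses bind their metadata type via the `Metadata` class attribute
# and implement forward() against it.
class _NoOpAttentionBackend(AttentionBackend[AttentionMetadata]):
    """Example backend that echoes the query tensor instead of attending."""

    Metadata = AttentionMetadata

    def forward(self,
                q: torch.Tensor,
                k: Optional[torch.Tensor],
                v: Optional[torch.Tensor],
                metadata: AttentionMetadata,
                *,
                attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
                **kwargs) -> torch.Tensor:
        # A real backend would update the KV cache via metadata.kv_cache_manager
        # and run its fused attention kernels here.
        return q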
@dataclass(kw_only=True, unsafe_hash=True)
class MLAParams:
q_lora_rank: int = 0
kv_lora_rank: int = 0
qk_rope_head_dim: int = 0
qk_nope_head_dim: int = 0
v_head_dim: int = 0
predicted_tokens_per_seq: int = 1