forked from NVIDIA/TensorRT-LLM

modeling_auto.py
33 lines (28 loc) · 1.23 KB
from typing import Generic

from ..model_config import ModelConfig
from ..utils import model_extra_attrs
from .modeling_utils import (MODEL_CLASS_MAPPING, DecoderModelForCausalLM,
                             TConfig, TModel)


class AutoModelForCausalLM(Generic[TModel, TConfig]):

    @staticmethod
    def from_config(
        config: ModelConfig[TConfig],
    ) -> DecoderModelForCausalLM[TModel, TConfig]:
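        """Resolve the model class registered for the checkpoint's
        architecture string and construct it from `config`."""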
        model_arch = config.pretrained_config.architectures[0]

        # Hack to detect eagle3 checkpoints. TODO: should we provide
        # our own checkpoints with the correct arch? It would let us
        # avoid nasty stuff like this.
        if hasattr(config.pretrained_config, "draft_vocab_size"):
            model_arch = "EAGLE3" + model_arch
        cls = MODEL_CLASS_MAPPING.get(model_arch)
        if cls is None:
            raise ValueError(
                f"Unknown architecture for AutoModelForCausalLM: "
                f"{config.pretrained_config.architectures[0]}")
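        # Presumably defers eager weight allocation until checkpoint
        # weights are actually loaded (inferred from the flag name).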
        if issubclass(cls, DecoderModelForCausalLM):
            config.skip_create_weights = True
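        # model_extra_attrs appears to expose `extra_attrs` as a
        # construction-scoped registry so submodules can stash
        # attributes while the model is being built.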
        extra_attrs = {}
        with model_extra_attrs(extra_attrs):
            model = cls(config)
        model.extra_attrs = extra_attrs
        return model
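For reference, a minimal sketch of how this factory might be driven. Only AutoModelForCausalLM.from_config is defined in this file; the module paths, the ModelConfig.from_pretrained helper, and the checkpoint path below are assumptions about the surrounding package, not something this file guarantees.

# Hypothetical usage sketch; everything except from_config is assumed.
from tensorrt_llm._torch.model_config import ModelConfig  # assumed module path
from tensorrt_llm._torch.models.modeling_auto import AutoModelForCausalLM  # assumed module path

# Build a ModelConfig from a HF-style checkpoint directory (assumed helper),
# then let the factory pick the concrete class via MODEL_CLASS_MAPPING.
config = ModelConfig.from_pretrained("/path/to/checkpoint")
model = AutoModelForCausalLM.from_config(config)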