Commit 591f3d2

fix: Support TLLM_OVERRIDE_LAYER_NUM for llama4. (#3679)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent a51f755 commit 591f3d2

1 file changed: 4 additions, 0 deletions

tensorrt_llm/_torch/pyexecutor/model_engine.py

@@ -771,6 +771,10 @@ def _load_model(self, checkpoint_dir: str, load_format: LoadFormat,
             num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0"))
             if num_layers > 0:
                 config.pretrained_config.num_hidden_layers = num_layers
+                for sub_config in ["text_config", "vision_config"]:
+                    if hasattr(config.pretrained_config, sub_config):
+                        getattr(config.pretrained_config,
+                                sub_config).num_hidden_layers = num_layers
 
             with timing("Model init total"):
                 try:
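
The effect of the added lines can be sketched outside TensorRT-LLM. The snippet below is a minimal illustration, not the library's code: `override_layer_num` is a hypothetical helper that mirrors the patched logic, `SimpleNamespace` objects stand in for the real pretrained config, and the layer counts are made-up values. It assumes, as the diff suggests, that a Llama 4 style config nests `text_config` and `vision_config` objects that each carry their own `num_hidden_layers`.

```python
import os
from types import SimpleNamespace


def override_layer_num(pretrained_config, env=os.environ):
    """Hypothetical helper mirroring the patched logic (not a TensorRT-LLM API)."""
    num_layers = int(env.get("TLLM_OVERRIDE_LAYER_NUM", "0"))
    if num_layers > 0:
        # Top-level override, as before the patch.
        pretrained_config.num_hidden_layers = num_layers
        # New in this commit: also override nested sub-configs when present.
        for sub_config in ["text_config", "vision_config"]:
            if hasattr(pretrained_config, sub_config):
                getattr(pretrained_config, sub_config).num_hidden_layers = num_layers
    return pretrained_config


# Stand-in for a nested (Llama 4 style) config; layer counts are illustrative only.
nested = SimpleNamespace(
    num_hidden_layers=48,
    text_config=SimpleNamespace(num_hidden_layers=48),
    vision_config=SimpleNamespace(num_hidden_layers=34),
)
# Stand-in for a flat config without sub-configs.
flat = SimpleNamespace(num_hidden_layers=32)

override_layer_num(nested, env={"TLLM_OVERRIDE_LAYER_NUM": "2"})
override_layer_num(flat, env={"TLLM_OVERRIDE_LAYER_NUM": "2"})
# With the variable unset, the config is left untouched.
untouched = override_layer_num(SimpleNamespace(num_hidden_layers=32), env={})
```

Without the loop, only the top-level `num_hidden_layers` would change, and a model that reads its layer count from `text_config` or `vision_config` would presumably still be built at full depth. The `hasattr` check leaves flat configs on the existing path.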
