@@ -87,14 +87,13 @@ def get_ckpt_type(model_path):
 VSWA_MODELS = VSWA_ATTENTION.keys()
 
 GEMMA2_MODELS = {GEMMA_2_9B_IT, GEMMA_2_27B_IT}
-"For plain, non VSWA testing"
 
 
 @skip_pre_hopper
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("data_type", ['bfloat16'])
 @pytest.mark.parametrize("qformat", ['fp8'])
-@pytest.mark.parametrize("gemma_model_root", VSWA_MODELS, indirect=True)
+@pytest.mark.parametrize("gemma_model_root", GEMMA2_MODELS, indirect=True)
 def test_llm_hf_gemma_quantization_1gpu_vswa(batch_size, data_type,
                                              gemma_model_root, llm_venv,
                                              cmodel_dir, engine_dir,
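
Side note on the decorator pattern above: `indirect=True` routes each parametrized value into the fixture of the same name (via `request.param`) rather than passing it straight to the test function. A minimal, self-contained sketch of the mechanism, with a hypothetical fixture body rather than the suite's real `gemma_model_root` fixture:

```python
import pytest


@pytest.fixture
def gemma_model_root(request):
    # Hypothetical body: the real suite resolves an actual model path;
    # here we just echo the indirect parameter to show the mechanism.
    return f"/models/{request.param}"


@pytest.mark.parametrize("gemma_model_root", ["gemma-2-9b-it"], indirect=True)
def test_model_root_is_resolved(gemma_model_root):
    assert gemma_model_root.endswith("gemma-2-9b-it")
```

With `indirect=True`, pytest instantiates the fixture once per parameter, so model-setup logic stays in the fixture instead of being repeated in each test.
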
@@ -175,7 +174,8 @@ def hf_gemma_quantization_1gpu(batch_size,
     threshold_score = 18
 
     window = [
-        f"--max_attention_window={max_attention_window}",
+        "--max_attention_window_size",
+        ','.join((str(w) for w in max_attention_window)),
     ] if max_attention_window is not None else []
 
     summary_cmd = [
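
For context on the replacement lines: the old code emitted a single `--max_attention_window=...` token, while the new code appends the flag name and a comma-joined value as two separate argv entries. A minimal sketch of how the new pair renders, using an assumed example window list (not a value from the diff):

```python
# Assumed example value for illustration only.
max_attention_window = [512, 512, 32768]

window = [
    "--max_attention_window_size",
    ','.join(str(w) for w in max_attention_window),
] if max_attention_window is not None else []

print(window)  # ['--max_attention_window_size', '512,512,32768']
```
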
@@ -306,7 +306,7 @@ def gemma_1gpu_summary(batch_size,
     check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
 
     window = {
-        'max_attention_window': max_attention_window
+        'max_attention_window_size': max_attention_window
     } if max_attention_window is not None else {}
 
     print("Run summarize...")