File tree 1 file changed +5
-3
lines changed
1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change 7
7
import modules .shared as shared
8
8
9
9
sys .path .insert (0 , str (Path ("repositories/GPTQ-for-LLaMa" )))
10
+ import llama
11
+ import opt
10
12
11
13
12
14
def load_quantized (model_name ):
@@ -21,9 +23,9 @@ def load_quantized(model_name):
21
23
model_type = shared .args .gptq_model_type .lower ()
22
24
23
25
if model_type == 'llama' :
24
- from llama import load_quant
26
+ load_quant = llama . load_quant
25
27
elif model_type == 'opt' :
26
- from opt import load_quant
28
+ load_quant = opt . load_quant
27
29
else :
28
30
print ("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported" )
29
31
exit ()
@@ -50,7 +52,7 @@ def load_quantized(model_name):
50
52
print (f"Could not find { pt_model } , exiting..." )
51
53
exit ()
52
54
53
- model = load_quant (path_to_model , str (pt_path ), shared .args .gptq_bits )
55
+ model = load_quant (str ( path_to_model ) , str (pt_path ), shared .args .gptq_bits )
54
56
55
57
# Multiple GPUs or GPU+CPU
56
58
if shared .args .gpu_memory :
You can’t perform that action at this time.
0 commit comments