
GGUF compatible quantization (2, 3, 4 bit / any bit) #285


Merged (5 commits, Jan 7, 2024)
awq/models/base.py (27 changes: 23 additions & 4 deletions)
@@ -83,17 +83,36 @@ def generate(self, *args, **kwargs):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                        calib_data: Union[str, List[str]]="pileval",
-                       split="train", text_column="text", duo_scaling=True, modules_to_not_convert=None):
+                       split="train", text_column="text", duo_scaling=True,
+                       modules_to_not_convert=None, export_compatible=False):
         self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)
 
-        quantizer = AwqQuantizer(
+        self.quantizer = AwqQuantizer(
             self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
-            self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert
+            self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert,
+            export_compatible=export_compatible
         )
-        quantizer.quantize()
+        self.quantizer.quantize()
 
         self.is_quantized = True
 
+    @torch.no_grad()
+    def pack(self):
+        """
+        A utility function for the following scenario. Note that save_quantized will
+        overwrite existing weights if you use the same quant_path.
+
+        model.quantize(
+            tokenizer,
+            quant_config=quant_config,
+            export_compatible=True
+        )
+        model.save_quantized(...)  # produces GGUF/other compat weights
+        model.pack(...)  # makes the model CUDA compat
+        model.save_quantized(...)  # produces CUDA compat weights
+        """
+        self.quantizer.pack()
+
     @staticmethod
     def fuse_layers(model):
         pass
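Taken together with the docstring above, the intended workflow is: run quantize() with export_compatible=True, save the scale-applied FP16 weights for GGUF conversion, then call pack() and save again to get weights the AutoAWQ CUDA kernels can load. A minimal sketch of that two-save flow (the model name, save paths, and the 4-bit config are illustrative placeholders, not taken from this PR):

# Sketch of the two-stage flow described in pack()'s docstring.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mistral-7B-v0.1"  # placeholder model
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Search and apply AWQ scales/clipping, but skip packing the weights.
model.quantize(tokenizer, quant_config=quant_config, export_compatible=True)
model.save_quantized("model-awq-fp16")   # FP16 weights with scales applied, ready for GGUF conversion

# Pack the same in-memory model into AWQ's quantized format for the CUDA kernels.
model.pack()
model.save_quantized("model-awq-gemm")   # packed weights loadable by AutoAWQ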
awq/quantize/quantizer.py (13 changes: 12 additions & 1 deletion)
@@ -21,7 +21,8 @@
 
 class AwqQuantizer:
     def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                 calib_data, split, text_column, duo_scaling, modules_to_not_convert=None) -> None:
+                 calib_data, split, text_column, duo_scaling, modules_to_not_convert=None,
+                 export_compatible=False) -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
@@ -32,6 +33,7 @@ def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
         self.split = split
         self.text_column = text_column
         self.duo_scaling = duo_scaling
+        self.export_compatible = export_compatible
         self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
         self.modules, self.module_kwargs, self.inps = self.init_quant()
 
@@ -115,6 +117,15 @@ def quantize(self):
             clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
 
             # [STEP 4]: Quantize weights
-            self._apply_quant(self.modules[i], named_linears)
+            if not self.export_compatible:
+                self._apply_quant(self.modules[i], named_linears)
 
             clear_memory()
 
+    def pack(self):
+        for i in tqdm(range(len(self.modules)), desc="Packing"):
+            named_linears = get_named_linears(self.modules[i])
+            named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)
+            self._apply_quant(self.modules[i], named_linears)
+            clear_memory()
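For reference, pack() iterates over the cached decoder blocks, collects each block's Linear layers with get_named_linears, filters out modules_to_not_convert with exclude_layers_to_not_quantize, and then runs _apply_quant, i.e. the packing step that quantize() now skips when export_compatible is set. The filtering helper's body is not part of this diff; the following is a rough, illustrative re-implementation of what it amounts to, not the repository's code:

def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
    # Keep only linear layers whose name matches none of the excluded patterns.
    # Illustrative sketch; the actual helper lives elsewhere in the repository.
    if not modules_to_not_convert:
        return linear_layers
    return {
        name: layer
        for name, layer in linear_layers.items()
        if not any(pattern in name for pattern in modules_to_not_convert)
    }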
examples/awq_to_gguf_quant.py (48 changes: 48 additions & 0 deletions)
@@ -0,0 +1,48 @@
import os
import subprocess
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'mistralai/Mistral-7B-v0.1'
quant_path = 'mistral-awq'
llama_cpp_path = '/workspace/llama.cpp'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 6, "version": "GEMM" }

# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
# NOTE: We avoid packing weights, so you cannot use this model in AutoAWQ
# after quantizing. The saved model is FP16 but has the AWQ scales applied.
model.quantize(
    tokenizer,
    quant_config=quant_config,
    export_compatible=True
)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')

# GGUF conversion
print('Converting model to GGUF...')
llama_cpp_method = "q4_K_M"
convert_cmd_path = os.path.join(llama_cpp_path, "convert.py")
quantize_cmd_path = os.path.join(llama_cpp_path, "quantize")

if not os.path.exists(llama_cpp_path):
    cmd = f"git clone https://github.com/ggerganov/llama.cpp.git {llama_cpp_path} && cd {llama_cpp_path} && make LLAMA_CUBLAS=1 LLAMA_CUDA_F16=1"
    subprocess.run([cmd], shell=True, check=True)

subprocess.run([
    f"python {convert_cmd_path} {quant_path} --outfile {quant_path}/model.gguf"
], shell=True, check=True)

subprocess.run([
    f"{quantize_cmd_path} {quant_path}/model.gguf {quant_path}/model_{llama_cpp_method}.gguf {llama_cpp_method}"
], shell=True, check=True)
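To sanity-check the resulting q4_K_M file, the script can be extended with one more step that generates a few tokens using llama.cpp's example main binary. This is a sketch, not part of the PR; the binary name, prompt, and token count are assumptions, and it reuses the variables defined in the script above:

# Optional smoke test: generate a few tokens from the quantized GGUF.
# Assumes the `main` binary was produced by the make step above.
main_bin = os.path.join(llama_cpp_path, "main")
gguf_file = f"{quant_path}/model_{llama_cpp_method}.gguf"

subprocess.run([
    f"{main_bin} -m {gguf_file} -p 'The meaning of life is' -n 32"
], shell=True, check=True)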