server quantize: store quantizer config in standard format (#2299)

- Create `quantization_config` option in the model config.
- Don't store the quantizer config in tensors anymore.
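For reference, the quantizer settings now land in the standard `quantization_config` section of the saved `config.json` instead of in extra checkpoint tensors. A minimal sketch of the resulting section, with illustrative values for a 4-bit, group-size-128 run (the real values come from the quantize command's arguments):

```json
{
  "quantization_config": {
    "bits": 4,
    "group_size": 128,
    "damp_percent": 0.01,
    "desc_act": false,
    "static_groups": false,
    "sym": true,
    "quant_method": "gptq"
  }
}
```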
Daniël de Kok 2024-07-30 15:16:20 +02:00 committed by GitHub
parent 0b95693fb8
commit 53aec27328
1 changed file with 11 additions and 5 deletions

@@ -17,7 +17,7 @@ from loguru import logger
 from typing import Optional
 from text_generation_server.layers.gptq.utils import torch_snr_error
-from text_generation_server.utils.weights import DefaultWeightsLoader
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
 
 DEV = torch.device("cuda:0")
@@ -897,7 +897,7 @@ def quantize(
         dtype=torch.float16,
         process_group=process_group,
         aliases={"embed_tokens.weight": ["lm_head.weight"]},
-        weights_loader=DefaultWeightsLoader(),
+        weights_loader=DefaultWeightsLoader(UnquantizedWeight),
     )
     hooks = []
     for name, module in model.named_modules():
@@ -960,9 +960,6 @@ def quantize(
     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
-    state_dict["gptq_bits"] = torch.LongTensor([bits])
-    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
-    state_dict["gptq_sym"] = torch.BoolTensor([sym])
 
     max_shard_size = "10GB"
     shards, index = shard_checkpoint(
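Before this change, a consumer of the checkpoint had to fish these parameters back out of the safetensors shards themselves. A sketch of that old read path, assuming a hypothetical shard name `model.safetensors`:

```python
# Pre-change layout: the quantizer parameters lived as tiny tensors inside
# the checkpoint itself ("model.safetensors" is a hypothetical shard name).
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    bits = int(f.get_tensor("gptq_bits"))            # LongTensor([bits])
    groupsize = int(f.get_tensor("gptq_groupsize"))  # LongTensor([groupsize])
    sym = bool(f.get_tensor("gptq_sym"))             # BoolTensor([sym])
```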
@@ -994,6 +991,15 @@
             f"index located at {save_index_file}."
         )
     config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+    config.quantization_config = {
+        "bits": bits,
+        "group_size": groupsize,
+        "damp_percent": percdamp,
+        "desc_act": act_order,
+        "static_groups": False,
+        "sym": sym,
+        "quant_method": "gptq",
+    }
     config.save_pretrained(output_dir)
     logger.info("Saved config")
     logger.info("Saving tokenizer")