diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py
index 42c986b0..b0086ea0 100644
--- a/server/text_generation_server/layers/gptq/quantize.py
+++ b/server/text_generation_server/layers/gptq/quantize.py
@@ -17,7 +17,7 @@ from loguru import logger
 from typing import Optional

 from text_generation_server.layers.gptq.utils import torch_snr_error
-from text_generation_server.utils.weights import DefaultWeightsLoader
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight

 DEV = torch.device("cuda:0")

@@ -897,7 +897,7 @@ def quantize(
         dtype=torch.float16,
         process_group=process_group,
         aliases={"embed_tokens.weight": ["lm_head.weight"]},
-        weights_loader=DefaultWeightsLoader(),
+        weights_loader=DefaultWeightsLoader(UnquantizedWeight),
     )
     hooks = []
     for name, module in model.named_modules():
@@ -960,9 +960,6 @@ def quantize(

     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
-    state_dict["gptq_bits"] = torch.LongTensor([bits])
-    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
-    state_dict["gptq_sym"] = torch.BoolTensor([sym])

     max_shard_size = "10GB"
     shards, index = shard_checkpoint(
@@ -994,6 +991,15 @@ def quantize(
             f"index located at {save_index_file}."
         )
     config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+    config.quantization_config = {
+        "bits": bits,
+        "group_size": groupsize,
+        "damp_percent": percdamp,
+        "desc_act": act_order,
+        "static_groups": False,
+        "sym": sym,
+        "quant_method": "gptq",
+    }
     config.save_pretrained(output_dir)
     logger.info("Saved config")
     logger.info("Saving tokenizer")
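
Net effect of this diff for consumers of the quantized checkpoint: the GPTQ metadata previously stored as ad-hoc tensors ("gptq_bits", "gptq_groupsize", "gptq_sym") in the state dict now lives in the standard "quantization_config" block of config.json, written by config.save_pretrained(output_dir). The sketch below is illustrative only and not part of the diff; "out" is a placeholder for the output directory passed to quantize, and the printed values depend on the quantization arguments actually used.

    import json
    import os

    # Hypothetical sanity check after running quantize(): the GPTQ
    # parameters should appear in config.json rather than as extra
    # tensors in the safetensors shards.
    output_dir = "out"  # placeholder for the quantize() output directory
    with open(os.path.join(output_dir, "config.json")) as f:
        config = json.load(f)

    print(config["quantization_config"])
    # e.g. {'bits': 4, 'group_size': 128, 'damp_percent': 0.01,
    #       'desc_act': False, 'static_groups': False, 'sym': False,
    #       'quant_method': 'gptq'}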