server quantize: store quantizer config in standard format (#2299)
- Create `quantization_config` option in the model config.
- Don't store the quantizer config in tensors anymore.
This commit is contained in:
parent 0b95693fb8
commit 53aec27328
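In short: a model produced by the server's quantize command now carries its GPTQ settings as a `quantization_config` entry in `config.json` instead of as extra tensors in the checkpoint. A sketch of the shape of that entry, written as a Python literal; the values are illustrative, not taken from a real run:

# Illustrative shape of the "quantization_config" entry that this commit
# writes into the saved config.json. The values below are made-up examples;
# the real ones come from the arguments passed to quantize().
quantization_config = {
    "bits": 4,
    "group_size": 128,
    "damp_percent": 0.01,
    "desc_act": False,
    "static_groups": False,
    "sym": False,
    "quant_method": "gptq",
}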
@@ -17,7 +17,7 @@ from loguru import logger
 from typing import Optional
 from text_generation_server.layers.gptq.utils import torch_snr_error
 
-from text_generation_server.utils.weights import DefaultWeightsLoader
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
 
 DEV = torch.device("cuda:0")
 
|
@@ -897,7 +897,7 @@ def quantize(
         dtype=torch.float16,
         process_group=process_group,
         aliases={"embed_tokens.weight": ["lm_head.weight"]},
-        weights_loader=DefaultWeightsLoader(),
+        weights_loader=DefaultWeightsLoader(UnquantizedWeight),
     )
     hooks = []
     for name, module in model.named_modules():
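The two hunks above swap `DefaultWeightsLoader()` for `DefaultWeightsLoader(UnquantizedWeight)`: in this version of the loader API, `DefaultWeightsLoader` is parameterized by the weight class used to wrap raw checkpoint tensors, and the quantize script reads an unquantized source model. A minimal sketch of that relationship, using simplified stand-ins rather than the real classes from `text_generation_server.utils.weights`:

# Simplified stand-ins for text_generation_server.utils.weights, to show
# why the loader is now constructed with a weight class. Names and method
# signatures here are assumptions, not the library's real implementation.
from dataclasses import dataclass
import torch

@dataclass
class UnquantizedWeight:
    weight: torch.Tensor  # plain full-precision tensor from the source model

class DefaultWeightsLoader:
    def __init__(self, weight_class):
        # Parameterized by how raw checkpoint tensors should be wrapped;
        # the quantize script loads an unquantized model, hence
        # DefaultWeightsLoader(UnquantizedWeight) in the hunk above.
        self.weight_class = weight_class

    def wrap(self, tensor: torch.Tensor):
        return self.weight_class(tensor)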
|
@@ -960,9 +960,6 @@ def quantize(
 
     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
-    state_dict["gptq_bits"] = torch.LongTensor([bits])
-    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
-    state_dict["gptq_sym"] = torch.BoolTensor([sym])
 
     max_shard_size = "10GB"
     shards, index = shard_checkpoint(
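The three deleted lines are the old transport mechanism: the quantizer settings rode along inside the checkpoint as one-element tensors. For contrast, a hedged sketch of what reading them back from a shard looked like; `read_legacy_gptq_params` is an illustrative helper, not an actual TGI entry point:

# Illustrative reader for the legacy format removed here: the settings
# traveled as one-element tensors inside the checkpoint and had to be
# fished back out on load. Not actual TGI code.
from safetensors.torch import load_file

def read_legacy_gptq_params(shard_path: str):
    tensors = load_file(shard_path)
    bits = int(tensors["gptq_bits"].item())
    groupsize = int(tensors["gptq_groupsize"].item())
    sym = bool(tensors["gptq_sym"].item())
    return bits, groupsize, sym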
@@ -994,6 +991,15 @@ def quantize(
             f"index located at {save_index_file}."
         )
     config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+    config.quantization_config = {
+        "bits": bits,
+        "group_size": groupsize,
+        "damp_percent": percdamp,
+        "desc_act": act_order,
+        "static_groups": False,
+        "sym": sym,
+        "quant_method": "gptq",
+    }
     config.save_pretrained(output_dir)
     logger.info("Saved config")
     logger.info("Saving tokenizer")
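The new dict uses the field names already common in the GPTQ ecosystem (`bits`, `group_size`, `damp_percent`, `desc_act`, `sym`, `quant_method`), so downstream code can recover the settings with a plain `AutoConfig` round-trip. A minimal consumer-side sketch; the output path is a placeholder:

# Read the settings back from config.json; after this commit they no longer
# live in the checkpoint tensors. The model path below is a placeholder, and
# quantization_config is assumed to load as a plain dict, which may vary
# across transformers versions.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("/path/to/quantized-output-dir")
quant = getattr(config, "quantization_config", None) or {}
if quant.get("quant_method") == "gptq":
    bits = quant["bits"]
    groupsize = quant["group_size"]
    sym = quant["sym"]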
|