Fixing a few things

Authored by Ubuntu on 2023-06-13 18:58:09 +00:00; committed by Nicolas Patry
parent dadbbc27d5
commit ffe8fc4699
3 changed files with 18 additions and 18 deletions

View File

@@ -150,7 +150,6 @@ def download_weights(
    # Convert pytorch weights to safetensors
    utils.convert_files(local_pt_files, local_st_files)


@app.command()
def quantize(
    model_id: str,
@@ -158,8 +157,9 @@ def quantize(
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
):
    extension: str = (".safetensors",)
    extension: str = ".safetensors",
    # Remove default handler
    logger.remove()
    logger.add(
@@ -171,15 +171,12 @@ def quantize(
        backtrace=True,
        diagnose=False,
    )
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )
    download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
    from text_generation_server.utils.gptq.quantize import quantize
    quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir, trust_remote_code=trust_remote_code)
    quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir)


if __name__ == "__main__":

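For context, the hunks above touch the CLI module: quantize appears to be a Typer-style command (the diff shows @app.command()) that first reuses download_weights to fetch and convert the fp16 checkpoint and then hands off to the GPTQ routine, and this commit threads a trust_remote_code flag through that chain. Below is a minimal sketch of that plumbing only; the stub bodies, the gptq_quantize rename (the real file imports the GPTQ quantize inside the command body, shadowing the command's own name), and the assumption that output_dir is one of the command's parameters are mine, not the repository's code.

from typing import Optional

import typer

app = typer.Typer()


def download_weights(model_id: str, revision: Optional[str], logger_level: str, json_output: bool):
    # Stub: the real command downloads the checkpoint and converts it to safetensors.
    print(f"downloading {model_id} (revision={revision})")


def gptq_quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
    # Stub: the real routine runs GPTQ calibration and writes quantized shards to output_dir.
    print(f"quantizing {model_id} -> {output_dir} ({bits}-bit, groupsize {groupsize})")


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
):
    # Make sure the fp16 weights exist locally before quantizing.
    download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
    # Forward the CLI flag so custom modeling code from the Hub may be executed when the model loads.
    gptq_quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir, trust_remote_code=trust_remote_code)


if __name__ == "__main__":
    app()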
View File

@@ -304,11 +304,14 @@ class QuantLinearFunction(torch.autograd.Function):
class QuantLinear(nn.Module):
    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
        super().__init__()
        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        self.g_idx = g_idx
        self.bias = bias
        # register_buffer returns None, so call it for its side effect without reassigning the attribute
        self.register_buffer("qweight", qweight)
        self.register_buffer("qzeros", qzeros)
        self.register_buffer("scales", scales)
        self.register_buffer("g_idx", g_idx)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.bits = bits

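The second file's change swaps plain attribute assignment for register_buffer so the packed GPTQ tensors behave like module state: buffers follow module.to(device), appear in state_dict(), and never receive gradients. Since nn.Module.register_buffer returns None, it is called only for its side effect. A self-contained sketch of the pattern with toy shapes (the class name and shapes are illustrative, not the repository's):

from typing import Optional

import torch
from torch import nn


class PackedLinearSketch(nn.Module):
    # Toy stand-in for QuantLinear that stores packed int32 weights as buffers.
    def __init__(self, qweight: torch.Tensor, scales: torch.Tensor, bias: Optional[torch.Tensor]):
        super().__init__()
        self.register_buffer("qweight", qweight)  # tracked by .to() and state_dict(), excluded from gradients
        self.register_buffer("scales", scales)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None


m = PackedLinearSketch(
    qweight=torch.zeros(8, 4, dtype=torch.int32),
    scales=torch.ones(1, 4, dtype=torch.float16),
    bias=None,
)
print(list(m.state_dict().keys()))  # ['qweight', 'scales'] -- buffers are serialized with the module
m.to("cpu")                         # buffers move with the module, unlike plain Python attributes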
View File

@@ -937,9 +937,9 @@ def pack(model, quantizers, bits, groupsize):
    # print('max memory(MiB):', max_memory)


def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
    print("loading model")
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0")
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0", trust_remote_code=trust_remote_code)
    print("LOADED model")
    model.seqlen = 2048
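In the third file, the loading call gains trust_remote_code so repositories that ship custom modeling code can be quantized, and device_map="balanced_low_0" asks accelerate to spread the fp16 weights over the available GPUs while keeping GPU 0 lighter. A hedged sketch of that load with an arbitrary placeholder model id (it needs transformers plus accelerate and at least one GPU to actually run):

import torch
from transformers import AutoModelForCausalLM

model_id = "bigscience/bloom-560m"  # placeholder id; the real CLI passes the user's model_id

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,     # keep the full-precision weights in fp16 while quantizing
    device_map="balanced_low_0",   # accelerate map: balance layers across GPUs, put less on GPU 0
    trust_remote_code=False,       # set True only for Hub repos whose custom code you trust
)
model.seqlen = 2048                # sequence length used later when building calibration batches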
@@ -1002,8 +1002,8 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
    from transformers.modeling_utils import shard_checkpoint
    state_dict = model.state_dict()
    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
    state_dict["gptq_bits"] = torch.LongTensor(bits)
    state_dict["gptq_groupsize"] = torch.LongTensor(groupsize)
    state_dict["gptq_bits"] = torch.LongTensor([bits])
    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
    max_shard_size = "10GB"
    shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors")
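The gptq_bits / gptq_groupsize change is the actual fix in this last hunk: torch.LongTensor(n) interprets an integer argument as a size and returns an uninitialized tensor with n elements, while torch.LongTensor([n]) returns a one-element tensor holding n, which is the metadata the quantized checkpoint is meant to carry. A quick illustration (the contents of the uninitialized tensor are arbitrary garbage):

import torch

bits, groupsize = 4, 128

wrong = torch.LongTensor(bits)    # size-4 tensor of uninitialized values, not the number 4
right = torch.LongTensor([bits])  # tensor([4]) -- what we actually want to store

print(wrong.shape)  # torch.Size([4])
print(right)        # tensor([4])

# Stored next to the packed weights so a loader can recover the quantization settings later.
state_dict = {
    "gptq_bits": torch.LongTensor([bits]),
    "gptq_groupsize": torch.LongTensor([groupsize]),
}
print(int(state_dict["gptq_bits"].item()), int(state_dict["gptq_groupsize"].item()))  # 4 128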