Fixing few things
This commit is contained in:
parent
dadbbc27d5
commit
ffe8fc4699
|
@ -150,7 +150,6 @@ def download_weights(
|
|||
# Convert pytorch weights to safetensors
|
||||
utils.convert_files(local_pt_files, local_st_files)
|
||||
|
||||
|
||||
@app.command()
|
||||
def quantize(
|
||||
model_id: str,
|
||||
|
@ -158,8 +157,9 @@ def quantize(
|
|||
revision: Optional[str] = None,
|
||||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
extension: str = (".safetensors",)
|
||||
extension: str = ".safetensors",
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
logger.add(
|
||||
|
@ -171,15 +171,12 @@ def quantize(
|
|||
backtrace=True,
|
||||
diagnose=False,
|
||||
)
|
||||
download_weights(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
logger_level=logger_level,
|
||||
json_output=json_output,
|
||||
)
|
||||
download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
|
||||
from text_generation_server.utils.gptq.quantize import quantize
|
||||
quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir, trust_remote_code=trust_remote_code)
|
||||
|
||||
|
||||
|
||||
quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -304,11 +304,14 @@ class QuantLinearFunction(torch.autograd.Function):
|
|||
class QuantLinear(nn.Module):
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
super().__init__()
|
||||
self.qweight = qweight
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
self.g_idx = g_idx
|
||||
self.bias = bias
|
||||
self.qweight = self.register_buffer("qweight", qweight)
|
||||
self.qzeros = self.register_buffer("qzeros", qzeros)
|
||||
self.scales = self.register_buffer("scales", scales)
|
||||
self.g_idx = self.register_buffer("g_idx", g_idx)
|
||||
if bias is not None:
|
||||
self.bias = self.register_buffer("bias", bias)
|
||||
else:
|
||||
self.bias = None
|
||||
if bits not in [2, 4, 8]:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
self.bits = bits
|
||||
|
|
|
@ -937,9 +937,9 @@ def pack(model, quantizers, bits, groupsize):
|
|||
# print('max memory(MiB):', max_memory)
|
||||
|
||||
|
||||
def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
|
||||
def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
|
||||
print("loading model")
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0")
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0", trust_remote_code=trust_remote_code)
|
||||
print("LOADED model")
|
||||
model.seqlen = 2048
|
||||
|
||||
|
@ -1002,8 +1002,8 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
|
|||
from transformers.modeling_utils import shard_checkpoint
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
||||
state_dict["gptq_bits"] = torch.LongTensor(bits)
|
||||
state_dict["gptq_groupsize"] = torch.LongTensor(groupsize)
|
||||
state_dict["gptq_bits"] = torch.LongTensor([bits])
|
||||
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
|
||||
|
||||
max_shard_size = "10GB"
|
||||
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors")
|
||||
|
|
Loading…
Reference in New Issue