fix: fix gpt-q params loading
parent 28821bfd5d
commit 44b267ab22
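This commit forwards the weight revision into Weights._set_gptq_params in every model loader, so the quantization parameters of GPTQ/AWQ checkpoints (bits and, where present, group size) are read from the same revision as the safetensors weights rather than from the repository's default branch. In short, the signature changes as below; hedged sketches of the surrounding code follow the relevant hunks further down.

# Before: the revision used to fetch the weights never reached the config lookup.
def _set_gptq_params(self, model_id): ...

# After: each loader forwards it, e.g. weights._set_gptq_params(model_id, revision).
def _set_gptq_params(self, model_id, revision): ...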
@@ -81,7 +81,7 @@ class BLOOMSharded(CausalLM):
             prefix="transformer",
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = BloomForCausalLM(config, weights)
@@ -64,7 +64,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashLlamaForCausalLM(config, weights)
         if use_medusa:
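The motivation is easiest to see at the hub call itself: without an explicit revision, hf_hub_download resolves files from the default branch, which may not match the revision the safetensors weights above were pulled from. A minimal illustration, using a placeholder repository and branch name rather than a real checkpoint:

from huggingface_hub import hf_hub_download

repo_id = "some-org/some-gptq-model"  # placeholder, not a real repository

# Old behaviour: quantize_config.json always came from the default branch.
cfg_default = hf_hub_download(repo_id, filename="quantize_config.json")

# New behaviour: the config is resolved from the same revision as the weights.
cfg_pinned = hf_hub_download(
    repo_id, filename="quantize_config.json", revision="gptq-4bit-128g"  # placeholder branch
)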
@@ -328,7 +328,7 @@ class BaseFlashMistral(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = model_cls(config, weights)
@@ -53,7 +53,7 @@ class FlashNeoXSharded(FlashCausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashGPTNeoXForCausalLM(config, weights)
@@ -62,7 +62,7 @@ class FlashRWSharded(FlashCausalLM):
 
         config.quantize = quantize
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashRWForCausalLM(config, weights)
@@ -63,7 +63,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashSantacoderForCausalLM(config, weights)
@@ -199,7 +199,7 @@ class GalacticaSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = OPTForCausalLM(config, weights)
@@ -57,7 +57,7 @@ class GPTNeoxSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = GPTNeoxForCausalLM(config, weights)
@@ -81,7 +81,7 @@ class MPTSharded(CausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         config.quantize = quantize
         model = MPTForCausalLM(config, weights)
@@ -55,7 +55,7 @@ class OPTSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = OPTForCausalLM(config, weights)
@@ -327,13 +327,15 @@ class Weights:
 
         return bits, groupsize
 
-    def _set_gptq_params(self, model_id):
+    def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
         try:
             if os.path.exists(os.path.join(model_id, filename)):
                 filename = os.path.join(model_id, filename)
             else:
-                filename = hf_hub_download(model_id, filename=filename)
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
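For orientation, the branch above reads a transformers-style config.json whose quantization settings sit under a nested quantization_config block. A hypothetical excerpt of what json.load returns for such a file; only the "bits" key is taken from the hunk above, the other fields are illustrative:

# Illustrative shape of a GPTQ model's config.json after json.load (excerpt).
data = {
    "model_type": "llama",        # illustrative
    "quantization_config": {
        "quant_method": "gptq",   # illustrative
        "bits": 4,                # the key read by the hunk above
        "group_size": 128,        # illustrative
    },
}
assert data["quantization_config"]["bits"] == 4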
@@ -344,7 +346,9 @@ class Weights:
                 if os.path.exists(os.path.join(model_id, filename)):
                     filename = os.path.join(model_id, filename)
                 else:
-                    filename = hf_hub_download(model_id, filename=filename)
+                    filename = hf_hub_download(
+                        model_id, filename=filename, revision=revision
+                    )
                 with open(filename, "r") as f:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
@@ -355,7 +359,9 @@ class Weights:
                     if os.path.exists(os.path.join(model_id, filename)):
                         filename = os.path.join(model_id, filename)
                     else:
-                        filename = hf_hub_download(model_id, filename=filename)
+                        filename = hf_hub_download(
+                            model_id, filename=filename, revision=revision
+                        )
                     with open(filename, "r") as f:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
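Putting the three Weights hunks together, below is a self-contained sketch of how _set_gptq_params plausibly ends up after this commit. The fallback order (config.json, then quantize_config.json, then quant_config.json) follows the hunks above; the shared download helper, the group-size keys, and the trailing pass are assumptions for readability, not code copied from the repository.

import json
import os

from huggingface_hub import hf_hub_download


class Weights:
    # Only the helper touched by this commit is sketched here.

    def _set_gptq_params(self, model_id, revision):
        def _load(filename):
            # Prefer a local file; otherwise fetch it from the hub at the same
            # revision the weights were downloaded from (the point of this fix).
            if os.path.exists(os.path.join(model_id, filename)):
                path = os.path.join(model_id, filename)
            else:
                path = hf_hub_download(model_id, filename=filename, revision=revision)
            with open(path, "r") as f:
                return json.load(f)

        try:
            # Transformers-style config.json with a nested quantization_config block.
            data = _load("config.json")
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]  # assumed key
        except Exception:
            try:
                # AutoGPTQ-style quantize_config.json with top-level keys.
                data = _load("quantize_config.json")
                self.gptq_bits = data["bits"]
                self.gptq_groupsize = data["group_size"]  # assumed key
            except Exception:
                try:
                    # AWQ-style quant_config.json.
                    data = _load("quant_config.json")
                    self.gptq_bits = data["w_bit"]
                    self.gptq_groupsize = data["q_group_size"]  # assumed key
                except Exception:
                    # Leave the attributes unset; callers fall back to defaults.
                    pass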