fix: fix gpt-q params loading

OlivierDehaene authored 2023-12-14 11:02:16 +01:00
parent 28821bfd5d
commit 44b267ab22
11 changed files with 20 additions and 14 deletions

@@ -81,7 +81,7 @@ class BLOOMSharded(CausalLM):
             prefix="transformer",
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         model = BloomForCausalLM(config, weights)
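
The same one-line change repeats in every loader below: each call site now forwards the revision that was used to resolve the weight files. Previously, _set_gptq_params fetched the quantization config without a revision, so a model pinned to a branch or commit with --revision still read its GPTQ parameters from the default branch. A minimal sketch of the difference, using the real huggingface_hub.hf_hub_download API (the repo id and branch name are placeholders, not from this commit):

    from huggingface_hub import hf_hub_download

    model_id = "some-org/some-gptq-model"  # placeholder repo id
    revision = "gptq-4bit"                 # placeholder branch holding quantized weights

    # Before the fix: revision is omitted, so config.json is always resolved
    # from the default branch, even when the weights came from `revision`.
    stale_config = hf_hub_download(model_id, filename="config.json")

    # After the fix: config.json comes from the same snapshot as the weights,
    # keeping the quantization parameters consistent with the loaded shards.
    pinned_config = hf_hub_download(model_id, filename="config.json", revision=revision)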

@@ -64,7 +64,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         model = FlashLlamaForCausalLM(config, weights)
         if use_medusa:

@@ -328,7 +328,7 @@ class BaseFlashMistral(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         model = model_cls(config, weights)

@@ -53,7 +53,7 @@ class FlashNeoXSharded(FlashCausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         model = FlashGPTNeoXForCausalLM(config, weights)

@@ -62,7 +62,7 @@ class FlashRWSharded(FlashCausalLM):
         config.quantize = quantize
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         model = FlashRWForCausalLM(config, weights)

@@ -63,7 +63,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         model = FlashSantacoderForCausalLM(config, weights)

@@ -199,7 +199,7 @@ class GalacticaSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         model = OPTForCausalLM(config, weights)

@@ -57,7 +57,7 @@ class GPTNeoxSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         model = GPTNeoxForCausalLM(config, weights)

@@ -81,7 +81,7 @@ class MPTSharded(CausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         config.quantize = quantize
         model = MPTForCausalLM(config, weights)

@@ -55,7 +55,7 @@ class OPTSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         model = OPTForCausalLM(config, weights)

@@ -327,13 +327,15 @@ class Weights:
         return bits, groupsize
 
-    def _set_gptq_params(self, model_id):
+    def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
         try:
             if os.path.exists(os.path.join(model_id, filename)):
                 filename = os.path.join(model_id, filename)
             else:
-                filename = hf_hub_download(model_id, filename=filename)
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
@@ -344,7 +346,9 @@ class Weights:
             if os.path.exists(os.path.join(model_id, filename)):
                 filename = os.path.join(model_id, filename)
             else:
-                filename = hf_hub_download(model_id, filename=filename)
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["bits"]
@@ -355,7 +359,9 @@ class Weights:
             if os.path.exists(os.path.join(model_id, filename)):
                 filename = os.path.join(model_id, filename)
             else:
-                filename = hf_hub_download(model_id, filename=filename)
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["w_bit"]