fix: fix gpt-q params loading
parent 28821bfd5d
commit 44b267ab22
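Why the extra argument matters: when no revision is given, hf_hub_download resolves the repository's default branch, so a model served at a pinned revision could read its GPTQ/AWQ quantization metadata from a different commit than its weights. A minimal sketch of the call pattern after this change (the repo id below is a placeholder, not from this diff):

    from huggingface_hub import hf_hub_download

    # Placeholder repo id for illustration only.
    # Forwarding revision pins the metadata fetch to the same commit as the weights.
    config_path = hf_hub_download(
        "org/some-gptq-model",
        filename="config.json",
        revision="main",
    )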
@@ -81,7 +81,7 @@ class BLOOMSharded(CausalLM):
             prefix="transformer",
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = BloomForCausalLM(config, weights)
 
@@ -64,7 +64,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashLlamaForCausalLM(config, weights)
         if use_medusa:
@@ -328,7 +328,7 @@ class BaseFlashMistral(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = model_cls(config, weights)
 
@@ -53,7 +53,7 @@ class FlashNeoXSharded(FlashCausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashGPTNeoXForCausalLM(config, weights)
 
@@ -62,7 +62,7 @@ class FlashRWSharded(FlashCausalLM):
 
         config.quantize = quantize
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashRWForCausalLM(config, weights)
 
@@ -63,7 +63,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = FlashSantacoderForCausalLM(config, weights)
 
@@ -199,7 +199,7 @@ class GalacticaSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = OPTForCausalLM(config, weights)
 
@@ -57,7 +57,7 @@ class GPTNeoxSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = GPTNeoxForCausalLM(config, weights)
 
@@ -81,7 +81,7 @@ class MPTSharded(CausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         config.quantize = quantize
         model = MPTForCausalLM(config, weights)
@@ -55,7 +55,7 @@ class OPTSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
 
         model = OPTForCausalLM(config, weights)
 
@@ -327,13 +327,15 @@ class Weights:
 
         return bits, groupsize
 
-    def _set_gptq_params(self, model_id):
+    def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
         try:
             if os.path.exists(os.path.join(model_id, filename)):
                 filename = os.path.join(model_id, filename)
             else:
-                filename = hf_hub_download(model_id, filename=filename)
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
@@ -344,7 +346,9 @@ class Weights:
                 if os.path.exists(os.path.join(model_id, filename)):
                     filename = os.path.join(model_id, filename)
                 else:
-                    filename = hf_hub_download(model_id, filename=filename)
+                    filename = hf_hub_download(
+                        model_id, filename=filename, revision=revision
+                    )
                 with open(filename, "r") as f:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
@@ -355,7 +359,9 @@ class Weights:
                     if os.path.exists(os.path.join(model_id, filename)):
                         filename = os.path.join(model_id, filename)
                     else:
-                        filename = hf_hub_download(model_id, filename=filename)
+                        filename = hf_hub_download(
+                            model_id, filename=filename, revision=revision
+                        )
                     with open(filename, "r") as f:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
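For orientation, a compact sketch of the lookup the three weights.py hunks above implement, with revision now threaded through every Hub download. This is not the literal upstream method: only the lines visible in the hunks are certain; the fallback filenames and the exact nesting outside the hunk context are assumptions.

    # Assumes module-level: import json, os; from huggingface_hub import hf_hub_download
    def _set_gptq_params(self, model_id, revision):
        # Try each metadata file in order and keep the first one that parses.
        candidates = [
            ("config.json", lambda d: d["quantization_config"]["bits"]),
            ("quantize_config.json", lambda d: d["bits"]),  # assumed filename
            ("quant_config.json", lambda d: d["w_bit"]),    # assumed filename (AWQ)
        ]
        for filename, read_bits in candidates:
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    path = os.path.join(model_id, filename)
                else:
                    path = hf_hub_download(
                        model_id, filename=filename, revision=revision
                    )
                with open(path, "r") as f:
                    data = json.load(f)
                self.gptq_bits = read_bits(data)
                return
            except Exception:
                continue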