fix(server): fix bnb quantization for CausalLM models (#385)
parent 87dc034b59
commit 337afb2842
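The change is the same in every sharded CausalLM implementation touched below: inside the `bitsandbytes` branch, parameters that do not belong to a quantizable tensor-parallel linear layer were never moved to the target device, which broke loading with bnb quantization enabled. The fix adds the missing `else: tensor = tensor.to(device)`, and for OPT it additionally replaces a bare truthiness check with an explicit dispatch on the `quantize` value. A minimal sketch of the resulting per-tensor logic; `place_tensor`, `is_quantizable_linear`, and `wrap_with_bnb` are hypothetical names standing in for the real bitsandbytes wiring:

import torch

def place_tensor(module, tensor, quantize, device,
                 is_quantizable_linear, wrap_with_bnb):
    # Sketch of the dispatch after this commit; helper names are hypothetical.
    if quantize == "bitsandbytes":
        if is_quantizable_linear(module):
            # wrap the weight as Int8Params and patch module.linear
            tensor = wrap_with_bnb(module, tensor, device)
        else:
            # the previously missing branch: embeddings, norms, and biases
            # still have to be moved to the target device
            tensor = tensor.to(device)
    elif quantize == "gptq":
        raise NotImplementedError("`gptq` is not implemented for now")
    elif quantize is None:
        tensor = tensor.to(device)
    else:
        raise ValueError(f"Unexpected quantize `{quantize}`")
    return tensor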
@@ -245,6 +245,8 @@ class BLOOMSharded(BLOOM):
                         return linear
 
                     module.linear = replace_linear(state)
+                else:
+                    tensor = tensor.to(device)
             elif quantize == "gptq":
                 raise NotImplementedError("`gptq` is not implemented for now")
             elif quantize is None:
@@ -364,6 +364,8 @@ class GalacticaSharded(Galactica):
                         return linear
 
                     module.linear = replace_linear(state)
+                else:
+                    tensor = tensor.to(device)
             elif quantize == "gptq":
                 raise NotImplementedError("`gptq` is not implemented for now")
             elif quantize is None:
@@ -210,6 +210,8 @@ class GPTNeoxSharded(CausalLM):
                         return linear
 
                     module.linear = replace_linear(state)
+                else:
+                    tensor = tensor.to(device)
             elif quantize == "gptq":
                 raise NotImplementedError("`gptq` is not implemented for now")
             elif quantize is None:
@@ -166,7 +166,7 @@ class OPTSharded(OPT):
             tensor = tensor.contiguous().to(dtype)
 
-            if quantize:
+            if quantize == "bitsandbytes":
                 if not HAS_BITS_AND_BYTES:
                     raise ImportError(
                         "bitsandbytes is not available on your machine either because it is not installed "
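This first OPT hunk fixes a separate truthiness bug: `quantize` is a string option, so the old bare `if quantize:` would send any non-empty value, including `"gptq"`, down the bitsandbytes path. A small illustration of the difference:

quantize = "gptq"

# old check: truthy for any non-empty string, so "gptq" would have
# been treated as a request for bitsandbytes quantization
if quantize:
    print("bitsandbytes path (wrong)")

# new check: explicit dispatch per backend
if quantize == "bitsandbytes":
    print("bitsandbytes path")
elif quantize == "gptq":
    print("gptq path (not implemented yet)")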
@@ -216,9 +216,14 @@ class OPTSharded(OPT):
                         return linear
 
                     module.linear = replace_linear(state)
-
                 else:
                     tensor = tensor.to(device)
+            elif quantize == "gptq":
+                raise NotImplementedError("`gptq` is not implemented for now")
+            elif quantize is None:
+                tensor = tensor.to(device)
+            else:
+                raise ValueError(f"Unexpected quantize `{quantize}`")
 
             module._parameters[param_name] = tensor
             if name == "model.decoder.embed_tokens.weight":
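The `replace_linear(state)` closure visible in the context lines is the usual bitsandbytes LLM.int8() pattern: the fp16 weight is re-wrapped as `Int8Params` so it is quantized when moved to the GPU, and the module's forward is redirected through `bnb.matmul` with a shared `MatmulLtState`. A rough sketch of that pattern, assuming bitsandbytes is installed (`make_int8_linear` is an illustrative name, not this repo's API):

import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params

def make_int8_linear(weight: torch.Tensor, device: torch.device):
    # Int8Params quantizes the fp16 weight to int8 on transfer to the device.
    weight = Int8Params(
        weight, has_fp16_weights=False, requires_grad=False
    ).to(device)

    state = bnb.MatmulLtState()
    state.threshold = 6.0  # outlier threshold used by LLM.int8()
    state.has_fp16_weights = False

    def linear(input, bias=None):
        # int8 matmul with fp16 handling of outlier columns
        return bnb.matmul(
            input, weight, state=state, threshold=state.threshold, bias=bias
        )

    return linear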
@@ -222,7 +222,8 @@ class T5Sharded(Seq2SeqLM):
                         return linear
 
                     module.linear = replace_linear(state)
-
+                else:
+                    tensor = tensor.to(device)
             elif quantize == "gptq" and not module_name.endswith("wo"):
                 raise NotImplementedError("`gptq` is not implemented for now")
             elif quantize is None or module_name.endswith("wo"):
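T5 is the one model with an extra predicate: tensors whose module name ends in `wo` (the feed-forward output projection) skip quantization and stay in full precision, presumably because that projection is numerically sensitive in T5 checkpoints. The condition can be read as a small helper; the module paths below are illustrative Hugging Face T5 names:

def uses_bnb(quantize, module_name):
    # Mirrors the T5Sharded condition: everything quantizes except
    # the feed-forward output projection "wo", which stays unquantized.
    return quantize == "bitsandbytes" and not module_name.endswith("wo")

assert uses_bnb("bitsandbytes", "encoder.block.0.layer.1.DenseReluDense.wi")
assert not uses_bnb("bitsandbytes", "encoder.block.0.layer.1.DenseReluDense.wo")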