feat(server): load santacoder/starcoder models with safetensors (#393)
Fix #366
This commit is contained in:
parent
c0928e6f26
commit
95d3546976
|
@ -546,11 +546,7 @@ enum LauncherError {
|
|||
WebserverCannotStart,
|
||||
}
|
||||
|
||||
fn download_convert_model(
|
||||
args: &Args,
|
||||
auto_convert: bool,
|
||||
running: Arc<AtomicBool>,
|
||||
) -> Result<(), LauncherError> {
|
||||
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
||||
let mut download_argv = vec![
|
||||
"text-generation-server".to_string(),
|
||||
"download-weights".to_string(),
|
||||
|
@ -562,11 +558,6 @@ fn download_convert_model(
|
|||
"--json-output".to_string(),
|
||||
];
|
||||
|
||||
// Auto convert weights to safetensors
|
||||
if auto_convert {
|
||||
download_argv.push("--auto-convert".to_string());
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
if let Some(revision) = &args.revision {
|
||||
download_argv.push("--revision".to_string());
|
||||
|
@ -932,11 +923,8 @@ fn main() -> Result<(), LauncherError> {
|
|||
})
|
||||
.expect("Error setting Ctrl-C handler");
|
||||
|
||||
// auto_convert is only needed for sharded models as we do not require safetensors in
|
||||
// single shard mode
|
||||
let auto_convert = num_shard > 1;
|
||||
// Download and convert model weights
|
||||
download_convert_model(&args, auto_convert, running.clone())?;
|
||||
download_convert_model(&args, running.clone())?;
|
||||
|
||||
// Shared shutdown bool
|
||||
let shutdown = Arc::new(Mutex::new(false));
|
||||
|
|
|
@ -54,12 +54,7 @@ class FlashSantacoder(FlashCausalLM):
|
|||
)
|
||||
|
||||
# We do not use from_pretrained as we modified the model internal module layout
|
||||
try:
|
||||
filenames = weight_files(model_id, revision, ".bin")
|
||||
# Local files not found
|
||||
except LocalEntryNotFoundError:
|
||||
hub_files = weight_hub_files(model_id, revision, ".bin")
|
||||
filenames = download_weights(hub_files, model_id, revision)
|
||||
filenames = weight_files(model_id, revision, ".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashSantacoderForCausalLM(config)
|
||||
|
@ -91,85 +86,100 @@ class FlashSantacoder(FlashCausalLM):
|
|||
transpose: bool,
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
value = value.to(device if quantize is None else "cpu").to(dtype)
|
||||
with safe_open(
|
||||
filename, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for key in f.keys():
|
||||
value = f.get_tensor(key)
|
||||
value = value.to(device if quantize is None else "cpu").to(dtype)
|
||||
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
|
||||
# Fused qkv
|
||||
if "q_attn.weight" in key or "kv_attn.weight" in key:
|
||||
final_key = layer_name + ".c_attn.weight"
|
||||
elif "q_attn.bias" in key or "kv_attn.bias" in key:
|
||||
final_key = layer_name + ".c_attn.bias"
|
||||
# Fused qkv
|
||||
if "q_attn.weight" in key or "kv_attn.weight" in key:
|
||||
final_key = layer_name + ".c_attn.weight"
|
||||
elif "q_attn.bias" in key or "kv_attn.bias" in key:
|
||||
final_key = layer_name + ".c_attn.bias"
|
||||
|
||||
else:
|
||||
final_key = key
|
||||
|
||||
module_name, param_name = final_key.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
except KeyError:
|
||||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
if transpose and (
|
||||
"c_fc.weight" in key
|
||||
or "c_proj.weight" in key
|
||||
or "q_attn.weight" in key
|
||||
or "kv_attn.weight" in key
|
||||
or "c_attn.weight" in key
|
||||
):
|
||||
# Tranpose as we use nn.Linear instead of Conv1D
|
||||
value = value.T
|
||||
|
||||
if current_parameter_tensor.device == torch.device("meta"):
|
||||
# Init qkv
|
||||
if "c_attn.weight" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2),
|
||||
value.shape[1],
|
||||
)
|
||||
)
|
||||
elif "c_attn.bias" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2)
|
||||
)
|
||||
)
|
||||
|
||||
# Copy to correct slice
|
||||
if "q_attn.weight" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "q_attn.bias" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "kv_attn.weight" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size * model.transformer.num_heads :
|
||||
] = value
|
||||
elif "kv_attn.bias" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size * model.transformer.num_heads :
|
||||
] = value
|
||||
else:
|
||||
if current_parameter_tensor.shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
|
||||
)
|
||||
module._parameters[param_name] = value
|
||||
else:
|
||||
module._buffers[param_name] = value
|
||||
final_key = key
|
||||
|
||||
del value
|
||||
module_name, param_name = final_key.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
except KeyError:
|
||||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
if transpose and (
|
||||
"c_fc.weight" in key
|
||||
or "c_proj.weight" in key
|
||||
or "q_attn.weight" in key
|
||||
or "kv_attn.weight" in key
|
||||
or "c_attn.weight" in key
|
||||
):
|
||||
# Tranpose as we use nn.Linear instead of Conv1D
|
||||
value = value.T
|
||||
|
||||
if current_parameter_tensor.device == torch.device("meta"):
|
||||
# Init qkv
|
||||
if "c_attn.weight" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2),
|
||||
value.shape[1],
|
||||
)
|
||||
)
|
||||
elif "c_attn.bias" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2)
|
||||
)
|
||||
)
|
||||
|
||||
# Copy to correct slice
|
||||
if "q_attn.weight" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "q_attn.bias" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "kv_attn.weight" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size * model.transformer.num_heads :
|
||||
] = value
|
||||
elif "kv_attn.bias" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size * model.transformer.num_heads :
|
||||
] = value
|
||||
else:
|
||||
if current_parameter_tensor.shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
|
||||
)
|
||||
module._parameters[param_name] = value
|
||||
else:
|
||||
module._buffers[param_name] = value
|
||||
|
||||
del value
|
||||
|
||||
if model.lm_head.weight.device == torch.device("meta"):
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model : {uninitialized_parameters}"
|
||||
)
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
|
@ -389,6 +399,8 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
if model.lm_head.weight.device == torch.device("meta"):
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
|
Loading…
Reference in New Issue