diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 0810d979..7ee8bf1b 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -546,11 +546,7 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
-fn download_convert_model(
-    args: &Args,
-    auto_convert: bool,
-    running: Arc<AtomicBool>,
-) -> Result<(), LauncherError> {
+fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     let mut download_argv = vec![
         "text-generation-server".to_string(),
         "download-weights".to_string(),
@@ -562,11 +558,6 @@ fn download_convert_model(
         "--json-output".to_string(),
     ];
 
-    // Auto convert weights to safetensors
-    if auto_convert {
-        download_argv.push("--auto-convert".to_string());
-    }
-
     // Model optional revision
     if let Some(revision) = &args.revision {
         download_argv.push("--revision".to_string());
@@ -932,11 +923,8 @@ fn main() -> Result<(), LauncherError> {
     })
     .expect("Error setting Ctrl-C handler");
 
-    // auto_convert is only needed for sharded models as we do not require safetensors in
-    // single shard mode
-    let auto_convert = num_shard > 1;
     // Download and convert model weights
-    download_convert_model(&args, auto_convert, running.clone())?;
+    download_convert_model(&args, running.clone())?;
 
     // Shared shutdown bool
     let shutdown = Arc::new(Mutex::new(false));
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 482e0f54..7907e2cc 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -54,12 +54,7 @@ class FlashSantacoder(FlashCausalLM):
         )
 
         # We do not use from_pretrained as we modified the model internal module layout
-        try:
-            filenames = weight_files(model_id, revision, ".bin")
-        # Local files not found
-        except LocalEntryNotFoundError:
-            hub_files = weight_hub_files(model_id, revision, ".bin")
-            filenames = download_weights(hub_files, model_id, revision)
+        filenames = weight_files(model_id, revision, ".safetensors")
 
         with init_empty_weights():
             model = FlashSantacoderForCausalLM(config)
@@ -91,85 +86,100 @@
         transpose: bool,
     ):
         for filename in filenames:
-            state_dict = torch.load(filename, map_location="cpu")
-            for key, value in state_dict.items():
-                value = value.to(device if quantize is None else "cpu").to(dtype)
+            with safe_open(
+                filename, framework="pt", device=str(device) if quantize is None else "cpu"
+            ) as f:
+                for key in f.keys():
+                    value = f.get_tensor(key)
+                    value = value.to(device if quantize is None else "cpu").to(dtype)
 
-                layer_name = ".".join(key.split(".")[:4])
+                    layer_name = ".".join(key.split(".")[:4])
 
-                # Fused qkv
-                if "q_attn.weight" in key or "kv_attn.weight" in key:
-                    final_key = layer_name + ".c_attn.weight"
-                elif "q_attn.bias" in key or "kv_attn.bias" in key:
-                    final_key = layer_name + ".c_attn.bias"
+                    # Fused qkv
+                    if "q_attn.weight" in key or "kv_attn.weight" in key:
+                        final_key = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
+                        final_key = layer_name + ".c_attn.bias"
 
-                else:
-                    final_key = key
-
-                module_name, param_name = final_key.rsplit(".", 1)
-                module = model.get_submodule(module_name)
-
-                try:
-                    current_parameter_tensor = module._parameters[param_name]
-                except KeyError:
-                    current_parameter_tensor = None
-
-                if current_parameter_tensor is not None:
-                    if transpose and (
-                        "c_fc.weight" in key
-                        or "c_proj.weight" in key
-                        or "q_attn.weight" in key
-                        or "kv_attn.weight" in key
-                        or "c_attn.weight" in key
-                    ):
-                        # Tranpose as we use nn.Linear instead of Conv1D
-                        value = value.T
-
-                    if current_parameter_tensor.device == torch.device("meta"):
-                        # Init qkv
-                        if "c_attn.weight" in final_key:
-                            module._parameters[param_name] = value.new_empty(
-                                (
-                                    model.transformer.head_size
-                                    * (model.transformer.num_heads + 2),
-                                    value.shape[1],
-                                )
-                            )
-                        elif "c_attn.bias" in final_key:
-                            module._parameters[param_name] = value.new_empty(
-                                (
-                                    model.transformer.head_size
-                                    * (model.transformer.num_heads + 2)
-                                )
-                            )
-
-                    # Copy to correct slice
-                    if "q_attn.weight" in key:
-                        module._parameters[param_name][: value.shape[0]] = value
-                    elif "q_attn.bias" in key:
-                        module._parameters[param_name][: value.shape[0]] = value
-                    elif "kv_attn.weight" in key:
-                        module._parameters[param_name][
-                            model.transformer.head_size * model.transformer.num_heads :
-                        ] = value
-                    elif "kv_attn.bias" in key:
-                        module._parameters[param_name][
-                            model.transformer.head_size * model.transformer.num_heads :
-                        ] = value
                     else:
-                        if current_parameter_tensor.shape != value.shape:
-                            raise ValueError(
-                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
-                            )
-                        module._parameters[param_name] = value
-                else:
-                    module._buffers[param_name] = value
+                        final_key = key
 
-                del value
+                    module_name, param_name = final_key.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    try:
+                        current_parameter_tensor = module._parameters[param_name]
+                    except KeyError:
+                        current_parameter_tensor = None
+
+                    if current_parameter_tensor is not None:
+                        if transpose and (
+                            "c_fc.weight" in key
+                            or "c_proj.weight" in key
+                            or "q_attn.weight" in key
+                            or "kv_attn.weight" in key
+                            or "c_attn.weight" in key
+                        ):
+                            # Transpose as we use nn.Linear instead of Conv1D
+                            value = value.T
+
+                        if current_parameter_tensor.device == torch.device("meta"):
+                            # Init qkv
+                            if "c_attn.weight" in final_key:
+                                module._parameters[param_name] = value.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2),
+                                        value.shape[1],
+                                    )
+                                )
+                            elif "c_attn.bias" in final_key:
+                                module._parameters[param_name] = value.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2)
+                                    )
+                                )
+
+                        # Copy to correct slice
+                        if "q_attn.weight" in key:
+                            module._parameters[param_name][: value.shape[0]] = value
+                        elif "q_attn.bias" in key:
+                            module._parameters[param_name][: value.shape[0]] = value
+                        elif "kv_attn.weight" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size * model.transformer.num_heads :
+                            ] = value
+                        elif "kv_attn.bias" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size * model.transformer.num_heads :
+                            ] = value
+                        else:
+                            if current_parameter_tensor.shape != value.shape:
+                                raise ValueError(
+                                    f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
+                                )
+                            module._parameters[param_name] = value
+                    else:
+                        module._buffers[param_name] = value
+
+                    del value
+
+        if model.lm_head.weight.device == torch.device("meta"):
+            model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
 
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def decode(self, generated_ids: List[int]) -> str:
         # Do not skip special tokens as they are used for custom parsing rules of the generated text
         return self.tokenizer.decode(
@@ -389,6 +399,8 @@ class FlashSantacoderSharded(FlashSantacoder):
             else:
                 module._buffers[param_name] = tensor
 
-        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+        if model.lm_head.weight.device == torch.device("meta"):
+            model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
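
Note on the loading pattern above: the patch swaps eager `torch.load` (which unpickles the entire state dict into memory) for safetensors' `safe_open`, which memory-maps the checkpoint and materializes only the tensors that are requested. Below is a minimal sketch of that pattern, separate from the patch itself; it assumes `from safetensors import safe_open` (the import hunk is not shown in this excerpt) and uses an illustrative local file name "model.safetensors":

    import torch
    from safetensors import safe_open

    # safe_open memory-maps the checkpoint; get_tensor() copies out one tensor
    # at a time, so peak host memory stays near the size of a single tensor
    # rather than the whole state dict.
    with safe_open("model.safetensors", framework="pt", device="cpu") as f:
        for key in f.keys():
            value = f.get_tensor(key)
            print(key, tuple(value.shape), value.dtype)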