From 49b4b33e805d0ffee62688fe2607120b0c759e3d Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 23 Jun 2023 12:40:46 +0200
Subject: [PATCH] feat(server): Update convert logic. (#483)

Should be more robust to shared tensors (OK when using `from_pretrained`),
but this forces us to add new checks in our loading code (since the chosen
key to keep might be different from `transformers`).

---------

Co-authored-by: Ubuntu
---
 .../models/flash_santacoder.py               |  3 +-
 .../text_generation_server/utils/convert.py  | 89 ++++++-------------
 .../text_generation_server/utils/weights.py  | 22 +++--
 3 files changed, 46 insertions(+), 68 deletions(-)

diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 54634e4a..a71c0061 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -52,7 +52,8 @@ class FlashSantacoderSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
+            filenames, device=device, dtype=dtype, process_group=self.process_group,
+            aliases = {"transformer.wte.weight": ["lm_head.weight"]}
         )
 
         model = FlashSantacoderForCausalLM(config, weights)
diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py
index c4e79432..0e4adaba 100644
--- a/server/text_generation_server/utils/convert.py
+++ b/server/text_generation_server/utils/convert.py
@@ -1,76 +1,45 @@
 import datetime
 import torch
+import os
 
-from collections import defaultdict
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file
-from safetensors import safe_open
-from typing import Dict, List
-
-
-def check_file_size(source_file: Path, target_file: Path):
-    """
-    Check that two files are close in size
-    """
-    source_file_size = source_file.stat().st_size
-    target_file_size = target_file.stat().st_size
-
-    if (source_file_size - target_file_size) / source_file_size > 0.05:
-        raise RuntimeError(
-            f"""The file size different is more than 5%:
-         - {source_file}: {source_file_size}
-         - {target_file}: {target_file_size}
-         """
-        )
-
-
-def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
-    """
-    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
-    remove them
-    """
-    ptrs = defaultdict(list)
-    for k, v in tensors.items():
-        ptrs[v.data_ptr()].append(k)
-
-    # Iterate over all found memory addresses
-    for ptr, names in ptrs.items():
-        if len(names) > 1:
-            # Multiple tensors are point to the same memory
-            # Only keep the first tensor
-            for name in names[1:]:
-                tensors.pop(name)
+from safetensors.torch import save_file, _remove_duplicate_names, load_file
+from typing import List
 
 
 def convert_file(pt_file: Path, sf_file: Path):
     """
     Convert a pytorch file to a safetensors file
+    This will remove duplicate tensors from the file.
+
+    Unfortunately, this might not respect *transformers* convention,
+    forcing us to check for potentially different keys during load when looking
+    for specific tensors (making tensor sharing explicit).
""" - logger.info(f"Convert {pt_file} to {sf_file}.") + loaded = torch.load(pt_file, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + to_removes = _remove_duplicate_names(loaded) - pt_state = torch.load(pt_file, map_location="cpu") - if "state_dict" in pt_state: - pt_state = pt_state["state_dict"] + metadata = {"format": "pt"} + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if to_remove not in metadata: + metadata[to_remove] = kept_name + del loaded[to_remove] + # Force tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} - remove_shared_pointers(pt_state) - - # Tensors need to be contiguous - pt_state = {k: v.contiguous() for k, v in pt_state.items()} - - sf_file.parent.mkdir(parents=True, exist_ok=True) - save_file(pt_state, str(sf_file), metadata={"format": "pt"}) - - # Check that both files are close in size - check_file_size(pt_file, sf_file) - - # Load safetensors state - for k in pt_state: - pt_tensor = pt_state[k] - with safe_open(sf_file, framework="pt") as f: - sf_tensor = f.get_tensor(k) - if not torch.equal(pt_tensor, sf_tensor): - raise RuntimeError(f"The output tensors do not match for key {k}") + dirname = os.path.dirname(sf_file) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_file, metadata=metadata) + reloaded = load_file(sf_file) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") def convert_files(pt_files: List[Path], sf_files: List[Path]): diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 76a4f65a..88347a6a 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,10 +1,10 @@ from pathlib import Path -from typing import List +from typing import List, Dict, Optional from safetensors import safe_open class Weights: - def __init__(self, filenames: List[Path], device, dtype, process_group): + def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): routing = {} for filename in filenames: with safe_open(filename, framework="pytorch") as f: @@ -14,6 +14,9 @@ class Weights: f"Key {k} was found in multiple files: {filename} and {routing[k]}" ) routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases self.routing = routing self.device = device self.dtype = dtype @@ -27,14 +30,19 @@ class Weights: return self._handles[filename] - def get_filename(self, tensor_name: str) -> str: + def get_filename(self, tensor_name: str) -> (str, str): filename = self.routing.get(tensor_name, None) if filename is None: + aliases = self.aliases.get(tensor_name, []) + for alias in aliases: + filename = self.routing.get(alias, None) + if filename is not None: + return str(filename), alias raise RuntimeError(f"weight {tensor_name} does not exist") - return str(filename) + return str(filename), tensor_name def _get_slice(self, tensor_name: str): - filename = self.get_filename(tensor_name) + filename, tensor_name= self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) return slice_ @@ -43,7 +51,7 @@ class Weights: return self._get_slice(tensor_name).get_shape() def get_tensor(self, tensor_name: str): - filename = self.get_filename(tensor_name) + filename, tensor_name = self.get_filename(tensor_name) f 
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
         tensor = tensor.to(dtype=self.dtype)
@@ -51,7 +59,7 @@
         return tensor
 
     def get_sharded(self, tensor_name: str, dim: int):
-        filename = self.get_filename(tensor_name)
+        filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
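
Note (not part of the patch): a minimal sketch of the bookkeeping `convert_file` now performs after safetensors' `_remove_duplicate_names` has grouped shared tensors. The state dict, the chosen kept key, and the use of plain strings instead of real torch tensors are illustrative assumptions; only the loop mirrors the patch.

```python
# Hypothetical inputs: a state dict with one tied weight, and the grouping that
# _remove_duplicate_names is assumed to return for it ({kept name: [dropped names]}).
loaded = {"transformer.wte.weight": "shared storage", "lm_head.weight": "shared storage"}
to_removes = {"lm_head.weight": ["transformer.wte.weight"]}

# Record each dropped name in the safetensors metadata so a loader can still
# discover which kept tensor it shared storage with, then drop it from the dict.
metadata = {"format": "pt"}
for kept_name, to_remove_group in to_removes.items():
    for to_remove in to_remove_group:
        if to_remove not in metadata:
            metadata[to_remove] = kept_name
        del loaded[to_remove]

print(loaded)    # {'lm_head.weight': 'shared storage'}
print(metadata)  # {'format': 'pt', 'transformer.wte.weight': 'lm_head.weight'}
```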
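
Note (not part of the patch): a minimal, dependency-free sketch of how the alias fallback added to `Weights.get_filename` is expected to behave when the converted file kept a different key for a shared tensor. `WeightsSketch`, the `routing` dict, and the file name are hypothetical stand-ins for the real class, which builds its routing table from the safetensors headers.

```python
from typing import Dict, List, Optional, Tuple


class WeightsSketch:
    """Stand-in for the alias-aware lookup that the patch adds to Weights."""

    def __init__(self, routing: Dict[str, str], aliases: Optional[Dict[str, List[str]]] = None):
        # routing maps a tensor name to the safetensors file that stores it
        self.routing = routing
        self.aliases = aliases or {}

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        filename = self.routing.get(tensor_name)
        if filename is None:
            # The conversion may have kept a different key for a shared tensor,
            # so try the declared aliases before giving up.
            for alias in self.aliases.get(tensor_name, []):
                filename = self.routing.get(alias)
                if filename is not None:
                    return filename, alias
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return filename, tensor_name


# Mirrors the FlashSantacoderSharded change: only "lm_head.weight" survived the
# conversion, but "transformer.wte.weight" still resolves to it via the alias.
weights = WeightsSketch(
    routing={"lm_head.weight": "model.safetensors"},
    aliases={"transformer.wte.weight": ["lm_head.weight"]},
)
print(weights.get_filename("transformer.wte.weight"))
# ('model.safetensors', 'lm_head.weight')
```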