diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 0151b017..8e8daad3 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -74,7 +74,7 @@ class BLOOMSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
+            filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
         )
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 8a19fd9f..4bae8cc0 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -16,6 +16,7 @@ class Weights:
         dtype,
         process_group,
         aliases: Optional[Dict[str, List[str]]] = None,
+        prefix: Optional[str] = None
     ):
         routing = {}
         for filename in filenames:
@@ -33,6 +34,7 @@ class Weights:
         self.device = device
         self.dtype = dtype
         self.process_group = process_group
+        self.prefix = prefix
         self._handles = {}
 
     def _get_handle(self, filename):
@@ -43,15 +45,22 @@ class Weights:
         return self._handles[filename]
 
     def get_filename(self, tensor_name: str) -> (str, str):
-        filename = self.routing.get(tensor_name, None)
-        if filename is None:
-            aliases = self.aliases.get(tensor_name, [])
+
+        names = [tensor_name]
+        if self.prefix is not None:
+            prefixed = f"{self.prefix}.{tensor_name}"
+            names.append(prefixed)
+        for name in names:
+            filename = self.routing.get(name, None)
+            if filename is not None:
+                return str(filename), name
+
+            aliases = self.aliases.get(name, [])
             for alias in aliases:
                 filename = self.routing.get(alias, None)
                 if filename is not None:
                     return str(filename), alias
-            raise RuntimeError(f"weight {tensor_name} does not exist")
-        return str(filename), tensor_name
+        raise RuntimeError(f"weight {tensor_name} does not exist")
 
     def _get_slice(self, tensor_name: str):
         filename, tensor_name = self.get_filename(tensor_name)