diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py index 31927b91..2b61f9bb 100644 --- a/server/text_generation_server/utils/adapter.py +++ b/server/text_generation_server/utils/adapter.py @@ -75,7 +75,6 @@ def load_and_merge_adapters( weight_names: Tuple[str], trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: - if len(adapter_parameters.adapter_info) == 1: adapter = next(iter(adapter_parameters.adapter_info)) return load_module_map( @@ -191,16 +190,15 @@ def load_module_map( weight_names: Tuple[str], trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: - adapter_config = LoraConfig.load(adapter_path or adapter_id, None) if not adapter_path and adapter_config.base_model_name_or_path != model_id: check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) adapter_filenames = ( - hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors") + hub._weight_files_from_dir(adapter_path, extension=".safetensors") if adapter_path - else hub._cached_adapter_weight_files( + else hub._cached_weight_files( adapter_id, revision=revision, extension=".safetensors" ) ) diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index db412aeb..f9c476ac 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -18,17 +18,6 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] -def _cached_adapter_weight_files( - adapter_id: str, revision: Optional[str], extension: str -) -> List[str]: - """Guess weight files from the cached revision snapshot directory""" - d = _get_cached_revision_directory(adapter_id, revision) - if not d: - return [] - filenames = _adapter_weight_files_from_dir(d, extension) - return filenames - - def _cached_weight_files( model_id: str, revision: Optional[str], extension: str ) -> List[str]: @@ -65,39 +54,11 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: if f.endswith(extension) and "arguments" not in f and "args" not in f - and "adapter" not in f and "training" not in f ] return filenames -def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]: - # os.walk: do not iterate, just scan for depth 1, not recursively - # see _weight_files_from_dir, that's also what is done there - root, _, files = next(os.walk(str(d))) - filenames = [ - os.path.join(root, f) - for f in files - if f.endswith(extension) - and "arguments" not in f - and "args" not in f - and "training" not in f - ] - return filenames - - -def _adapter_config_files_from_dir(d: Path) -> List[str]: - # os.walk: do not iterate, just scan for depth 1, not recursively - # see _weight_files_from_dir, that's also what is done there - root, _, files = next(os.walk(str(d))) - filenames = [ - os.path.join(root, f) - for f in files - if f.endswith(".json") and "arguments" not in f and "args" not in f - ] - return filenames - - def _get_cached_revision_directory( model_id: str, revision: Optional[str] ) -> Optional[Path]: