Micro cleanup. (#2555)

This commit is contained in:
Nicolas Patry 2024-09-24 11:19:24 +02:00 committed by GitHub
parent d31a6f75cc
commit 74d3ce106e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 43 deletions

View File

@ -75,7 +75,6 @@ def load_and_merge_adapters(
weight_names: Tuple[str], weight_names: Tuple[str],
trust_remote_code: bool = False, trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
if len(adapter_parameters.adapter_info) == 1: if len(adapter_parameters.adapter_info) == 1:
adapter = next(iter(adapter_parameters.adapter_info)) adapter = next(iter(adapter_parameters.adapter_info))
return load_module_map( return load_module_map(
@ -191,16 +190,15 @@ def load_module_map(
weight_names: Tuple[str], weight_names: Tuple[str],
trust_remote_code: bool = False, trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
adapter_config = LoraConfig.load(adapter_path or adapter_id, None) adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
if not adapter_path and adapter_config.base_model_name_or_path != model_id: if not adapter_path and adapter_config.base_model_name_or_path != model_id:
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
adapter_filenames = ( adapter_filenames = (
hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors") hub._weight_files_from_dir(adapter_path, extension=".safetensors")
if adapter_path if adapter_path
else hub._cached_adapter_weight_files( else hub._cached_weight_files(
adapter_id, revision=revision, extension=".safetensors" adapter_id, revision=revision, extension=".safetensors"
) )
) )

View File

@ -18,17 +18,6 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def _cached_adapter_weight_files(
adapter_id: str, revision: Optional[str], extension: str
) -> List[str]:
"""Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(adapter_id, revision)
if not d:
return []
filenames = _adapter_weight_files_from_dir(d, extension)
return filenames
def _cached_weight_files( def _cached_weight_files(
model_id: str, revision: Optional[str], extension: str model_id: str, revision: Optional[str], extension: str
) -> List[str]: ) -> List[str]:
@ -65,39 +54,11 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
if f.endswith(extension) if f.endswith(extension)
and "arguments" not in f and "arguments" not in f
and "args" not in f and "args" not in f
and "adapter" not in f
and "training" not in f and "training" not in f
] ]
return filenames return filenames
def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "training" not in f
]
return filenames
def _adapter_config_files_from_dir(d: Path) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(".json") and "arguments" not in f and "args" not in f
]
return filenames
def _get_cached_revision_directory( def _get_cached_revision_directory(
model_id: str, revision: Optional[str] model_id: str, revision: Optional[str]
) -> Optional[Path]: ) -> Optional[Path]: