Micro cleanup. (#2555)

parent d31a6f75cc
commit 74d3ce106e
@@ -75,7 +75,6 @@ def load_and_merge_adapters(
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-
     if len(adapter_parameters.adapter_info) == 1:
         adapter = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
@@ -191,16 +190,15 @@ def load_module_map(
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-
     adapter_config = LoraConfig.load(adapter_path or adapter_id, None)

     if not adapter_path and adapter_config.base_model_name_or_path != model_id:
         check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

     adapter_filenames = (
-        hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
+        hub._weight_files_from_dir(adapter_path, extension=".safetensors")
         if adapter_path
-        else hub._cached_adapter_weight_files(
+        else hub._cached_weight_files(
             adapter_id, revision=revision, extension=".safetensors"
         )
     )
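Note: after this change, both branches of the adapter-filename resolution in load_module_map go through the generic hub helpers. A minimal sketch of that call path, assuming the helpers live in text_generation_server.utils.hub (the import path and the wrapper name resolve_adapter_filenames are illustrative, not part of the diff):

    from typing import List, Optional

    # Assumed import path; the diff only shows the module referenced as `hub`.
    from text_generation_server.utils import hub


    def resolve_adapter_filenames(
        adapter_id: str,
        adapter_path: Optional[str],
        revision: Optional[str] = None,
    ) -> List[str]:
        """Illustrative wrapper mirroring the expression in load_module_map."""
        if adapter_path:
            # Local adapter checkout: scan the directory with the shared helper.
            return hub._weight_files_from_dir(adapter_path, extension=".safetensors")
        # Hub adapter: list weight files from the cached revision snapshot.
        return hub._cached_weight_files(
            adapter_id, revision=revision, extension=".safetensors"
        )
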
@@ -18,17 +18,6 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
 HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]


-def _cached_adapter_weight_files(
-    adapter_id: str, revision: Optional[str], extension: str
-) -> List[str]:
-    """Guess weight files from the cached revision snapshot directory"""
-    d = _get_cached_revision_directory(adapter_id, revision)
-    if not d:
-        return []
-    filenames = _adapter_weight_files_from_dir(d, extension)
-    return filenames
-
-
 def _cached_weight_files(
     model_id: str, revision: Optional[str], extension: str
 ) -> List[str]:
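The deleted _cached_adapter_weight_files was an adapter-specific twin of _cached_weight_files. The body of the surviving helper is not part of this hunk, so the following is only a sketch of the shared pattern, based on the deleted twin and assuming the same module layout:

    from typing import List, Optional

    # Assumed import path; helper names taken from this file's diff context.
    from text_generation_server.utils.hub import (
        _get_cached_revision_directory,
        _weight_files_from_dir,
    )


    def cached_weight_files_sketch(
        model_id: str, revision: Optional[str], extension: str
    ) -> List[str]:
        """Sketch of the pattern: find the cached revision snapshot, then scan it."""
        d = _get_cached_revision_directory(model_id, revision)
        if not d:
            return []
        return _weight_files_from_dir(d, extension)
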
@@ -65,39 +54,11 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
         if f.endswith(extension)
         and "arguments" not in f
         and "args" not in f
-        and "adapter" not in f
         and "training" not in f
     ]
     return filenames


-def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
-    # os.walk: do not iterate, just scan for depth 1, not recursively
-    # see _weight_files_from_dir, that's also what is done there
-    root, _, files = next(os.walk(str(d)))
-    filenames = [
-        os.path.join(root, f)
-        for f in files
-        if f.endswith(extension)
-        and "arguments" not in f
-        and "args" not in f
-        and "training" not in f
-    ]
-    return filenames
-
-
-def _adapter_config_files_from_dir(d: Path) -> List[str]:
-    # os.walk: do not iterate, just scan for depth 1, not recursively
-    # see _weight_files_from_dir, that's also what is done there
-    root, _, files = next(os.walk(str(d)))
-    filenames = [
-        os.path.join(root, f)
-        for f in files
-        if f.endswith(".json") and "arguments" not in f and "args" not in f
-    ]
-    return filenames
-
-
 def _get_cached_revision_directory(
     model_id: str, revision: Optional[str]
 ) -> Optional[Path]:
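With the `"adapter" not in f` filter dropped, _weight_files_from_dir also returns adapter weight files, which is what lets the adapter-specific copies above be deleted. A small usage sketch under that assumption (the file names and the import path are illustrative):

    import os
    import tempfile

    # Assumed import path; the diff does not name the file the helper lives in.
    from text_generation_server.utils.hub import _weight_files_from_dir

    with tempfile.TemporaryDirectory() as d:
        # Illustrative contents: a base-model shard, an adapter weight file,
        # and a training artifact that the remaining filters still exclude.
        for name in (
            "model-00001-of-00002.safetensors",
            "adapter_model.safetensors",
            "training_args.safetensors",
        ):
            open(os.path.join(d, name), "w").close()

        files = _weight_files_from_dir(d, extension=".safetensors")
        # Both the shard and the adapter file are listed; "training"/"args"
        # names remain excluded.
        print(sorted(os.path.basename(f) for f in files))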