Micro cleanup. (#2555)
This commit is contained in:
parent
d31a6f75cc
commit
74d3ce106e
|
@ -75,7 +75,6 @@ def load_and_merge_adapters(
|
||||||
weight_names: Tuple[str],
|
weight_names: Tuple[str],
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
||||||
|
|
||||||
if len(adapter_parameters.adapter_info) == 1:
|
if len(adapter_parameters.adapter_info) == 1:
|
||||||
adapter = next(iter(adapter_parameters.adapter_info))
|
adapter = next(iter(adapter_parameters.adapter_info))
|
||||||
return load_module_map(
|
return load_module_map(
|
||||||
|
@ -191,16 +190,15 @@ def load_module_map(
|
||||||
weight_names: Tuple[str],
|
weight_names: Tuple[str],
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
||||||
|
|
||||||
adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
|
adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
|
||||||
|
|
||||||
if not adapter_path and adapter_config.base_model_name_or_path != model_id:
|
if not adapter_path and adapter_config.base_model_name_or_path != model_id:
|
||||||
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
|
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
|
||||||
|
|
||||||
adapter_filenames = (
|
adapter_filenames = (
|
||||||
hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
|
hub._weight_files_from_dir(adapter_path, extension=".safetensors")
|
||||||
if adapter_path
|
if adapter_path
|
||||||
else hub._cached_adapter_weight_files(
|
else hub._cached_weight_files(
|
||||||
adapter_id, revision=revision, extension=".safetensors"
|
adapter_id, revision=revision, extension=".safetensors"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,17 +18,6 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||||
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
|
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
|
||||||
|
|
||||||
|
|
||||||
def _cached_adapter_weight_files(
|
|
||||||
adapter_id: str, revision: Optional[str], extension: str
|
|
||||||
) -> List[str]:
|
|
||||||
"""Guess weight files from the cached revision snapshot directory"""
|
|
||||||
d = _get_cached_revision_directory(adapter_id, revision)
|
|
||||||
if not d:
|
|
||||||
return []
|
|
||||||
filenames = _adapter_weight_files_from_dir(d, extension)
|
|
||||||
return filenames
|
|
||||||
|
|
||||||
|
|
||||||
def _cached_weight_files(
|
def _cached_weight_files(
|
||||||
model_id: str, revision: Optional[str], extension: str
|
model_id: str, revision: Optional[str], extension: str
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
@ -65,39 +54,11 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
||||||
if f.endswith(extension)
|
if f.endswith(extension)
|
||||||
and "arguments" not in f
|
and "arguments" not in f
|
||||||
and "args" not in f
|
and "args" not in f
|
||||||
and "adapter" not in f
|
|
||||||
and "training" not in f
|
and "training" not in f
|
||||||
]
|
]
|
||||||
return filenames
|
return filenames
|
||||||
|
|
||||||
|
|
||||||
def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
|
||||||
# os.walk: do not iterate, just scan for depth 1, not recursively
|
|
||||||
# see _weight_files_from_dir, that's also what is done there
|
|
||||||
root, _, files = next(os.walk(str(d)))
|
|
||||||
filenames = [
|
|
||||||
os.path.join(root, f)
|
|
||||||
for f in files
|
|
||||||
if f.endswith(extension)
|
|
||||||
and "arguments" not in f
|
|
||||||
and "args" not in f
|
|
||||||
and "training" not in f
|
|
||||||
]
|
|
||||||
return filenames
|
|
||||||
|
|
||||||
|
|
||||||
def _adapter_config_files_from_dir(d: Path) -> List[str]:
|
|
||||||
# os.walk: do not iterate, just scan for depth 1, not recursively
|
|
||||||
# see _weight_files_from_dir, that's also what is done there
|
|
||||||
root, _, files = next(os.walk(str(d)))
|
|
||||||
filenames = [
|
|
||||||
os.path.join(root, f)
|
|
||||||
for f in files
|
|
||||||
if f.endswith(".json") and "arguments" not in f and "args" not in f
|
|
||||||
]
|
|
||||||
return filenames
|
|
||||||
|
|
||||||
|
|
||||||
def _get_cached_revision_directory(
|
def _get_cached_revision_directory(
|
||||||
model_id: str, revision: Optional[str]
|
model_id: str, revision: Optional[str]
|
||||||
) -> Optional[Path]:
|
) -> Optional[Path]:
|
||||||
|
|
Loading…
Reference in New Issue