fix: pass missing revision arg for lora adapter when loading multiple… (#2510)

fix: pass missing revision arg for lora adapter when loading multiple adapters
This commit is contained in:
drbh 2024-09-12 17:04:52 +02:00 committed by GitHub
parent d95c670ada
commit 628334d336
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 6 deletions

View File

@ -77,12 +77,12 @@ def load_and_merge_adapters(
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
if len(adapter_parameters.adapter_info) == 1:
adapter_info = next(iter(adapter_parameters.adapter_info))
adapter = next(iter(adapter_parameters.adapter_info))
return load_module_map(
model_id,
adapter_info.revision,
adapter_info.id,
adapter_info.path,
adapter.revision,
adapter.id,
adapter.path,
weight_names,
trust_remote_code,
)
@ -90,7 +90,6 @@ def load_and_merge_adapters(
adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
return _load_and_merge(
model_id,
adapter_params.revision,
adapter_params,
weight_names,
trust_remote_code,
@ -109,7 +108,6 @@ class AdapterParametersContainer:
@lru_cache(maxsize=32)
def _load_and_merge(
model_id: str,
revision: str,
adapter_params: AdapterParametersContainer,
weight_names: Tuple[str],
trust_remote_code: bool = False,
@ -126,6 +124,7 @@ def _load_and_merge(
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
load_module_map(
model_id,
adapter.revision,
adapter.id,
adapter.path,
weight_names,