fix: pass missing revision arg for lora adapter when loading multiple… (#2510)

fix: pass missing revision arg for lora adapter when loading multiple adapters
This commit is contained in:
drbh 2024-09-12 17:04:52 +02:00 committed by GitHub
parent d95c670ada
commit 628334d336
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 6 deletions

View File

@ -77,12 +77,12 @@ def load_and_merge_adapters(
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
if len(adapter_parameters.adapter_info) == 1: if len(adapter_parameters.adapter_info) == 1:
adapter_info = next(iter(adapter_parameters.adapter_info)) adapter = next(iter(adapter_parameters.adapter_info))
return load_module_map( return load_module_map(
model_id, model_id,
adapter_info.revision, adapter.revision,
adapter_info.id, adapter.id,
adapter_info.path, adapter.path,
weight_names, weight_names,
trust_remote_code, trust_remote_code,
) )
@ -90,7 +90,6 @@ def load_and_merge_adapters(
adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index) adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
return _load_and_merge( return _load_and_merge(
model_id, model_id,
adapter_params.revision,
adapter_params, adapter_params,
weight_names, weight_names,
trust_remote_code, trust_remote_code,
@ -109,7 +108,6 @@ class AdapterParametersContainer:
@lru_cache(maxsize=32) @lru_cache(maxsize=32)
def _load_and_merge( def _load_and_merge(
model_id: str, model_id: str,
revision: str,
adapter_params: AdapterParametersContainer, adapter_params: AdapterParametersContainer,
weight_names: Tuple[str], weight_names: Tuple[str],
trust_remote_code: bool = False, trust_remote_code: bool = False,
@ -126,6 +124,7 @@ def _load_and_merge(
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
load_module_map( load_module_map(
model_id, model_id,
adapter.revision,
adapter.id, adapter.id,
adapter.path, adapter.path,
weight_names, weight_names,