fix(server): Handle loading from local files for MPT (#534)
This PR allows the MPT model to be loaded from local files. Without this change, an exception will be thrown by `hf_hub_download` function if `model_id` is a local path.
This commit is contained in:
parent
e6888d0e87
commit
2a101207d4
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
|
||||
|
@ -60,7 +61,12 @@ class MPTSharded(CausalLM):
|
|||
)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
|
||||
# If model_id is a local path, load the file directly
|
||||
local_path = Path(model_id, "config.json")
|
||||
if local_path.exists():
|
||||
filename = str(local_path.resolve())
|
||||
else:
|
||||
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
|
||||
with open(filename, "r") as f:
|
||||
config = json.load(f)
|
||||
config = PretrainedConfig(**config)
|
||||
|
|
Loading…
Reference in New Issue