fix(server): Handle loading from local files for MPT (#534)

This PR allows the MPT model to be loaded from local files. Without this
change, an exception will be thrown by `hf_hub_download` function if
`model_id` is a local path.
This commit is contained in:
Antoni Baum 2023-07-04 09:37:25 -07:00 committed by GitHub
parent e6888d0e87
commit 2a101207d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 1 deletions

View File

@ -1,6 +1,7 @@
import torch
import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
@ -60,7 +61,12 @@ class MPTSharded(CausalLM):
)
tokenizer.pad_token = tokenizer.eos_token
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
# If model_id is a local path, load the file directly
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)