feat(server): allow the server to use a local weight cache (#49)

OlivierDehaene authored 2023-02-01 16:22:10 +01:00, committed by GitHub
parent 313194f6d7
commit 775115e3a5
2 changed files with 13 additions and 0 deletions


@@ -313,6 +313,12 @@ fn shard_manager(
        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };
    // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
        env.push(("WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into()));
    };
    // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
    if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
        env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));


@@ -25,6 +25,7 @@ from transformers.generation.logits_process import (
from text_generation.pb import generate_pb2

WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)


class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
@@ -230,6 +231,9 @@ def try_to_load_from_cache(model_name, revision, filename):
def weight_files(model_name, revision=None, extension=".safetensors"):
    """Get the local safetensors filenames"""
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
    filenames = weight_hub_files(model_name, revision, extension)
    files = []
    for filename in filenames:
@@ -249,6 +253,9 @@ def weight_files(model_name, revision=None, extension=".safetensors"):
def download_weights(model_name, revision=None, extension=".safetensors"):
    """Download the safetensors files from the hub"""
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
    filenames = weight_hub_files(model_name, revision, extension)
    download_function = partial(
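A minimal, self-contained sketch of the short-circuit behaviour that weight_files and download_weights now share; the temporary directory and file names below are illustrative, not taken from the repository:

import os
import tempfile
from pathlib import Path

# Illustrative setup: pretend an Inference Endpoint mounted the weights locally.
cache_dir = Path(tempfile.mkdtemp())
(cache_dir / "model-00001-of-00002.safetensors").touch()
(cache_dir / "model-00002-of-00002.safetensors").touch()
os.environ["WEIGHTS_CACHE_OVERRIDE"] = str(cache_dir)

# Mirrors the guard added above: when the override is set, glob the local
# directory directly and skip the hub lookup/download entirely.
override = os.getenv("WEIGHTS_CACHE_OVERRIDE")
if override is not None:
    files = list(Path(override).glob("*.safetensors"))
    print(files)  # the two local shards, no network access needed

This is the point of the override inside a HuggingFace Inference Endpoint: the weights are already on disk, so both listing and downloading reduce to a directory glob.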