feat(server): allow the server to use a local weight cache (#49)
This commit is contained in:
parent
313194f6d7
commit
775115e3a5
|
@ -313,6 +313,12 @@ fn shard_manager(
|
|||
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||
};
|
||||
|
||||
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
|
||||
// Useful when running inside a HuggingFace Inference Endpoint
|
||||
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
|
||||
env.push(("WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into()));
|
||||
};
|
||||
|
||||
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
|
||||
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
|
||||
env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
|
||||
|
|
|
@ -25,6 +25,7 @@ from transformers.generation.logits_process import (
|
|||
|
||||
from text_generation.pb import generate_pb2
|
||||
|
||||
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||
|
||||
class Sampling:
|
||||
def __init__(self, seed: int, device: str = "cpu"):
|
||||
|
@ -230,6 +231,9 @@ def try_to_load_from_cache(model_name, revision, filename):
|
|||
|
||||
def weight_files(model_name, revision=None, extension=".safetensors"):
|
||||
"""Get the local safetensors filenames"""
|
||||
if WEIGHTS_CACHE_OVERRIDE is not None:
|
||||
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
|
||||
|
||||
filenames = weight_hub_files(model_name, revision, extension)
|
||||
files = []
|
||||
for filename in filenames:
|
||||
|
@ -249,6 +253,9 @@ def weight_files(model_name, revision=None, extension=".safetensors"):
|
|||
|
||||
def download_weights(model_name, revision=None, extension=".safetensors"):
|
||||
"""Download the safetensors files from the hub"""
|
||||
if WEIGHTS_CACHE_OVERRIDE is not None:
|
||||
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
|
||||
|
||||
filenames = weight_hub_files(model_name, revision, extension)
|
||||
|
||||
download_function = partial(
|
||||
|
|
Loading…
Reference in New Issue