feat(server): allow the server to use a local weight cache (#49)
This commit is contained in:
parent
313194f6d7
commit
775115e3a5
|
@ -313,6 +313,12 @@ fn shard_manager(
|
||||||
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
|
||||||
|
// Useful when running inside a HuggingFace Inference Endpoint
|
||||||
|
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
|
||||||
|
env.push(("WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into()));
|
||||||
|
};
|
||||||
|
|
||||||
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
|
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
|
||||||
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
|
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
|
||||||
env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
|
env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
|
||||||
|
|
|
@ -25,6 +25,7 @@ from transformers.generation.logits_process import (
|
||||||
|
|
||||||
from text_generation.pb import generate_pb2
|
from text_generation.pb import generate_pb2
|
||||||
|
|
||||||
|
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||||
|
|
||||||
class Sampling:
|
class Sampling:
|
||||||
def __init__(self, seed: int, device: str = "cpu"):
|
def __init__(self, seed: int, device: str = "cpu"):
|
||||||
|
@ -230,6 +231,9 @@ def try_to_load_from_cache(model_name, revision, filename):
|
||||||
|
|
||||||
def weight_files(model_name, revision=None, extension=".safetensors"):
|
def weight_files(model_name, revision=None, extension=".safetensors"):
|
||||||
"""Get the local safetensors filenames"""
|
"""Get the local safetensors filenames"""
|
||||||
|
if WEIGHTS_CACHE_OVERRIDE is not None:
|
||||||
|
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
|
||||||
|
|
||||||
filenames = weight_hub_files(model_name, revision, extension)
|
filenames = weight_hub_files(model_name, revision, extension)
|
||||||
files = []
|
files = []
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
|
@ -249,6 +253,9 @@ def weight_files(model_name, revision=None, extension=".safetensors"):
|
||||||
|
|
||||||
def download_weights(model_name, revision=None, extension=".safetensors"):
|
def download_weights(model_name, revision=None, extension=".safetensors"):
|
||||||
"""Download the safetensors files from the hub"""
|
"""Download the safetensors files from the hub"""
|
||||||
|
if WEIGHTS_CACHE_OVERRIDE is not None:
|
||||||
|
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
|
||||||
|
|
||||||
filenames = weight_hub_files(model_name, revision, extension)
|
filenames = weight_hub_files(model_name, revision, extension)
|
||||||
|
|
||||||
download_function = partial(
|
download_function = partial(
|
||||||
|
|
Loading…
Reference in New Issue