From f830706b215a94e1cdedeb89766b2822d8ceac24 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 31 Jan 2023 18:53:56 +0100 Subject: [PATCH] feat(server): Support GPT-Neox (#39) --- README.md | 1 + launcher/src/main.rs | 11 + server/tests/test_utils.py | 6 +- server/text_generation/cli.py | 7 +- server/text_generation/models/__init__.py | 31 ++- server/text_generation/models/bloom.py | 16 +- server/text_generation/models/causal_lm.py | 7 +- server/text_generation/models/galactica.py | 25 +- server/text_generation/models/gpt_neox.py | 244 ++++++++++++++++++++ server/text_generation/models/santacoder.py | 7 +- server/text_generation/models/seq2seq_lm.py | 7 +- server/text_generation/server.py | 8 +- server/text_generation/utils.py | 63 ++++- 13 files changed, 386 insertions(+), 47 deletions(-) create mode 100644 server/text_generation/models/gpt_neox.py diff --git a/README.md b/README.md index 2635f641..67940ae8 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets. - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl) - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated) - [SantaCoder](https://huggingface.co/bigcode/santacoder) +- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13` Other models are supported on a best effort basis using: diff --git a/launcher/src/main.rs b/launcher/src/main.rs index f897d2cd..bd449e28 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -21,6 +21,8 @@ struct Args { #[clap(default_value = "bigscience/bloom-560m", long, env)] model_name: String, #[clap(long, env)] + revision: Option, + #[clap(long, env)] num_shard: Option, #[clap(long, env)] quantize: bool, @@ -48,6 +50,7 @@ fn main() -> ExitCode { // Pattern match configuration let Args { model_name, + revision, num_shard, quantize, max_concurrent_requests, @@ -90,6 +93,7 @@ fn main() -> ExitCode { // Start shard processes for rank in 0..num_shard { let model_name = model_name.clone(); + let revision = revision.clone(); let uds_path = shard_uds_path.clone(); let master_addr = master_addr.clone(); let status_sender = status_sender.clone(); @@ -98,6 +102,7 @@ fn main() -> ExitCode { thread::spawn(move || { shard_manager( model_name, + revision, quantize, uds_path, rank, @@ -252,6 +257,7 @@ enum ShardStatus { #[allow(clippy::too_many_arguments)] fn shard_manager( model_name: String, + revision: Option, quantize: bool, uds_path: String, rank: usize, @@ -288,6 +294,11 @@ fn shard_manager( shard_argv.push("--quantize".to_string()) } + if let Some(revision) = revision { + shard_argv.push("--revision".to_string()); + shard_argv.push(revision) + } + let mut env = vec![ ("RANK".into(), rank.to_string().into()), ("WORLD_SIZE".into(), world_size.to_string().into()), diff --git a/server/tests/test_utils.py b/server/tests/test_utils.py index 643cb834..1dc6801b 100644 --- a/server/tests/test_utils.py +++ b/server/tests/test_utils.py @@ -1,5 +1,7 @@ import pytest +from huggingface_hub.utils import RevisionNotFoundError + from text_generation.utils import ( weight_hub_files, download_weights, @@ -51,7 +53,7 @@ def test_weight_hub_files_llm(): def test_weight_hub_files_empty(): - filenames = weight_hub_files("bigscience/bloom", ".errors") + filenames = weight_hub_files("bigscience/bloom", extension=".errors") assert filenames == [] @@ -62,5 +64,7 @@ def test_download_weights(): def test_weight_files_error(): + with pytest.raises(RevisionNotFoundError): + weight_files("bigscience/bloom-560m", revision="error") with pytest.raises(LocalEntryNotFoundError): weight_files("bert-base-uncased") diff --git a/server/text_generation/cli.py b/server/text_generation/cli.py index 8101cea4..b133cb0a 100644 --- a/server/text_generation/cli.py +++ b/server/text_generation/cli.py @@ -4,6 +4,7 @@ import typer from pathlib import Path from loguru import logger +from typing import Optional from text_generation import server, utils @@ -13,6 +14,7 @@ app = typer.Typer() @app.command() def serve( model_name: str, + revision: Optional[str] = None, sharded: bool = False, quantize: bool = False, uds_path: Path = "/tmp/text-generation", @@ -44,15 +46,16 @@ def serve( os.getenv("MASTER_PORT", None) is not None ), "MASTER_PORT must be set when sharded is True" - server.serve(model_name, sharded, quantize, uds_path) + server.serve(model_name, revision, sharded, quantize, uds_path) @app.command() def download_weights( model_name: str, + revision: Optional[str] = None, extension: str = ".safetensors", ): - utils.download_weights(model_name, extension) + utils.download_weights(model_name, revision, extension) if __name__ == "__main__": diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py index 41d73815..9309c887 100644 --- a/server/text_generation/models/__init__.py +++ b/server/text_generation/models/__init__.py @@ -1,11 +1,15 @@ import torch +from transformers import AutoConfig +from typing import Optional + from text_generation.models.model import Model from text_generation.models.causal_lm import CausalLM from text_generation.models.bloom import BLOOM, BLOOMSharded from text_generation.models.seq2seq_lm import Seq2SeqLM from text_generation.models.galactica import Galactica, GalacticaSharded from text_generation.models.santacoder import SantaCoder +from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded __all__ = [ "Model", @@ -25,23 +29,32 @@ torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True -def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: - if model_name.startswith("bigscience/bloom"): +def get_model( + model_name: str, revision: Optional[str], sharded: bool, quantize: bool +) -> Model: + config = AutoConfig.from_pretrained(model_name) + + if config.model_type == "bloom": if sharded: - return BLOOMSharded(model_name, quantize=quantize) + return BLOOMSharded(model_name, revision, quantize=quantize) else: - return BLOOM(model_name, quantize=quantize) + return BLOOM(model_name, revision, quantize=quantize) + elif config.model_type == "gpt_neox": + if sharded: + return GPTNeoxSharded(model_name, revision, quantize=quantize) + else: + return GPTNeox(model_name, revision, quantize=quantize) elif model_name.startswith("facebook/galactica"): if sharded: - return GalacticaSharded(model_name, quantize=quantize) + return GalacticaSharded(model_name, revision, quantize=quantize) else: - return Galactica(model_name, quantize=quantize) + return Galactica(model_name, revision, quantize=quantize) elif "santacoder" in model_name: - return SantaCoder(model_name, quantize) + return SantaCoder(model_name, revision, quantize) else: if sharded: raise ValueError("sharded is not supported for AutoModel") try: - return CausalLM(model_name, quantize=quantize) + return CausalLM(model_name, revision, quantize=quantize) except Exception: - return Seq2SeqLM(model_name, quantize=quantize) + return Seq2SeqLM(model_name, revision, quantize=quantize) diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index 7708bb4a..4f55afc0 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -56,7 +56,9 @@ class BLOOM(CausalLM): class BLOOMSharded(BLOOM): - def __init__(self, model_name: str, quantize: bool = False): + def __init__( + self, model_name: str, revision: Optional[str] = None, quantize: bool = False + ): if not model_name.startswith("bigscience/bloom"): raise ValueError(f"Model {model_name} is not supported") @@ -69,19 +71,23 @@ class BLOOMSharded(BLOOM): device = torch.device("cpu") dtype = torch.float32 - tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + model_name, revision=revision, padding_side="left" + ) config = AutoConfig.from_pretrained( - model_name, slow_but_exact=False, tp_parallel=True + model_name, revision=revision, slow_but_exact=False, tp_parallel=True ) config.pad_token_id = 3 # Only download weights for small models if self.master and model_name == "bigscience/bloom-560m": - download_weights(model_name, extension=".safetensors") + download_weights(model_name, revision=revision, extension=".safetensors") torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_name, extension=".safetensors") + filenames = weight_files( + model_name, revision=revision, extension=".safetensors" + ) if not filenames: raise ValueError("No safetensors weights found") diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 31996e06..4dc834b8 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -232,7 +232,7 @@ class CausalLMBatch(Batch): class CausalLM(Model): - def __init__(self, model_name: str, quantize=False): + def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False): if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 @@ -243,9 +243,12 @@ class CausalLM(Model): device = torch.device("cpu") dtype = torch.float32 - tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + model_name, revision=revision, padding_side="left" + ) self.model = AutoModelForCausalLM.from_pretrained( model_name, + revision=revision, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None, load_in_8bit=quantize, diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index 4722e1d8..be9b1699 100644 --- a/server/text_generation/models/galactica.py +++ b/server/text_generation/models/galactica.py @@ -148,7 +148,9 @@ class Galactica(CausalLM): class GalacticaSharded(Galactica): - def __init__(self, model_name: str, quantize: bool = False): + def __init__( + self, model_name: str, revision: Optional[str] = None, quantize: bool = False + ): if not model_name.startswith("facebook/galactica"): raise ValueError(f"Model {model_name} is not supported") @@ -161,24 +163,23 @@ class GalacticaSharded(Galactica): device = torch.device("cpu") dtype = torch.float32 - tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + model_name, revision=revision, padding_side="left" + ) - config = AutoConfig.from_pretrained(model_name, tp_parallel=True) + config = AutoConfig.from_pretrained( + model_name, revision=revision, tp_parallel=True + ) tokenizer.pad_token_id = config.pad_token_id - # The flag below controls whether to allow TF32 on matmul. This flag defaults to False - # in PyTorch 1.12 and later. - torch.backends.cuda.matmul.allow_tf32 = True - - # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. - torch.backends.cudnn.allow_tf32 = True - # Only download weights for small models if self.master and model_name == "facebook/galactica-125m": - download_weights(model_name, extension=".safetensors") + download_weights(model_name, revision=revision, extension=".safetensors") torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_name, extension=".safetensors") + filenames = weight_files( + model_name, revision=revision, extension=".safetensors" + ) if not filenames: raise ValueError("No safetensors weights found") diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py new file mode 100644 index 00000000..d901cae3 --- /dev/null +++ b/server/text_generation/models/gpt_neox.py @@ -0,0 +1,244 @@ +import torch +import torch.distributed + +from typing import List, Optional, Tuple + +from accelerate import init_empty_weights +from safetensors import safe_open +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + AutoConfig, +) +from transformers.models.gpt_neox.parallel_layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) + +from text_generation.models import CausalLM +from text_generation.utils import ( + initialize_torch_distributed, + weight_files, + download_weights, +) + +HAS_BITS_AND_BYTES = True +try: + import bitsandbytes as bnb + from bitsandbytes.nn import Int8Params +except Exception as e: + HAS_BITS_AND_BYTES = False + + +class GPTNeox(CausalLM): + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + """Overwrite forward to ignore position_ids""" + + # Model Forward + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + return outputs.logits, outputs.past_key_values + + +class GPTNeoxSharded(GPTNeox): + def __init__( + self, model_name: str, revision: Optional[str] = None, quantize: bool = False + ): + self.process_group, self.rank, self.world_size = initialize_torch_distributed() + self.master = self.rank == 0 + if torch.cuda.is_available(): + device = torch.device(f"cuda:{self.rank}") + dtype = torch.bfloat16 + else: + device = torch.device("cpu") + dtype = torch.float32 + + tokenizer = AutoTokenizer.from_pretrained( + model_name, revision=revision, padding_side="left" + ) + tokenizer.pad_token = tokenizer.eos_token + + config = AutoConfig.from_pretrained( + model_name, revision=revision, tp_parallel=True + ) + + # Only master download weights + if self.master: + download_weights(model_name, revision=revision, extension=".safetensors") + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files( + model_name, revision=revision, extension=".safetensors" + ) + if not filenames: + raise ValueError("No safetensors weights found") + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + device=device, + rank=self.rank, + world_size=self.world_size, + ) + self.model = model.eval().to(dtype) + torch.distributed.barrier(group=self.process_group) + super(CausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: + for name in f.keys(): + module_name, param_name = name.rsplit(".", 1) + module = model.get_submodule(module_name) + + current_parameter_tensor = parameters.get(name, None) + + slice_ = f.get_slice(name) + + if isinstance(module, TensorParallelColumnLinear): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once. + if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif name == "embed_out.weight": + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + try: + tensor = slice_[:] + except: + tensor = f.get_tensor(name) + + if ( + current_parameter_tensor is not None + and current_parameter_tensor.shape != tensor.shape + ): + raise ValueError( + f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous() + + if quantize: + if not HAS_BITS_AND_BYTES: + raise ImportError( + "bitsandbytes is not available on your machine either because it is not installed " + "or you don't have a GPU.\n" + "You can install it with `pip install bitsandbytes`." + ) + + if ( + type(module) + in [TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" + ): + tensor = Int8Params( + tensor, + has_fp16_weights=False, + requires_grad=False, + ).to(device) + state = bnb.MatmulLtState() + state.threshold = 6.0 + state.has_fp16_weights = False + state.memory_efficient_backward = False + state.use_pool = True + state.CB = tensor.CB + state.SCB = tensor.SCB + tensor.CB = None + tensor.SCB = None + + def replace_linear(state): + def linear(input, weight, bias): + out = bnb.matmul( + input, + weight, + state=state, + threshold=state.threshold, + bias=bias, + ) + + if state.CB is not None: + # we converted 8-bit row major to turing/ampere format + # in the first inference pass + # we no longer need the row-major weight + del state.CB + weight.data = state.CxB + + return out + + return linear + + module.linear = replace_linear(state) + + else: + tensor = tensor.to(device) + + if current_parameter_tensor is not None: + module._parameters[param_name] = tensor + else: + module._buffers[param_name] = tensor + + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ): + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + + # Logits are sharded, so we need to gather them + logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] + torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) + logits = torch.cat(logits, dim=2) + + return logits, outputs.past_key_values diff --git a/server/text_generation/models/santacoder.py b/server/text_generation/models/santacoder.py index 4b898ab9..6c1a250f 100644 --- a/server/text_generation/models/santacoder.py +++ b/server/text_generation/models/santacoder.py @@ -14,7 +14,7 @@ EOD = "<|endoftext|>" class SantaCoder(CausalLM): - def __init__(self, model_name: str, quantize=False): + def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False): if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 @@ -25,7 +25,9 @@ class SantaCoder(CausalLM): device = torch.device("cpu") dtype = torch.float32 - tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + model_name, revision=revision, padding_side="left" + ) tokenizer.add_special_tokens( { "additional_special_tokens": [ @@ -42,6 +44,7 @@ class SantaCoder(CausalLM): self.model = ( AutoModelForCausalLM.from_pretrained( model_name, + revision=revision, torch_dtype=dtype, load_in_8bit=quantize, trust_remote_code=True, # required diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 6d5dc22e..29492dd7 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch): class Seq2SeqLM(Model): - def __init__(self, model_name: str, quantize=False): + def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False): if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 @@ -302,11 +302,14 @@ class Seq2SeqLM(Model): self.model = AutoModelForSeq2SeqLM.from_pretrained( model_name, + revision=revision, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None, load_in_8bit=quantize, ).eval() - tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + model_name, revision=revision, padding_side="left" + ) tokenizer.bos_token_id = self.model.config.decoder_start_token_id super(Seq2SeqLM, self).__init__( diff --git a/server/text_generation/server.py b/server/text_generation/server.py index a2bad8a7..852deebf 100644 --- a/server/text_generation/server.py +++ b/server/text_generation/server.py @@ -6,7 +6,7 @@ from loguru import logger from grpc_reflection.v1alpha import reflection from pathlib import Path -from typing import List +from typing import List, Optional from text_generation.cache import Cache from text_generation.interceptor import ExceptionInterceptor @@ -67,12 +67,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def serve( model_name: str, + revision: Optional[str], sharded: bool, quantize: bool, uds_path: Path, ): async def serve_inner( model_name: str, + revision: Optional[str], sharded: bool = False, quantize: bool = False, ): @@ -87,7 +89,7 @@ def serve( local_url = unix_socket_template.format(uds_path, 0) server_urls = [local_url] - model = get_model(model_name, sharded, quantize) + model = get_model(model_name, revision, sharded, quantize) server = aio.server(interceptors=[ExceptionInterceptor()]) generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( @@ -107,4 +109,4 @@ def serve( logger.info("Signal received. Shutting down") await server.stop(0) - asyncio.run(serve_inner(model_name, sharded, quantize)) + asyncio.run(serve_inner(model_name, revision, sharded, quantize)) diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index a2029911..62d60635 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -8,7 +8,9 @@ from datetime import timedelta from concurrent.futures import ThreadPoolExecutor from functools import partial -from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache +from pathlib import Path +from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import LocalEntryNotFoundError from tqdm import tqdm from typing import List, Optional, Tuple @@ -170,20 +172,62 @@ def initialize_torch_distributed(): return torch.distributed.distributed_c10d._get_default_group(), rank, world_size -def weight_hub_files(model_name, extension=".safetensors"): +def weight_hub_files(model_name, revision=None, extension=".safetensors"): """Get the safetensors filenames on the hub""" api = HfApi() - info = api.model_info(model_name) + info = api.model_info(model_name, revision=revision) filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)] return filenames -def weight_files(model_name, extension=".safetensors"): +def try_to_load_from_cache(model_name, revision, filename): + """Try to load a file from the Hugging Face cache""" + if revision is None: + revision = "main" + + object_id = model_name.replace("/", "--") + repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}" + + if not repo_cache.is_dir(): + # No cache for this model + return None + + refs_dir = repo_cache / "refs" + snapshots_dir = repo_cache / "snapshots" + no_exist_dir = repo_cache / ".no_exist" + + # Resolve refs (for instance to convert main to the associated commit sha) + if refs_dir.is_dir(): + revision_file = refs_dir / revision + if revision_file.exists(): + with revision_file.open() as f: + revision = f.read() + + # Check if file is cached as "no_exist" + if (no_exist_dir / revision / filename).is_file(): + return _CACHED_NO_EXIST + + # Check if revision folder exists + if not snapshots_dir.exists(): + return None + cached_shas = os.listdir(snapshots_dir) + if revision not in cached_shas: + # No cache for this revision and we won't try to return a random revision + return None + + # Check if file exists in cache + cached_file = snapshots_dir / revision / filename + return str(cached_file) if cached_file.is_file() else None + + +def weight_files(model_name, revision=None, extension=".safetensors"): """Get the local safetensors filenames""" - filenames = weight_hub_files(model_name, extension) + filenames = weight_hub_files(model_name, revision, extension) files = [] for filename in filenames: - cache_file = try_to_load_from_cache(model_name, filename=filename) + cache_file = try_to_load_from_cache( + model_name, revision=revision, filename=filename + ) if cache_file is None: raise LocalEntryNotFoundError( f"File {filename} of model {model_name} not found in " @@ -195,9 +239,9 @@ def weight_files(model_name, extension=".safetensors"): return files -def download_weights(model_name, extension=".safetensors"): +def download_weights(model_name, revision=None, extension=".safetensors"): """Download the safetensors files from the hub""" - filenames = weight_hub_files(model_name, extension) + filenames = weight_hub_files(model_name, revision, extension) download_function = partial( hf_hub_download, @@ -207,7 +251,8 @@ def download_weights(model_name, extension=".safetensors"): executor = ThreadPoolExecutor(max_workers=5) futures = [ - executor.submit(download_function, filename=filename) for filename in filenames + executor.submit(download_function, filename=filename, revision=revision) + for filename in filenames ] files = [ future.result()