feat(server): Support GPT-Neox (#39)

OlivierDehaene 2023-01-31 18:53:56 +01:00 committed by GitHub
parent c6e8b9442b
commit f830706b21
13 changed files with 386 additions and 47 deletions

View File

@@ -26,6 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
 Other models are supported on a best effort basis using:
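Because the GPT-Neox 20B safetensors live on a pull-request revision, the weights have to be fetched from that revision explicitly. A minimal sketch using the `download_weights` / `weight_files` helpers whose signatures this commit extends (see the utils changes at the end of the diff); the model id and revision come from the README line above, and running it downloads the full checkpoint:

```python
# Sketch: pre-fetch the GPT-Neox 20B safetensors from the PR revision, then list the
# cached files. Helper signatures match the utils.py changes further down in this commit.
from text_generation.utils import download_weights, weight_files

model_name = "EleutherAI/gpt-neox-20b"
revision = "refs/pr/13"

download_weights(model_name, revision=revision, extension=".safetensors")
files = weight_files(model_name, revision=revision, extension=".safetensors")
print(files)  # local paths inside the Hugging Face hub cache
```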

View File

@@ -21,6 +21,8 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_name: String,
     #[clap(long, env)]
+    revision: Option<String>,
+    #[clap(long, env)]
     num_shard: Option<usize>,
     #[clap(long, env)]
     quantize: bool,
@@ -48,6 +50,7 @@ fn main() -> ExitCode {
     // Pattern match configuration
     let Args {
         model_name,
+        revision,
         num_shard,
         quantize,
         max_concurrent_requests,
@@ -90,6 +93,7 @@ fn main() -> ExitCode {
     // Start shard processes
     for rank in 0..num_shard {
         let model_name = model_name.clone();
+        let revision = revision.clone();
         let uds_path = shard_uds_path.clone();
         let master_addr = master_addr.clone();
         let status_sender = status_sender.clone();
@@ -98,6 +102,7 @@ fn main() -> ExitCode {
         thread::spawn(move || {
             shard_manager(
                 model_name,
+                revision,
                 quantize,
                 uds_path,
                 rank,
@@ -252,6 +257,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_name: String,
+    revision: Option<String>,
     quantize: bool,
     uds_path: String,
     rank: usize,
@@ -288,6 +294,11 @@ fn shard_manager(
         shard_argv.push("--quantize".to_string())
     }
+    if let Some(revision) = revision {
+        shard_argv.push("--revision".to_string());
+        shard_argv.push(revision)
+    }
     let mut env = vec![
         ("RANK".into(), rank.to_string().into()),
         ("WORLD_SIZE".into(), world_size.to_string().into()),

View File

@@ -1,5 +1,7 @@
 import pytest
+from huggingface_hub.utils import RevisionNotFoundError
 from text_generation.utils import (
     weight_hub_files,
     download_weights,
@@ -51,7 +53,7 @@ def test_weight_hub_files_llm():
 def test_weight_hub_files_empty():
-    filenames = weight_hub_files("bigscience/bloom", ".errors")
+    filenames = weight_hub_files("bigscience/bloom", extension=".errors")
     assert filenames == []
@@ -62,5 +64,7 @@ def test_download_weights():
 def test_weight_files_error():
+    with pytest.raises(RevisionNotFoundError):
+        weight_files("bigscience/bloom-560m", revision="error")
     with pytest.raises(LocalEntryNotFoundError):
         weight_files("bert-base-uncased")

View File

@@ -4,6 +4,7 @@ import typer
 from pathlib import Path
 from loguru import logger
+from typing import Optional
 from text_generation import server, utils
@@ -13,6 +14,7 @@ app = typer.Typer()
 @app.command()
 def serve(
     model_name: str,
+    revision: Optional[str] = None,
     sharded: bool = False,
     quantize: bool = False,
     uds_path: Path = "/tmp/text-generation",
@@ -44,15 +46,16 @@ def serve(
             os.getenv("MASTER_PORT", None) is not None
         ), "MASTER_PORT must be set when sharded is True"
-    server.serve(model_name, sharded, quantize, uds_path)
+    server.serve(model_name, revision, sharded, quantize, uds_path)
 @app.command()
 def download_weights(
     model_name: str,
+    revision: Optional[str] = None,
     extension: str = ".safetensors",
 ):
-    utils.download_weights(model_name, extension)
+    utils.download_weights(model_name, revision, extension)
 if __name__ == "__main__":

View File

@@ -1,11 +1,15 @@
 import torch
+from transformers import AutoConfig
+from typing import Optional
 from text_generation.models.model import Model
 from text_generation.models.causal_lm import CausalLM
 from text_generation.models.bloom import BLOOM, BLOOMSharded
 from text_generation.models.seq2seq_lm import Seq2SeqLM
 from text_generation.models.galactica import Galactica, GalacticaSharded
 from text_generation.models.santacoder import SantaCoder
+from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
 __all__ = [
     "Model",
@@ -25,23 +29,32 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
-def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
-    if model_name.startswith("bigscience/bloom"):
+def get_model(
+    model_name: str, revision: Optional[str], sharded: bool, quantize: bool
+) -> Model:
+    config = AutoConfig.from_pretrained(model_name)
+    if config.model_type == "bloom":
         if sharded:
-            return BLOOMSharded(model_name, quantize=quantize)
+            return BLOOMSharded(model_name, revision, quantize=quantize)
         else:
-            return BLOOM(model_name, quantize=quantize)
+            return BLOOM(model_name, revision, quantize=quantize)
+    elif config.model_type == "gpt_neox":
+        if sharded:
+            return GPTNeoxSharded(model_name, revision, quantize=quantize)
+        else:
+            return GPTNeox(model_name, revision, quantize=quantize)
     elif model_name.startswith("facebook/galactica"):
         if sharded:
-            return GalacticaSharded(model_name, quantize=quantize)
+            return GalacticaSharded(model_name, revision, quantize=quantize)
         else:
-            return Galactica(model_name, quantize=quantize)
+            return Galactica(model_name, revision, quantize=quantize)
     elif "santacoder" in model_name:
-        return SantaCoder(model_name, quantize)
+        return SantaCoder(model_name, revision, quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
         try:
-            return CausalLM(model_name, quantize=quantize)
+            return CausalLM(model_name, revision, quantize=quantize)
         except Exception:
-            return Seq2SeqLM(model_name, quantize=quantize)
+            return Seq2SeqLM(model_name, revision, quantize=quantize)
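`get_model` now dispatches on `config.model_type` for BLOOM and GPT-Neox instead of hard-coded repository-name prefixes, so checkpoints are matched by architecture rather than by repo name. A quick, hedged way to see what a given checkpoint reports (the model id is only an example):

```python
from transformers import AutoConfig

# Illustrative check: a GPT-NeoX-style checkpoint reports model_type == "gpt_neox"
# and is therefore routed to GPTNeox / GPTNeoxSharded by get_model above.
config = AutoConfig.from_pretrained("EleutherAI/gpt-neox-20b")
print(config.model_type)
```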

View File

@@ -56,7 +56,9 @@ class BLOOM(CausalLM):
 class BLOOMSharded(BLOOM):
-    def __init__(self, model_name: str, quantize: bool = False):
+    def __init__(
+        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+    ):
         if not model_name.startswith("bigscience/bloom"):
             raise ValueError(f"Model {model_name} is not supported")
@@ -69,19 +71,23 @@ class BLOOMSharded(BLOOM):
             device = torch.device("cpu")
             dtype = torch.float32
-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )
         config = AutoConfig.from_pretrained(
-            model_name, slow_but_exact=False, tp_parallel=True
+            model_name, revision=revision, slow_but_exact=False, tp_parallel=True
         )
         config.pad_token_id = 3
         # Only download weights for small models
         if self.master and model_name == "bigscience/bloom-560m":
-            download_weights(model_name, extension=".safetensors")
+            download_weights(model_name, revision=revision, extension=".safetensors")
         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_name, extension=".safetensors")
+        filenames = weight_files(
+            model_name, revision=revision, extension=".safetensors"
+        )
         if not filenames:
             raise ValueError("No safetensors weights found")

View File

@@ -232,7 +232,7 @@ class CausalLMBatch(Batch):
 class CausalLM(Model):
-    def __init__(self, model_name: str, quantize=False):
+    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -243,9 +243,12 @@ class CausalLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32
-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
+            revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,

View File

@@ -148,7 +148,9 @@ class Galactica(CausalLM):
 class GalacticaSharded(Galactica):
-    def __init__(self, model_name: str, quantize: bool = False):
+    def __init__(
+        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+    ):
         if not model_name.startswith("facebook/galactica"):
             raise ValueError(f"Model {model_name} is not supported")
@@ -161,24 +163,23 @@ class GalacticaSharded(Galactica):
             device = torch.device("cpu")
             dtype = torch.float32
-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )
-        config = AutoConfig.from_pretrained(model_name, tp_parallel=True)
+        config = AutoConfig.from_pretrained(
+            model_name, revision=revision, tp_parallel=True
+        )
         tokenizer.pad_token_id = config.pad_token_id
-        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
-        # in PyTorch 1.12 and later.
-        torch.backends.cuda.matmul.allow_tf32 = True
-        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
-        torch.backends.cudnn.allow_tf32 = True
         # Only download weights for small models
         if self.master and model_name == "facebook/galactica-125m":
-            download_weights(model_name, extension=".safetensors")
+            download_weights(model_name, revision=revision, extension=".safetensors")
         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_name, extension=".safetensors")
+        filenames = weight_files(
+            model_name, revision=revision, extension=".safetensors"
+        )
         if not filenames:
             raise ValueError("No safetensors weights found")

View File

@@ -0,0 +1,244 @@
import torch
import torch.distributed

from typing import List, Optional, Tuple

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
)
from transformers.models.gpt_neox.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation.models import CausalLM
from text_generation.utils import (
    initialize_torch_distributed,
    weight_files,
    download_weights,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception as e:
    HAS_BITS_AND_BYTES = False


class GPTNeox(CausalLM):
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Overwrite forward to ignore position_ids"""

        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values


class GPTNeoxSharded(GPTNeox):
    def __init__(
        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_name, revision=revision, padding_side="left"
        )
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(
            model_name, revision=revision, tp_parallel=True
        )

        # Only master download weights
        if self.master:
            download_weights(model_name, revision=revision, extension=".safetensors")

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(
            model_name, revision=revision, extension=".safetensors"
        )
        if not filenames:
            raise ValueError("No safetensors weights found")

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(dtype)
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif name == "embed_out.weight":
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except:
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()

                    if quantize:
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values
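The slicing in `load_weights` gives each rank a contiguous block of the sharded dimension (rows for column-parallel and embedding weights, columns for row-parallel weights). A self-contained toy sketch of that arithmetic, with made-up shapes:

```python
import torch

# Toy illustration of the rank-wise block slicing used in load_weights above:
# each rank keeps rows [rank * block_size, (rank + 1) * block_size).
world_size = 2
weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)  # hypothetical [out, in] weight

size = weight.shape[0]
block_size = size // world_size
shards = [weight[rank * block_size : (rank + 1) * block_size] for rank in range(world_size)]

# Concatenating the per-rank shards recovers the full weight.
assert torch.equal(torch.cat(shards, dim=0), weight)
```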

View File

@@ -14,7 +14,7 @@ EOD = "<|endoftext|>"
 class SantaCoder(CausalLM):
-    def __init__(self, model_name: str, quantize=False):
+    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -25,7 +25,9 @@ class SantaCoder(CausalLM):
             device = torch.device("cpu")
             dtype = torch.float32
-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )
         tokenizer.add_special_tokens(
             {
                 "additional_special_tokens": [
@@ -42,6 +44,7 @@ class SantaCoder(CausalLM):
         self.model = (
             AutoModelForCausalLM.from_pretrained(
                 model_name,
+                revision=revision,
                 torch_dtype=dtype,
                 load_in_8bit=quantize,
                 trust_remote_code=True,  # required

View File

@@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch):
 class Seq2SeqLM(Model):
-    def __init__(self, model_name: str, quantize=False):
+    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -302,11 +302,14 @@ class Seq2SeqLM(Model):
         self.model = AutoModelForSeq2SeqLM.from_pretrained(
             model_name,
+            revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
         super(Seq2SeqLM, self).__init__(

View File

@@ -6,7 +6,7 @@ from loguru import logger
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 from text_generation.cache import Cache
 from text_generation.interceptor import ExceptionInterceptor
@@ -67,12 +67,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 def serve(
     model_name: str,
+    revision: Optional[str],
     sharded: bool,
     quantize: bool,
     uds_path: Path,
 ):
     async def serve_inner(
         model_name: str,
+        revision: Optional[str],
         sharded: bool = False,
         quantize: bool = False,
     ):
@@ -87,7 +89,7 @@ def serve(
         local_url = unix_socket_template.format(uds_path, 0)
         server_urls = [local_url]
-        model = get_model(model_name, sharded, quantize)
+        model = get_model(model_name, revision, sharded, quantize)
         server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -107,4 +109,4 @@ def serve(
             logger.info("Signal received. Shutting down")
             await server.stop(0)
-    asyncio.run(serve_inner(model_name, sharded, quantize))
+    asyncio.run(serve_inner(model_name, revision, sharded, quantize))

View File

@@ -8,7 +8,9 @@ from datetime import timedelta
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
+from pathlib import Path
+from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
 from typing import List, Optional, Tuple
@@ -170,20 +172,62 @@ def initialize_torch_distributed():
     return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
-def weight_hub_files(model_name, extension=".safetensors"):
+def weight_hub_files(model_name, revision=None, extension=".safetensors"):
     """Get the safetensors filenames on the hub"""
     api = HfApi()
-    info = api.model_info(model_name)
+    info = api.model_info(model_name, revision=revision)
     filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
     return filenames
-def weight_files(model_name, extension=".safetensors"):
+def try_to_load_from_cache(model_name, revision, filename):
+    """Try to load a file from the Hugging Face cache"""
+    if revision is None:
+        revision = "main"
+    object_id = model_name.replace("/", "--")
+    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
+    if not repo_cache.is_dir():
+        # No cache for this model
+        return None
+    refs_dir = repo_cache / "refs"
+    snapshots_dir = repo_cache / "snapshots"
+    no_exist_dir = repo_cache / ".no_exist"
+    # Resolve refs (for instance to convert main to the associated commit sha)
+    if refs_dir.is_dir():
+        revision_file = refs_dir / revision
+        if revision_file.exists():
+            with revision_file.open() as f:
+                revision = f.read()
+    # Check if file is cached as "no_exist"
+    if (no_exist_dir / revision / filename).is_file():
+        return _CACHED_NO_EXIST
+    # Check if revision folder exists
+    if not snapshots_dir.exists():
+        return None
+    cached_shas = os.listdir(snapshots_dir)
+    if revision not in cached_shas:
+        # No cache for this revision and we won't try to return a random revision
+        return None
+    # Check if file exists in cache
+    cached_file = snapshots_dir / revision / filename
+    return str(cached_file) if cached_file.is_file() else None
+def weight_files(model_name, revision=None, extension=".safetensors"):
     """Get the local safetensors filenames"""
-    filenames = weight_hub_files(model_name, extension)
+    filenames = weight_hub_files(model_name, revision, extension)
     files = []
     for filename in filenames:
-        cache_file = try_to_load_from_cache(model_name, filename=filename)
+        cache_file = try_to_load_from_cache(
+            model_name, revision=revision, filename=filename
+        )
         if cache_file is None:
             raise LocalEntryNotFoundError(
                 f"File {filename} of model {model_name} not found in "
@@ -195,9 +239,9 @@ def weight_files(model_name, extension=".safetensors"):
     return files
-def download_weights(model_name, extension=".safetensors"):
+def download_weights(model_name, revision=None, extension=".safetensors"):
     """Download the safetensors files from the hub"""
-    filenames = weight_hub_files(model_name, extension)
+    filenames = weight_hub_files(model_name, revision, extension)
     download_function = partial(
         hf_hub_download,
@@ -207,7 +251,8 @@ def download_weights(model_name, extension=".safetensors"):
     executor = ThreadPoolExecutor(max_workers=5)
     futures = [
-        executor.submit(download_function, filename=filename) for filename in filenames
+        executor.submit(download_function, filename=filename, revision=revision)
+        for filename in filenames
     ]
     files = [
         future.result()
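The reimplemented `try_to_load_from_cache` first resolves a ref such as `main` to a commit sha via the cache's `refs/` directory, then looks the file up under `snapshots/<sha>/`. A self-contained toy sketch of that lookup against a fabricated cache directory (real lookups walk `HUGGINGFACE_HUB_CACHE`; all names below are made up for the illustration):

```python
import tempfile
from pathlib import Path

# Fabricated cache layout mirroring what try_to_load_from_cache walks.
cache = Path(tempfile.mkdtemp()) / "models--EleutherAI--gpt-neox-20b"
sha = "0123456789abcdef"
(cache / "refs").mkdir(parents=True)
(cache / "refs" / "main").write_text(sha)                  # ref -> commit sha
(cache / "snapshots" / sha).mkdir(parents=True)
(cache / "snapshots" / sha / "model.safetensors").touch()  # illustrative filename

# Resolve the ref, then check the snapshot, the same way the helper above does.
revision = "main"
revision_file = cache / "refs" / revision
if revision_file.exists():
    revision = revision_file.read_text()

cached_file = cache / "snapshots" / revision / "model.safetensors"
print(cached_file.is_file())  # True
```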