feat: add safetensors conversion (#63)

This commit is contained in:
OlivierDehaene 2023-02-14 13:02:16 +01:00 committed by GitHub
parent 9af454142a
commit 0fbc691946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 749 additions and 380 deletions

View File

@ -49,17 +49,17 @@ to power LLMs api-inference widgets.
- Log probabilities
- Distributed tracing with Open Telemetry
## Officially supported models
## Officially supported architectures
- [BLOOM](https://huggingface.co/bigscience/bloom)
- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
- [SantaCoder](https://huggingface.co/bigcode/santacoder)
- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
Other models are supported on a best effort basis using:
Other architectures are supported on a best effort basis using:
`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
@ -191,7 +191,7 @@ Be aware that the official Docker image has them enabled by default.
### Download
First you need to download the weights:
It is advised to download the weights ahead of time with the following command:
```shell
make download-bloom

View File

@ -12,7 +12,7 @@ use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use subprocess::{Popen, PopenConfig, PopenError, Redirection};
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
/// App Configuration
#[derive(Parser, Debug)]
@ -43,6 +43,10 @@ struct Args {
#[clap(default_value = "29500", long, env)]
master_port: usize,
#[clap(long, env)]
huggingface_hub_cache: Option<String>,
#[clap(long, env)]
weights_cache_override: Option<String>,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
@ -63,6 +67,8 @@ fn main() -> ExitCode {
shard_uds_path,
master_addr,
master_port,
huggingface_hub_cache,
weights_cache_override,
json_output,
otlp_endpoint,
} = Args::parse();
@ -84,6 +90,124 @@ fn main() -> ExitCode {
})
.expect("Error setting Ctrl-C handler");
// Download weights
if weights_cache_override.is_none() {
let mut download_argv = vec![
"text-generation-server".to_string(),
"download-weights".to_string(),
model_id.clone(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
if num_shard == 1 {
download_argv.push("--extension".to_string());
download_argv.push(".bin".to_string());
} else {
download_argv.push("--extension".to_string());
download_argv.push(".safetensors".to_string());
}
// Model optional revision
if let Some(ref revision) = revision {
download_argv.push("--revision".to_string());
download_argv.push(revision.to_string())
}
let mut env = Vec::new();
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// Start process
tracing::info!("Starting download");
let mut download_process = match Popen::create(
&download_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
env: Some(env),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
}
}
return ExitCode::FAILURE;
}
};
// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
thread::spawn(move || {
// Enter download tracing span
let stdout = BufReader::new(download_stdout);
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
if let Some(text) = value.get("text") {
// Format escaped newlines
tracing::info!("{}", text.to_string().replace("\\n", ""));
}
}
}
});
loop {
if let Some(status) = download_process.poll() {
match status {
ExitStatus::Exited(exit_code) => {
if exit_code == 0 {
tracing::info!("Successfully downloaded weights.");
break;
} else {
let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
tracing::error!("Download encountered an error: {err}");
return ExitCode::FAILURE;
}
}
_ => {
tracing::error!("Download process exited with an unkown status.");
return ExitCode::FAILURE;
}
}
}
if !running.load(Ordering::SeqCst) {
download_process.terminate().unwrap();
tracing::info!("Waiting for download process to gracefully shutdown");
download_process
.wait_timeout(Duration::from_secs(90))
.unwrap();
tracing::info!("Download process terminated");
return ExitCode::SUCCESS;
}
sleep(Duration::from_millis(100));
}
} else {
tracing::info!(
"weights_cache_override is set to {:?}.",
weights_cache_override
);
tracing::info!("Skipping download.")
}
// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel
@ -99,6 +223,8 @@ fn main() -> ExitCode {
let revision = revision.clone();
let uds_path = shard_uds_path.clone();
let master_addr = master_addr.clone();
let huggingface_hub_cache = huggingface_hub_cache.clone();
let weights_cache_override = weights_cache_override.clone();
let status_sender = status_sender.clone();
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
@ -113,6 +239,8 @@ fn main() -> ExitCode {
num_shard,
master_addr,
master_port,
huggingface_hub_cache,
weights_cache_override,
otlp_endpoint,
status_sender,
shutdown,
@ -232,7 +360,7 @@ fn main() -> ExitCode {
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {} failed:\n{}", rank, err);
tracing::error!("Shard {rank} failed:\n{err}");
exit_code = ExitCode::FAILURE;
break;
};
@ -275,6 +403,8 @@ fn shard_manager(
world_size: usize,
master_addr: String,
master_port: usize,
huggingface_hub_cache: Option<String>,
weights_cache_override: Option<String>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
@ -328,15 +458,15 @@ fn shard_manager(
("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
];
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
// If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
if let Some(weights_cache_override) = weights_cache_override {
env.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
@ -355,7 +485,7 @@ fn shard_manager(
};
// Start process
tracing::info!("Starting shard {}", rank);
tracing::info!("Starting shard {rank}");
let mut p = match Popen::create(
&shard_argv,
PopenConfig {
@ -419,17 +549,17 @@ fn shard_manager(
if *shutdown.lock().unwrap() {
p.terminate().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {} terminated", rank);
tracing::info!("Shard {rank} terminated");
return;
}
// Shard is ready
if uds.exists() && !ready {
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap();
ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {} to be ready...", rank);
tracing::info!("Waiting for shard {rank} to be ready...");
wait_time = Instant::now();
}
sleep(Duration::from_millis(100));

View File

@ -0,0 +1,17 @@
from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
from text_generation.utils.convert import convert_files
def test_convert_files():
model_id = "bigscience/bloom-560m"
pt_filenames = weight_hub_files(model_id, extension=".bin")
local_pt_files = download_weights(pt_filenames, model_id)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files)
found_st_files = weight_files(model_id)
assert all([p in found_st_files for p in local_st_files])

View File

@ -0,0 +1,40 @@
import pytest
from text_generation.utils.hub import (
weight_hub_files,
download_weights,
weight_files,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
with pytest.raises(EntryNotFoundError):
weight_hub_files("bigscience/bloom", extension=".errors")
def test_download_weights():
model_id = "bigscience/bloom-560m"
filenames = weight_hub_files(model_id)
files = download_weights(filenames, model_id)
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -1,14 +1,6 @@
import pytest
from huggingface_hub.utils import RevisionNotFoundError
from text_generation.utils import (
weight_hub_files,
download_weights,
weight_files,
from text_generation.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
LocalEntryNotFoundError,
FinishReason,
)
@ -41,31 +33,3 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
filenames = weight_hub_files("bigscience/bloom", extension=".errors")
assert filenames == []
def test_download_weights():
files = download_weights("bigscience/bloom-560m")
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -60,8 +60,55 @@ def download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
logger_level: str = "INFO",
json_output: bool = False,
):
utils.download_weights(model_id, revision, extension)
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Test if files were already download
try:
utils.weight_files(model_id, revision, extension)
logger.info(
"Files are already present in the local cache. " "Skipping download."
)
return
# Local files not found
except utils.LocalEntryNotFoundError:
pass
# Download weights directly
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(filenames, model_id, revision)
except utils.EntryNotFoundError as e:
if not extension == ".safetensors":
raise e
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights instead."
)
# Try to see if there are pytorch weights
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
# Download pytorch weights
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files
]
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files)
if __name__ == "__main__":

View File

@ -41,6 +41,15 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
if model_id.startswith("facebook/galactica"):
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
if "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision)
if config.model_type == "bloom":
@ -48,27 +57,22 @@ def get_model(
return BLOOMSharded(model_id, revision, quantize=quantize)
else:
return BLOOM(model_id, revision, quantize=quantize)
elif config.model_type == "gpt_neox":
if config.model_type == "gpt_neox":
if sharded:
return GPTNeoxSharded(model_id, revision, quantize=quantize)
else:
return GPTNeox(model_id, revision, quantize=quantize)
elif config.model_type == "t5":
if config.model_type == "t5":
if sharded:
return T5Sharded(model_id, revision, quantize=quantize)
else:
return Seq2SeqLM(model_id, revision, quantize=quantize)
elif model_id.startswith("facebook/galactica"):
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
elif "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
else:
if sharded:
raise ValueError("sharded is not supported for AutoModel")
try:
return CausalLM(model_id, revision, quantize=quantize)
except Exception:
return Seq2SeqLM(model_id, revision, quantize=quantize)
if sharded:
raise ValueError("sharded is not supported for AutoModel")
try:
return CausalLM(model_id, revision, quantize=quantize)
except Exception:
return Seq2SeqLM(model_id, revision, quantize=quantize)

View File

@ -23,7 +23,6 @@ from text_generation.pb import generate_pb2
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -80,14 +79,8 @@ class BLOOMSharded(BLOOM):
)
config.pad_token_id = 3
# Only download weights for small models
if self.master and model_id == "bigscience/bloom-560m":
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

View File

@ -26,7 +26,6 @@ from text_generation.utils import (
StoppingCriteria,
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -172,14 +171,8 @@ class GalacticaSharded(Galactica):
)
tokenizer.pad_token_id = config.pad_token_id
# Only download weights for small models
if self.master and model_id == "facebook/galactica-125m":
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

View File

@ -20,7 +20,6 @@ from text_generation.models import CausalLM
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -69,14 +68,8 @@ class GPTNeoxSharded(GPTNeox):
model_id, revision=revision, tp_parallel=True
)
# Only master download weights
if self.master:
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

View File

@ -1,7 +1,7 @@
import torch
import torch.distributed
from typing import Optional, List, Tuple
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation.models import CausalLM

View File

@ -20,7 +20,6 @@ from text_generation.models import Seq2SeqLM
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM):
)
tokenizer.bos_token_id = config.decoder_start_token_id
# Only master download weights
if self.master:
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(config)

View File

@ -1,283 +0,0 @@
import concurrent
import os
import re
import torch
import torch.distributed
from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopPLogitsWarper,
TopKLogitsWarper,
)
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax()
class NextTokenChooser:
def __init__(
self,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
def weight_hub_files(model_id, revision=None, extension=".safetensors"):
"""Get the safetensors filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
return filenames
def try_to_load_from_cache(model_id, revision, filename):
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return _CACHED_NO_EXIST
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return str(cached_file) if cached_file.is_file() else None
def weight_files(model_id, revision=None, extension=".safetensors"):
"""Get the local safetensors filenames"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(model_id, revision=None, extension=".safetensors"):
"""Download the safetensors files from the hub"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
download_function = partial(
hf_hub_download,
repo_id=model_id,
local_files_only=False,
)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_function, filename=filename, revision=revision)
for filename in filenames
]
files = [
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]
return files

View File

@ -0,0 +1,36 @@
from text_generation.utils.convert import convert_file, convert_files
from text_generation.utils.dist import initialize_torch_distributed
from text_generation.utils.hub import (
weight_files,
weight_hub_files,
download_weights,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
from text_generation.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,
StoppingCriteria,
StopSequenceCriteria,
FinishReason,
)
__all__ = [
"convert_file",
"convert_files",
"initialize_torch_distributed",
"weight_files",
"weight_hub_files",
"download_weights",
"EntryNotFoundError",
"LocalEntryNotFoundError",
"RevisionNotFoundError",
"Greedy",
"NextTokenChooser",
"Sampling",
"StoppingCriteria",
"StopSequenceCriteria",
"FinishReason",
]

View File

@ -0,0 +1,96 @@
import concurrent
import time
import torch
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from typing import Dict, List
def check_file_size(source_file: Path, target_file: Path):
"""
Check that two files are close in size
"""
source_file_size = source_file.stat().st_size
target_file_size = target_file.stat().st_size
if (source_file_size - target_file_size) / source_file_size > 0.01:
raise RuntimeError(
f"""The file size different is more than 1%:
- {source_file}: {source_file_size}
- {target_file}: {target_file_size}
"""
)
def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
"""
For a Dict of tensors, check if two or more tensors point to the same underlying memory and
remove them
"""
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
# Iterate over all found memory addresses
for ptr, names in ptrs.items():
if len(names) > 1:
# Multiple tensors are point to the same memory
# Only keep the first tensor
for name in names[1:]:
tensors.pop(name)
def convert_file(pt_file: Path, st_file: Path):
"""
Convert a pytorch file to a safetensors file
"""
pt_state = torch.load(pt_file, map_location="cpu")
if "state_dict" in pt_state:
pt_state = pt_state["state_dict"]
remove_shared_pointers(pt_state)
# Tensors need to be contiguous
pt_state = {k: v.contiguous() for k, v in pt_state.items()}
st_file.parent.mkdir(parents=True, exist_ok=True)
save_file(pt_state, str(st_file), metadata={"format": "pt"})
# Check that both files are close in size
check_file_size(pt_file, st_file)
# Load safetensors state
st_state = load_file(str(st_file))
for k in st_state:
pt_tensor = pt_state[k]
st_tensor = st_state[k]
if not torch.equal(pt_tensor, st_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
def convert_files(pt_files: List[Path], st_files: List[Path]):
assert len(pt_files) == len(st_files)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
for pt_file, st_file in zip(pt_files, st_files)
]
# We do this instead of using tqdm because we want to parse the logs with the launcher
logger.info("Converting weights...")
start_time = time.time()
for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1)
if remaining != 0:
eta = (elapsed / (i + 1)) * remaining
else:
eta = 0
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")

View File

@ -0,0 +1,35 @@
import os
import torch
from datetime import timedelta
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size

View File

@ -0,0 +1,169 @@
import time
import concurrent
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from loguru import logger
from pathlib import Path
from typing import Optional, List
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import (
LocalEntryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
)
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
if not filenames:
raise EntryNotFoundError(
f"No {extension} weights found for model {model_id} and revision {revision}.",
None,
)
return filenames
def try_to_load_from_cache(
model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return None
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return cached_file if cached_file.is_file() else None
def weight_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
"""Get the local files"""
try:
filenames = weight_hub_files(model_id, revision, extension)
except EntryNotFoundError as e:
if extension != ".safetensors":
raise e
# Try to see if there are pytorch weights
pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
# Change pytorch extension to safetensors extension
# It is possible that we have safetensors weights locally even though they are not on the
# hub if we converted weights locally without pushing them
filenames = [
f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
]
if WEIGHTS_CACHE_OVERRIDE is not None:
files = []
for filename in filenames:
p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
if not p.exists():
raise LocalEntryNotFoundError(
f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
)
files.append(p)
return files
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(
filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
"""Download the safetensors files from the hub"""
def download_file(filename):
local_file = try_to_load_from_cache(model_id, revision, filename)
if local_file is not None:
logger.info(f"File {filename} already present in cache.")
return local_file
start_time = time.time()
local_file = hf_hub_download(
filename=filename,
repo_id=model_id,
revision=revision,
local_files_only=False,
)
logger.info(
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
)
return local_file
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_file, filename=filename) for filename in filenames
]
# We do this instead of using tqdm because we want to parse the logs with the launcher
logger.info("Downloading weights...")
start_time = time.time()
files = []
for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1)
if remaining != 0:
eta = (elapsed / (i + 1)) * remaining
else:
eta = 0
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
files.append(Path(future.result()))
return [Path(p) for p in files]

View File

@ -0,0 +1,142 @@
import re
import torch
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax()
class NextTokenChooser:
def __init__(
self,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)