From 0fbc69194694b60badae3bf643bc76985f69c0f4 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 14 Feb 2023 13:02:16 +0100 Subject: [PATCH] feat: add safetensors conversion (#63) --- README.md | 8 +- launcher/src/main.rs | 150 +++++++++- server/tests/utils/test_convert.py | 17 ++ server/tests/utils/test_hub.py | 40 +++ .../{test_utils.py => utils/test_tokens.py} | 38 +-- server/text_generation/cli.py | 49 ++- server/text_generation/models/__init__.py | 36 ++- server/text_generation/models/bloom.py | 7 - server/text_generation/models/galactica.py | 7 - server/text_generation/models/gpt_neox.py | 7 - server/text_generation/models/santacoder.py | 2 +- server/text_generation/models/t5.py | 7 - server/text_generation/utils.py | 283 ------------------ server/text_generation/utils/__init__.py | 36 +++ server/text_generation/utils/convert.py | 96 ++++++ server/text_generation/utils/dist.py | 35 +++ server/text_generation/utils/hub.py | 169 +++++++++++ server/text_generation/utils/tokens.py | 142 +++++++++ 18 files changed, 749 insertions(+), 380 deletions(-) create mode 100644 server/tests/utils/test_convert.py create mode 100644 server/tests/utils/test_hub.py rename server/tests/{test_utils.py => utils/test_tokens.py} (52%) delete mode 100644 server/text_generation/utils.py create mode 100644 server/text_generation/utils/__init__.py create mode 100644 server/text_generation/utils/convert.py create mode 100644 server/text_generation/utils/dist.py create mode 100644 server/text_generation/utils/hub.py create mode 100644 server/text_generation/utils/tokens.py diff --git a/README.md b/README.md index 39cec857..5234844b 100644 --- a/README.md +++ b/README.md @@ -49,17 +49,17 @@ to power LLMs api-inference widgets. - Log probabilities - Distributed tracing with Open Telemetry -## Officially supported models +## Officially supported architectures - [BLOOM](https://huggingface.co/bigscience/bloom) - [BLOOMZ](https://huggingface.co/bigscience/bloomz) - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl) -- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated) +- [Galactica](https://huggingface.co/facebook/galactica-120b) - [SantaCoder](https://huggingface.co/bigcode/santacoder) - [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b) - [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl) -Other models are supported on a best effort basis using: +Other architectures are supported on a best effort basis using: `AutoModelForCausalLM.from_pretrained(, device_map="auto")` @@ -191,7 +191,7 @@ Be aware that the official Docker image has them enabled by default. 
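> Note on the "best effort" line above: the fallback path is the standard `transformers` auto classes. A minimal sketch, assuming `transformers` and `accelerate` are installed; `"gpt2"` is only a stand-in for an architecture without a dedicated implementation:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" stands in for any causal LM without a dedicated implementation.
model_id = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" lets accelerate place the weights on the available devices.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("def hello_world():", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```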
### Download -First you need to download the weights: +It is advised to download the weights ahead of time with the following command: ```shell make download-bloom diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 3c8d9fcc..0848dd9a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -12,7 +12,7 @@ use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; use std::{fs, io}; -use subprocess::{Popen, PopenConfig, PopenError, Redirection}; +use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection}; /// App Configuration #[derive(Parser, Debug)] @@ -43,6 +43,10 @@ struct Args { #[clap(default_value = "29500", long, env)] master_port: usize, #[clap(long, env)] + huggingface_hub_cache: Option, + #[clap(long, env)] + weights_cache_override: Option, + #[clap(long, env)] json_output: bool, #[clap(long, env)] otlp_endpoint: Option, @@ -63,6 +67,8 @@ fn main() -> ExitCode { shard_uds_path, master_addr, master_port, + huggingface_hub_cache, + weights_cache_override, json_output, otlp_endpoint, } = Args::parse(); @@ -84,6 +90,124 @@ fn main() -> ExitCode { }) .expect("Error setting Ctrl-C handler"); + // Download weights + if weights_cache_override.is_none() { + let mut download_argv = vec![ + "text-generation-server".to_string(), + "download-weights".to_string(), + model_id.clone(), + "--logger-level".to_string(), + "INFO".to_string(), + "--json-output".to_string(), + ]; + if num_shard == 1 { + download_argv.push("--extension".to_string()); + download_argv.push(".bin".to_string()); + } else { + download_argv.push("--extension".to_string()); + download_argv.push(".safetensors".to_string()); + } + + // Model optional revision + if let Some(ref revision) = revision { + download_argv.push("--revision".to_string()); + download_argv.push(revision.to_string()) + } + + let mut env = Vec::new(); + + // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard + // Useful when running inside a docker container + if let Some(ref huggingface_hub_cache) = huggingface_hub_cache { + env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); + }; + + // Start process + tracing::info!("Starting download"); + let mut download_process = match Popen::create( + &download_argv, + PopenConfig { + stdout: Redirection::Pipe, + stderr: Redirection::Pipe, + // Needed for the shutdown procedure + setpgid: true, + env: Some(env), + ..Default::default() + }, + ) { + Ok(p) => p, + Err(err) => { + if let PopenError::IoError(ref err) = err { + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-server not found in PATH"); + tracing::error!("Please install it with `make install-server`") + } + } + return ExitCode::FAILURE; + } + }; + + // Redirect STDOUT to the console + let download_stdout = download_process.stdout.take().unwrap(); + thread::spawn(move || { + // Enter download tracing span + let stdout = BufReader::new(download_stdout); + let _span = tracing::span!(tracing::Level::INFO, "download").entered(); + for line in stdout.lines() { + // Parse loguru logs + if let Ok(value) = serde_json::from_str::(&line.unwrap()) { + if let Some(text) = value.get("text") { + // Format escaped newlines + tracing::info!("{}", text.to_string().replace("\\n", "")); + } + } + } + }); + + loop { + if let Some(status) = download_process.poll() { + match status { + ExitStatus::Exited(exit_code) => { + if exit_code == 0 { + tracing::info!("Successfully downloaded weights."); + break; + } else { + let mut err = String::new(); + download_process 
+ .stderr + .take() + .unwrap() + .read_to_string(&mut err) + .unwrap(); + tracing::error!("Download encountered an error: {err}"); + return ExitCode::FAILURE; + } + } + _ => { + tracing::error!("Download process exited with an unkown status."); + return ExitCode::FAILURE; + } + } + } + if !running.load(Ordering::SeqCst) { + download_process.terminate().unwrap(); + tracing::info!("Waiting for download process to gracefully shutdown"); + download_process + .wait_timeout(Duration::from_secs(90)) + .unwrap(); + tracing::info!("Download process terminated"); + return ExitCode::SUCCESS; + } + sleep(Duration::from_millis(100)); + } + } else { + tracing::info!( + "weights_cache_override is set to {:?}.", + weights_cache_override + ); + tracing::info!("Skipping download.") + } + // Shared shutdown bool let shutdown = Arc::new(Mutex::new(false)); // Shared shutdown channel @@ -99,6 +223,8 @@ fn main() -> ExitCode { let revision = revision.clone(); let uds_path = shard_uds_path.clone(); let master_addr = master_addr.clone(); + let huggingface_hub_cache = huggingface_hub_cache.clone(); + let weights_cache_override = weights_cache_override.clone(); let status_sender = status_sender.clone(); let shutdown = shutdown.clone(); let shutdown_sender = shutdown_sender.clone(); @@ -113,6 +239,8 @@ fn main() -> ExitCode { num_shard, master_addr, master_port, + huggingface_hub_cache, + weights_cache_override, otlp_endpoint, status_sender, shutdown, @@ -232,7 +360,7 @@ fn main() -> ExitCode { while running.load(Ordering::SeqCst) { if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { - tracing::error!("Shard {} failed:\n{}", rank, err); + tracing::error!("Shard {rank} failed:\n{err}"); exit_code = ExitCode::FAILURE; break; }; @@ -275,6 +403,8 @@ fn shard_manager( world_size: usize, master_addr: String, master_port: usize, + huggingface_hub_cache: Option, + weights_cache_override: Option, otlp_endpoint: Option, status_sender: mpsc::Sender, shutdown: Arc>, @@ -328,15 +458,15 @@ fn shard_manager( ("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()), ]; - // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard + // If huggingface_hub_cache is some, pass it to the shard // Useful when running inside a docker container - if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") { + if let Some(huggingface_hub_cache) = huggingface_hub_cache { env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; - // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard + // If weights_cache_override is some, pass it to the shard // Useful when running inside a HuggingFace Inference Endpoint - if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") { + if let Some(weights_cache_override) = weights_cache_override { env.push(( "WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into(), @@ -355,7 +485,7 @@ fn shard_manager( }; // Start process - tracing::info!("Starting shard {}", rank); + tracing::info!("Starting shard {rank}"); let mut p = match Popen::create( &shard_argv, PopenConfig { @@ -419,17 +549,17 @@ fn shard_manager( if *shutdown.lock().unwrap() { p.terminate().unwrap(); let _ = p.wait_timeout(Duration::from_secs(90)); - tracing::info!("Shard {} terminated", rank); + tracing::info!("Shard {rank} terminated"); return; } // Shard is ready if uds.exists() && !ready { - tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed()); + tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed()); 
status_sender.send(ShardStatus::Ready).unwrap(); ready = true; } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { - tracing::info!("Waiting for shard {} to be ready...", rank); + tracing::info!("Waiting for shard {rank} to be ready..."); wait_time = Instant::now(); } sleep(Duration::from_millis(100)); diff --git a/server/tests/utils/test_convert.py b/server/tests/utils/test_convert.py new file mode 100644 index 00000000..5f284be5 --- /dev/null +++ b/server/tests/utils/test_convert.py @@ -0,0 +1,17 @@ +from text_generation.utils.hub import download_weights, weight_hub_files, weight_files + +from text_generation.utils.convert import convert_files + + +def test_convert_files(): + model_id = "bigscience/bloom-560m" + pt_filenames = weight_hub_files(model_id, extension=".bin") + local_pt_files = download_weights(pt_filenames, model_id) + local_st_files = [ + p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files + ] + convert_files(local_pt_files, local_st_files) + + found_st_files = weight_files(model_id) + + assert all([p in found_st_files for p in local_st_files]) diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py new file mode 100644 index 00000000..b3120160 --- /dev/null +++ b/server/tests/utils/test_hub.py @@ -0,0 +1,40 @@ +import pytest + +from text_generation.utils.hub import ( + weight_hub_files, + download_weights, + weight_files, + EntryNotFoundError, + LocalEntryNotFoundError, + RevisionNotFoundError, +) + + +def test_weight_hub_files(): + filenames = weight_hub_files("bigscience/bloom-560m") + assert filenames == ["model.safetensors"] + + +def test_weight_hub_files_llm(): + filenames = weight_hub_files("bigscience/bloom") + assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)] + + +def test_weight_hub_files_empty(): + with pytest.raises(EntryNotFoundError): + weight_hub_files("bigscience/bloom", extension=".errors") + + +def test_download_weights(): + model_id = "bigscience/bloom-560m" + filenames = weight_hub_files(model_id) + files = download_weights(filenames, model_id) + local_files = weight_files("bigscience/bloom-560m") + assert files == local_files + + +def test_weight_files_error(): + with pytest.raises(RevisionNotFoundError): + weight_files("bigscience/bloom-560m", revision="error") + with pytest.raises(LocalEntryNotFoundError): + weight_files("bert-base-uncased") diff --git a/server/tests/test_utils.py b/server/tests/utils/test_tokens.py similarity index 52% rename from server/tests/test_utils.py rename to server/tests/utils/test_tokens.py index ffe9be65..7eca482f 100644 --- a/server/tests/test_utils.py +++ b/server/tests/utils/test_tokens.py @@ -1,14 +1,6 @@ -import pytest - -from huggingface_hub.utils import RevisionNotFoundError - -from text_generation.utils import ( - weight_hub_files, - download_weights, - weight_files, +from text_generation.utils.tokens import ( StopSequenceCriteria, StoppingCriteria, - LocalEntryNotFoundError, FinishReason, ) @@ -41,31 +33,3 @@ def test_stopping_criteria_max(): assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None) assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) - - -def test_weight_hub_files(): - filenames = weight_hub_files("bigscience/bloom-560m") - assert filenames == ["model.safetensors"] - - -def test_weight_hub_files_llm(): - filenames = weight_hub_files("bigscience/bloom") - assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)] - - -def 
test_weight_hub_files_empty(): - filenames = weight_hub_files("bigscience/bloom", extension=".errors") - assert filenames == [] - - -def test_download_weights(): - files = download_weights("bigscience/bloom-560m") - local_files = weight_files("bigscience/bloom-560m") - assert files == local_files - - -def test_weight_files_error(): - with pytest.raises(RevisionNotFoundError): - weight_files("bigscience/bloom-560m", revision="error") - with pytest.raises(LocalEntryNotFoundError): - weight_files("bert-base-uncased") diff --git a/server/text_generation/cli.py b/server/text_generation/cli.py index e9c8ea92..678dce16 100644 --- a/server/text_generation/cli.py +++ b/server/text_generation/cli.py @@ -60,8 +60,55 @@ def download_weights( model_id: str, revision: Optional[str] = None, extension: str = ".safetensors", + logger_level: str = "INFO", + json_output: bool = False, ): - utils.download_weights(model_id, revision, extension) + # Remove default handler + logger.remove() + logger.add( + sys.stdout, + format="{message}", + filter="text_generation", + level=logger_level, + serialize=json_output, + backtrace=True, + diagnose=False, + ) + + # Test if files were already download + try: + utils.weight_files(model_id, revision, extension) + logger.info( + "Files are already present in the local cache. " "Skipping download." + ) + return + # Local files not found + except utils.LocalEntryNotFoundError: + pass + + # Download weights directly + try: + filenames = utils.weight_hub_files(model_id, revision, extension) + utils.download_weights(filenames, model_id, revision) + except utils.EntryNotFoundError as e: + if not extension == ".safetensors": + raise e + + logger.warning( + f"No safetensors weights found for model {model_id} at revision {revision}. " + f"Converting PyTorch weights instead." 
+ ) + + # Try to see if there are pytorch weights + pt_filenames = utils.weight_hub_files(model_id, revision, ".bin") + # Download pytorch weights + local_pt_files = utils.download_weights(pt_filenames, model_id, revision) + local_st_files = [ + p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" + for p in local_pt_files + ] + # Convert pytorch weights to safetensors + utils.convert_files(local_pt_files, local_st_files) if __name__ == "__main__": diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py index 7445b427..908b144c 100644 --- a/server/text_generation/models/__init__.py +++ b/server/text_generation/models/__init__.py @@ -41,6 +41,15 @@ torch.set_grad_enabled(False) def get_model( model_id: str, revision: Optional[str], sharded: bool, quantize: bool ) -> Model: + if model_id.startswith("facebook/galactica"): + if sharded: + return GalacticaSharded(model_id, revision, quantize=quantize) + else: + return Galactica(model_id, revision, quantize=quantize) + + if "santacoder" in model_id: + return SantaCoder(model_id, revision, quantize) + config = AutoConfig.from_pretrained(model_id, revision=revision) if config.model_type == "bloom": @@ -48,27 +57,22 @@ def get_model( return BLOOMSharded(model_id, revision, quantize=quantize) else: return BLOOM(model_id, revision, quantize=quantize) - elif config.model_type == "gpt_neox": + + if config.model_type == "gpt_neox": if sharded: return GPTNeoxSharded(model_id, revision, quantize=quantize) else: return GPTNeox(model_id, revision, quantize=quantize) - elif config.model_type == "t5": + + if config.model_type == "t5": if sharded: return T5Sharded(model_id, revision, quantize=quantize) else: return Seq2SeqLM(model_id, revision, quantize=quantize) - elif model_id.startswith("facebook/galactica"): - if sharded: - return GalacticaSharded(model_id, revision, quantize=quantize) - else: - return Galactica(model_id, revision, quantize=quantize) - elif "santacoder" in model_id: - return SantaCoder(model_id, revision, quantize) - else: - if sharded: - raise ValueError("sharded is not supported for AutoModel") - try: - return CausalLM(model_id, revision, quantize=quantize) - except Exception: - return Seq2SeqLM(model_id, revision, quantize=quantize) + + if sharded: + raise ValueError("sharded is not supported for AutoModel") + try: + return CausalLM(model_id, revision, quantize=quantize) + except Exception: + return Seq2SeqLM(model_id, revision, quantize=quantize) diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index 992d7b5b..08c3ac94 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -23,7 +23,6 @@ from text_generation.pb import generate_pb2 from text_generation.utils import ( initialize_torch_distributed, weight_files, - download_weights, ) HAS_BITS_AND_BYTES = True @@ -80,14 +79,8 @@ class BLOOMSharded(BLOOM): ) config.pad_token_id = 3 - # Only download weights for small models - if self.master and model_id == "bigscience/bloom-560m": - download_weights(model_id, revision=revision, extension=".safetensors") - torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - if not filenames: - raise ValueError("No safetensors weights found") with init_empty_weights(): model = AutoModelForCausalLM.from_config(config) diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index f1dc8a30..780a94f1 100644 
--- a/server/text_generation/models/galactica.py +++ b/server/text_generation/models/galactica.py @@ -26,7 +26,6 @@ from text_generation.utils import ( StoppingCriteria, initialize_torch_distributed, weight_files, - download_weights, ) HAS_BITS_AND_BYTES = True @@ -172,14 +171,8 @@ class GalacticaSharded(Galactica): ) tokenizer.pad_token_id = config.pad_token_id - # Only download weights for small models - if self.master and model_id == "facebook/galactica-125m": - download_weights(model_id, revision=revision, extension=".safetensors") - torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - if not filenames: - raise ValueError("No safetensors weights found") with init_empty_weights(): model = AutoModelForCausalLM.from_config(config) diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py index 2d467f4c..0197f976 100644 --- a/server/text_generation/models/gpt_neox.py +++ b/server/text_generation/models/gpt_neox.py @@ -20,7 +20,6 @@ from text_generation.models import CausalLM from text_generation.utils import ( initialize_torch_distributed, weight_files, - download_weights, ) HAS_BITS_AND_BYTES = True @@ -69,14 +68,8 @@ class GPTNeoxSharded(GPTNeox): model_id, revision=revision, tp_parallel=True ) - # Only master download weights - if self.master: - download_weights(model_id, revision=revision, extension=".safetensors") - torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - if not filenames: - raise ValueError("No safetensors weights found") with init_empty_weights(): model = AutoModelForCausalLM.from_config(config) diff --git a/server/text_generation/models/santacoder.py b/server/text_generation/models/santacoder.py index fb496197..5d271c85 100644 --- a/server/text_generation/models/santacoder.py +++ b/server/text_generation/models/santacoder.py @@ -1,7 +1,7 @@ import torch import torch.distributed -from typing import Optional, List, Tuple +from typing import Optional, List from transformers import AutoTokenizer, AutoModelForCausalLM from text_generation.models import CausalLM diff --git a/server/text_generation/models/t5.py b/server/text_generation/models/t5.py index d7241c81..536ebda3 100644 --- a/server/text_generation/models/t5.py +++ b/server/text_generation/models/t5.py @@ -20,7 +20,6 @@ from text_generation.models import Seq2SeqLM from text_generation.utils import ( initialize_torch_distributed, weight_files, - download_weights, ) HAS_BITS_AND_BYTES = True @@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM): ) tokenizer.bos_token_id = config.decoder_start_token_id - # Only master download weights - if self.master: - download_weights(model_id, revision=revision, extension=".safetensors") - torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - if not filenames: - raise ValueError("No safetensors weights found") with init_empty_weights(): model = AutoModelForSeq2SeqLM.from_config(config) diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py deleted file mode 100644 index 3b3f08c7..00000000 --- a/server/text_generation/utils.py +++ /dev/null @@ -1,283 +0,0 @@ -import concurrent -import os -import re -import torch -import torch.distributed - -from datetime import timedelta - -from concurrent.futures import ThreadPoolExecutor -from functools import partial -from pathlib import Path -from 
huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE -from huggingface_hub.utils import LocalEntryNotFoundError -from tqdm import tqdm -from typing import List, Optional, Tuple -from transformers import PreTrainedTokenizerBase -from transformers.generation.logits_process import ( - LogitsProcessorList, - RepetitionPenaltyLogitsProcessor, - TemperatureLogitsWarper, - TopPLogitsWarper, - TopKLogitsWarper, -) - -from text_generation.pb import generate_pb2 -from text_generation.pb.generate_pb2 import FinishReason - -WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) - - -class Sampling: - def __init__(self, seed: int, device: str = "cpu"): - self.generator = torch.Generator(device) - self.generator.manual_seed(seed) - self.seed = seed - - def __call__(self, logits): - probs = torch.nn.functional.softmax(logits) - next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator) - return next_tokens - - -class Greedy: - def __call__(self, logits): - return logits.argmax() - - -class NextTokenChooser: - def __init__( - self, - temperature=1.0, - repetition_penalty=1.0, - top_k=None, - top_p=None, - do_sample=False, - seed=0, - device="cpu", - ): - warpers = LogitsProcessorList() - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` - sampling = do_sample - if temperature is not None and temperature != 1.0: - temperature = float(temperature) - warpers.append(TemperatureLogitsWarper(temperature)) - sampling = True - if top_k is not None and top_k != 0: - warpers.append(TopKLogitsWarper(top_k=top_k)) - sampling = True - if top_p is not None and top_p < 1.0: - warpers.append(TopPLogitsWarper(top_p=top_p)) - sampling = True - if repetition_penalty is not None and repetition_penalty != 1.0: - warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) - - self.warpers = warpers - self.choice = Sampling(seed, device) if sampling else Greedy() - - def __call__(self, input_ids, scores): - # Warp logits - scores = self.warpers(input_ids, scores) - - # Compute logprobs - logprobs = torch.log_softmax(scores, -1) - - # Choose tokens - next_id = self.choice(scores[-1]) - - return next_id.view(1, 1), logprobs - - @classmethod - def from_pb( - cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device - ) -> "NextTokenChooser": - return NextTokenChooser( - temperature=pb.temperature, - repetition_penalty=pb.repetition_penalty, - top_k=pb.top_k, - top_p=pb.top_p, - do_sample=pb.do_sample, - seed=pb.seed, - device=device, - ) - - -class StopSequenceCriteria: - def __init__(self, stop_sequence: str): - self.regex = re.compile(f".*{stop_sequence}$") - - def __call__(self, output: str) -> bool: - if self.regex.findall(output): - return True - return False - - -class StoppingCriteria: - def __init__( - self, - eos_token_id: int, - stop_sequence_criterias: List[StopSequenceCriteria], - max_new_tokens=20, - ): - self.eos_token_id = eos_token_id - self.stop_sequence_criterias = stop_sequence_criterias - self.max_new_tokens = max_new_tokens - self.current_tokens = 0 - self.current_output = "" - - def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: - self.current_tokens += 1 - if self.current_tokens >= self.max_new_tokens: - return True, FinishReason.FINISH_REASON_LENGTH - - if last_token == self.eos_token_id: - return True, 
FinishReason.FINISH_REASON_EOS_TOKEN - - self.current_output += last_output - for stop_sequence_criteria in self.stop_sequence_criterias: - if stop_sequence_criteria(self.current_output): - return True, FinishReason.FINISH_REASON_STOP_SEQUENCE - - return False, None - - @classmethod - def from_pb( - cls, - pb: generate_pb2.StoppingCriteriaParameters, - tokenizer: PreTrainedTokenizerBase, - ) -> "StoppingCriteria": - stop_sequence_criterias = [ - StopSequenceCriteria(sequence) for sequence in pb.stop_sequences - ] - return StoppingCriteria( - tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens - ) - - -def initialize_torch_distributed(): - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - if torch.cuda.is_available(): - from torch.distributed import ProcessGroupNCCL - - # Set the device id. - assert world_size <= torch.cuda.device_count(), "Each process is one gpu" - device = rank % torch.cuda.device_count() - torch.cuda.set_device(device) - backend = "nccl" - options = ProcessGroupNCCL.Options() - options.is_high_priority_stream = True - options._timeout = timedelta(seconds=60) - else: - backend = "gloo" - options = None - - # Call the init process. - torch.distributed.init_process_group( - backend=backend, - world_size=world_size, - rank=rank, - timeout=timedelta(seconds=60), - pg_options=options, - ) - - return torch.distributed.group.WORLD, rank, world_size - - -def weight_hub_files(model_id, revision=None, extension=".safetensors"): - """Get the safetensors filenames on the hub""" - api = HfApi() - info = api.model_info(model_id, revision=revision) - filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)] - return filenames - - -def try_to_load_from_cache(model_id, revision, filename): - """Try to load a file from the Hugging Face cache""" - if revision is None: - revision = "main" - - object_id = model_id.replace("/", "--") - repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}" - - if not repo_cache.is_dir(): - # No cache for this model - return None - - refs_dir = repo_cache / "refs" - snapshots_dir = repo_cache / "snapshots" - no_exist_dir = repo_cache / ".no_exist" - - # Resolve refs (for instance to convert main to the associated commit sha) - if refs_dir.is_dir(): - revision_file = refs_dir / revision - if revision_file.exists(): - with revision_file.open() as f: - revision = f.read() - - # Check if file is cached as "no_exist" - if (no_exist_dir / revision / filename).is_file(): - return _CACHED_NO_EXIST - - # Check if revision folder exists - if not snapshots_dir.exists(): - return None - cached_shas = os.listdir(snapshots_dir) - if revision not in cached_shas: - # No cache for this revision and we won't try to return a random revision - return None - - # Check if file exists in cache - cached_file = snapshots_dir / revision / filename - return str(cached_file) if cached_file.is_file() else None - - -def weight_files(model_id, revision=None, extension=".safetensors"): - """Get the local safetensors filenames""" - if WEIGHTS_CACHE_OVERRIDE is not None: - return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}")) - - filenames = weight_hub_files(model_id, revision, extension) - files = [] - for filename in filenames: - cache_file = try_to_load_from_cache( - model_id, revision=revision, filename=filename - ) - if cache_file is None: - raise LocalEntryNotFoundError( - f"File {filename} of model {model_id} not found in " - f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. 
" - f"Please run `text-generation-server download-weights {model_id}` first." - ) - files.append(cache_file) - - return files - - -def download_weights(model_id, revision=None, extension=".safetensors"): - """Download the safetensors files from the hub""" - if WEIGHTS_CACHE_OVERRIDE is not None: - return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}")) - - filenames = weight_hub_files(model_id, revision, extension) - - download_function = partial( - hf_hub_download, - repo_id=model_id, - local_files_only=False, - ) - - executor = ThreadPoolExecutor(max_workers=5) - futures = [ - executor.submit(download_function, filename=filename, revision=revision) - for filename in filenames - ] - files = [ - future.result() - for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)) - ] - - return files diff --git a/server/text_generation/utils/__init__.py b/server/text_generation/utils/__init__.py new file mode 100644 index 00000000..a390b710 --- /dev/null +++ b/server/text_generation/utils/__init__.py @@ -0,0 +1,36 @@ +from text_generation.utils.convert import convert_file, convert_files +from text_generation.utils.dist import initialize_torch_distributed +from text_generation.utils.hub import ( + weight_files, + weight_hub_files, + download_weights, + EntryNotFoundError, + LocalEntryNotFoundError, + RevisionNotFoundError, +) +from text_generation.utils.tokens import ( + Greedy, + NextTokenChooser, + Sampling, + StoppingCriteria, + StopSequenceCriteria, + FinishReason, +) + +__all__ = [ + "convert_file", + "convert_files", + "initialize_torch_distributed", + "weight_files", + "weight_hub_files", + "download_weights", + "EntryNotFoundError", + "LocalEntryNotFoundError", + "RevisionNotFoundError", + "Greedy", + "NextTokenChooser", + "Sampling", + "StoppingCriteria", + "StopSequenceCriteria", + "FinishReason", +] diff --git a/server/text_generation/utils/convert.py b/server/text_generation/utils/convert.py new file mode 100644 index 00000000..e7f9660c --- /dev/null +++ b/server/text_generation/utils/convert.py @@ -0,0 +1,96 @@ +import concurrent +import time +import torch + +from concurrent.futures import ThreadPoolExecutor +from collections import defaultdict +from datetime import timedelta +from loguru import logger +from pathlib import Path +from safetensors.torch import load_file, save_file +from typing import Dict, List + + +def check_file_size(source_file: Path, target_file: Path): + """ + Check that two files are close in size + """ + source_file_size = source_file.stat().st_size + target_file_size = target_file.stat().st_size + + if (source_file_size - target_file_size) / source_file_size > 0.01: + raise RuntimeError( + f"""The file size different is more than 1%: + - {source_file}: {source_file_size} + - {target_file}: {target_file_size} + """ + ) + + +def remove_shared_pointers(tensors: Dict[str, torch.Tensor]): + """ + For a Dict of tensors, check if two or more tensors point to the same underlying memory and + remove them + """ + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + + # Iterate over all found memory addresses + for ptr, names in ptrs.items(): + if len(names) > 1: + # Multiple tensors are point to the same memory + # Only keep the first tensor + for name in names[1:]: + tensors.pop(name) + + +def convert_file(pt_file: Path, st_file: Path): + """ + Convert a pytorch file to a safetensors file + """ + pt_state = torch.load(pt_file, map_location="cpu") + if "state_dict" in pt_state: + pt_state = 
pt_state["state_dict"] + + remove_shared_pointers(pt_state) + + # Tensors need to be contiguous + pt_state = {k: v.contiguous() for k, v in pt_state.items()} + + st_file.parent.mkdir(parents=True, exist_ok=True) + save_file(pt_state, str(st_file), metadata={"format": "pt"}) + + # Check that both files are close in size + check_file_size(pt_file, st_file) + + # Load safetensors state + st_state = load_file(str(st_file)) + for k in st_state: + pt_tensor = pt_state[k] + st_tensor = st_state[k] + if not torch.equal(pt_tensor, st_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +def convert_files(pt_files: List[Path], st_files: List[Path]): + assert len(pt_files) == len(st_files) + + executor = ThreadPoolExecutor(max_workers=5) + futures = [ + executor.submit(convert_file, pt_file=pt_file, st_file=st_file) + for pt_file, st_file in zip(pt_files, st_files) + ] + + # We do this instead of using tqdm because we want to parse the logs with the launcher + logger.info("Converting weights...") + start_time = time.time() + for i, future in enumerate(concurrent.futures.as_completed(futures)): + elapsed = timedelta(seconds=int(time.time() - start_time)) + remaining = len(futures) - (i + 1) + if remaining != 0: + eta = (elapsed / (i + 1)) * remaining + else: + eta = 0 + + logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}") diff --git a/server/text_generation/utils/dist.py b/server/text_generation/utils/dist.py new file mode 100644 index 00000000..9785493e --- /dev/null +++ b/server/text_generation/utils/dist.py @@ -0,0 +1,35 @@ +import os +import torch + +from datetime import timedelta + + +def initialize_torch_distributed(): + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + if torch.cuda.is_available(): + from torch.distributed import ProcessGroupNCCL + + # Set the device id. + assert world_size <= torch.cuda.device_count(), "Each process is one gpu" + device = rank % torch.cuda.device_count() + torch.cuda.set_device(device) + backend = "nccl" + options = ProcessGroupNCCL.Options() + options.is_high_priority_stream = True + options._timeout = timedelta(seconds=60) + else: + backend = "gloo" + options = None + + # Call the init process. 
+ torch.distributed.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + timeout=timedelta(seconds=60), + pg_options=options, + ) + + return torch.distributed.group.WORLD, rank, world_size diff --git a/server/text_generation/utils/hub.py b/server/text_generation/utils/hub.py new file mode 100644 index 00000000..60072a20 --- /dev/null +++ b/server/text_generation/utils/hub.py @@ -0,0 +1,169 @@ +import time +import concurrent +import os + +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta +from loguru import logger +from pathlib import Path +from typing import Optional, List + +from huggingface_hub import HfApi, hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub.utils import ( + LocalEntryNotFoundError, + EntryNotFoundError, + RevisionNotFoundError, # Import here to ease try/except in other part of the lib +) + +WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) + + +def weight_hub_files( + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" +) -> List[str]: + """Get the weights filenames on the hub""" + api = HfApi() + info = api.model_info(model_id, revision=revision) + filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)] + + if not filenames: + raise EntryNotFoundError( + f"No {extension} weights found for model {model_id} and revision {revision}.", + None, + ) + + return filenames + + +def try_to_load_from_cache( + model_id: str, revision: Optional[str], filename: str +) -> Optional[Path]: + """Try to load a file from the Hugging Face cache""" + if revision is None: + revision = "main" + + object_id = model_id.replace("/", "--") + repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}" + + if not repo_cache.is_dir(): + # No cache for this model + return None + + refs_dir = repo_cache / "refs" + snapshots_dir = repo_cache / "snapshots" + no_exist_dir = repo_cache / ".no_exist" + + # Resolve refs (for instance to convert main to the associated commit sha) + if refs_dir.is_dir(): + revision_file = refs_dir / revision + if revision_file.exists(): + with revision_file.open() as f: + revision = f.read() + + # Check if file is cached as "no_exist" + if (no_exist_dir / revision / filename).is_file(): + return None + + # Check if revision folder exists + if not snapshots_dir.exists(): + return None + cached_shas = os.listdir(snapshots_dir) + if revision not in cached_shas: + # No cache for this revision and we won't try to return a random revision + return None + + # Check if file exists in cache + cached_file = snapshots_dir / revision / filename + return cached_file if cached_file.is_file() else None + + +def weight_files( + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" +) -> List[Path]: + """Get the local files""" + try: + filenames = weight_hub_files(model_id, revision, extension) + except EntryNotFoundError as e: + if extension != ".safetensors": + raise e + # Try to see if there are pytorch weights + pt_filenames = weight_hub_files(model_id, revision, extension=".bin") + # Change pytorch extension to safetensors extension + # It is possible that we have safetensors weights locally even though they are not on the + # hub if we converted weights locally without pushing them + filenames = [ + f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames + ] + + if WEIGHTS_CACHE_OVERRIDE is not None: + files = [] + for filename in filenames: + p = 
Path(WEIGHTS_CACHE_OVERRIDE) / filename + if not p.exists(): + raise LocalEntryNotFoundError( + f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}." + ) + files.append(p) + return files + + files = [] + for filename in filenames: + cache_file = try_to_load_from_cache( + model_id, revision=revision, filename=filename + ) + if cache_file is None: + raise LocalEntryNotFoundError( + f"File {filename} of model {model_id} not found in " + f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. " + f"Please run `text-generation-server download-weights {model_id}` first." + ) + files.append(cache_file) + + return files + + +def download_weights( + filenames: List[str], model_id: str, revision: Optional[str] = None +) -> List[Path]: + """Download the safetensors files from the hub""" + + def download_file(filename): + local_file = try_to_load_from_cache(model_id, revision, filename) + if local_file is not None: + logger.info(f"File {filename} already present in cache.") + return local_file + + start_time = time.time() + local_file = hf_hub_download( + filename=filename, + repo_id=model_id, + revision=revision, + local_files_only=False, + ) + logger.info( + f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}." + ) + return local_file + + executor = ThreadPoolExecutor(max_workers=5) + futures = [ + executor.submit(download_file, filename=filename) for filename in filenames + ] + + # We do this instead of using tqdm because we want to parse the logs with the launcher + logger.info("Downloading weights...") + start_time = time.time() + files = [] + for i, future in enumerate(concurrent.futures.as_completed(futures)): + elapsed = timedelta(seconds=int(time.time() - start_time)) + remaining = len(futures) - (i + 1) + if remaining != 0: + eta = (elapsed / (i + 1)) * remaining + else: + eta = 0 + + logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}") + files.append(Path(future.result())) + + return [Path(p) for p in files] diff --git a/server/text_generation/utils/tokens.py b/server/text_generation/utils/tokens.py new file mode 100644 index 00000000..cc0e6c35 --- /dev/null +++ b/server/text_generation/utils/tokens.py @@ -0,0 +1,142 @@ +import re +import torch + +from transformers import ( + LogitsProcessorList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + RepetitionPenaltyLogitsProcessor, + PreTrainedTokenizerBase, +) +from typing import List, Tuple, Optional + +from text_generation.pb import generate_pb2 +from text_generation.pb.generate_pb2 import FinishReason + + +class Sampling: + def __init__(self, seed: int, device: str = "cpu"): + self.generator = torch.Generator(device) + self.generator.manual_seed(seed) + self.seed = seed + + def __call__(self, logits): + probs = torch.nn.functional.softmax(logits) + next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator) + return next_tokens + + +class Greedy: + def __call__(self, logits): + return logits.argmax() + + +class NextTokenChooser: + def __init__( + self, + temperature=1.0, + repetition_penalty=1.0, + top_k=None, + top_p=None, + do_sample=False, + seed=0, + device="cpu", + ): + warpers = LogitsProcessorList() + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + sampling = do_sample + if temperature is not None and temperature != 1.0: + temperature = float(temperature) + warpers.append(TemperatureLogitsWarper(temperature)) + 
sampling = True + if top_k is not None and top_k != 0: + warpers.append(TopKLogitsWarper(top_k=top_k)) + sampling = True + if top_p is not None and top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=top_p)) + sampling = True + if repetition_penalty is not None and repetition_penalty != 1.0: + warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) + + self.warpers = warpers + self.choice = Sampling(seed, device) if sampling else Greedy() + + def __call__(self, input_ids, scores): + # Warp logits + scores = self.warpers(input_ids, scores) + + # Compute logprobs + logprobs = torch.log_softmax(scores, -1) + + # Choose tokens + next_id = self.choice(scores[-1]) + + return next_id.view(1, 1), logprobs + + @classmethod + def from_pb( + cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device + ) -> "NextTokenChooser": + return NextTokenChooser( + temperature=pb.temperature, + repetition_penalty=pb.repetition_penalty, + top_k=pb.top_k, + top_p=pb.top_p, + do_sample=pb.do_sample, + seed=pb.seed, + device=device, + ) + + +class StopSequenceCriteria: + def __init__(self, stop_sequence: str): + self.regex = re.compile(f".*{stop_sequence}$") + + def __call__(self, output: str) -> bool: + if self.regex.findall(output): + return True + return False + + +class StoppingCriteria: + def __init__( + self, + eos_token_id: int, + stop_sequence_criterias: List[StopSequenceCriteria], + max_new_tokens=20, + ): + self.eos_token_id = eos_token_id + self.stop_sequence_criterias = stop_sequence_criterias + self.max_new_tokens = max_new_tokens + self.current_tokens = 0 + self.current_output = "" + + def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: + self.current_tokens += 1 + if self.current_tokens >= self.max_new_tokens: + return True, FinishReason.FINISH_REASON_LENGTH + + if last_token == self.eos_token_id: + return True, FinishReason.FINISH_REASON_EOS_TOKEN + + self.current_output += last_output + for stop_sequence_criteria in self.stop_sequence_criterias: + if stop_sequence_criteria(self.current_output): + return True, FinishReason.FINISH_REASON_STOP_SEQUENCE + + return False, None + + @classmethod + def from_pb( + cls, + pb: generate_pb2.StoppingCriteriaParameters, + tokenizer: PreTrainedTokenizerBase, + ) -> "StoppingCriteria": + stop_sequence_criterias = [ + StopSequenceCriteria(sequence) for sequence in pb.stop_sequences + ] + return StoppingCriteria( + tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens + )
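> To close the loop on the relocated token utilities, here is a small usage sketch of the stopping logic in the spirit of the tests above; the token ids, text fragments, and stop sequence are made up, and `0` stands in for the tokenizer's EOS id:

```python
from text_generation.utils.tokens import (
    FinishReason,
    StoppingCriteria,
    StopSequenceCriteria,
)

# Stop either on the literal sequence "world", on EOS, or after 5 new tokens.
criteria = StoppingCriteria(
    eos_token_id=0,
    stop_sequence_criterias=[StopSequenceCriteria("world")],
    max_new_tokens=5,
)

assert criteria(42, "hello ") == (False, None)
assert criteria(43, "world") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)

# Emitting the EOS id stops generation immediately.
eos_criteria = StoppingCriteria(0, [], max_new_tokens=5)
assert eos_criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
```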