From cd5961b5dad560d63f4dd42d08d6ee3877b82003 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Mon, 6 Mar 2023 14:39:36 +0100
Subject: [PATCH] feat: allow local models (#101)

closes #99
---
 launcher/src/main.rs                        |  6 +++++-
 router/src/main.rs                          | 17 +++++++++++++----
 router/src/queue.rs                         |  2 +-
 server/text_generation/models/__init__.py   |  2 +-
 server/text_generation/models/bloom.py      |  3 ---
 server/text_generation/models/galactica.py  |  3 ---
 server/text_generation/utils/hub.py         |  4 ++++
 7 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 7a32ad2c..66dcb2db 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -110,8 +110,12 @@ fn main() -> ExitCode {
     })
     .expect("Error setting Ctrl-C handler");
 
+    // Check if model_id is a local model
+    let local_path = Path::new(&model_id);
+    let is_local_model = local_path.exists() && local_path.is_dir();
+
     // Download weights for sharded models
-    if weights_cache_override.is_none() && num_shard > 1 {
+    if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
         let mut download_argv = vec![
             "text-generation-server".to_string(),
             "download-weights".to_string(),
diff --git a/router/src/main.rs b/router/src/main.rs
index 2baf9e72..a51d3168 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -8,6 +8,7 @@ use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::path::Path;
 use text_generation_client::ShardedClient;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
@@ -83,11 +84,19 @@ fn main() -> Result<(), std::io::Error> {
         )
     });
 
-    // Download and instantiate tokenizer
+    // Tokenizer instance
     // This will only be used to validate payloads
-    //
-    // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();
+    let local_path = Path::new(&tokenizer_name);
+    let tokenizer =
+        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
+        {
+            // Load local tokenizer
+            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
+        } else {
+            // Download and instantiate tokenizer
+            // We need to download it outside of the Tokio runtime
+            Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
+        };
 
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 088bdd3c..0ebfed9b 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -234,7 +234,7 @@ mod tests {
                     do_sample: false,
                     seed: 0,
                     repetition_penalty: 0.0,
-                    watermark: false
+                    watermark: false,
                 },
                 stopping_parameters: StoppingCriteriaParameters {
                     max_new_tokens: 0,
diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index 908b144c..386b7dc9 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -41,7 +41,7 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    if model_id.startswith("facebook/galactica"):
+    if "facebook/galactica" in model_id:
         if sharded:
             return GalacticaSharded(model_id, revision, quantize=quantize)
         else:
diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 08c3ac94..83a0d63e 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -58,9 +58,6 @@ class BLOOMSharded(BLOOM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index f3a76459..9a71c5d3 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -164,9 +164,6 @@ class GalacticaSharded(Galactica):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
-        if not model_id.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
diff --git a/server/text_generation/utils/hub.py b/server/text_generation/utils/hub.py
index 7166b4cb..d338fb29 100644
--- a/server/text_generation/utils/hub.py
+++ b/server/text_generation/utils/hub.py
@@ -80,6 +80,10 @@ def weight_files(
     model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
 ) -> List[Path]:
     """Get the local files"""
+    # Local model
+    if Path(model_id).exists() and Path(model_id).is_dir():
+        return list(Path(model_id).glob(f"*{extension}"))
+
     try:
         filenames = weight_hub_files(model_id, revision, extension)
     except EntryNotFoundError as e:
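
For context on what the patch does, the launcher and router now apply the same test: a model_id (or tokenizer_name) that resolves to an existing directory is treated as a local model, so the launcher skips the weight-download step, the router prefers a local tokenizer.json, and weight_files globs the directory instead of querying the Hub. Below is a minimal standalone Rust sketch of that shared check; it is not part of the patch itself, and the model ids in it are hypothetical, used purely for illustration.

use std::path::Path;

// Sketch of the local-model test introduced by this patch: an id that
// points at an existing directory on disk is treated as a local model.
fn is_local_model(model_id: &str) -> bool {
    let local_path = Path::new(model_id);
    local_path.exists() && local_path.is_dir()
}

fn main() {
    // Hypothetical ids, for illustration only: a Hub repo id and a local directory.
    for id in ["bigscience/bloom-560m", "/data/bloom-560m"] {
        println!("{id}: local model = {}", is_local_model(id));
    }
}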