feat: allow local models (#101)

closes #99
Author: OlivierDehaene (committed by GitHub)
Date:   2023-03-06 14:39:36 +01:00
Commit: cd5961b5da (parent 9b205d33cc)
7 changed files with 24 additions and 13 deletions


@@ -110,8 +110,12 @@ fn main() -> ExitCode {
     })
     .expect("Error setting Ctrl-C handler");
 
+    // Check if model_id is a local model
+    let local_path = Path::new(&model_id);
+    let is_local_model = local_path.exists() && local_path.is_dir();
+
     // Download weights for sharded models
-    if weights_cache_override.is_none() && num_shard > 1 {
+    if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
         let mut download_argv = vec![
             "text-generation-server".to_string(),
             "download-weights".to_string(),


@@ -8,6 +8,7 @@ use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::path::Path;
 use text_generation_client::ShardedClient;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
@@ -83,11 +84,19 @@ fn main() -> Result<(), std::io::Error> {
         )
     });
 
-    // Download and instantiate tokenizer
+    // Tokenizer instance
     // This will only be used to validate payloads
-    // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();
+    let local_path = Path::new(&tokenizer_name);
+    let tokenizer =
+        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
+        {
+            // Load local tokenizer
+            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
+        } else {
+            // Download and instantiate tokenizer
+            // We need to download it outside of the Tokio runtime
+            Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
+        };
 
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
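The tokenizer change follows the same idea: if the identifier resolves to a local directory that ships a tokenizer.json, load it from disk; otherwise fall back to downloading from the Hub. Restated as a small free function purely for readability (this helper does not exist in the repository; it only mirrors the branch added above, using the same tokenizers-crate calls):

use std::path::Path;
use tokenizers::Tokenizer;

// Illustrative helper mirroring the router change: prefer a local tokenizer.json,
// otherwise download the tokenizer from the Hugging Face Hub.
fn load_tokenizer(tokenizer_name: &str) -> Tokenizer {
    let local_path = Path::new(tokenizer_name);
    let tokenizer_json = local_path.join("tokenizer.json");
    if local_path.is_dir() && tokenizer_json.exists() {
        // Local model directory that bundles its own tokenizer
        Tokenizer::from_file(tokenizer_json).unwrap()
    } else {
        // Hub repo id: download and instantiate the tokenizer
        Tokenizer::from_pretrained(tokenizer_name, None).unwrap()
    }
}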


@@ -234,7 +234,7 @@ mod tests {
                     do_sample: false,
                     seed: 0,
                     repetition_penalty: 0.0,
-                    watermark: false
+                    watermark: false,
                 },
                 stopping_parameters: StoppingCriteriaParameters {
                     max_new_tokens: 0,


@@ -41,7 +41,7 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    if model_id.startswith("facebook/galactica"):
+    if "facebook/galactica" in model_id:
         if sharded:
             return GalacticaSharded(model_id, revision, quantize=quantize)
         else:


@@ -58,9 +58,6 @@ class BLOOMSharded(BLOOM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():


@@ -164,9 +164,6 @@ class GalacticaSharded(Galactica):
     def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():


@@ -80,6 +80,10 @@ def weight_files(
     model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
 ) -> List[Path]:
     """Get the local files"""
+    # Local model
+    if Path(model_id).exists() and Path(model_id).is_dir():
+        return list(Path(model_id).glob(f"*{extension}"))
+
     try:
         filenames = weight_hub_files(model_id, revision, extension)
     except EntryNotFoundError as e: