commit cd5961b5da
parent 9b205d33cc

@@ -110,8 +110,12 @@ fn main() -> ExitCode {
     })
     .expect("Error setting Ctrl-C handler");
 
+    // Check if model_id is a local model
+    let local_path = Path::new(&model_id);
+    let is_local_model = local_path.exists() && local_path.is_dir();
+
     // Download weights for sharded models
-    if weights_cache_override.is_none() && num_shard > 1 {
+    if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
         let mut download_argv = vec![
             "text-generation-server".to_string(),
             "download-weights".to_string(),
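
The hunk above makes the launcher skip the weight-download step whenever model_id resolves to an existing directory. A minimal Python sketch of that decision (the launcher itself is Rust, and the function name here is hypothetical):

    from pathlib import Path

    def should_download_weights(model_id: str, weights_cache_override, num_shard: int) -> bool:
        # Path.is_dir() is False for non-existent paths, mirroring the
        # exists() && is_dir() pair in the Rust code above.
        is_local_model = Path(model_id).is_dir()
        return not is_local_model and weights_cache_override is None and num_shard > 1

    # should_download_weights("/data/bloom-560m", None, 2) -> False for a local checkout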

@@ -8,6 +8,7 @@ use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::path::Path;
 use text_generation_client::ShardedClient;
 use text_generation_router::server;
 use tokenizers::Tokenizer;

@@ -83,11 +84,19 @@ fn main() -> Result<(), std::io::Error> {
         )
     });
 
-    // Download and instantiate tokenizer
+    // Tokenizer instance
     // This will only be used to validate payloads
-    //
-    // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();
+    let local_path = Path::new(&tokenizer_name);
+    let tokenizer =
+        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
+        {
+            // Load local tokenizer
+            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
+        } else {
+            // Download and instantiate tokenizer
+            // We need to download it outside of the Tokio runtime
+            Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
+        };
 
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
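
The router now resolves the tokenizer in two steps: a directory containing tokenizer.json is loaded from disk, and anything else is treated as a hub name. A sketch of the same fallback using the Python tokenizers package (the router uses the Rust crate; these calls are the Python equivalents):

    from pathlib import Path
    from tokenizers import Tokenizer

    def load_tokenizer(tokenizer_name: str) -> Tokenizer:
        local_path = Path(tokenizer_name)
        if local_path.is_dir() and (local_path / "tokenizer.json").exists():
            # A serialized tokenizer on disk: no network access needed.
            return Tokenizer.from_file(str(local_path / "tokenizer.json"))
        # Otherwise resolve the name against the Hugging Face Hub.
        return Tokenizer.from_pretrained(tokenizer_name)

As the comment in the diff notes, this tokenizer is only used to validate payloads; a directory without a tokenizer.json falls through to the hub branch, which attempts Tokenizer::from_pretrained on the raw name.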

@@ -234,7 +234,7 @@ mod tests {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
-                watermark: false
+                watermark: false,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 max_new_tokens: 0,

@@ -41,7 +41,7 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    if model_id.startswith("facebook/galactica"):
+    if "facebook/galactica" in model_id:
         if sharded:
             return GalacticaSharded(model_id, revision, quantize=quantize)
         else:
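
Replacing startswith with a substring test keeps model routing working when model_id is a local path that embeds the hub id in its directory name. For example (the path is hypothetical):

    model_id = "/data/facebook/galactica-120b"  # hypothetical local checkout
    model_id.startswith("facebook/galactica")   # False: the old check misses local paths
    "facebook/galactica" in model_id            # True: still routes to the Galactica classes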

@@ -58,9 +58,6 @@ class BLOOMSharded(BLOOM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():

@@ -164,9 +164,6 @@ class GalacticaSharded(Galactica):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
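
This hunk and the BLOOM one above drop the hard model-id guards. Those guards could never pass for a local model, since a filesystem path does not start with a hub prefix, as in this illustration:

    model_id = "/data/bloom"  # hypothetical local checkout
    if not model_id.startswith("bigscience/bloom"):
        # Always taken for local paths, which is why the guard was removed.
        raise ValueError(f"Model {model_id} is not supported")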

@@ -80,6 +80,10 @@ def weight_files(
     model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
 ) -> List[Path]:
     """Get the local files"""
+    # Local model
+    if Path(model_id).exists() and Path(model_id).is_dir():
+        return list(Path(model_id).glob(f"*{extension}"))
+
     try:
         filenames = weight_hub_files(model_id, revision, extension)
     except EntryNotFoundError as e:
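
With the early return above, weight_files lists an existing directory directly and never contacts the hub. Note that glob with this pattern is non-recursive, so the weights must sit at the top level of the directory. A hypothetical call:

    # Assumed layout: /data/bloom-560m/*.safetensors at the top level
    files = weight_files("/data/bloom-560m")
    # -> every *.safetensors file found in that directory; no hub request is made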