parent 9b205d33cc
commit cd5961b5da

@@ -110,8 +110,12 @@ fn main() -> ExitCode {
     })
     .expect("Error setting Ctrl-C handler");
 
+    // Check if model_id is a local model
+    let local_path = Path::new(&model_id);
+    let is_local_model = local_path.exists() && local_path.is_dir();
+
     // Download weights for sharded models
-    if weights_cache_override.is_none() && num_shard > 1 {
+    if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
        let mut download_argv = vec![
            "text-generation-server".to_string(),
            "download-weights".to_string(),
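For reference, the launcher's new check reduces to a directory test on model_id. A minimal Python sketch of the same rule (the paths below are hypothetical):

from pathlib import Path

def is_local_model(model_id: str) -> bool:
    # A model_id counts as local when it names an existing directory.
    local_path = Path(model_id)
    return local_path.exists() and local_path.is_dir()

# A hub id such as "bigscience/bloom" is normally not a directory on
# disk, so it still goes through the download path; an existing
# snapshot directory skips the weight download entirely.
print(is_local_model("bigscience/bloom"))  # False unless such a dir exists
print(is_local_model("/data/bloom-560m"))  # True if that directory exists
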
@@ -8,6 +8,7 @@ use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::path::Path;
 use text_generation_client::ShardedClient;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
@@ -83,11 +84,19 @@ fn main() -> Result<(), std::io::Error> {
         )
     });
 
-    // Download and instantiate tokenizer
+    // Tokenizer instance
     // This will only be used to validate payloads
-    //
-    // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();
+    let local_path = Path::new(&tokenizer_name);
+    let tokenizer =
+        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
+        {
+            // Load local tokenizer
+            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
+        } else {
+            // Download and instantiate tokenizer
+            // We need to download it outside of the Tokio runtime
+            Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
+        };
 
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
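The Python tokenizers bindings expose the same from_file / from_pretrained constructors, so the router's fallback can be sketched directly; load_tokenizer is a hypothetical helper name:

from pathlib import Path
from tokenizers import Tokenizer

def load_tokenizer(tokenizer_name: str) -> Tokenizer:
    local_path = Path(tokenizer_name)
    tokenizer_file = local_path / "tokenizer.json"
    if local_path.is_dir() and tokenizer_file.exists():
        # Load local tokenizer from disk
        return Tokenizer.from_file(str(tokenizer_file))
    # Otherwise download and instantiate it from the Hub
    return Tokenizer.from_pretrained(tokenizer_name)
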
@@ -234,7 +234,7 @@ mod tests {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
-                watermark: false
+                watermark: false,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 max_new_tokens: 0,
@@ -41,7 +41,7 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    if model_id.startswith("facebook/galactica"):
+    if "facebook/galactica" in model_id:
         if sharded:
             return GalacticaSharded(model_id, revision, quantize=quantize)
         else:
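The move from a prefix test to a substring test is what lets local snapshots still route to the Galactica classes: a path such as /data/facebook/galactica-125m contains the marker but does not start with it. A quick check (the path is hypothetical):

model_id = "/data/facebook/galactica-125m"  # hypothetical local snapshot path

print(model_id.startswith("facebook/galactica"))  # False: prefix test rejects local paths
print("facebook/galactica" in model_id)           # True: substring test accepts them
print("facebook/galactica" in "facebook/galactica-125m")  # True: hub ids still match
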
@@ -58,9 +58,6 @@ class BLOOMSharded(BLOOM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
@ -164,9 +164,6 @@ class GalacticaSharded(Galactica):
|
|||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
if not model_id.startswith("facebook/galactica"):
|
||||
raise ValueError(f"Model {model_id} is not supported")
|
||||
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
if torch.cuda.is_available():
|
||||
|
|
|
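The two removed guards enforced the same hub-prefix assumption, which no longer holds once model_id may be a filesystem path; for example (hypothetical path):

model_id = "/data/bloom-560m"  # hypothetical local path with valid BLOOM weights
print(model_id.startswith("bigscience/bloom"))  # False: the old guard would have raised
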
@@ -80,6 +80,10 @@ def weight_files(
     model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
 ) -> List[Path]:
     """Get the local files"""
+    # Local model
+    if Path(model_id).exists() and Path(model_id).is_dir():
+        return list(Path(model_id).glob(f"*{extension}"))
+
     try:
         filenames = weight_hub_files(model_id, revision, extension)
     except EntryNotFoundError as e:
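With the early return in place, weight_files resolves a local directory without any Hub call. A sketch of the lookup against a hypothetical snapshot layout:

from pathlib import Path

# Hypothetical local snapshot layout:
#   /data/bloom-560m/model-00001-of-00002.safetensors
#   /data/bloom-560m/model-00002-of-00002.safetensors
#   /data/bloom-560m/config.json
model_dir = Path("/data/bloom-560m")
if model_dir.is_dir():
    # Mirrors the new early return: only the weight shards match the glob.
    weight_paths = list(model_dir.glob("*.safetensors"))
    print([p.name for p in weight_paths])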