diff --git a/proto/generate.proto b/proto/generate.proto
index cc14cbf..2bf7385 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -3,6 +3,8 @@ syntax = "proto3";
 package generate.v1;
 
 service TextGenerationService {
+    /// Model Info
+    rpc Info (InfoRequest) returns (InfoResponse) {}
     /// Service discovery
     rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
     /// Empties batch cache
@@ -13,6 +15,15 @@ service TextGenerationService {
     rpc Decode (DecodeRequest) returns (DecodeResponse);
 }
 
+/// Empty request
+message InfoRequest {}
+
+message InfoResponse {
+    bool requires_padding = 1;
+    string dtype = 2;
+    string device_type = 3;
+}
+
 /// Empty request
 message ServiceDiscoveryRequest {}
 
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 1b2086a..cccd500 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -54,6 +54,14 @@ impl Client {
         Ok(urls)
     }
 
+    /// Get model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<InfoResponse> {
+        let request = tonic::Request::new(InfoRequest {}).inject_context();
+        let response = self.stub.info(request).await?.into_inner();
+        Ok(response)
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 21fbc1e..6a00130 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -6,6 +6,7 @@ mod pb;
 mod sharded_client;
 
 pub use client::Client;
+pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
     Request, StoppingCriteriaParameters,
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 31f7631..903c7a6 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,6 +1,6 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation};
+use crate::{Batch, Client, Generation, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
@@ -37,6 +37,17 @@ impl ShardedClient {
         Self::from_master_client(master_client).await
     }
 
+    /// Get the model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<ShardInfo> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.info())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7dc115f..2f93ec0 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -12,7 +12,7 @@ use validation::Validation;
 
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
-pub struct ModelInfo {
+pub struct HubModelInfo {
     #[serde(rename(deserialize = "id"))]
     pub model_id: String,
     pub sha: Option<String>,
@@ -25,6 +25,10 @@ pub struct Info {
     pub model_id: String,
     #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
     pub model_sha: Option<String>,
+    #[schema(example = "torch.float16")]
+    pub model_dtype: String,
+    #[schema(example = "cuda")]
+    pub model_device_type: String,
     #[schema(nullable = true, example = "text-generation")]
     pub model_pipeline_tag: Option<String>,
     #[schema(example = "0.5.0")]
diff --git a/router/src/main.rs b/router/src/main.rs
index 31783bb..712071b 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -10,7 +10,7 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use text_generation_client::ShardedClient;
-use text_generation_router::{server, ModelInfo};
+use text_generation_router::{server, HubModelInfo};
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
@@ -128,7 +128,7 @@ fn main() -> Result<(), std::io::Error> {
 
             // Get Model info
             let model_info = match local_model {
-                true => ModelInfo {
+                true => HubModelInfo {
                     model_id: tokenizer_name.clone(),
                     sha: None,
                     pipeline_tag: None,
@@ -154,6 +154,11 @@ fn main() -> Result<(), std::io::Error> {
                 .clear_cache(None)
                 .await
                 .expect("Unable to clear cache");
+            // Get info from the shard
+            let shard_info = sharded_client
+                .info()
+                .await
+                .expect("Unable to get shard info");
             tracing::info!("Connected");
 
             // Binds on localhost
@@ -162,6 +167,7 @@ fn main() -> Result<(), std::io::Error> {
             // Run server
             server::run(
                 model_info,
+                shard_info,
                 compat_return_full_text,
                 max_concurrent_requests,
                 max_best_of,
@@ -237,7 +243,7 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }
 
 /// get model info from the Huggingface Hub
-pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo {
+pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
     let client = reqwest::Client::new();
     let mut builder = client.get(format!(
         "https://huggingface.co/api/models/{model_id}/revision/{revision}"
diff --git a/router/src/server.rs b/router/src/server.rs
index fee748e..8891443 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -3,7 +3,7 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
     StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
@@ -18,7 +18,7 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::time::Instant;
@@ -78,13 +78,19 @@ async fn compat_generate(
     responses((status = 200, description = "Served model info", body = Info))
 )]
 #[instrument]
-async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
+async fn get_model_info(
+    model_info: Extension<HubModelInfo>,
+    shard_info: Extension<ShardInfo>,
+) -> Json<Info> {
     let model_info = model_info.0;
+    let shard_info = shard_info.0;
     let info = Info {
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
         model_id: model_info.model_id,
         model_sha: model_info.sha,
+        model_dtype: shard_info.dtype,
+        model_device_type: shard_info.device_type,
         model_pipeline_tag: model_info.pipeline_tag,
     };
     Json(info)
@@ -497,7 +503,8 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
-    model_info: ModelInfo,
+    model_info: HubModelInfo,
+    shard_info: ShardInfo,
     compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_best_of: usize,
@@ -641,6 +648,7 @@ pub async fn run(
         // Prometheus metrics route
.route("/metrics", get(metrics)) .layer(Extension(model_info)) + .layer(Extension(shard_info)) .layer(Extension(compat_return_full_text)) .layer(Extension(infer)) .layer(Extension(prom_handle)) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 731a985..e43a4b7 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -100,7 +100,11 @@ class BLOOMSharded(BLOOM): self.model = model.eval() torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( - tokenizer=tokenizer, device=device, decode_buffer=1 + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + decode_buffer=1, ) @staticmethod diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 71eff43..9831325 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -400,7 +400,11 @@ class CausalLM(Model): ) super(CausalLM, self).__init__( - tokenizer=tokenizer, device=device, decode_buffer=decode_buffer + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + decode_buffer=decode_buffer, ) @property diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2843f27..0e2fbaa 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -343,7 +343,11 @@ class FlashCausalLM(Model): ) super(FlashCausalLM, self).__init__( - tokenizer=tokenizer, device=device, decode_buffer=decode_buffer + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + decode_buffer=decode_buffer, ) @property diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 9cbf1b5..764de2a 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -63,6 +63,8 @@ class FlashLlama(FlashCausalLM): super(FlashCausalLM, self).__init__( tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, device=device, ) @@ -184,6 +186,8 @@ class FlashLlamaSharded(FlashLlama): torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, device=device, ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 0cda728..259fc20 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -70,6 +70,8 @@ class FlashNeoXSharded(FlashNeoX): torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, device=device, ) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index e3066c9..7dcd8b0 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -65,7 +65,11 @@ class FlashSantacoder(FlashCausalLM): self.model = model.eval().to(device) super(FlashCausalLM, self).__init__( - tokenizer=tokenizer, device=device, decode_buffer=1 + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + decode_buffer=1, 
         )
 
     @staticmethod
@@ -206,6 +210,8 @@ class FlashSantacoderSharded(FlashSantacoder):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
 
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 746e9e8..753e86e 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -228,6 +228,8 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
 
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 489615e..3b5fe2c 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -72,6 +72,8 @@ class GPTNeoxSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
 
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 08a4855..bfae829 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -5,6 +5,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase
 
 from text_generation_server.models.types import Batch, GeneratedText
+from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
 
@@ -13,6 +14,8 @@ class Model(ABC):
     def __init__(
         self,
         tokenizer: PreTrainedTokenizerBase,
+        requires_padding: bool,
+        dtype: torch.dtype,
         device: torch.device,
         decode_buffer: int = 3,
     ):
@@ -21,9 +24,19 @@ class Model(ABC):
 
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.requires_padding = requires_padding
+        self.dtype = dtype
         self.device = device
         self.decode_buffer = decode_buffer
 
+    @property
+    def info(self) -> InfoResponse:
+        return InfoResponse(
+            requires_padding=self.requires_padding,
+            dtype=str(self.dtype),
+            device_type=self.device.type,
+        )
+
     @property
     @abstractmethod
     def batch_type(self) -> Type[B]:
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 8e5527c..1a21186 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -88,6 +88,8 @@ class OPTSharded(OPT):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
 
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 8646a4e..796c33e 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -54,7 +54,11 @@ class SantaCoder(CausalLM):
         )
 
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )
 
     def decode(self, generated_ids: List[int]) -> str:
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index dd2f999..aa452c7 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -460,7 +460,11 @@ class Seq2SeqLM(Model):
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
 
         super(Seq2SeqLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )
 
     @property
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index b9f7701..487a598 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -72,6 +72,8 @@ class T5Sharded(Seq2SeqLM):
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
 
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 3caee80..95b431c 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -26,6 +26,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # Force inference mode for the lifetime of TextGenerationService
         self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
+    async def Info(self, request, context):
+        return self.model.info
+
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
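
For reference, a minimal standalone sketch (not part of the patch above) of what the new plumbing reports end to end: `Model.info` builds an `InfoResponse`, the new `Info` RPC returns it, `ShardedClient::info()` re-exports it as `ShardInfo`, and the router copies `dtype` and `device_type` into the `model_dtype` and `model_device_type` fields of its `Info` payload. The concrete dtype and device values below are illustrative assumptions, not taken from the diff.

# Illustrative sketch only -- mirrors Model.info from the patch above.
import torch

from text_generation_server.pb.generate_pb2 import InfoResponse

# Example runtime state; a real shard derives these from the loaded model.
dtype = torch.float16
device = torch.device("cuda")

info = InfoResponse(
    requires_padding=False,   # False for the flash-attention models, True otherwise
    dtype=str(dtype),         # "torch.float16", surfaced by the router as `model_dtype`
    device_type=device.type,  # "cuda", surfaced by the router as `model_device_type`
)
print(info)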