feat(router): add device and dtype info (#215)

This commit is contained in:
OlivierDehaene 2023-04-21 15:36:29 +02:00 committed by GitHub
parent ac8c0f6fe4
commit 343437c7b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 120 additions and 15 deletions

View File

@ -3,6 +3,8 @@ syntax = "proto3";
package generate.v1; package generate.v1;
service TextGenerationService { service TextGenerationService {
/// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {}
/// Service discovery /// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache /// Empties batch cache
@ -13,6 +15,15 @@ service TextGenerationService {
rpc Decode (DecodeRequest) returns (DecodeResponse); rpc Decode (DecodeRequest) returns (DecodeResponse);
} }
/// Empty request
message InfoRequest {}
message InfoResponse {
bool requires_padding = 1;
string dtype = 2;
string device_type = 3;
}
/// Empty request /// Empty request
message ServiceDiscoveryRequest {} message ServiceDiscoveryRequest {}

View File

@ -54,6 +54,14 @@ impl Client {
Ok(urls) Ok(urls)
} }
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Clear the past generations cache /// Clear the past generations cache
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> { pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {

View File

@ -6,6 +6,7 @@ mod pb;
mod sharded_client; mod sharded_client;
pub use client::Client; pub use client::Client;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{ pub use pb::generate::v1::{
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
Request, StoppingCriteriaParameters, Request, StoppingCriteriaParameters,

View File

@ -1,6 +1,6 @@
/// Multi shard Client /// Multi shard Client
use crate::Result; use crate::Result;
use crate::{Batch, Client, Generation}; use crate::{Batch, Client, Generation, ShardInfo};
use futures::future::join_all; use futures::future::join_all;
use tonic::transport::Uri; use tonic::transport::Uri;
use tracing::instrument; use tracing::instrument;
@ -37,6 +37,17 @@ impl ShardedClient {
Self::from_master_client(master_client).await Self::from_master_client(master_client).await
} }
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache /// Clear the past generations cache
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> { pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {

View File

@ -12,7 +12,7 @@ use validation::Validation;
/// Hub type /// Hub type
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct ModelInfo { pub struct HubModelInfo {
#[serde(rename(deserialize = "id"))] #[serde(rename(deserialize = "id"))]
pub model_id: String, pub model_id: String,
pub sha: Option<String>, pub sha: Option<String>,
@ -25,6 +25,10 @@ pub struct Info {
pub model_id: String, pub model_id: String,
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")] #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
pub model_sha: Option<String>, pub model_sha: Option<String>,
#[schema(example = "torch.float16")]
pub model_dtype: String,
#[schema(example = "cuda")]
pub model_device_type: String,
#[schema(nullable = true, example = "text-generation")] #[schema(nullable = true, example = "text-generation")]
pub model_pipeline_tag: Option<String>, pub model_pipeline_tag: Option<String>,
#[schema(example = "0.5.0")] #[schema(example = "0.5.0")]

View File

@ -10,7 +10,7 @@ use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path; use std::path::Path;
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
use text_generation_router::{server, ModelInfo}; use text_generation_router::{server, HubModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer}; use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin; use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
@ -128,7 +128,7 @@ fn main() -> Result<(), std::io::Error> {
// Get Model info // Get Model info
let model_info = match local_model { let model_info = match local_model {
true => ModelInfo { true => HubModelInfo {
model_id: tokenizer_name.clone(), model_id: tokenizer_name.clone(),
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
@ -154,6 +154,11 @@ fn main() -> Result<(), std::io::Error> {
.clear_cache(None) .clear_cache(None)
.await .await
.expect("Unable to clear cache"); .expect("Unable to clear cache");
// Get info from the shard
let shard_info = sharded_client
.info()
.await
.expect("Unable to get shard info");
tracing::info!("Connected"); tracing::info!("Connected");
// Binds on localhost // Binds on localhost
@ -162,6 +167,7 @@ fn main() -> Result<(), std::io::Error> {
// Run server // Run server
server::run( server::run(
model_info, model_info,
shard_info,
compat_return_full_text, compat_return_full_text,
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
@ -237,7 +243,7 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
} }
/// get model info from the Huggingface Hub /// get model info from the Huggingface Hub
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo { pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let mut builder = client.get(format!( let mut builder = client.get(format!(
"https://huggingface.co/api/models/{model_id}/revision/{revision}" "https://huggingface.co/api/models/{model_id}/revision/{revision}"

View File

@ -3,7 +3,7 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason, BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
StreamDetails, StreamResponse, Token, Validation, StreamDetails, StreamResponse, Token, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
@ -18,7 +18,7 @@ use futures::Stream;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use text_generation_client::ShardedClient; use text_generation_client::{ShardInfo, ShardedClient};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::signal; use tokio::signal;
use tokio::time::Instant; use tokio::time::Instant;
@ -78,13 +78,19 @@ async fn compat_generate(
responses((status = 200, description = "Served model info", body = Info)) responses((status = 200, description = "Served model info", body = Info))
)] )]
#[instrument] #[instrument]
async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> { async fn get_model_info(
model_info: Extension<HubModelInfo>,
shard_info: Extension<ShardInfo>,
) -> Json<Info> {
let model_info = model_info.0; let model_info = model_info.0;
let shard_info = shard_info.0;
let info = Info { let info = Info {
version: env!("CARGO_PKG_VERSION"), version: env!("CARGO_PKG_VERSION"),
sha: option_env!("VERGEN_GIT_SHA"), sha: option_env!("VERGEN_GIT_SHA"),
model_id: model_info.model_id, model_id: model_info.model_id,
model_sha: model_info.sha, model_sha: model_info.sha,
model_dtype: shard_info.dtype,
model_device_type: shard_info.device_type,
model_pipeline_tag: model_info.pipeline_tag, model_pipeline_tag: model_info.pipeline_tag,
}; };
Json(info) Json(info)
@ -497,7 +503,8 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
model_info: ModelInfo, model_info: HubModelInfo,
shard_info: ShardInfo,
compat_return_full_text: bool, compat_return_full_text: bool,
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize, max_best_of: usize,
@ -641,6 +648,7 @@ pub async fn run(
// Prometheus metrics route // Prometheus metrics route
.route("/metrics", get(metrics)) .route("/metrics", get(metrics))
.layer(Extension(model_info)) .layer(Extension(model_info))
.layer(Extension(shard_info))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))
.layer(Extension(infer)) .layer(Extension(infer))
.layer(Extension(prom_handle)) .layer(Extension(prom_handle))

View File

@ -100,7 +100,11 @@ class BLOOMSharded(BLOOM):
self.model = model.eval() self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, device=device, decode_buffer=1 tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
decode_buffer=1,
) )
@staticmethod @staticmethod

View File

@ -400,7 +400,11 @@ class CausalLM(Model):
) )
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, device=device, decode_buffer=decode_buffer tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
decode_buffer=decode_buffer,
) )
@property @property

View File

@ -343,7 +343,11 @@ class FlashCausalLM(Model):
) )
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, device=device, decode_buffer=decode_buffer tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
decode_buffer=decode_buffer,
) )
@property @property

View File

@ -63,6 +63,8 @@ class FlashLlama(FlashCausalLM):
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device, device=device,
) )
@ -184,6 +186,8 @@ class FlashLlamaSharded(FlashLlama):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device, device=device,
) )

View File

@ -70,6 +70,8 @@ class FlashNeoXSharded(FlashNeoX):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device, device=device,
) )

View File

@ -65,7 +65,11 @@ class FlashSantacoder(FlashCausalLM):
self.model = model.eval().to(device) self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, device=device, decode_buffer=1 tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
decode_buffer=1,
) )
@staticmethod @staticmethod
@ -206,6 +210,8 @@ class FlashSantacoderSharded(FlashSantacoder):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device, device=device,
) )

View File

@ -228,6 +228,8 @@ class GalacticaSharded(Galactica):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device, device=device,
) )

View File

@ -72,6 +72,8 @@ class GPTNeoxSharded(CausalLM):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device, device=device,
) )

View File

@ -5,6 +5,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, GeneratedText from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch) B = TypeVar("B", bound=Batch)
@ -13,6 +14,8 @@ class Model(ABC):
def __init__( def __init__(
self, self,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
requires_padding: bool,
dtype: torch.dtype,
device: torch.device, device: torch.device,
decode_buffer: int = 3, decode_buffer: int = 3,
): ):
@ -21,9 +24,19 @@ class Model(ABC):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids) self.all_special_ids = set(tokenizer.all_special_ids)
self.requires_padding = requires_padding
self.dtype = dtype
self.device = device self.device = device
self.decode_buffer = decode_buffer self.decode_buffer = decode_buffer
@property
def info(self) -> InfoResponse:
return InfoResponse(
requires_padding=self.requires_padding,
dtype=str(self.dtype),
device_type=self.device.type,
)
@property @property
@abstractmethod @abstractmethod
def batch_type(self) -> Type[B]: def batch_type(self) -> Type[B]:

View File

@ -88,6 +88,8 @@ class OPTSharded(OPT):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device, device=device,
) )

View File

@ -54,7 +54,11 @@ class SantaCoder(CausalLM):
) )
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, device=device, decode_buffer=1 tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
decode_buffer=1,
) )
def decode(self, generated_ids: List[int]) -> str: def decode(self, generated_ids: List[int]) -> str:

View File

@ -460,7 +460,11 @@ class Seq2SeqLM(Model):
tokenizer.bos_token_id = self.model.config.decoder_start_token_id tokenizer.bos_token_id = self.model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__( super(Seq2SeqLM, self).__init__(
tokenizer=tokenizer, device=device, decode_buffer=decode_buffer tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
decode_buffer=decode_buffer,
) )
@property @property

View File

@ -72,6 +72,8 @@ class T5Sharded(Seq2SeqLM):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__( super(Seq2SeqLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device, device=device,
) )

View File

@ -26,6 +26,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
# Force inference mode for the lifetime of TextGenerationService # Force inference mode for the lifetime of TextGenerationService
self._inference_mode_raii_guard = torch._C._InferenceMode(True) self._inference_mode_raii_guard = torch._C._InferenceMode(True)
async def Info(self, request, context):
return self.model.info
async def ServiceDiscovery(self, request, context): async def ServiceDiscovery(self, request, context):
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)