feat(router): add device and dtype info (#215)
parent ac8c0f6fe4
commit 343437c7b5
@@ -3,6 +3,8 @@ syntax = "proto3";
 package generate.v1;

 service TextGenerationService {
+    /// Model Info
+    rpc Info (InfoRequest) returns (InfoResponse) {}
     /// Service discovery
     rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
     /// Empties batch cache

@@ -13,6 +15,15 @@ service TextGenerationService {
     rpc Decode (DecodeRequest) returns (DecodeResponse);
 }

+/// Empty request
+message InfoRequest {}
+
+message InfoResponse {
+    bool requires_padding = 1;
+    string dtype = 2;
+    string device_type = 3;
+}
+
 /// Empty request
 message ServiceDiscoveryRequest {}
@@ -54,6 +54,14 @@ impl Client {
         Ok(urls)
     }

+    /// Get model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<InfoResponse> {
+        let request = tonic::Request::new(InfoRequest {}).inject_context();
+        let response = self.stub.info(request).await?.into_inner();
+        Ok(response)
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {

@@ -6,6 +6,7 @@ mod pb;
 mod sharded_client;

 pub use client::Client;
+pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
     Request, StoppingCriteriaParameters,
@@ -1,6 +1,6 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation};
+use crate::{Batch, Client, Generation, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;

@@ -37,6 +37,17 @@ impl ShardedClient {
         Self::from_master_client(master_client).await
     }

+    /// Get the model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<ShardInfo> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.info())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
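Note: ShardedClient::info fans the request out to every shard with join_all and keeps only the last reply, on the implicit assumption that all shards load the same model on the same kind of device. A toy asyncio analogue of that fan-out-and-take-one pattern; the shard_info coroutine below is purely illustrative and not part of the codebase:

import asyncio

async def shard_info(rank: int) -> dict:
    # Stand-in for Client::info(); every shard reports the same metadata.
    return {"rank": rank, "requires_padding": False, "dtype": "torch.float16", "device_type": "cuda"}

async def main():
    # Query all shards concurrently, then keep a single (the last) response,
    # mirroring `join_all(futures).await.pop().unwrap()` on the Rust side.
    responses = await asyncio.gather(*(shard_info(rank) for rank in range(4)))
    print(responses[-1])

asyncio.run(main())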
@@ -12,7 +12,7 @@ use validation::Validation;

 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
-pub struct ModelInfo {
+pub struct HubModelInfo {
     #[serde(rename(deserialize = "id"))]
     pub model_id: String,
     pub sha: Option<String>,

@@ -25,6 +25,10 @@ pub struct Info {
     pub model_id: String,
     #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
     pub model_sha: Option<String>,
+    #[schema(example = "torch.float16")]
+    pub model_dtype: String,
+    #[schema(example = "cuda")]
+    pub model_device_type: String,
     #[schema(nullable = true, example = "text-generation")]
     pub model_pipeline_tag: Option<String>,
     #[schema(example = "0.5.0")]
@@ -10,7 +10,7 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use text_generation_client::ShardedClient;
-use text_generation_router::{server, ModelInfo};
+use text_generation_router::{server, HubModelInfo};
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;

@@ -128,7 +128,7 @@ fn main() -> Result<(), std::io::Error> {

     // Get Model info
     let model_info = match local_model {
-        true => ModelInfo {
+        true => HubModelInfo {
            model_id: tokenizer_name.clone(),
            sha: None,
            pipeline_tag: None,

@@ -154,6 +154,11 @@ fn main() -> Result<(), std::io::Error> {
                 .clear_cache(None)
                 .await
                 .expect("Unable to clear cache");
+            // Get info from the shard
+            let shard_info = sharded_client
+                .info()
+                .await
+                .expect("Unable to get shard info");
             tracing::info!("Connected");

             // Binds on localhost

@@ -162,6 +167,7 @@ fn main() -> Result<(), std::io::Error> {
             // Run server
             server::run(
                 model_info,
+                shard_info,
                 compat_return_full_text,
                 max_concurrent_requests,
                 max_best_of,

@@ -237,7 +243,7 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }

 /// get model info from the Huggingface Hub
-pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo {
+pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
     let client = reqwest::Client::new();
     let mut builder = client.get(format!(
         "https://huggingface.co/api/models/{model_id}/revision/{revision}"
@@ -3,7 +3,7 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
     StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;

@@ -18,7 +18,7 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::time::Instant;

@@ -78,13 +78,19 @@ async fn compat_generate(
     responses((status = 200, description = "Served model info", body = Info))
 )]
 #[instrument]
-async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
+async fn get_model_info(
+    model_info: Extension<HubModelInfo>,
+    shard_info: Extension<ShardInfo>,
+) -> Json<Info> {
     let model_info = model_info.0;
+    let shard_info = shard_info.0;
     let info = Info {
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
         model_id: model_info.model_id,
         model_sha: model_info.sha,
+        model_dtype: shard_info.dtype,
+        model_device_type: shard_info.device_type,
         model_pipeline_tag: model_info.pipeline_tag,
     };
     Json(info)
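Note: with the handler reading both extensions, GET /info now merges the Hub metadata with the dtype and device type reported by the shards. A minimal sketch of querying it; the host and port are hypothetical and the commented values are the schema examples from this diff, not real output:

import requests

# Hypothetical router address; adjust host/port to your deployment.
info = requests.get("http://127.0.0.1:3000/info").json()

# Expected keys, with the example values used in the schema annotations above:
#   model_id, model_sha ("e985a63cdc139290c5f700ff1929f0b5942cced2"),
#   model_dtype ("torch.float16"), model_device_type ("cuda"),
#   model_pipeline_tag ("text-generation"), version ("0.5.0"), sha
print(info["model_dtype"], info["model_device_type"])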
@@ -497,7 +503,8 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
-    model_info: ModelInfo,
+    model_info: HubModelInfo,
+    shard_info: ShardInfo,
     compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_best_of: usize,

@@ -641,6 +648,7 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(model_info))
+        .layer(Extension(shard_info))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
         .layer(Extension(prom_handle))
@@ -100,7 +100,11 @@ class BLOOMSharded(BLOOM):
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )

     @staticmethod
@@ -400,7 +400,11 @@ class CausalLM(Model):
         )

         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )

     @property
@@ -343,7 +343,11 @@ class FlashCausalLM(Model):
         )

         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )

     @property
@@ -63,6 +63,8 @@ class FlashLlama(FlashCausalLM):

         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )

@@ -184,6 +186,8 @@ class FlashLlamaSharded(FlashLlama):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
@@ -70,6 +70,8 @@ class FlashNeoXSharded(FlashNeoX):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
@@ -65,7 +65,11 @@ class FlashSantacoder(FlashCausalLM):
         self.model = model.eval().to(device)

         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )

     @staticmethod

@@ -206,6 +210,8 @@ class FlashSantacoderSharded(FlashSantacoder):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
@@ -228,6 +228,8 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
@@ -72,6 +72,8 @@ class GPTNeoxSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
@@ -5,6 +5,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase

 from text_generation_server.models.types import Batch, GeneratedText
+from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)

@@ -13,6 +14,8 @@ class Model(ABC):
     def __init__(
         self,
         tokenizer: PreTrainedTokenizerBase,
+        requires_padding: bool,
+        dtype: torch.dtype,
         device: torch.device,
         decode_buffer: int = 3,
     ):

@@ -21,9 +24,19 @@ class Model(ABC):
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.requires_padding = requires_padding
+        self.dtype = dtype
         self.device = device
         self.decode_buffer = decode_buffer

+    @property
+    def info(self) -> InfoResponse:
+        return InfoResponse(
+            requires_padding=self.requires_padding,
+            dtype=str(self.dtype),
+            device_type=self.device.type,
+        )
+
     @property
     @abstractmethod
     def batch_type(self) -> Type[B]:
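Note: Model.info simply stringifies the torch dtype and reports the device kind, which is where the "torch.float16" / "cuda" strings surfaced by the router come from. A small sketch of the same construction outside the Model class; the dtype and device values are illustrative:

import torch

from text_generation_server.pb.generate_pb2 import InfoResponse

dtype = torch.float16
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

info = InfoResponse(
    requires_padding=True,      # True for the padded CausalLM/Seq2SeqLM paths, False for flash models
    dtype=str(dtype),           # str(torch.float16) == "torch.float16"
    device_type=device.type,    # "cuda" or "cpu"
)
print(info)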
@@ -88,6 +88,8 @@ class OPTSharded(OPT):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
@@ -54,7 +54,11 @@ class SantaCoder(CausalLM):
         )

         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )

     def decode(self, generated_ids: List[int]) -> str:
@@ -460,7 +460,11 @@ class Seq2SeqLM(Model):
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id

         super(Seq2SeqLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )

     @property
@@ -72,6 +72,8 @@ class T5Sharded(Seq2SeqLM):
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
@@ -26,6 +26,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # Force inference mode for the lifetime of TextGenerationService
         self._inference_mode_raii_guard = torch._C._InferenceMode(True)

+    async def Info(self, request, context):
+        return self.model.info
+
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
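Note: the servicer returns the cached model.info, so any gRPC client generated from the proto above can read shard metadata directly. A hedged sketch using the blocking grpcio API; the unix-socket path is hypothetical (it depends on how the shard was launched), and the stub class name follows standard protoc codegen for service TextGenerationService:

import grpc

from text_generation_server.pb import generate_pb2, generate_pb2_grpc

# Hypothetical socket path; the launcher decides where each shard listens.
channel = grpc.insecure_channel("unix:///tmp/text-generation-server-0")
stub = generate_pb2_grpc.TextGenerationServiceStub(channel)

info = stub.Info(generate_pb2.InfoRequest())
print(info.requires_padding, info.dtype, info.device_type)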