diff --git a/router/Cargo.toml b/router/Cargo.toml
index 1a7ceb70..7d6dc017 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -52,3 +52,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
 [features]
 default = ["ngrok"]
 ngrok = ["dep:ngrok"]
+google = []
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 8c7ca74b..b7285e65 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -20,6 +20,25 @@ pub(crate) type GenerateStreamResponse = (
     UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
 );

+#[derive(Clone, Deserialize, ToSchema)]
+pub(crate) struct VertexInstance {
+    #[schema(example = "What is Deep Learning?")]
+    pub inputs: String,
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub parameters: Option<GenerateParameters>,
+}
+
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct VertexRequest {
+    #[serde(rename = "instances")]
+    pub instances: Vec<VertexInstance>,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct VertexResponse {
+    pub predictions: Vec<String>,
+}
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -70,7 +89,7 @@ mod json_object_or_string_to_string {
     }
 }

-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
     #[serde(
@@ -153,7 +172,7 @@ pub struct Info {
     pub docker_label: Option<&'static str>,
 }

-#[derive(Clone, Debug, Deserialize, ToSchema)]
+#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
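The three new types mirror the Vertex AI prediction protocol: a request body carries a list of `instances`, and the handler answers with one entry in `predictions` per instance. Below is a minimal round-trip sketch of that JSON shape; the mirror structs are standalone stand-ins (utoipa schema attributes dropped, `serde_json::Value` in place of `GenerateParameters`) so it compiles with only `serde` (derive feature) and `serde_json`.

```rust
use serde::{Deserialize, Serialize};

// Standalone mirrors of the router's Vertex types (schema attributes
// omitted; `serde_json::Value` stands in for `GenerateParameters`).
#[derive(Deserialize)]
struct VertexInstance {
    inputs: String,
    parameters: Option<serde_json::Value>,
}

#[derive(Deserialize)]
struct VertexRequest {
    instances: Vec<VertexInstance>,
}

#[derive(Serialize)]
struct VertexResponse {
    predictions: Vec<String>,
}

fn main() {
    // A Vertex prediction request: one or more instances, each with
    // optional generation parameters.
    let body = r#"{
        "instances": [
            {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 128}}
        ]
    }"#;
    let req: VertexRequest = serde_json::from_str(body).unwrap();
    assert_eq!(req.instances[0].inputs, "What is Deep Learning?");
    assert!(req.instances[0].parameters.is_some());

    // The response carries one prediction per instance.
    let resp = VertexResponse {
        predictions: vec!["Deep Learning is ...".to_string()],
    };
    println!("{}", serde_json::to_string(&resp).unwrap());
}
```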
diff --git a/router/src/main.rs b/router/src/main.rs
index 457bca8e..60a66a41 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -328,6 +328,15 @@ async fn main() -> Result<(), RouterError> {
     tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
     tracing::info!("Connected");

+    // Determine the server port based on the feature and environment variable.
+    let port = if cfg!(feature = "google") {
+        std::env::var("AIP_HTTP_PORT")
+            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
+            .unwrap_or(port)
+    } else {
+        port
+    };
+
     let addr = match hostname.parse() {
         Ok(ip) => SocketAddr::new(ip, port),
         Err(_) => {
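When the router is built with the `google` feature, Vertex AI injects the serving port through the `AIP_HTTP_PORT` environment variable. The logic above keeps the CLI-provided port whenever the variable is absent or does not parse as a `u16`; here is a self-contained sketch of that fallback (it assumes a `google` feature flag declared in the sketch crate's own Cargo.toml):

```rust
// Mirrors the port-override logic: AIP_HTTP_PORT wins only when it is set
// and parses as a valid u16; otherwise the CLI port is kept unchanged.
fn resolve_port(cli_port: u16) -> u16 {
    if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(cli_port))
            .unwrap_or(cli_port)
    } else {
        cli_port
    }
}

fn main() {
    // Without the feature (or with AIP_HTTP_PORT unset or invalid), the
    // CLI port survives untouched.
    println!("listening on port {}", resolve_port(3000));
}
```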
({"error": "Incomplete generation"})), + ) + )] +#[instrument( + skip_all, + fields( + total_time, + validation_time, + queue_time, + inference_time, + time_per_token, + seed, + ) +)] +async fn vertex_compatibility( + Extension(infer): Extension, + Extension(compute_type): Extension, + Json(req): Json, +) -> Result)> { + metrics::increment_counter!("tgi_request_count"); + + // check that theres at least one instance + if req.instances.is_empty() { + return Err(( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: "Input validation error".to_string(), + error_type: "Input validation error".to_string(), + }), + )); + } + + // Process all instances + let predictions = req + .instances + .iter() + .map(|instance| { + let generate_request = GenerateRequest { + inputs: instance.inputs.clone(), + parameters: GenerateParameters { + do_sample: true, + max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens), + seed: instance.parameters.as_ref().and_then(|p| p.seed), + details: true, + decoder_input_details: true, + ..Default::default() + }, + }; + + async { + generate( + Extension(infer.clone()), + Extension(compute_type.clone()), + Json(generate_request), + ) + .await + .map(|(_, Json(generation))| generation.generated_text) + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Incomplete generation".into(), + error_type: "Incomplete generation".into(), + }), + ) + }) + } + }) + .collect::>() + .try_collect::>() + .await?; + + let response = VertexResponse { predictions }; + Ok((HeaderMap::new(), Json(response)).into_response()) +} + /// Tokenize inputs #[utoipa::path( post, @@ -818,6 +911,7 @@ pub async fn run( StreamResponse, StreamDetails, ErrorResponse, + GrammarType, ) ), tags( @@ -942,8 +1036,30 @@ pub async fn run( docker_label: option_env!("DOCKER_LABEL"), }; + // Define VertextApiDoc conditionally only if the "google" feature is enabled + #[cfg(feature = "google")] + #[derive(OpenApi)] + #[openapi( + paths(vertex_compatibility), + components(schemas(VertexInstance, VertexRequest, VertexResponse)) + )] + struct VertextApiDoc; + + let doc = { + // avoid `mut` if possible + #[cfg(feature = "google")] + { + // limiting mutability to the smallest scope necessary + let mut doc = doc; + doc.merge(VertextApiDoc::openapi()); + doc + } + #[cfg(not(feature = "google"))] + ApiDoc::openapi() + }; + // Configure Swagger UI - let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()); + let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); // Define base and health routes let base_routes = Router::new() @@ -953,6 +1069,7 @@ pub async fn run( .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) + .route("/vertex", post(vertex_compatibility)) .route("/tokenize", post(tokenize)) .route("/health", get(health)) .route("/ping", get(health)) @@ -969,10 +1086,27 @@ pub async fn run( ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); // Combine routes and layers - let app = Router::new() + let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) - .merge(aws_sagemaker_route) + .merge(aws_sagemaker_route); + + #[cfg(feature = "google")] + { + tracing::info!("Built with `google` feature"); + tracing::info!( + "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected." 
@@ -818,6 +911,7 @@ pub async fn run(
             StreamResponse,
             StreamDetails,
             ErrorResponse,
+            GrammarType,
         )
     ),
     tags(
@@ -942,8 +1036,30 @@ pub async fn run(
         docker_label: option_env!("DOCKER_LABEL"),
     };

+    // Define VertexApiDoc conditionally only if the "google" feature is enabled
+    #[cfg(feature = "google")]
+    #[derive(OpenApi)]
+    #[openapi(
+        paths(vertex_compatibility),
+        components(schemas(VertexInstance, VertexRequest, VertexResponse))
+    )]
+    struct VertexApiDoc;
+
+    let doc = {
+        // avoid `mut` if possible
+        #[cfg(feature = "google")]
+        {
+            // limiting mutability to the smallest scope necessary
+            let mut doc = ApiDoc::openapi();
+            doc.merge(VertexApiDoc::openapi());
+            doc
+        }
+        #[cfg(not(feature = "google"))]
+        ApiDoc::openapi()
+    };
+
     // Configure Swagger UI
-    let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
+    let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);

     // Define base and health routes
     let base_routes = Router::new()
@@ -953,6 +1069,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/vertex", post(vertex_compatibility))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
@@ -969,10 +1086,27 @@ pub async fn run(
         ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));

     // Combine routes and layers
-    let app = Router::new()
+    let mut app = Router::new()
         .merge(swagger_ui)
         .merge(base_routes)
-        .merge(aws_sagemaker_route)
+        .merge(aws_sagemaker_route);
+
+    #[cfg(feature = "google")]
+    {
+        tracing::info!("Built with `google` feature");
+        tracing::info!(
+            "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
+        );
+        if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
+            app = app.route(&env_predict_route, post(vertex_compatibility));
+        }
+        if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
+            app = app.route(&env_health_route, get(health));
+        }
+    }
+
+    // Add layers after routes
+    app = app
         .layer(Extension(info))
         .layer(Extension(health_ext.clone()))
         .layer(Extension(compat_return_full_text))
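Besides the port, Vertex AI dictates the predict and health endpoints through `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE`. A minimal sketch of the same conditional wiring on a bare axum `Router`, using placeholder handlers instead of the router's real ones:

```rust
use axum::routing::{get, post};
use axum::Router;

// Placeholder handlers standing in for the router's `health` and
// `vertex_compatibility`.
async fn health() -> &'static str {
    "ok"
}
async fn predict() -> &'static str {
    "prediction"
}

fn build_app() -> Router {
    // Fixed routes first, mirroring `base_routes` above.
    let mut app = Router::new().route("/health", get(health));

    // Extra routes only when the platform provides the paths; layers are
    // attached after all routes, as in the diff.
    if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
        app = app.route(&env_predict_route, post(predict));
    }
    if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
        app = app.route(&env_health_route, get(health));
    }
    app
}

fn main() {
    let _app = build_app();
    println!("router wired");
}
```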