improve endpoint support (#1577)

small PR to add a new interface endpoint behind a feature
This commit is contained in:
drbh 2024-02-20 08:04:51 -05:00 committed by GitHub
parent d19c768cb8
commit df23062574
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 170 additions and 7 deletions

View File

@ -52,3 +52,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
[features] [features]
default = ["ngrok"] default = ["ngrok"]
ngrok = ["dep:ngrok"] ngrok = ["dep:ngrok"]
google = []

View File

@ -20,6 +20,25 @@ pub(crate) type GenerateStreamResponse = (
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
); );
#[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct VertexInstance {
#[schema(example = "What is Deep Learning?")]
pub inputs: String,
#[schema(nullable = true, default = "null", example = "null")]
pub parameters: Option<GenerateParameters>,
}
#[derive(Deserialize, ToSchema)]
pub(crate) struct VertexRequest {
#[serde(rename = "instances")]
pub instances: Vec<VertexInstance>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
pub predictions: Vec<String>,
}
/// Hub type /// Hub type
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo { pub struct HubModelInfo {
@ -70,7 +89,7 @@ mod json_object_or_string_to_string {
} }
} }
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize, ToSchema)]
#[serde(tag = "type", content = "value")] #[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType { pub(crate) enum GrammarType {
#[serde( #[serde(
@ -153,7 +172,7 @@ pub struct Info {
pub docker_label: Option<&'static str>, pub docker_label: Option<&'static str>,
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema, Default)]
pub(crate) struct GenerateParameters { pub(crate) struct GenerateParameters {
#[serde(default)] #[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]

View File

@ -328,6 +328,15 @@ async fn main() -> Result<(), RouterError> {
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
tracing::info!("Connected"); tracing::info!("Connected");
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() { let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port), Ok(ip) => SocketAddr::new(ip, port),
Err(_) => { Err(_) => {

View File

@ -5,9 +5,9 @@ use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta, BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse, ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Validation, StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{http, Json, Router}; use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer; use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use futures::stream::FuturesUnordered;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use futures::Stream; use futures::Stream;
use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -693,6 +695,97 @@ async fn chat_completions(
} }
} }
/// Generate tokens from Vertex request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/vertex",
request_body = VertexRequest,
responses(
(status = 200, description = "Generated Text", body = VertexResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip_all,
fields(
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
async fn vertex_compatibility(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("tgi_request_count");
// check that theres at least one instance
if req.instances.is_empty() {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Input validation error".to_string(),
error_type: "Input validation error".to_string(),
}),
));
}
// Process all instances
let predictions = req
.instances
.iter()
.map(|instance| {
let generate_request = GenerateRequest {
inputs: instance.inputs.clone(),
parameters: GenerateParameters {
do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
seed: instance.parameters.as_ref().and_then(|p| p.seed),
details: true,
decoder_input_details: true,
..Default::default()
},
};
async {
generate(
Extension(infer.clone()),
Extension(compute_type.clone()),
Json(generate_request),
)
.await
.map(|(_, Json(generation))| generation.generated_text)
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".into(),
error_type: "Incomplete generation".into(),
}),
)
})
}
})
.collect::<FuturesUnordered<_>>()
.try_collect::<Vec<_>>()
.await?;
let response = VertexResponse { predictions };
Ok((HeaderMap::new(), Json(response)).into_response())
}
/// Tokenize inputs /// Tokenize inputs
#[utoipa::path( #[utoipa::path(
post, post,
@ -818,6 +911,7 @@ pub async fn run(
StreamResponse, StreamResponse,
StreamDetails, StreamDetails,
ErrorResponse, ErrorResponse,
GrammarType,
) )
), ),
tags( tags(
@ -942,8 +1036,30 @@ pub async fn run(
docker_label: option_env!("DOCKER_LABEL"), docker_label: option_env!("DOCKER_LABEL"),
}; };
// Define VertextApiDoc conditionally only if the "google" feature is enabled
#[cfg(feature = "google")]
#[derive(OpenApi)]
#[openapi(
paths(vertex_compatibility),
components(schemas(VertexInstance, VertexRequest, VertexResponse))
)]
struct VertextApiDoc;
let doc = {
// avoid `mut` if possible
#[cfg(feature = "google")]
{
// limiting mutability to the smallest scope necessary
let mut doc = doc;
doc.merge(VertextApiDoc::openapi());
doc
}
#[cfg(not(feature = "google"))]
ApiDoc::openapi()
};
// Configure Swagger UI // Configure Swagger UI
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()); let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
// Define base and health routes // Define base and health routes
let base_routes = Router::new() let base_routes = Router::new()
@ -953,6 +1069,7 @@ pub async fn run(
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions)) .route("/v1/chat/completions", post(chat_completions))
.route("/vertex", post(vertex_compatibility))
.route("/tokenize", post(tokenize)) .route("/tokenize", post(tokenize))
.route("/health", get(health)) .route("/health", get(health))
.route("/ping", get(health)) .route("/ping", get(health))
@ -969,10 +1086,27 @@ pub async fn run(
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
// Combine routes and layers // Combine routes and layers
let app = Router::new() let mut app = Router::new()
.merge(swagger_ui) .merge(swagger_ui)
.merge(base_routes) .merge(base_routes)
.merge(aws_sagemaker_route) .merge(aws_sagemaker_route);
#[cfg(feature = "google")]
{
tracing::info!("Built with `google` feature");
tracing::info!(
"Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
);
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
app = app.route(&env_predict_route, post(vertex_compatibility));
}
if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
app = app.route(&env_health_route, get(health));
}
}
// add layers after routes
app = app
.layer(Extension(info)) .layer(Extension(info))
.layer(Extension(health_ext.clone())) .layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))