improve endpoint support (#1577)
Small PR that adds a new interface endpoint behind a feature flag.
This commit is contained in:
parent
d19c768cb8
commit
df23062574
|
@ -52,3 +52,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||||
[features]
|
[features]
|
||||||
default = ["ngrok"]
|
default = ["ngrok"]
|
||||||
ngrok = ["dep:ngrok"]
|
ngrok = ["dep:ngrok"]
|
||||||
|
google = []
|
||||||
|
|
|
@ -20,6 +20,25 @@ pub(crate) type GenerateStreamResponse = (
|
||||||
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
|
pub(crate) struct VertexInstance {
|
||||||
|
#[schema(example = "What is Deep Learning?")]
|
||||||
|
pub inputs: String,
|
||||||
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
|
pub parameters: Option<GenerateParameters>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, ToSchema)]
|
||||||
|
pub(crate) struct VertexRequest {
|
||||||
|
#[serde(rename = "instances")]
|
||||||
|
pub instances: Vec<VertexInstance>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
pub(crate) struct VertexResponse {
|
||||||
|
pub predictions: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Hub type
|
/// Hub type
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
pub struct HubModelInfo {
|
pub struct HubModelInfo {
|
||||||
|
@ -70,7 +89,7 @@ mod json_object_or_string_to_string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
#[serde(tag = "type", content = "value")]
|
#[serde(tag = "type", content = "value")]
|
||||||
pub(crate) enum GrammarType {
|
pub(crate) enum GrammarType {
|
||||||
#[serde(
|
#[serde(
|
||||||
|
@ -153,7 +172,7 @@ pub struct Info {
|
||||||
pub docker_label: Option<&'static str>,
|
pub docker_label: Option<&'static str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
|
||||||
pub(crate) struct GenerateParameters {
|
pub(crate) struct GenerateParameters {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
|
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
|
||||||
|
|
|
@ -328,6 +328,15 @@ async fn main() -> Result<(), RouterError> {
|
||||||
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
|
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
|
// Determine the server port based on the feature and environment variable.
|
||||||
|
let port = if cfg!(feature = "google") {
|
||||||
|
std::env::var("AIP_HTTP_PORT")
|
||||||
|
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
||||||
|
.unwrap_or(port)
|
||||||
|
} else {
|
||||||
|
port
|
||||||
|
};
|
||||||
|
|
||||||
let addr = match hostname.parse() {
|
let addr = match hostname.parse() {
|
||||||
Ok(ip) => SocketAddr::new(ip, port),
|
Ok(ip) => SocketAddr::new(ip, port),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
|
|
@ -5,9 +5,9 @@ use crate::validation::ValidationError;
|
||||||
use crate::{
|
use crate::{
|
||||||
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
|
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
|
||||||
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
|
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
|
||||||
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
|
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||||
StreamResponse, Token, TokenizeResponse, Validation,
|
StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
|
@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{http, Json, Router};
|
use axum::{http, Json, Router};
|
||||||
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
|
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
|
||||||
|
use futures::stream::FuturesUnordered;
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
use futures::TryStreamExt;
|
||||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
@ -693,6 +695,97 @@ async fn chat_completions(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate tokens from Vertex request
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/vertex",
|
||||||
|
request_body = VertexRequest,
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Generated Text", body = VertexResponse),
|
||||||
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Request failed during generation"})),
|
||||||
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Model is overloaded"})),
|
||||||
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Input validation error"})),
|
||||||
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Incomplete generation"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
#[instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
total_time,
|
||||||
|
validation_time,
|
||||||
|
queue_time,
|
||||||
|
inference_time,
|
||||||
|
time_per_token,
|
||||||
|
seed,
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn vertex_compatibility(
|
||||||
|
Extension(infer): Extension<Infer>,
|
||||||
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
|
Json(req): Json<VertexRequest>,
|
||||||
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
|
// check that theres at least one instance
|
||||||
|
if req.instances.is_empty() {
|
||||||
|
return Err((
|
||||||
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "Input validation error".to_string(),
|
||||||
|
error_type: "Input validation error".to_string(),
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process all instances
|
||||||
|
let predictions = req
|
||||||
|
.instances
|
||||||
|
.iter()
|
||||||
|
.map(|instance| {
|
||||||
|
let generate_request = GenerateRequest {
|
||||||
|
inputs: instance.inputs.clone(),
|
||||||
|
parameters: GenerateParameters {
|
||||||
|
do_sample: true,
|
||||||
|
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
||||||
|
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
||||||
|
details: true,
|
||||||
|
decoder_input_details: true,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
async {
|
||||||
|
generate(
|
||||||
|
Extension(infer.clone()),
|
||||||
|
Extension(compute_type.clone()),
|
||||||
|
Json(generate_request),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map(|(_, Json(generation))| generation.generated_text)
|
||||||
|
.map_err(|_| {
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "Incomplete generation".into(),
|
||||||
|
error_type: "Incomplete generation".into(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<FuturesUnordered<_>>()
|
||||||
|
.try_collect::<Vec<_>>()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let response = VertexResponse { predictions };
|
||||||
|
Ok((HeaderMap::new(), Json(response)).into_response())
|
||||||
|
}
|
||||||
|
|
||||||
/// Tokenize inputs
|
/// Tokenize inputs
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
|
@ -818,6 +911,7 @@ pub async fn run(
|
||||||
StreamResponse,
|
StreamResponse,
|
||||||
StreamDetails,
|
StreamDetails,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
GrammarType,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
tags(
|
tags(
|
||||||
|
@ -942,8 +1036,30 @@ pub async fn run(
|
||||||
docker_label: option_env!("DOCKER_LABEL"),
|
docker_label: option_env!("DOCKER_LABEL"),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Define VertextApiDoc conditionally only if the "google" feature is enabled
|
||||||
|
#[cfg(feature = "google")]
|
||||||
|
#[derive(OpenApi)]
|
||||||
|
#[openapi(
|
||||||
|
paths(vertex_compatibility),
|
||||||
|
components(schemas(VertexInstance, VertexRequest, VertexResponse))
|
||||||
|
)]
|
||||||
|
struct VertextApiDoc;
|
||||||
|
|
||||||
|
let doc = {
|
||||||
|
// avoid `mut` if possible
|
||||||
|
#[cfg(feature = "google")]
|
||||||
|
{
|
||||||
|
// limiting mutability to the smallest scope necessary
|
||||||
|
let mut doc = doc;
|
||||||
|
doc.merge(VertextApiDoc::openapi());
|
||||||
|
doc
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "google"))]
|
||||||
|
ApiDoc::openapi()
|
||||||
|
};
|
||||||
|
|
||||||
// Configure Swagger UI
|
// Configure Swagger UI
|
||||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
|
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
||||||
|
|
||||||
// Define base and health routes
|
// Define base and health routes
|
||||||
let base_routes = Router::new()
|
let base_routes = Router::new()
|
||||||
|
@ -953,6 +1069,7 @@ pub async fn run(
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
.route("/v1/chat/completions", post(chat_completions))
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
|
.route("/vertex", post(vertex_compatibility))
|
||||||
.route("/tokenize", post(tokenize))
|
.route("/tokenize", post(tokenize))
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
.route("/ping", get(health))
|
.route("/ping", get(health))
|
||||||
|
@ -969,10 +1086,27 @@ pub async fn run(
|
||||||
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
|
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
|
||||||
|
|
||||||
// Combine routes and layers
|
// Combine routes and layers
|
||||||
let app = Router::new()
|
let mut app = Router::new()
|
||||||
.merge(swagger_ui)
|
.merge(swagger_ui)
|
||||||
.merge(base_routes)
|
.merge(base_routes)
|
||||||
.merge(aws_sagemaker_route)
|
.merge(aws_sagemaker_route);
|
||||||
|
|
||||||
|
#[cfg(feature = "google")]
|
||||||
|
{
|
||||||
|
tracing::info!("Built with `google` feature");
|
||||||
|
tracing::info!(
|
||||||
|
"Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
|
||||||
|
);
|
||||||
|
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
|
||||||
|
app = app.route(&env_predict_route, post(vertex_compatibility));
|
||||||
|
}
|
||||||
|
if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
|
||||||
|
app = app.route(&env_health_route, get(health));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add layers after routes
|
||||||
|
app = app
|
||||||
.layer(Extension(info))
|
.layer(Extension(info))
|
||||||
.layer(Extension(health_ext.clone()))
|
.layer(Extension(health_ext.clone()))
|
||||||
.layer(Extension(compat_return_full_text))
|
.layer(Extension(compat_return_full_text))
|
||||||
|
|
Loading…
Reference in New Issue