improve endpoint support (#1577)
small PR to add a new interface endpoint behind a feature
This commit is contained in:
parent
d19c768cb8
commit
df23062574
|
@ -52,3 +52,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
|||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["dep:ngrok"]
|
||||
google = []
|
||||
|
|
|
@ -20,6 +20,25 @@ pub(crate) type GenerateStreamResponse = (
|
|||
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
||||
);
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
pub(crate) struct VertexInstance {
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub inputs: String,
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub parameters: Option<GenerateParameters>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
pub(crate) struct VertexRequest {
|
||||
#[serde(rename = "instances")]
|
||||
pub instances: Vec<VertexInstance>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
pub(crate) struct VertexResponse {
|
||||
pub predictions: Vec<String>,
|
||||
}
|
||||
|
||||
/// Hub type
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct HubModelInfo {
|
||||
|
@ -70,7 +89,7 @@ mod json_object_or_string_to_string {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
pub(crate) enum GrammarType {
|
||||
#[serde(
|
||||
|
@ -153,7 +172,7 @@ pub struct Info {
|
|||
pub docker_label: Option<&'static str>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
|
||||
pub(crate) struct GenerateParameters {
|
||||
#[serde(default)]
|
||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
|
||||
|
|
|
@ -328,6 +328,15 @@ async fn main() -> Result<(), RouterError> {
|
|||
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
|
||||
tracing::info!("Connected");
|
||||
|
||||
// Determine the server port based on the feature and environment variable.
|
||||
let port = if cfg!(feature = "google") {
|
||||
std::env::var("AIP_HTTP_PORT")
|
||||
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
||||
.unwrap_or(port)
|
||||
} else {
|
||||
port
|
||||
};
|
||||
|
||||
let addr = match hostname.parse() {
|
||||
Ok(ip) => SocketAddr::new(ip, port),
|
||||
Err(_) => {
|
||||
|
|
|
@ -5,9 +5,9 @@ use crate::validation::ValidationError;
|
|||
use crate::{
|
||||
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
|
||||
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
|
||||
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
|
||||
FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||
StreamResponse, Token, TokenizeResponse, Validation,
|
||||
StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
|
@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
|
|||
use axum::routing::{get, post};
|
||||
use axum::{http, Json, Router};
|
||||
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
|
@ -693,6 +695,97 @@ async fn chat_completions(
|
|||
}
|
||||
}
|
||||
|
||||
/// Generate tokens from Vertex request
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/vertex",
|
||||
request_body = VertexRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = VertexResponse),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json ! ({"error": "Model is overloaded"})),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Input validation error"})),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json ! ({"error": "Incomplete generation"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
total_time,
|
||||
validation_time,
|
||||
queue_time,
|
||||
inference_time,
|
||||
time_per_token,
|
||||
seed,
|
||||
)
|
||||
)]
|
||||
async fn vertex_compatibility(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Json(req): Json<VertexRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
// check that theres at least one instance
|
||||
if req.instances.is_empty() {
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Input validation error".to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
// Process all instances
|
||||
let predictions = req
|
||||
.instances
|
||||
.iter()
|
||||
.map(|instance| {
|
||||
let generate_request = GenerateRequest {
|
||||
inputs: instance.inputs.clone(),
|
||||
parameters: GenerateParameters {
|
||||
do_sample: true,
|
||||
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
||||
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
||||
details: true,
|
||||
decoder_input_details: true,
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
|
||||
async {
|
||||
generate(
|
||||
Extension(infer.clone()),
|
||||
Extension(compute_type.clone()),
|
||||
Json(generate_request),
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await?;
|
||||
|
||||
let response = VertexResponse { predictions };
|
||||
Ok((HeaderMap::new(), Json(response)).into_response())
|
||||
}
|
||||
|
||||
/// Tokenize inputs
|
||||
#[utoipa::path(
|
||||
post,
|
||||
|
@ -818,6 +911,7 @@ pub async fn run(
|
|||
StreamResponse,
|
||||
StreamDetails,
|
||||
ErrorResponse,
|
||||
GrammarType,
|
||||
)
|
||||
),
|
||||
tags(
|
||||
|
@ -942,8 +1036,30 @@ pub async fn run(
|
|||
docker_label: option_env!("DOCKER_LABEL"),
|
||||
};
|
||||
|
||||
// Define VertextApiDoc conditionally only if the "google" feature is enabled
|
||||
#[cfg(feature = "google")]
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
paths(vertex_compatibility),
|
||||
components(schemas(VertexInstance, VertexRequest, VertexResponse))
|
||||
)]
|
||||
struct VertextApiDoc;
|
||||
|
||||
let doc = {
|
||||
// avoid `mut` if possible
|
||||
#[cfg(feature = "google")]
|
||||
{
|
||||
// limiting mutability to the smallest scope necessary
|
||||
let mut doc = doc;
|
||||
doc.merge(VertextApiDoc::openapi());
|
||||
doc
|
||||
}
|
||||
#[cfg(not(feature = "google"))]
|
||||
ApiDoc::openapi()
|
||||
};
|
||||
|
||||
// Configure Swagger UI
|
||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
|
||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
||||
|
||||
// Define base and health routes
|
||||
let base_routes = Router::new()
|
||||
|
@ -953,6 +1069,7 @@ pub async fn run(
|
|||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/vertex", post(vertex_compatibility))
|
||||
.route("/tokenize", post(tokenize))
|
||||
.route("/health", get(health))
|
||||
.route("/ping", get(health))
|
||||
|
@ -969,10 +1086,27 @@ pub async fn run(
|
|||
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
|
||||
|
||||
// Combine routes and layers
|
||||
let app = Router::new()
|
||||
let mut app = Router::new()
|
||||
.merge(swagger_ui)
|
||||
.merge(base_routes)
|
||||
.merge(aws_sagemaker_route)
|
||||
.merge(aws_sagemaker_route);
|
||||
|
||||
#[cfg(feature = "google")]
|
||||
{
|
||||
tracing::info!("Built with `google` feature");
|
||||
tracing::info!(
|
||||
"Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
|
||||
);
|
||||
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
|
||||
app = app.route(&env_predict_route, post(vertex_compatibility));
|
||||
}
|
||||
if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
|
||||
app = app.route(&env_health_route, get(health));
|
||||
}
|
||||
}
|
||||
|
||||
// add layers after routes
|
||||
app = app
|
||||
.layer(Extension(info))
|
||||
.layer(Extension(health_ext.clone()))
|
||||
.layer(Extension(compat_return_full_text))
|
||||
|
|
Loading…
Reference in New Issue