implement Open Inference Protocol endpoints (#1942)

* feat: add kserve feature and basic routes

* feat: implement infer endpoint wrapper around generate

* fix: refactor and improve types

* fix: improve infer and simplify

* fix: cleanup and improve api docs

* fix: refactor and encapsulate kserve feat in file

* fix: remove typos after rebase
drbh 2024-06-13 12:51:51 -04:00 committed by GitHub
parent 42aa8ee1bb
commit f433f1f770
4 changed files with 328 additions and 21 deletions

router/Cargo.toml
@@ -59,3 +59,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
 default = ["ngrok"]
 ngrok = ["dep:ngrok"]
 google = []
+kserve = []
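
The `kserve` feature is additive and off by default; the routes and schemas added below are compiled in only when the router is built with the feature enabled (presumably something like `cargo build --features kserve` from the router crate).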

router/src/kserve.rs (new file)

@@ -0,0 +1,247 @@
use crate::{
default_parameters,
server::{generate_internal, ComputeType},
Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
};
use axum::extract::{Extension, Path};
use axum::response::{IntoResponse, Response};
use axum::Json;
use futures::stream::FuturesUnordered;
use futures::TryStreamExt;
use reqwest::header::HeaderMap;
use reqwest::StatusCode;
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct OutputChunk {
pub name: String,
pub shape: Vec<usize>,
pub datatype: String,
pub data: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct InferenceOutput {
pub id: String,
pub outputs: Vec<OutputChunk>,
}
#[derive(Debug, Deserialize, ToSchema)]
pub(crate) struct InferenceRequest {
pub id: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
pub inputs: Vec<Input>,
pub outputs: Vec<Output>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Input {
pub name: String,
pub shape: Vec<usize>,
pub datatype: String,
pub data: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Output {
pub name: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct LiveResponse {
pub live: bool,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ReadyResponse {
pub live: bool,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct MetadataServerResponse {
pub name: String,
pub version: String,
pub extensions: Vec<String>,
}
// Routes
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/health/live",
responses(
(status = 200, description = "Service is live", body = LiveResponse),
(status = 404, description = "Service not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = LiveResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/health/ready",
responses(
(status = 200, description = "Service is ready", body = ReadyResponse),
(status = 404, description = "Service not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = ReadyResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2",
responses(
(status = 200, description = "Metadata retrieved", body = MetadataServerResponse),
(status = 404, description = "Service not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = MetadataServerResponse {
name: "text-generation-inference".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
extensions: vec![
"health".to_string(),
"models".to_string(),
"metrics".to_string(),
],
};
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}",
responses(
(status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_metadata(
Path((model_name, model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = MetadataServerResponse {
name: model_name,
version: model_version,
extensions: vec!["infer".to_string(), "ready".to_string()],
};
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/infer",
request_body = Json<InferenceRequest>,
responses(
(status = 200, description = "Inference executed successfully", body = InferenceOutput),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_infer(
infer: Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(payload): Json<InferenceRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let id = payload.id.clone();
let str_inputs = payload
.inputs
.iter()
.map(|input| {
std::str::from_utf8(&input.data).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "utf8".to_string(),
}),
)
})
})
.collect::<Result<Vec<_>, _>>()?;
if str_inputs.len() != payload.outputs.len() {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Inputs and outputs length mismatch".to_string(),
error_type: "length mismatch".to_string(),
}),
));
}
let output_chunks = str_inputs
.iter()
.zip(&payload.outputs)
.map(|(str_input, output)| {
let generate_request = GenerateRequest {
inputs: str_input.to_string(),
parameters: payload.parameters.clone(),
};
let infer = infer.clone();
let compute_type = compute_type.clone();
let span = tracing::Span::current();
async move {
generate_internal(infer, compute_type, Json(generate_request), span)
.await
.map(|(_, Json(generation))| {
let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
OutputChunk {
name: output.name.clone(),
shape: vec![1, generation_as_bytes.len()],
datatype: "BYTES".to_string(),
data: generation_as_bytes,
}
})
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".into(),
error_type: "Incomplete generation".into(),
}),
)
})
}
})
.collect::<FuturesUnordered<_>>()
.try_collect::<Vec<_>>()
.await?;
let inference_output = InferenceOutput {
id: id.clone(),
outputs: output_chunks,
};
Ok((HeaderMap::new(), Json(inference_output)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/ready",
responses(
(status = 200, description = "Model version is ready", body = ReadyResponse),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_metadata_ready(
Path((_model_name, _model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = ReadyResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
}
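
To make the wire format concrete: each input's `data` field is a raw byte vector, which serde carries as a JSON array of integers, and for text inputs those are simply the UTF-8 bytes of the prompt that `kserve_model_infer` decodes back with `std::str::from_utf8`. The following standalone sketch is illustrative only and not part of this diff; it mirrors the crate-private structs above (omitting `parameters`, which has a default) and assumes `serde` and `serde_json` as dependencies.

use serde::Deserialize;

// Minimal mirrors of the crate-private request types above, for illustration only.
#[derive(Debug, Deserialize)]
struct InferenceRequest {
    id: String,
    inputs: Vec<Input>,
    outputs: Vec<Output>,
}

#[derive(Debug, Deserialize)]
struct Input {
    name: String,
    shape: Vec<usize>,
    datatype: String,
    data: Vec<u8>, // UTF-8 bytes of the prompt, carried as a JSON array of integers
}

#[derive(Debug, Deserialize)]
struct Output {
    name: String,
}

fn main() {
    // Example payload (values are made up); `parameters` is omitted because it defaults.
    let body = r#"{
        "id": "42",
        "inputs": [
            {"name": "input0", "shape": [1, 2], "datatype": "BYTES", "data": [72, 105]}
        ],
        "outputs": [{"name": "output0"}]
    }"#;

    let req: InferenceRequest = serde_json::from_str(body).expect("valid payload");
    // The handler turns each input back into a &str prompt before calling generate_internal.
    let prompt = std::str::from_utf8(&req.inputs[0].data).unwrap();
    assert_eq!(prompt, "Hi");
    println!(
        "id={}, input {}={:?} ({}), first requested output={}",
        req.id, req.inputs[0].name, prompt, req.inputs[0].datatype, req.outputs[0].name
    );
}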

router/src/lib.rs

@@ -4,6 +4,9 @@ mod infer;
 pub mod server;
 mod validation;
+
+#[cfg(feature = "kserve")]
+mod kserve;
 
 use serde::{Deserialize, Serialize};
 use tracing::warn;
 use utoipa::ToSchema;

router/src/server.rs

@@ -4,6 +4,11 @@ use crate::infer::v2::SchedulerV2;
 use crate::infer::v3::SchedulerV3;
 use crate::infer::{HealthCheck, Scheduler};
 use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
+#[cfg(feature = "kserve")]
+use crate::kserve::{
+    kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
+    kserve_model_metadata, kserve_model_metadata_ready,
+};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@@ -172,7 +177,7 @@ async fn generate(
     generate_internal(infer, ComputeType(compute_type), Json(req), span).await
 }
 
-async fn generate_internal(
+pub(crate) async fn generate_internal(
     infer: Extension<Infer>,
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
@@ -1727,28 +1732,58 @@ pub async fn run(
         docker_label: option_env!("DOCKER_LABEL"),
     };
 
-    // Define VertextApiDoc conditionally only if the "google" feature is enabled
-    let doc = {
-        // avoid `mut` if possible
-        #[cfg(feature = "google")]
-        {
-            use crate::VertexInstance;
-
-            #[derive(OpenApi)]
-            #[openapi(
-                paths(vertex_compatibility),
-                components(schemas(VertexInstance, VertexRequest, VertexResponse))
-            )]
-            struct VertextApiDoc;
-
-            // limiting mutability to the smallest scope necessary
-            let mut doc = ApiDoc::openapi();
-            doc.merge(VertextApiDoc::openapi());
-            doc
-        }
-        #[cfg(not(feature = "google"))]
-        ApiDoc::openapi()
-    };
+    #[allow(unused_mut)] // mut is needed for conditional compilation
+    let mut doc = ApiDoc::openapi();
+
+    #[cfg(feature = "google")]
+    {
+        use crate::VertexInstance;
+
+        #[derive(OpenApi)]
+        #[openapi(
+            paths(vertex_compatibility),
+            components(schemas(VertexInstance, VertexRequest, VertexResponse))
+        )]
+        struct VertexApiDoc;
+
+        doc.merge(VertexApiDoc::openapi());
+    }
+
+    #[cfg(feature = "kserve")]
+    {
+        use crate::kserve::{
+            InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk,
+            ReadyResponse,
+        };
+        use crate::kserve::{
+            __path_kerve_server_metadata, __path_kserve_health_live, __path_kserve_health_ready,
+            __path_kserve_model_infer, __path_kserve_model_metadata,
+            __path_kserve_model_metadata_ready,
+        };
+
+        #[derive(OpenApi)]
+        #[openapi(
+            paths(
+                kserve_model_infer,
+                kserve_health_live,
+                kserve_health_ready,
+                kerve_server_metadata,
+                kserve_model_metadata,
+                kserve_model_metadata_ready,
+            ),
+            components(schemas(
+                InferenceOutput,
+                InferenceRequest,
+                LiveResponse,
+                MetadataServerResponse,
+                OutputChunk,
+                ReadyResponse,
+            ))
+        )]
+        struct KServeApiDoc;
+
+        doc.merge(KServeApiDoc::openapi());
+    }
 
     // Configure Swagger UI
     let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
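
With the merge above, the KServe paths and schemas land in the same OpenAPI document that is served at `/api-doc/openapi.json` and rendered by Swagger UI at `/docs`.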
@@ -1798,6 +1833,27 @@ pub async fn run(
         }
     }
 
+    #[cfg(feature = "kserve")]
+    {
+        tracing::info!("Built with `kserve` feature");
+        app = app
+            .route(
+                "/v2/models/:model_name/versions/:model_version/infer",
+                post(kserve_model_infer),
+            )
+            .route(
+                "/v2/models/:model_name/versions/:model_version",
+                get(kserve_model_metadata),
+            )
+            .route("/v2/health/ready", get(kserve_health_ready))
+            .route("/v2/health/live", get(kserve_health_live))
+            .route("/v2", get(kerve_server_metadata))
+            .route(
+                "/v2/models/:model_name/versions/:model_version/ready",
+                get(kserve_model_metadata_ready),
+            );
+    }
+
     // add layers after routes
     app = app
         .layer(Extension(info))
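
Taken together, the new routes follow the Open Inference Protocol (KServe v2) layout: server-level health and metadata under `/v2`, and per-model metadata, readiness, and inference under `/v2/models/{model_name}/versions/{model_version}`. As a rough usage sketch (not part of this change; the host, port, model name, and version are placeholders, and it assumes `reqwest` with its `json` feature plus `tokio` and `serde_json`), a call to the new infer route might look like this:

use serde_json::{json, Value};

// Hypothetical client for the new infer route; endpoint, model name, and version are placeholders.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let url = "http://localhost:3000/v2/models/my-model/versions/1/infer";

    // One entry in `outputs` per entry in `inputs` (the handler rejects a length mismatch),
    // and `data` holds the UTF-8 bytes of the prompt ("Hello" here).
    let body = json!({
        "id": "42",
        "inputs": [{
            "name": "input0",
            "shape": [1, 5],
            "datatype": "BYTES",
            "data": [72, 101, 108, 108, 111]
        }],
        "outputs": [{"name": "output0"}]
    });

    let resp: Value = reqwest::Client::new()
        .post(url)
        .json(&body)
        .send()
        .await?
        .json()
        .await?;

    // Each returned chunk is "BYTES" with shape [1, len]; decode the bytes back into text.
    let bytes: Vec<u8> = resp["outputs"][0]["data"]
        .as_array()
        .map(|a| a.iter().filter_map(|v| v.as_u64()).map(|v| v as u8).collect())
        .unwrap_or_default();
    println!("generated: {}", String::from_utf8_lossy(&bytes));
    Ok(())
}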