fix: simplify kserve endpoint and fix imports (#2119)

drbh 2024-06-25 19:30:10 -04:00 committed by GitHub
parent f1f98e369f
commit be2d38032a
2 changed files with 33 additions and 35 deletions

router/src/kserve.rs

@@ -1,15 +1,15 @@
+use crate::infer::Infer;
 use crate::{
     default_parameters,
     server::{generate_internal, ComputeType},
-    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
+    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Serialize, ToSchema,
 };
 use axum::extract::{Extension, Path};
-use axum::response::{IntoResponse, Response};
+use axum::http::{HeaderMap, StatusCode};
+use axum::response::IntoResponse;
 use axum::Json;
 use futures::stream::FuturesUnordered;
 use futures::TryStreamExt;
-use reqwest::header::HeaderMap;
-use reqwest::StatusCode;
 
 #[derive(Debug, Serialize, Deserialize, ToSchema)]
 pub struct OutputChunk {
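Note on the import change: `axum::http` re-exports the exact `http` crate types (`HeaderMap`, `StatusCode`) that axum's own traits are built against, whereas `reqwest` may pin a different major version of `http`, in which case the seemingly identical types no longer unify. A minimal sketch of the simplified imports (illustrative only; `demo` is a made-up name):

    // Take the types from axum's own re-export so they always match the
    // `http` version axum was compiled against.
    use axum::http::{HeaderMap, StatusCode};

    fn demo() {
        let headers = HeaderMap::new(); // empty header map
        assert!(headers.is_empty());
        assert_eq!(StatusCode::OK.as_u16(), 200);
    }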
@@ -64,8 +64,6 @@ pub struct MetadataServerResponse {
     pub extensions: Vec<String>,
 }
 
-// Routes
-
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
@@ -76,13 +74,13 @@ pub struct MetadataServerResponse {
         example = json!({"error": "No response"}))
     )
 )]
-pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+pub async fn kserve_health_live() -> Json<LiveResponse> {
     let data = LiveResponse { live: true };
-    Ok((HeaderMap::new(), Json(data)).into_response())
+    Json(data)
 }
 
 #[utoipa::path(
-    post,
+    get,
     tag = "Text Generation Inference",
     path = "/v2/health/ready",
     responses(
@@ -91,9 +89,9 @@ pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
         example = json!({"error": "No response"}))
     )
 )]
-pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+pub async fn kserve_health_ready() -> Json<ReadyResponse> {
     let data = ReadyResponse { live: true };
-    Ok((HeaderMap::new(), Json(data)).into_response())
+    Json(data)
 }
 
 #[utoipa::path(
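The health handlers can drop the `Result<Response, …>` wrapper because they cannot fail: in axum, any value implementing `IntoResponse` is a valid handler return, and `Json<T>` with `T: Serialize` already implements it, so no explicit `.into_response()` is needed. A self-contained sketch of the pattern (assuming axum 0.7-style routing; `health_live` and `routes` are made-up names, and `LiveResponse` stands in for the crate's type):

    use axum::{routing::post, Json, Router};
    use serde::Serialize;

    #[derive(Serialize)]
    struct LiveResponse {
        live: bool,
    }

    // Infallible handler: Json<T: Serialize> implements IntoResponse,
    // so it can be returned directly.
    async fn health_live() -> Json<LiveResponse> {
        Json(LiveResponse { live: true })
    }

    fn routes() -> Router {
        Router::new().route("/v2/health/live", post(health_live))
    }

The `post` → `get` switch in the `/v2/health/ready` annotation presumably aligns the docs with the KServe v2 protocol, whose health endpoints are GETs.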
@@ -106,7 +104,7 @@ pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
         example = json!({"error": "No response"}))
     )
 )]
-pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+pub async fn kerve_server_metadata() -> Json<MetadataServerResponse> {
     let data = MetadataServerResponse {
         name: "text-generation-inference".to_string(),
         version: env!("CARGO_PKG_VERSION").to_string(),
@@ -116,7 +114,7 @@ pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
             "metrics".to_string(),
         ],
     };
-    Ok((HeaderMap::new(), Json(data)).into_response())
+    Json(data)
 }
 
 #[utoipa::path(
@@ -131,13 +129,30 @@ pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
 )]
 pub async fn kserve_model_metadata(
     Path((model_name, model_version)): Path<(String, String)>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Json<MetadataServerResponse> {
     let data = MetadataServerResponse {
         name: model_name,
         version: model_version,
         extensions: vec!["infer".to_string(), "ready".to_string()],
     };
-    Ok((HeaderMap::new(), Json(data)).into_response())
+    Json(data)
+}
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/ready",
+    responses(
+        (status = 200, description = "Model version is ready", body = ReadyResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_model_metadata_ready(
+    Path((_model_name, _model_version)): Path<(String, String)>,
+) -> Json<ReadyResponse> {
+    let data = ReadyResponse { live: true };
+    Json(data)
 }
 
 #[utoipa::path(
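`kserve_model_metadata_ready` is relocated here (its old copy is removed further down) and simplified the same way. A sketch of how a handler with two path captures wires into a router, assuming axum 0.7's `:param` capture syntax (`model_ready` and `routes` are made-up names; `ReadyResponse` again stands in for the crate's type):

    use axum::{extract::Path, routing::get, Json, Router};
    use serde::Serialize;

    #[derive(Serialize)]
    struct ReadyResponse {
        live: bool,
    }

    // The two captures arrive as a (model_name, model_version) tuple;
    // this stub reports ready unconditionally, like the handler above.
    async fn model_ready(
        Path((_model_name, _model_version)): Path<(String, String)>,
    ) -> Json<ReadyResponse> {
        Json(ReadyResponse { live: true })
    }

    fn routes() -> Router {
        Router::new().route(
            "/v2/models/:model_name/versions/:model_version/ready",
            get(model_ready),
        )
    }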
@@ -155,7 +170,7 @@ pub async fn kserve_model_infer(
     infer: Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Json(payload): Json<InferenceRequest>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
     let id = payload.id.clone();
     let str_inputs = payload
         .inputs
@@ -226,22 +241,5 @@ pub async fn kserve_model_infer(
         outputs: output_chunks,
     };
-    Ok((HeaderMap::new(), Json(inference_output)).into_response())
+    Ok((HeaderMap::new(), Json(inference_output)))
 }
-
-#[utoipa::path(
-    get,
-    tag = "Text Generation Inference",
-    path = "/v2/models/{model_name}/versions/{model_version}/ready",
-    responses(
-        (status = 200, description = "Model version is ready", body = ReadyResponse),
-        (status = 404, description = "Model or version not found", body = ErrorResponse,
-        example = json!({"error": "No response"}))
-    )
-)]
-pub async fn kserve_model_metadata_ready(
-    Path((_model_name, _model_version)): Path<(String, String)>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let data = ReadyResponse { live: true };
-    Ok((HeaderMap::new(), Json(data)).into_response())
-}
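`kserve_model_infer` keeps its `Result` since inference can fail, but the success type loosens from a concrete `Response` to `impl IntoResponse`, and the `(HeaderMap, Json<…>)` tuple is returned as-is: axum implements `IntoResponse` for (response-parts, body) tuples, so the trailing `.into_response()` call is redundant. A minimal sketch of the pattern (`infer_stub` is a made-up name, and the error status here is illustrative, not the router's actual error mapping):

    use axum::http::{HeaderMap, StatusCode};
    use axum::response::IntoResponse;
    use axum::Json;
    use serde::Serialize;

    #[derive(Serialize)]
    struct InferenceOutput {
        outputs: Vec<String>,
    }

    #[derive(Serialize)]
    struct ErrorResponse {
        error: String,
    }

    // Fallible handler: both arms already implement IntoResponse
    // ((StatusCode, Json<T>) on the error side, (HeaderMap, Json<T>)
    // on the success side), so neither needs an explicit conversion.
    async fn infer_stub(
        fail: bool,
    ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
        if fail {
            let err = ErrorResponse { error: "No response".to_string() };
            return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(err)));
        }
        Ok((HeaderMap::new(), Json(InferenceOutput { outputs: vec![] })))
    }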

router/src/server.rs

@@ -1766,12 +1766,12 @@ pub async fn run(
 #[derive(OpenApi)]
 #[openapi(
     paths(
-        kserve_model_infer,
         kserve_health_live,
         kserve_health_ready,
         kerve_server_metadata,
         kserve_model_metadata,
         kserve_model_metadata_ready,
+        kserve_model_infer,
     ),
     components(schemas(
         InferenceOutput,