feat: conditionally toggle chat on invocations route (#1454)
This PR adds support for reading the `OAI_ENABLED` env var which will changes the function called when the `/invocations` is called. If `OAI_ENABLED=true` the `chat_completions` method is used otherwise it defaults to `compat_generate`. example running the router ```bash OAI_ENABLED=true \ cargo run -- \ --tokenizer-name mistralai/Mistral-7B-Instruct-v0.2 ``` example request ```bash curl localhost:3000/invocations \ -X POST \ -d '{ "model": "tgi", "messages": [ { "role": "user", "content": "What is the IP address of the Google DNS servers?" } ], "stream": false, "max_tokens": 20, "logprobs": true, "seed": 0 }' \ -H 'Content-Type: application/json' | jq ``` **please let me know if any naming changes are needed or if any other routes need similar functionality.
This commit is contained in:
parent
becd09978c
commit
98e5faff9d
|
@ -71,6 +71,8 @@ struct Args {
|
||||||
ngrok_authtoken: Option<String>,
|
ngrok_authtoken: Option<String>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
ngrok_edge: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
|
#[clap(long, env, default_value_t = false)]
|
||||||
|
chat_enabled_api: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
@ -102,6 +104,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
|
chat_enabled_api,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
|
@ -345,6 +348,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
|
chat_enabled_api,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -708,6 +708,7 @@ pub async fn run(
|
||||||
ngrok_authtoken: Option<String>,
|
ngrok_authtoken: Option<String>,
|
||||||
ngrok_edge: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
tokenizer_config: HubTokenizerConfig,
|
tokenizer_config: HubTokenizerConfig,
|
||||||
|
chat_enabled_api: bool,
|
||||||
) -> Result<(), axum::BoxError> {
|
) -> Result<(), axum::BoxError> {
|
||||||
// OpenAPI documentation
|
// OpenAPI documentation
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
@ -856,25 +857,32 @@ pub async fn run(
|
||||||
docker_label: option_env!("DOCKER_LABEL"),
|
docker_label: option_env!("DOCKER_LABEL"),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create router
|
// Configure Swagger UI
|
||||||
let app = Router::new()
|
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
|
||||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
|
||||||
// Base routes
|
// Define base and health routes
|
||||||
|
let base_routes = Router::new()
|
||||||
.route("/", post(compat_generate))
|
.route("/", post(compat_generate))
|
||||||
.route("/info", get(get_model_info))
|
.route("/info", get(get_model_info))
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
.route("/v1/chat/completions", post(chat_completions))
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
// AWS Sagemaker route
|
|
||||||
.route("/invocations", post(compat_generate))
|
|
||||||
// Base Health route
|
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
// Inference API health route
|
|
||||||
.route("/", get(health))
|
|
||||||
// AWS Sagemaker health route
|
|
||||||
.route("/ping", get(health))
|
.route("/ping", get(health))
|
||||||
// Prometheus metrics route
|
.route("/metrics", get(metrics));
|
||||||
.route("/metrics", get(metrics))
|
|
||||||
|
// Conditional AWS Sagemaker route
|
||||||
|
let aws_sagemaker_route = if chat_enabled_api {
|
||||||
|
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
|
||||||
|
} else {
|
||||||
|
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
|
||||||
|
};
|
||||||
|
|
||||||
|
// Combine routes and layers
|
||||||
|
let app = Router::new()
|
||||||
|
.merge(swagger_ui)
|
||||||
|
.merge(base_routes)
|
||||||
|
.merge(aws_sagemaker_route)
|
||||||
.layer(Extension(info))
|
.layer(Extension(info))
|
||||||
.layer(Extension(health_ext.clone()))
|
.layer(Extension(health_ext.clone()))
|
||||||
.layer(Extension(compat_return_full_text))
|
.layer(Extension(compat_return_full_text))
|
||||||
|
|
Loading…
Reference in New Issue