feat: conditionally toggle chat on invocations route (#1454)
This PR adds support for reading the `OAI_ENABLED` env var, which changes the function called when the `/invocations` route is invoked. If `OAI_ENABLED=true`, the `chat_completions` method is used; otherwise it defaults to `compat_generate`. Example of running the router: ```bash OAI_ENABLED=true \ cargo run -- \ --tokenizer-name mistralai/Mistral-7B-Instruct-v0.2 ``` Example request: ```bash curl localhost:3000/invocations \ -X POST \ -d '{ "model": "tgi", "messages": [ { "role": "user", "content": "What is the IP address of the Google DNS servers?" } ], "stream": false, "max_tokens": 20, "logprobs": true, "seed": 0 }' \ -H 'Content-Type: application/json' | jq ``` Please let me know if any naming changes are needed or if any other routes need similar functionality.
This commit is contained in:
parent
becd09978c
commit
98e5faff9d
|
@ -71,6 +71,8 @@ struct Args {
|
|||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_edge: Option<String>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
chat_enabled_api: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
|
@ -102,6 +104,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
chat_enabled_api,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
|
@ -345,6 +348,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
tokenizer_config,
|
||||
chat_enabled_api,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
|
|
@ -708,6 +708,7 @@ pub async fn run(
|
|||
ngrok_authtoken: Option<String>,
|
||||
ngrok_edge: Option<String>,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
chat_enabled_api: bool,
|
||||
) -> Result<(), axum::BoxError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
|
@ -856,25 +857,32 @@ pub async fn run(
|
|||
docker_label: option_env!("DOCKER_LABEL"),
|
||||
};
|
||||
|
||||
// Create router
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||
// Base routes
|
||||
// Configure Swagger UI
|
||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
|
||||
|
||||
// Define base and health routes
|
||||
let base_routes = Router::new()
|
||||
.route("/", post(compat_generate))
|
||||
.route("/info", get(get_model_info))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
// AWS Sagemaker route
|
||||
.route("/invocations", post(compat_generate))
|
||||
// Base Health route
|
||||
.route("/health", get(health))
|
||||
// Inference API health route
|
||||
.route("/", get(health))
|
||||
// AWS Sagemaker health route
|
||||
.route("/ping", get(health))
|
||||
// Prometheus metrics route
|
||||
.route("/metrics", get(metrics))
|
||||
.route("/metrics", get(metrics));
|
||||
|
||||
// Conditional AWS Sagemaker route
|
||||
let aws_sagemaker_route = if chat_enabled_api {
|
||||
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
|
||||
} else {
|
||||
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
|
||||
};
|
||||
|
||||
// Combine routes and layers
|
||||
let app = Router::new()
|
||||
.merge(swagger_ui)
|
||||
.merge(base_routes)
|
||||
.merge(aws_sagemaker_route)
|
||||
.layer(Extension(info))
|
||||
.layer(Extension(health_ext.clone()))
|
||||
.layer(Extension(compat_return_full_text))
|
||||
|
|
Loading…
Reference in New Issue