feat: conditionally toggle chat on invocations route (#1454)

This PR adds support for reading the `OAI_ENABLED` env var, which
changes the handler invoked when the `/invocations` route is called.

If `OAI_ENABLED=true`, the `chat_completions` method is used; otherwise it
defaults to `compat_generate`.

Example of running the router:
```bash
OAI_ENABLED=true \
  cargo run -- \
  --tokenizer-name mistralai/Mistral-7B-Instruct-v0.2
```

Example request:
```bash
curl localhost:3000/invocations \
    -X POST \
    -d '{ "model": "tgi", "messages": [ { "role": "user", "content": "What is the IP address of the Google DNS servers?" } ], "stream": false, "max_tokens": 20, "logprobs": true, "seed": 0 }' \
    -H 'Content-Type: application/json' | jq 
```

**Please let me know if any naming changes are needed or if any other
routes need similar functionality.**
This commit is contained in:
drbh 2024-01-22 10:29:01 -05:00 committed by GitHub
parent becd09978c
commit 98e5faff9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 12 deletions

View File

@ -71,6 +71,8 @@ struct Args {
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
#[clap(long, env)] #[clap(long, env)]
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
chat_enabled_api: bool,
} }
#[tokio::main] #[tokio::main]
@ -102,6 +104,7 @@ async fn main() -> Result<(), RouterError> {
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
chat_enabled_api,
} = args; } = args;
// Launch Tokio runtime // Launch Tokio runtime
@ -345,6 +348,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
tokenizer_config, tokenizer_config,
chat_enabled_api,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@ -708,6 +708,7 @@ pub async fn run(
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
chat_enabled_api: bool,
) -> Result<(), axum::BoxError> { ) -> Result<(), axum::BoxError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -856,25 +857,32 @@ pub async fn run(
docker_label: option_env!("DOCKER_LABEL"), docker_label: option_env!("DOCKER_LABEL"),
}; };
// Create router // Configure Swagger UI
let app = Router::new() let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
// Base routes // Define base and health routes
let base_routes = Router::new()
.route("/", post(compat_generate)) .route("/", post(compat_generate))
.route("/info", get(get_model_info)) .route("/info", get(get_model_info))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions)) .route("/v1/chat/completions", post(chat_completions))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route
.route("/health", get(health)) .route("/health", get(health))
// Inference API health route
.route("/", get(health))
// AWS Sagemaker health route
.route("/ping", get(health)) .route("/ping", get(health))
// Prometheus metrics route .route("/metrics", get(metrics));
.route("/metrics", get(metrics))
// Conditional AWS Sagemaker route
let aws_sagemaker_route = if chat_enabled_api {
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
} else {
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
};
// Combine routes and layers
let app = Router::new()
.merge(swagger_ui)
.merge(base_routes)
.merge(aws_sagemaker_route)
.layer(Extension(info)) .layer(Extension(info))
.layer(Extension(health_ext.clone())) .layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text)) .layer(Extension(compat_return_full_text))