From 0424dabb0179a0b6b76186244d716cf43e034cfd Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 29 Jan 2024 11:20:08 +0100
Subject: [PATCH] Sending compute type from the environment instead of
 hardcoded string (#1504)

# What does this PR do?

Sends the compute type reported in the `x-compute-type` response header from the `COMPUTE_TYPE` environment variable instead of the hardcoded `"gpu+optimized"` string. Reading the environment on every request is slow, so the value is read once at startup and shared with the handlers through the router's global state (an axum `Extension` layer); see the sketch below.

Fixes # (issue)
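A minimal, self-contained sketch of the pattern this PR adopts, not the actual TGI router code: the `/health` route, `health` handler, and port are placeholders, and it assumes axum 0.7 with tokio.

```rust
use axum::{http::HeaderMap, routing::get, Extension, Router};

// Newtype mirroring the one added in this PR: wraps the compute type
// string so it can be stored as its own axum Extension.
#[derive(Clone, Debug)]
struct ComputeType(String);

// Handlers destructure the extension; no per-request std::env::var call.
async fn health(
    Extension(ComputeType(compute_type)): Extension<ComputeType>,
) -> (HeaderMap, &'static str) {
    let mut headers = HeaderMap::new();
    // Same header the router sets in generate/generate_stream.
    headers.insert("x-compute-type", compute_type.parse().unwrap());
    (headers, "ok")
}

#[tokio::main]
async fn main() {
    // Read the environment once at startup, falling back to the old
    // hardcoded default.
    let compute_type =
        ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));

    // Share the value with every handler via an Extension layer.
    let app = Router::new()
        .route("/health", get(health))
        .layer(Extension(compute_type));

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```

The `Extension` layer clones the stored value into each request's extensions, so handlers read a cached `String` instead of hitting `std::env::var` on the hot path.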
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 router/src/server.rs | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index 998d6265..39d1de38 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -57,6 +57,7 @@ example = json ! ({"error": "Incomplete generation"})),
 async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
+    compute_type: Extension<ComputeType>,
     Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     // default return_full_text given the pipeline_tag
@@ -66,11 +67,11 @@ async fn compat_generate(
 
     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer, Json(req.into()))
+        Ok(generate_stream(infer, compute_type, Json(req.into()))
             .await
             .into_response())
     } else {
-        let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
+        let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation])).into_response())
     }
@@ -145,6 +146,7 @@ seed,
 )]
 async fn generate(
     infer: Extension<Infer>,
+    Extension(ComputeType(compute_type)): Extension<ComputeType>,
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
@@ -230,7 +232,7 @@ async fn generate(
 
     // Headers
     let mut headers = HeaderMap::new();
-    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+    headers.insert("x-compute-type", compute_type.parse().unwrap());
     headers.insert(
         "x-compute-time",
         total_time.as_millis().to_string().parse().unwrap(),
@@ -339,6 +341,7 @@ seed,
 )]
 async fn generate_stream(
     Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
     Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
@@ -349,13 +352,14 @@ async fn generate_stream(
         event.json_data(stream_token).unwrap()
     };
     let (headers, response_stream) =
-        generate_stream_internal(infer, Json(req), on_message_callback).await;
+        generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
 
 async fn generate_stream_internal(
     infer: Infer,
+    ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     on_message_callback: impl Fn(StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
@@ -368,7 +372,7 @@ async fn generate_stream_internal(
     let compute_characters = req.inputs.chars().count();
 
     let mut headers = HeaderMap::new();
-    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+    headers.insert("x-compute-type", compute_type.parse().unwrap());
     headers.insert(
         "x-compute-characters",
         compute_characters.to_string().parse().unwrap(),
@@ -557,6 +561,7 @@ async fn generate_stream_internal(
 )]
 async fn chat_completions(
     Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -645,12 +650,12 @@ async fn chat_completions(
         };
 
         let (headers, response_stream) =
-            generate_stream_internal(infer, Json(generate_request), on_message_callback).await;
+            generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await;
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
         let (headers, Json(generation)) =
-            generate(Extension(infer), Json(generate_request)).await?;
+            generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?;
 
         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
@@ -729,6 +734,9 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }
 
+#[derive(Clone, Debug)]
+pub(crate) struct ComputeType(String);
+
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
@@ -935,6 +943,8 @@ pub async fn run(
         Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
     };
 
+    let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
+
     // Combine routes and layers
     let app = Router::new()
         .merge(swagger_ui)
@@ -944,6 +954,7 @@ pub async fn run(
         .layer(Extension(health_ext.clone()))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
+        .layer(Extension(compute_type))
         .layer(Extension(prom_handle.clone()))
         .layer(OtelAxumLayer::default())
         .layer(cors_layer);
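One behavioral note on the new startup line: `unwrap_or` keeps `"gpu+optimized"` as the default, so deployments that never set `COMPUTE_TYPE` are unchanged. A tiny standalone check of that fallback logic (the `cpu+optimized` value is invented for illustration, and the expression is duplicated here rather than imported from the router):

```rust
use std::env::VarError;

fn main() {
    // Model both outcomes of std::env::var("COMPUTE_TYPE") without
    // mutating the real process environment.
    let unset: Result<String, VarError> = Err(VarError::NotPresent);
    let set: Result<String, VarError> = Ok("cpu+optimized".to_string());

    // Same fallback expression the PR adds in run().
    assert_eq!(unset.unwrap_or("gpu+optimized".to_string()), "gpu+optimized");
    assert_eq!(set.unwrap_or("gpu+optimized".to_string()), "cpu+optimized");
}
```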