Use destructuring in router arguments to avoid '.0' (#798)

# What does this PR do?

This is purely code style - not anything important.
Instead of writing `req.0` all over we can use
[destructuring](https://doc.rust-lang.org/rust-by-example/flow_control/match/destructuring/destructure_structures.html)
to access the contained value that we actually want.

(Destructuring in function parameters
[here](https://doc.rust-lang.org/reference/items/functions.html#function-parameters))

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

@OlivierDehaene
This commit is contained in:
ivarflakstad 2023-08-10 10:52:50 +02:00 committed by GitHub
parent 647ae7a7d3
commit 8bdb16ee9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 26 additions and 28 deletions

View File

@ -54,15 +54,13 @@ example = json ! ({"error": "Incomplete generation"})),
)] )]
#[instrument(skip(infer, req))] #[instrument(skip(infer, req))]
async fn compat_generate( async fn compat_generate(
default_return_full_text: Extension<bool>, Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>, infer: Extension<Infer>,
req: Json<CompatGenerateRequest>, Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut req = req.0;
// default return_full_text given the pipeline_tag // default return_full_text given the pipeline_tag
if req.parameters.return_full_text.is_none() { if req.parameters.return_full_text.is_none() {
req.parameters.return_full_text = Some(default_return_full_text.0) req.parameters.return_full_text = Some(default_return_full_text)
} }
// switch on stream // switch on stream
@ -71,9 +69,9 @@ async fn compat_generate(
.await .await
.into_response()) .into_response())
} else { } else {
let (headers, generation) = generate(infer, Json(req.into())).await?; let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference // wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation.0])).into_response()) Ok((headers, Json(vec![generation])).into_response())
} }
} }
@ -135,7 +133,7 @@ example = json ! ({"error": "Incomplete generation"})),
#[instrument( #[instrument(
skip_all, skip_all,
fields( fields(
parameters = ? req.0.parameters, parameters = ? req.parameters,
total_time, total_time,
validation_time, validation_time,
queue_time, queue_time,
@ -146,29 +144,29 @@ seed,
)] )]
async fn generate( async fn generate(
infer: Extension<Infer>, infer: Extension<Infer>,
req: Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
tracing::debug!("Input: {}", req.0.inputs); tracing::debug!("Input: {}", req.inputs);
let compute_characters = req.0.inputs.chars().count(); let compute_characters = req.inputs.chars().count();
let mut add_prompt = None; let mut add_prompt = None;
if req.0.parameters.return_full_text.unwrap_or(false) { if req.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone()); add_prompt = Some(req.inputs.clone());
} }
let details = req.0.parameters.details || req.0.parameters.decoder_input_details; let details = req.parameters.details || req.parameters.decoder_input_details;
// Inference // Inference
let (response, best_of_responses) = match req.0.parameters.best_of { let (response, best_of_responses) = match req.parameters.best_of {
Some(best_of) if best_of > 1 => { Some(best_of) if best_of > 1 => {
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?; let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
(response, Some(best_of_responses)) (response, Some(best_of_responses))
} }
_ => (infer.generate(req.0).await?, None), _ => (infer.generate(req).await?, None),
}; };
// Token details // Token details
@ -321,7 +319,7 @@ content_type = "text/event-stream"),
#[instrument( #[instrument(
skip_all, skip_all,
fields( fields(
parameters = ? req.0.parameters, parameters = ? req.parameters,
total_time, total_time,
validation_time, validation_time,
queue_time, queue_time,
@ -331,8 +329,8 @@ seed,
) )
)] )]
async fn generate_stream( async fn generate_stream(
infer: Extension<Infer>, Extension(infer): Extension<Infer>,
req: Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> ( ) -> (
HeaderMap, HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>, Sse<impl Stream<Item = Result<Event, Infallible>>>,
@ -341,9 +339,9 @@ async fn generate_stream(
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
tracing::debug!("Input: {}", req.0.inputs); tracing::debug!("Input: {}", req.inputs);
let compute_characters = req.0.inputs.chars().count(); let compute_characters = req.inputs.chars().count();
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
@ -359,24 +357,24 @@ async fn generate_stream(
let mut error = false; let mut error = false;
let mut add_prompt = None; let mut add_prompt = None;
if req.0.parameters.return_full_text.unwrap_or(false) { if req.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone()); add_prompt = Some(req.inputs.clone());
} }
let details = req.0.parameters.details; let details = req.parameters.details;
let best_of = req.0.parameters.best_of.unwrap_or(1); let best_of = req.parameters.best_of.unwrap_or(1);
if best_of != 1 { if best_of != 1 {
let err = InferError::from(ValidationError::BestOfStream); let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)); yield Ok(Event::from(err));
} else if req.0.parameters.decoder_input_details { } else if req.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream); let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)); yield Ok(Event::from(err));
} else { } else {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives // Keep permit as long as generate_stream lives
Ok((_permit, mut response_stream)) => { Ok((_permit, mut response_stream)) => {
// Server-Sent Event stream // Server-Sent Event stream