Use destructuring in router arguments to avoid '.0' (#798)

# What does this PR do?

This is purely a code-style change, nothing functional. Instead of writing `req.0` everywhere, we can use
[destructuring](https://doc.rust-lang.org/rust-by-example/flow_control/match/destructuring/destructure_structures.html)
to access the contained value we actually want.

(Destructuring in function parameters is described
[here](https://doc.rust-lang.org/reference/items/functions.html#function-parameters); a minimal sketch follows.)
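For illustration, a dependency-free sketch of the pattern (`Wrapper` is a stand-in for axum's `Json`/`Extension` tuple structs, not a type from this repo):

```rust
// A one-field tuple struct, like axum's `Json<T>` or `Extension<T>`.
struct Wrapper<T>(T);

// Before: take the wrapper, then reach inside with `.0` everywhere.
fn char_count_with_dot_zero(req: Wrapper<String>) -> usize {
    req.0.chars().count()
}

// After: destructure in the parameter list; `req` is already the inner value.
fn char_count_destructured(Wrapper(req): Wrapper<String>) -> usize {
    req.chars().count()
}

fn main() {
    let a = char_count_with_dot_zero(Wrapper("hello".to_string()));
    let b = char_count_destructured(Wrapper("hello".to_string()));
    assert_eq!(a, b);
    println!("both count {a} chars");
}
```

Function parameters accept any irrefutable pattern, so the compiler unpacks the wrapper at the call boundary; there is no runtime difference between the two forms.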

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

@OlivierDehaene
ivarflakstad 2023-08-10 10:52:50 +02:00 committed by GitHub
parent 647ae7a7d3
commit 8bdb16ee9a
1 changed file with 26 additions and 28 deletions

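For context, this is what the pattern looks like in an axum handler. A self-contained sketch using the axum 0.6-style API; `AppState`, `EchoRequest`, and `EchoResponse` are hypothetical types invented for the example, not types from this repository:

```rust
// Assumed Cargo deps: axum = "0.6", tokio = { version = "1", features = ["full"] },
// serde = { version = "1", features = ["derive"] }.
use std::{net::SocketAddr, sync::Arc};

use axum::{routing::post, Extension, Json, Router};
use serde::{Deserialize, Serialize};

struct AppState {
    prefix: String,
}

#[derive(Deserialize)]
struct EchoRequest {
    inputs: String,
}

#[derive(Serialize)]
struct EchoResponse {
    outputs: String,
}

// `Extension(state)` and `Json(req)` are destructured in the parameter
// list, so the body never needs `state.0` or `req.0`.
async fn echo(
    Extension(state): Extension<Arc<AppState>>,
    Json(req): Json<EchoRequest>,
) -> Json<EchoResponse> {
    Json(EchoResponse {
        outputs: format!("{}{}", state.prefix, req.inputs),
    })
}

#[tokio::main]
async fn main() {
    let state = Arc::new(AppState { prefix: "echo: ".into() });
    let app = Router::new()
        .route("/echo", post(echo))
        .layer(Extension(state));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}
```

The diff below applies exactly this change to `compat_generate`, `generate`, and `generate_stream`.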

@@ -54,15 +54,13 @@ example = json ! ({"error": "Incomplete generation"})),
 )]
 #[instrument(skip(infer, req))]
 async fn compat_generate(
-    default_return_full_text: Extension<bool>,
+    Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
-    req: Json<CompatGenerateRequest>,
+    Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let mut req = req.0;
-
     // default return_full_text given the pipeline_tag
     if req.parameters.return_full_text.is_none() {
-        req.parameters.return_full_text = Some(default_return_full_text.0)
+        req.parameters.return_full_text = Some(default_return_full_text)
     }
 
     // switch on stream
@@ -71,9 +69,9 @@ async fn compat_generate(
             .await
             .into_response())
     } else {
-        let (headers, generation) = generate(infer, Json(req.into())).await?;
+        let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
 
         // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation.0])).into_response())
+        Ok((headers, Json(vec![generation])).into_response())
     }
 }
@@ -135,7 +133,7 @@ example = json ! ({"error": "Incomplete generation"})),
 #[instrument(
     skip_all,
     fields(
-        parameters = ? req.0.parameters,
+        parameters = ? req.parameters,
         total_time,
         validation_time,
         queue_time,
@@ -146,29 +144,29 @@ seed,
 )]
 async fn generate(
     infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
 
-    tracing::debug!("Input: {}", req.0.inputs);
+    tracing::debug!("Input: {}", req.inputs);
 
-    let compute_characters = req.0.inputs.chars().count();
+    let compute_characters = req.inputs.chars().count();
     let mut add_prompt = None;
-    if req.0.parameters.return_full_text.unwrap_or(false) {
-        add_prompt = Some(req.0.inputs.clone());
+    if req.parameters.return_full_text.unwrap_or(false) {
+        add_prompt = Some(req.inputs.clone());
     }
 
-    let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
+    let details = req.parameters.details || req.parameters.decoder_input_details;
 
     // Inference
-    let (response, best_of_responses) = match req.0.parameters.best_of {
+    let (response, best_of_responses) = match req.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
             (response, Some(best_of_responses))
         }
-        _ => (infer.generate(req.0).await?, None),
+        _ => (infer.generate(req).await?, None),
     };
 
     // Token details
@@ -321,7 +319,7 @@ content_type = "text/event-stream"),
 #[instrument(
     skip_all,
     fields(
-        parameters = ? req.0.parameters,
+        parameters = ? req.parameters,
         total_time,
         validation_time,
         queue_time,
@@ -331,8 +329,8 @@ seed,
     )
 )]
 async fn generate_stream(
-    infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
@@ -341,9 +339,9 @@ async fn generate_stream(
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
 
-    tracing::debug!("Input: {}", req.0.inputs);
+    tracing::debug!("Input: {}", req.inputs);
 
-    let compute_characters = req.0.inputs.chars().count();
+    let compute_characters = req.inputs.chars().count();
 
     let mut headers = HeaderMap::new();
     headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
@@ -359,24 +357,24 @@ async fn generate_stream(
         let mut error = false;
 
         let mut add_prompt = None;
-        if req.0.parameters.return_full_text.unwrap_or(false) {
-            add_prompt = Some(req.0.inputs.clone());
+        if req.parameters.return_full_text.unwrap_or(false) {
+            add_prompt = Some(req.inputs.clone());
         }
-        let details = req.0.parameters.details;
+        let details = req.parameters.details;
 
-        let best_of = req.0.parameters.best_of.unwrap_or(1);
+        let best_of = req.parameters.best_of.unwrap_or(1);
         if best_of != 1 {
             let err = InferError::from(ValidationError::BestOfStream);
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
-        } else if req.0.parameters.decoder_input_details {
+        } else if req.parameters.decoder_input_details {
             let err = InferError::from(ValidationError::PrefillDetailsStream);
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         } else {
-            match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
+            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, mut response_stream)) => {
                     // Server-Sent Event stream