# Use destructuring in router arguments to avoid '.0' (#798)
# What does this PR do?

This is purely a code-style change, nothing functionally important. Instead of writing `req.0` all over, we can use [destructuring](https://doc.rust-lang.org/rust-by-example/flow_control/match/destructuring/destructure_structures.html) to access the contained value that we actually want. (Destructuring in function parameters is documented [here](https://doc.rust-lang.org/reference/items/functions.html#function-parameters).)

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@OlivierDehaene
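As context for the pattern: axum's `Json` and `Extension` extractors are single-field tuple structs, so a handler can bind the inner value right in its parameter list instead of reaching for `.0` in the body. Here is a minimal before/after sketch; the handler names and request type are hypothetical, not from this PR:

```rust
use axum::{Extension, Json};
use serde::Deserialize;

#[derive(Deserialize)]
struct GreetRequest {
    name: String,
}

// Before: the handler receives the wrappers and unwraps them with `.0`.
async fn greet_before(prefix: Extension<String>, req: Json<GreetRequest>) -> String {
    format!("{} {}", prefix.0, req.0.name)
}

// After: destructuring the tuple structs in the parameter list binds the
// inner values directly, so the body never mentions `.0`.
async fn greet_after(
    Extension(prefix): Extension<String>,
    Json(req): Json<GreetRequest>,
) -> String {
    format!("{} {}", prefix, req.name)
}
```

This works because a tuple-struct pattern on a struct is irrefutable, so it is accepted anywhere a binding is, including function parameters; the generated code is identical either way.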
This commit is contained in: parent `647ae7a7d3`, commit `8bdb16ee9a`.
```diff
@@ -54,15 +54,13 @@ example = json ! ({"error": "Incomplete generation"})),
 )]
 #[instrument(skip(infer, req))]
 async fn compat_generate(
-    default_return_full_text: Extension<bool>,
+    Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
-    req: Json<CompatGenerateRequest>,
+    Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let mut req = req.0;
-
     // default return_full_text given the pipeline_tag
     if req.parameters.return_full_text.is_none() {
-        req.parameters.return_full_text = Some(default_return_full_text.0)
+        req.parameters.return_full_text = Some(default_return_full_text)
     }

     // switch on stream
@@ -71,9 +69,9 @@ async fn compat_generate(
             .await
             .into_response())
     } else {
-        let (headers, generation) = generate(infer, Json(req.into())).await?;
+        let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation.0])).into_response())
+        Ok((headers, Json(vec![generation])).into_response())
     }
 }

@@ -135,7 +133,7 @@ example = json ! ({"error": "Incomplete generation"})),
 #[instrument(
     skip_all,
     fields(
-        parameters = ? req.0.parameters,
+        parameters = ? req.parameters,
         total_time,
         validation_time,
         queue_time,
@@ -146,29 +144,29 @@ seed,
 )]
 async fn generate(
     infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");

-    tracing::debug!("Input: {}", req.0.inputs);
+    tracing::debug!("Input: {}", req.inputs);

-    let compute_characters = req.0.inputs.chars().count();
+    let compute_characters = req.inputs.chars().count();
     let mut add_prompt = None;
-    if req.0.parameters.return_full_text.unwrap_or(false) {
-        add_prompt = Some(req.0.inputs.clone());
+    if req.parameters.return_full_text.unwrap_or(false) {
+        add_prompt = Some(req.inputs.clone());
     }

-    let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
+    let details = req.parameters.details || req.parameters.decoder_input_details;

     // Inference
-    let (response, best_of_responses) = match req.0.parameters.best_of {
+    let (response, best_of_responses) = match req.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
             (response, Some(best_of_responses))
         }
-        _ => (infer.generate(req.0).await?, None),
+        _ => (infer.generate(req).await?, None),
     };

     // Token details
@@ -321,7 +319,7 @@ content_type = "text/event-stream"),
 #[instrument(
     skip_all,
     fields(
-        parameters = ? req.0.parameters,
+        parameters = ? req.parameters,
         total_time,
         validation_time,
         queue_time,
@@ -331,8 +329,8 @@ seed,
     )
 )]
 async fn generate_stream(
-    infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
@@ -341,9 +339,9 @@ async fn generate_stream(
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");

-    tracing::debug!("Input: {}", req.0.inputs);
+    tracing::debug!("Input: {}", req.inputs);

-    let compute_characters = req.0.inputs.chars().count();
+    let compute_characters = req.inputs.chars().count();

     let mut headers = HeaderMap::new();
     headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
@@ -359,24 +357,24 @@ async fn generate_stream(
         let mut error = false;

         let mut add_prompt = None;
-        if req.0.parameters.return_full_text.unwrap_or(false) {
-            add_prompt = Some(req.0.inputs.clone());
+        if req.parameters.return_full_text.unwrap_or(false) {
+            add_prompt = Some(req.inputs.clone());
         }
-        let details = req.0.parameters.details;
+        let details = req.parameters.details;

-        let best_of = req.0.parameters.best_of.unwrap_or(1);
+        let best_of = req.parameters.best_of.unwrap_or(1);
         if best_of != 1 {
             let err = InferError::from(ValidationError::BestOfStream);
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
-        } else if req.0.parameters.decoder_input_details {
+        } else if req.parameters.decoder_input_details {
             let err = InferError::from(ValidationError::PrefillDetailsStream);
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         } else {
-            match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
+            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, mut response_stream)) => {
                     // Server-Sent Event stream
```