Use destructuring in router arguments to avoid '.0' (#798)
# What does this PR do?

This is purely a code-style change, nothing important. Instead of writing `req.0` all over, we can use [destructuring](https://doc.rust-lang.org/rust-by-example/flow_control/match/destructuring/destructure_structures.html) to access the contained value that we actually want. (Destructuring in function parameters is documented [here](https://doc.rust-lang.org/reference/items/functions.html#function-parameters).)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@OlivierDehaene
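For illustration, here is a minimal sketch of the pattern (the `GenerateRequest` below is a cut-down stand-in for the router's real type, not its actual definition):

```rust
use axum::Json;
use serde::Deserialize;

// Simplified stand-in for the router's request type.
#[derive(Deserialize)]
struct GenerateRequest {
    inputs: String,
}

// Before: keep the `Json` wrapper and reach through it with `.0`.
async fn generate_with_index(req: Json<GenerateRequest>) -> String {
    req.0.inputs.clone()
}

// After: destructure the tuple struct directly in the parameter list,
// so the handler body only ever sees the inner request.
async fn generate_with_destructuring(Json(req): Json<GenerateRequest>) -> String {
    req.inputs
}
```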
This commit is contained in:
parent 647ae7a7d3
commit 8bdb16ee9a
@@ -54,15 +54,13 @@ example = json ! ({"error": "Incomplete generation"})),
 )]
 #[instrument(skip(infer, req))]
 async fn compat_generate(
-    default_return_full_text: Extension<bool>,
+    Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
-    req: Json<CompatGenerateRequest>,
+    Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let mut req = req.0;
-
     // default return_full_text given the pipeline_tag
     if req.parameters.return_full_text.is_none() {
-        req.parameters.return_full_text = Some(default_return_full_text.0)
+        req.parameters.return_full_text = Some(default_return_full_text)
     }
 
     // switch on stream
@@ -71,9 +69,9 @@ async fn compat_generate(
             .await
             .into_response())
     } else {
-        let (headers, generation) = generate(infer, Json(req.into())).await?;
+        let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation.0])).into_response())
+        Ok((headers, Json(vec![generation])).into_response())
     }
 }
 
@@ -135,7 +133,7 @@ example = json ! ({"error": "Incomplete generation"})),
 #[instrument(
     skip_all,
     fields(
-        parameters = ? req.0.parameters,
+        parameters = ? req.parameters,
         total_time,
         validation_time,
         queue_time,
@@ -146,29 +144,29 @@ seed,
 )]
 async fn generate(
     infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
 
-    tracing::debug!("Input: {}", req.0.inputs);
+    tracing::debug!("Input: {}", req.inputs);
 
-    let compute_characters = req.0.inputs.chars().count();
+    let compute_characters = req.inputs.chars().count();
     let mut add_prompt = None;
-    if req.0.parameters.return_full_text.unwrap_or(false) {
-        add_prompt = Some(req.0.inputs.clone());
+    if req.parameters.return_full_text.unwrap_or(false) {
+        add_prompt = Some(req.inputs.clone());
     }
 
-    let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
+    let details = req.parameters.details || req.parameters.decoder_input_details;
 
     // Inference
-    let (response, best_of_responses) = match req.0.parameters.best_of {
+    let (response, best_of_responses) = match req.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
             (response, Some(best_of_responses))
         }
-        _ => (infer.generate(req.0).await?, None),
+        _ => (infer.generate(req).await?, None),
     };
 
     // Token details
@@ -321,7 +319,7 @@ content_type = "text/event-stream"),
 #[instrument(
     skip_all,
     fields(
-        parameters = ? req.0.parameters,
+        parameters = ? req.parameters,
         total_time,
         validation_time,
         queue_time,
@@ -331,8 +329,8 @@ seed,
     )
 )]
 async fn generate_stream(
-    infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
@@ -341,9 +339,9 @@ async fn generate_stream(
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
 
-    tracing::debug!("Input: {}", req.0.inputs);
+    tracing::debug!("Input: {}", req.inputs);
 
-    let compute_characters = req.0.inputs.chars().count();
+    let compute_characters = req.inputs.chars().count();
 
     let mut headers = HeaderMap::new();
     headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
@@ -359,24 +357,24 @@ async fn generate_stream(
         let mut error = false;
 
         let mut add_prompt = None;
-        if req.0.parameters.return_full_text.unwrap_or(false) {
-            add_prompt = Some(req.0.inputs.clone());
+        if req.parameters.return_full_text.unwrap_or(false) {
+            add_prompt = Some(req.inputs.clone());
         }
-        let details = req.0.parameters.details;
+        let details = req.parameters.details;
 
-        let best_of = req.0.parameters.best_of.unwrap_or(1);
+        let best_of = req.parameters.best_of.unwrap_or(1);
         if best_of != 1 {
             let err = InferError::from(ValidationError::BestOfStream);
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
-        } else if req.0.parameters.decoder_input_details {
+        } else if req.parameters.decoder_input_details {
             let err = InferError::from(ValidationError::PrefillDetailsStream);
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         } else {
-            match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
+            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, mut response_stream)) => {
                     // Server-Sent Event stream