fix: improve header init and error handling

This commit is contained in:
drbh 2024-04-11 21:18:14 +00:00
parent 6e9fa460ba
commit ac74ec67c5
1 changed file with 35 additions and 21 deletions

View File

@@ -640,9 +640,9 @@ async fn completions(
})
.collect();
let mut x_compute_type = "unknown".to_string();
let mut x_compute_type = None;
let mut x_compute_characters = 0u32;
let mut x_accel_buffering = "no".to_string();
let mut x_accel_buffering = None;
if stream {
let mut response_streams = FuturesOrdered::new();
@@ -704,29 +704,28 @@ async fn completions(
}
});
(index, header_rx, sse_rx)
(header_rx, sse_rx)
};
response_streams.push_back(generate_future);
}
let mut all_rxs = vec![];
while let Some((index, header_rx, sse_rx)) = response_streams.next().await {
while let Some((header_rx, sse_rx)) = response_streams.next().await {
all_rxs.push(sse_rx);
// get the headers from the first response of each stream
let headers = header_rx.await.expect("Failed to get headers");
if index == 0 {
if x_compute_type.is_none() {
x_compute_type = headers
.get("x-compute-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string();
.map(|v| v.to_string());
x_accel_buffering = headers
.get("x-accel-buffering")
.and_then(|v| v.to_str().ok())
.unwrap_or("no")
.to_string();
.map(|v| v.to_string());
}
x_compute_characters += headers
.get("x-compute-characters")
@@ -736,9 +735,13 @@ async fn completions(
}
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
if let Some(x_compute_type) = x_compute_type {
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
}
headers.insert("x-compute-characters", x_compute_characters.into());
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
if let Some(x_accel_buffering) = x_accel_buffering {
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
}
// now sink the sse streams into a single stream and remove the ones that are done
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
@@ -805,13 +808,20 @@ async fn completions(
let choices = generate_responses
.into_iter()
.map(|(index, headers, Json(generation))| {
let details = generation.details.unwrap_or_default();
if index == 0 {
let details = generation.details.ok_or((
// this should never happen but handle if details are missing unexpectedly
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "No details in generation".to_string(),
error_type: "no details".to_string(),
}),
))?;
if x_compute_type.is_none() {
x_compute_type = headers
.get("x-compute-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string();
.map(|v| v.to_string());
}
// accumulate headers and usage from each response
@@ -856,14 +866,15 @@ async fn completions(
completion_tokens += details.generated_tokens;
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
CompletionComplete {
Ok(CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: index as u32,
logprobs: None,
text: generation.generated_text,
}
})
})
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()
.map_err(|(status, Json(err))| (status, Json(err)))?;
let response = Completion {
id: "".to_string(),
@@ -885,7 +896,9 @@ async fn completions(
// headers similar to `generate` but aggregated
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
if let Some(x_compute_type) = x_compute_type {
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
}
headers.insert("x-compute-characters", x_compute_characters.into());
headers.insert("x-total-time", x_total_time.into());
headers.insert("x-validation-time", x_validation_time.into());
@@ -894,8 +907,9 @@ async fn completions(
headers.insert("x-time-per-token", x_time_per_token.into());
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
headers.insert("x-generated-tokens", x_generated_tokens.into());
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
if let Some(x_accel_buffering) = x_accel_buffering {
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
}
Ok((headers, Json(response)).into_response())
}
}