fix: improve header init and error handling
This commit is contained in:
parent
6e9fa460ba
commit
ac74ec67c5
|
@ -640,9 +640,9 @@ async fn completions(
|
|||
})
|
||||
.collect();
|
||||
|
||||
let mut x_compute_type = "unknown".to_string();
|
||||
let mut x_compute_type = None;
|
||||
let mut x_compute_characters = 0u32;
|
||||
let mut x_accel_buffering = "no".to_string();
|
||||
let mut x_accel_buffering = None;
|
||||
|
||||
if stream {
|
||||
let mut response_streams = FuturesOrdered::new();
|
||||
|
@ -704,29 +704,28 @@ async fn completions(
|
|||
}
|
||||
});
|
||||
|
||||
(index, header_rx, sse_rx)
|
||||
(header_rx, sse_rx)
|
||||
};
|
||||
response_streams.push_back(generate_future);
|
||||
}
|
||||
|
||||
let mut all_rxs = vec![];
|
||||
|
||||
while let Some((index, header_rx, sse_rx)) = response_streams.next().await {
|
||||
while let Some((header_rx, sse_rx)) = response_streams.next().await {
|
||||
all_rxs.push(sse_rx);
|
||||
|
||||
// get the headers from the first response of each stream
|
||||
let headers = header_rx.await.expect("Failed to get headers");
|
||||
if index == 0 {
|
||||
if x_compute_type.is_none() {
|
||||
x_compute_type = headers
|
||||
.get("x-compute-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
.map(|v| v.to_string());
|
||||
|
||||
x_accel_buffering = headers
|
||||
.get("x-accel-buffering")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("no")
|
||||
.to_string();
|
||||
.map(|v| v.to_string());
|
||||
}
|
||||
x_compute_characters += headers
|
||||
.get("x-compute-characters")
|
||||
|
@ -736,9 +735,13 @@ async fn completions(
|
|||
}
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||
if let Some(x_compute_type) = x_compute_type {
|
||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||
}
|
||||
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||
if let Some(x_accel_buffering) = x_accel_buffering {
|
||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||
}
|
||||
|
||||
// now sink the sse streams into a single stream and remove the ones that are done
|
||||
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
|
||||
|
@ -805,13 +808,20 @@ async fn completions(
|
|||
let choices = generate_responses
|
||||
.into_iter()
|
||||
.map(|(index, headers, Json(generation))| {
|
||||
let details = generation.details.unwrap_or_default();
|
||||
if index == 0 {
|
||||
let details = generation.details.ok_or((
|
||||
// this should never happen but handle if details are missing unexpectedly
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "No details in generation".to_string(),
|
||||
error_type: "no details".to_string(),
|
||||
}),
|
||||
))?;
|
||||
|
||||
if x_compute_type.is_none() {
|
||||
x_compute_type = headers
|
||||
.get("x-compute-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
.map(|v| v.to_string());
|
||||
}
|
||||
|
||||
// accumulate headers and usage from each response
|
||||
|
@ -856,14 +866,15 @@ async fn completions(
|
|||
completion_tokens += details.generated_tokens;
|
||||
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
|
||||
|
||||
CompletionComplete {
|
||||
Ok(CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: generation.generated_text,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|(status, Json(err))| (status, Json(err)))?;
|
||||
|
||||
let response = Completion {
|
||||
id: "".to_string(),
|
||||
|
@ -885,7 +896,9 @@ async fn completions(
|
|||
|
||||
// headers similar to `generate` but aggregated
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||
if let Some(x_compute_type) = x_compute_type {
|
||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||
}
|
||||
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||
headers.insert("x-total-time", x_total_time.into());
|
||||
headers.insert("x-validation-time", x_validation_time.into());
|
||||
|
@ -894,8 +907,9 @@ async fn completions(
|
|||
headers.insert("x-time-per-token", x_time_per_token.into());
|
||||
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
|
||||
headers.insert("x-generated-tokens", x_generated_tokens.into());
|
||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||
|
||||
if let Some(x_accel_buffering) = x_accel_buffering {
|
||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||
}
|
||||
Ok((headers, Json(response)).into_response())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue