fix: decrease default batch, refactors and include index in batch
This commit is contained in:
parent
5c14e08e85
commit
0063bf64fc
|
@ -398,7 +398,7 @@ struct Args {
|
|||
env: bool,
|
||||
|
||||
/// Control the maximum number of inputs that a client can send in a single request
|
||||
#[clap(default_value = "32", long, env)]
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
}
|
||||
|
||||
|
|
|
@ -279,6 +279,9 @@ mod prompt_serde {
|
|||
let value = Value::deserialize(deserializer)?;
|
||||
match value {
|
||||
Value::String(s) => Ok(vec![s]),
|
||||
Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
|
||||
"Empty array detected. Do not use an empty array for the prompt.",
|
||||
)),
|
||||
Value::Array(arr) => arr
|
||||
.iter()
|
||||
.map(|v| match v {
|
||||
|
|
|
@ -77,7 +77,7 @@ struct Args {
|
|||
messages_api_enabled: bool,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
}
|
||||
|
||||
|
|
|
@ -612,10 +612,10 @@ async fn completions(
|
|||
));
|
||||
}
|
||||
|
||||
let mut generate_requests = Vec::new();
|
||||
for prompt in req.prompt.iter() {
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
let generate_requests: Vec<GenerateRequest> = req
|
||||
.prompt
|
||||
.iter()
|
||||
.map(|prompt| GenerateRequest {
|
||||
inputs: prompt.to_string(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
|
@ -637,9 +637,8 @@ async fn completions(
|
|||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
},
|
||||
};
|
||||
generate_requests.push(generate_request);
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut x_compute_type = "unknown".to_string();
|
||||
let mut x_compute_characters = 0u32;
|
||||
|
@ -766,130 +765,139 @@ async fn completions(
|
|||
};
|
||||
|
||||
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
|
||||
return Ok((headers, sse).into_response());
|
||||
Ok((headers, sse).into_response())
|
||||
} else {
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let responses = FuturesUnordered::new();
|
||||
for (index, generate_request) in generate_requests.into_iter().enumerate() {
|
||||
let infer_clone = infer.clone();
|
||||
let compute_type_clone = compute_type.clone();
|
||||
let response_future = async move {
|
||||
let result = generate(
|
||||
Extension(infer_clone),
|
||||
Extension(compute_type_clone),
|
||||
Json(generate_request),
|
||||
)
|
||||
.await;
|
||||
result.map(|(headers, generation)| (index, headers, generation))
|
||||
};
|
||||
responses.push(response_future);
|
||||
}
|
||||
let generate_responses = responses.try_collect::<Vec<_>>().await?;
|
||||
|
||||
let mut prompt_tokens = 0u32;
|
||||
let mut completion_tokens = 0u32;
|
||||
let mut total_tokens = 0u32;
|
||||
|
||||
let mut x_compute_time = 0u32;
|
||||
let mut x_total_time = 0u32;
|
||||
let mut x_validation_time = 0u32;
|
||||
let mut x_queue_time = 0u32;
|
||||
let mut x_inference_time = 0u32;
|
||||
let mut x_time_per_token = 0u32;
|
||||
let mut x_prompt_tokens = 0u32;
|
||||
let mut x_generated_tokens = 0u32;
|
||||
|
||||
let choices = generate_responses
|
||||
.into_iter()
|
||||
.map(|(index, headers, Json(generation))| {
|
||||
let details = generation.details.unwrap_or_default();
|
||||
if index == 0 {
|
||||
x_compute_type = headers
|
||||
.get("x-compute-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
}
|
||||
|
||||
// accumulate headers and usage from each response
|
||||
x_compute_time += headers
|
||||
.get("x-compute-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_compute_characters += headers
|
||||
.get("x-compute-characters")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_total_time += headers
|
||||
.get("x-total-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_validation_time += headers
|
||||
.get("x-validation-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_queue_time += headers
|
||||
.get("x-queue-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_inference_time += headers
|
||||
.get("x-inference-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_time_per_token += headers
|
||||
.get("x-time-per-token")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_prompt_tokens += headers
|
||||
.get("x-prompt-tokens")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_generated_tokens += headers
|
||||
.get("x-generated-tokens")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
prompt_tokens += details.prefill.len() as u32;
|
||||
completion_tokens += details.generated_tokens;
|
||||
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
|
||||
|
||||
CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: generation.generated_text,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let response = Completion {
|
||||
id: "".to_string(),
|
||||
object: "text_completion".to_string(),
|
||||
created: current_time,
|
||||
model: info.model_id.clone(),
|
||||
system_fingerprint: format!(
|
||||
"{}-{}",
|
||||
info.version,
|
||||
info.docker_label.unwrap_or("native")
|
||||
),
|
||||
choices,
|
||||
usage: Usage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
},
|
||||
};
|
||||
|
||||
// headers similar to `generate` but aggregated
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||
headers.insert("x-total-time", x_total_time.into());
|
||||
headers.insert("x-validation-time", x_validation_time.into());
|
||||
headers.insert("x-queue-time", x_queue_time.into());
|
||||
headers.insert("x-inference-time", x_inference_time.into());
|
||||
headers.insert("x-time-per-token", x_time_per_token.into());
|
||||
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
|
||||
headers.insert("x-generated-tokens", x_generated_tokens.into());
|
||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||
|
||||
Ok((headers, Json(response)).into_response())
|
||||
}
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let responses = FuturesUnordered::new();
|
||||
for generate_request in generate_requests.into_iter() {
|
||||
responses.push(generate(
|
||||
Extension(infer.clone()),
|
||||
Extension(compute_type.clone()),
|
||||
Json(generate_request),
|
||||
));
|
||||
}
|
||||
|
||||
let generate_responses = responses.try_collect::<Vec<_>>().await?;
|
||||
|
||||
let mut prompt_tokens = 0u32;
|
||||
let mut completion_tokens = 0u32;
|
||||
let mut total_tokens = 0u32;
|
||||
|
||||
let mut x_compute_time = 0u32;
|
||||
let mut x_total_time = 0u32;
|
||||
let mut x_validation_time = 0u32;
|
||||
let mut x_queue_time = 0u32;
|
||||
let mut x_inference_time = 0u32;
|
||||
let mut x_time_per_token = 0u32;
|
||||
let mut x_prompt_tokens = 0u32;
|
||||
let mut x_generated_tokens = 0u32;
|
||||
|
||||
let choices = generate_responses
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, (headers, Json(generation)))| {
|
||||
let details = generation.details.unwrap_or_default();
|
||||
if index == 0 {
|
||||
x_compute_type = headers
|
||||
.get("x-compute-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
}
|
||||
|
||||
// accumulate headers and usage from each response
|
||||
x_compute_time += headers
|
||||
.get("x-compute-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_compute_characters += headers
|
||||
.get("x-compute-characters")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_total_time += headers
|
||||
.get("x-total-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_validation_time += headers
|
||||
.get("x-validation-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_queue_time += headers
|
||||
.get("x-queue-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_inference_time += headers
|
||||
.get("x-inference-time")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_time_per_token += headers
|
||||
.get("x-time-per-token")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_prompt_tokens += headers
|
||||
.get("x-prompt-tokens")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
x_generated_tokens += headers
|
||||
.get("x-generated-tokens")
|
||||
.and_then(|v| v.to_str().ok()?.parse().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
prompt_tokens += details.prefill.len() as u32;
|
||||
completion_tokens += details.generated_tokens;
|
||||
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
|
||||
|
||||
CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: generation.generated_text,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let response = Completion {
|
||||
id: "".to_string(),
|
||||
object: "text_completion".to_string(),
|
||||
created: current_time,
|
||||
model: info.model_id.clone(),
|
||||
system_fingerprint: format!("{}-{}", info.version, info.docker_label.unwrap_or("native")),
|
||||
choices,
|
||||
usage: Usage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
},
|
||||
};
|
||||
|
||||
// headers similar to `generate` but aggregated
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
|
||||
headers.insert("x-compute-characters", x_compute_characters.into());
|
||||
headers.insert("x-total-time", x_total_time.into());
|
||||
headers.insert("x-validation-time", x_validation_time.into());
|
||||
headers.insert("x-queue-time", x_queue_time.into());
|
||||
headers.insert("x-inference-time", x_inference_time.into());
|
||||
headers.insert("x-time-per-token", x_time_per_token.into());
|
||||
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
|
||||
headers.insert("x-generated-tokens", x_generated_tokens.into());
|
||||
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
|
||||
|
||||
Ok((headers, Json(response)).into_response())
|
||||
}
|
||||
|
||||
/// Generate tokens
|
||||
|
|
Loading…
Reference in New Issue