fix: decrease default batch, refactors and include index in batch

drbh 2024-04-11 20:37:35 +00:00
parent 5c14e08e85
commit 0063bf64fc
4 changed files with 143 additions and 132 deletions

View File

@@ -398,7 +398,7 @@ struct Args {
     env: bool,
 
     /// Control the maximum number of inputs that a client can send in a single request
-    #[clap(default_value = "32", long, env)]
+    #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
 }
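Together with the matching router change below, this lowers the default client batch size from 32 to 4. As a minimal sketch of how this clap attribute resolves a value (a standalone example, not the launcher's actual argument set; assumes the clap crate with the derive and env features enabled):

use clap::Parser;

// Standalone sketch; not the real launcher Args struct.
#[derive(Parser, Debug)]
struct Args {
    /// Control the maximum number of inputs that a client can send in a single request
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}

fn main() {
    // clap derives both a `--max-client-batch-size` flag and a
    // `MAX_CLIENT_BATCH_SIZE` environment variable; the flag takes
    // precedence over the env var, and the env var over the default of 4.
    let args = Args::parse();
    println!("max_client_batch_size = {}", args.max_client_batch_size);
}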

View File

@@ -279,6 +279,9 @@ mod prompt_serde {
         let value = Value::deserialize(deserializer)?;
         match value {
             Value::String(s) => Ok(vec![s]),
+            Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
+                "Empty array detected. Do not use an empty array for the prompt.",
+            )),
             Value::Array(arr) => arr
                 .iter()
                 .map(|v| match v {
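The added guard rejects an empty prompt array at deserialization time rather than letting an empty batch through to the server. A self-contained sketch of the same string-or-array pattern (struct and field names here are illustrative, not the router's actual types; assumes serde with the derive feature plus serde_json):

use serde::{Deserialize, Deserializer};
use serde_json::Value;

// Accept either a JSON string or a non-empty JSON array of strings.
fn string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: Deserializer<'de>,
{
    let value = Value::deserialize(deserializer)?;
    match value {
        Value::String(s) => Ok(vec![s]),
        Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
            "Empty array detected. Do not use an empty array for the prompt.",
        )),
        Value::Array(arr) => arr
            .into_iter()
            .map(|v| match v {
                Value::String(s) => Ok(s),
                _ => Err(serde::de::Error::custom("prompt array entries must be strings")),
            })
            .collect(),
        _ => Err(serde::de::Error::custom("prompt must be a string or an array of strings")),
    }
}

#[derive(Deserialize, Debug)]
struct CompletionRequest {
    #[serde(deserialize_with = "string_or_vec")]
    prompt: Vec<String>,
}

fn main() {
    // A bare string becomes a one-element batch.
    let one: CompletionRequest = serde_json::from_str(r#"{"prompt": "hello"}"#).unwrap();
    assert_eq!(one.prompt, vec!["hello".to_string()]);

    // An empty array is rejected with the custom error above.
    assert!(serde_json::from_str::<CompletionRequest>(r#"{"prompt": []}"#).is_err());
}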

View File

@@ -77,7 +77,7 @@ struct Args {
     messages_api_enabled: bool,
     #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
-    #[clap(default_value = "32", long, env)]
+    #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
 }

View File

@@ -612,10 +612,10 @@ async fn completions(
         ));
     }
 
-    let mut generate_requests = Vec::new();
-    for prompt in req.prompt.iter() {
-        // build the request passing some parameters
-        let generate_request = GenerateRequest {
+    let generate_requests: Vec<GenerateRequest> = req
+        .prompt
+        .iter()
+        .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
             parameters: GenerateParameters {
                 best_of: None,
@@ -637,9 +637,8 @@
                 top_n_tokens: None,
                 grammar: None,
             },
-        };
-
-        generate_requests.push(generate_request);
-    }
+        })
+        .collect();
 
     let mut x_compute_type = "unknown".to_string();
     let mut x_compute_characters = 0u32;
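These two hunks are the "refactors" half of the commit: the mutable Vec plus for loop becomes an equivalent map/collect over the prompts. A stripped-down illustration of the before/after, using a hypothetical stand-in for GenerateRequest (shape only):

// Hypothetical stand-in for the router's GenerateRequest.
#[derive(Debug)]
struct GenerateRequest {
    inputs: String,
}

fn main() {
    let prompts = vec!["a".to_string(), "b".to_string()];

    // Before: imperative accumulation into a mutable Vec.
    let mut generate_requests = Vec::new();
    for prompt in prompts.iter() {
        generate_requests.push(GenerateRequest {
            inputs: prompt.to_string(),
        });
    }

    // After: the same 1:1 mapping expressed as map/collect, as in the diff.
    let collected: Vec<GenerateRequest> = prompts
        .iter()
        .map(|prompt| GenerateRequest {
            inputs: prompt.to_string(),
        })
        .collect();

    assert_eq!(generate_requests.len(), collected.len());
    println!("{collected:?}");
}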
@@ -766,130 +765,139 @@ async fn completions(
         };
 
         let sse = Sse::new(stream).keep_alive(KeepAlive::default());
 
-        return Ok((headers, sse).into_response());
+        Ok((headers, sse).into_response())
     } else {
-        let current_time = std::time::SystemTime::now()
-            .duration_since(std::time::UNIX_EPOCH)
-            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-            .as_secs();
-
-        let responses = FuturesUnordered::new();
-        for generate_request in generate_requests.into_iter() {
-            responses.push(generate(
-                Extension(infer.clone()),
-                Extension(compute_type.clone()),
-                Json(generate_request),
-            ));
-        }
-        let generate_responses = responses.try_collect::<Vec<_>>().await?;
-
-        let mut prompt_tokens = 0u32;
-        let mut completion_tokens = 0u32;
-        let mut total_tokens = 0u32;
-        let mut x_compute_time = 0u32;
-        let mut x_total_time = 0u32;
-        let mut x_validation_time = 0u32;
-        let mut x_queue_time = 0u32;
-        let mut x_inference_time = 0u32;
-        let mut x_time_per_token = 0u32;
-        let mut x_prompt_tokens = 0u32;
-        let mut x_generated_tokens = 0u32;
-
-        let choices = generate_responses
-            .into_iter()
-            .enumerate()
-            .map(|(index, (headers, Json(generation)))| {
-                let details = generation.details.unwrap_or_default();
-                if index == 0 {
-                    x_compute_type = headers
-                        .get("x-compute-type")
-                        .and_then(|v| v.to_str().ok())
-                        .unwrap_or("unknown")
-                        .to_string();
-                }
-
-                // accumulate headers and usage from each response
-                x_compute_time += headers
-                    .get("x-compute-time")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or(0);
-                x_compute_characters += headers
-                    .get("x-compute-characters")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or(0);
-                x_total_time += headers
-                    .get("x-total-time")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or(0);
-                x_validation_time += headers
-                    .get("x-validation-time")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or(0);
-                x_queue_time += headers
-                    .get("x-queue-time")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or(0);
-                x_inference_time += headers
-                    .get("x-inference-time")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or(0);
-                x_time_per_token += headers
-                    .get("x-time-per-token")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or(0);
-                x_prompt_tokens += headers
-                    .get("x-prompt-tokens")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or(0);
-                x_generated_tokens += headers
-                    .get("x-generated-tokens")
-                    .and_then(|v| v.to_str().ok()?.parse().ok())
-                    .unwrap_or(0);
-
-                prompt_tokens += details.prefill.len() as u32;
-                completion_tokens += details.generated_tokens;
-                total_tokens += details.prefill.len() as u32 + details.generated_tokens;
-
-                CompletionComplete {
-                    finish_reason: details.finish_reason.to_string(),
-                    index: index as u32,
-                    logprobs: None,
-                    text: generation.generated_text,
-                }
-            })
-            .collect::<Vec<_>>();
-
-        let response = Completion {
-            id: "".to_string(),
-            object: "text_completion".to_string(),
-            created: current_time,
-            model: info.model_id.clone(),
-            system_fingerprint: format!("{}-{}", info.version, info.docker_label.unwrap_or("native")),
-            choices,
-            usage: Usage {
-                prompt_tokens,
-                completion_tokens,
-                total_tokens,
-            },
-        };
-
-        // headers similar to `generate` but aggregated
-        let mut headers = HeaderMap::new();
-        headers.insert("x-compute-type", x_compute_type.parse().unwrap());
-        headers.insert("x-compute-characters", x_compute_characters.into());
-        headers.insert("x-total-time", x_total_time.into());
-        headers.insert("x-validation-time", x_validation_time.into());
-        headers.insert("x-queue-time", x_queue_time.into());
-        headers.insert("x-inference-time", x_inference_time.into());
-        headers.insert("x-time-per-token", x_time_per_token.into());
-        headers.insert("x-prompt-tokens", x_prompt_tokens.into());
-        headers.insert("x-generated-tokens", x_generated_tokens.into());
-        headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
-
-        Ok((headers, Json(response)).into_response())
-    }
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        let responses = FuturesUnordered::new();
+        for (index, generate_request) in generate_requests.into_iter().enumerate() {
+            let infer_clone = infer.clone();
+            let compute_type_clone = compute_type.clone();
+            let response_future = async move {
+                let result = generate(
+                    Extension(infer_clone),
+                    Extension(compute_type_clone),
+                    Json(generate_request),
+                )
+                .await;
+                result.map(|(headers, generation)| (index, headers, generation))
+            };
+            responses.push(response_future);
+        }
+        let generate_responses = responses.try_collect::<Vec<_>>().await?;
+
+        let mut prompt_tokens = 0u32;
+        let mut completion_tokens = 0u32;
+        let mut total_tokens = 0u32;
+        let mut x_compute_time = 0u32;
+        let mut x_total_time = 0u32;
+        let mut x_validation_time = 0u32;
+        let mut x_queue_time = 0u32;
+        let mut x_inference_time = 0u32;
+        let mut x_time_per_token = 0u32;
+        let mut x_prompt_tokens = 0u32;
+        let mut x_generated_tokens = 0u32;
+
+        let choices = generate_responses
+            .into_iter()
+            .map(|(index, headers, Json(generation))| {
+                let details = generation.details.unwrap_or_default();
+                if index == 0 {
+                    x_compute_type = headers
+                        .get("x-compute-type")
+                        .and_then(|v| v.to_str().ok())
+                        .unwrap_or("unknown")
+                        .to_string();
+                }
+
+                // accumulate headers and usage from each response
+                x_compute_time += headers
+                    .get("x-compute-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_compute_characters += headers
+                    .get("x-compute-characters")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_total_time += headers
+                    .get("x-total-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_validation_time += headers
+                    .get("x-validation-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_queue_time += headers
+                    .get("x-queue-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_inference_time += headers
+                    .get("x-inference-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_time_per_token += headers
+                    .get("x-time-per-token")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_prompt_tokens += headers
+                    .get("x-prompt-tokens")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_generated_tokens += headers
+                    .get("x-generated-tokens")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+
+                prompt_tokens += details.prefill.len() as u32;
+                completion_tokens += details.generated_tokens;
+                total_tokens += details.prefill.len() as u32 + details.generated_tokens;
+
+                CompletionComplete {
+                    finish_reason: details.finish_reason.to_string(),
+                    index: index as u32,
+                    logprobs: None,
+                    text: generation.generated_text,
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let response = Completion {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created: current_time,
+            model: info.model_id.clone(),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
+            choices,
+            usage: Usage {
+                prompt_tokens,
+                completion_tokens,
+                total_tokens,
+            },
+        };
+
+        // headers similar to `generate` but aggregated
+        let mut headers = HeaderMap::new();
+        headers.insert("x-compute-type", x_compute_type.parse().unwrap());
+        headers.insert("x-compute-characters", x_compute_characters.into());
+        headers.insert("x-total-time", x_total_time.into());
+        headers.insert("x-validation-time", x_validation_time.into());
+        headers.insert("x-queue-time", x_queue_time.into());
+        headers.insert("x-inference-time", x_inference_time.into());
+        headers.insert("x-time-per-token", x_time_per_token.into());
+        headers.insert("x-prompt-tokens", x_prompt_tokens.into());
+        headers.insert("x-generated-tokens", x_generated_tokens.into());
+        headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
+
+        Ok((headers, Json(response)).into_response())
+    }
 
 /// Generate tokens
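This hunk is the "include index in batch" part of the commit, and it fixes an ordering bug: FuturesUnordered yields results in completion order, not submission order, so the removed code's enumerate() after try_collect could pair a response with the wrong index. The added code captures each request's index inside its future before pushing it, so the pairing survives out-of-order completion. A runnable sketch of the idea (fake_generate is hypothetical; assumes the tokio and futures crates):

use futures::stream::{FuturesUnordered, TryStreamExt};
use std::time::Duration;

// Stand-in for the `generate` call; later indices finish first to
// simulate out-of-order completion.
async fn fake_generate(index: usize) -> Result<(usize, String), ()> {
    tokio::time::sleep(Duration::from_millis((4 - index as u64) * 20)).await;
    Ok((index, format!("response for prompt {index}")))
}

#[tokio::main]
async fn main() -> Result<(), ()> {
    let responses = FuturesUnordered::new();
    for index in 0..4usize {
        // The index travels with the future itself, as in the added code.
        responses.push(fake_generate(index));
    }

    // try_collect gathers results in completion order: index 3 arrives first
    // here, so enumerating this Vec would assign wrong indices.
    let mut results: Vec<(usize, String)> = responses.try_collect().await?;

    // The carried index lets us recover request order explicitly.
    results.sort_by_key(|(index, _)| *index);
    for (index, text) in results {
        println!("{index}: {text}");
    }
    Ok(())
}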