feat: handle batch completions requests
This commit is contained in:
parent b08038cc81
commit 70e8d2aaa9
@@ -393,6 +393,15 @@ Options:
  -e, --env
          Display a lot of information about your runtime environment

```
## MAX_CLIENT_BATCH_SIZE
```shell
      --max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
          Control the maximum number of inputs that a client can send in a single request

          [env: MAX_CLIENT_BATCH_SIZE=]
          [default: 32]

```
## HELP
```shell
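For orientation, a minimal client-side sketch (not part of this commit) of how a caller might respect this limit by chunking a prompt list before posting to `/v1/completions`. The base URL, the `complete_in_chunks` helper, and the hard-coded limit of 32 are assumptions for illustration; the JSON payload shape follows the integration tests added below.

```python
# Illustrative only: split a prompt list into slices no larger than the
# server's --max-client-batch-size (32 by default) before posting.
import requests

BASE_URL = "http://localhost:3000"  # assumed local TGI router address
MAX_CLIENT_BATCH_SIZE = 32          # must match the server's configured limit


def complete_in_chunks(prompts, max_tokens=5):
    """POST the prompts in slices of at most MAX_CLIENT_BATCH_SIZE each."""
    choices = []
    for start in range(0, len(prompts), MAX_CLIENT_BATCH_SIZE):
        chunk = prompts[start : start + MAX_CLIENT_BATCH_SIZE]
        response = requests.post(
            f"{BASE_URL}/v1/completions",
            json={"model": "tgi", "prompt": chunk, "max_tokens": max_tokens},
        )
        response.raise_for_status()
        choices.extend(response.json()["choices"])
    return choices
```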
@@ -0,0 +1,56 @@
import pytest
import requests


@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
    await flash_llama_completion_handle.health(300)
    return flash_llama_completion_handle.client


# NOTE: since `v1/completions` is a deprecated interface/endpoint, we do not provide a convenience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.


def test_flash_llama_grammar_single_prompt(flash_llama_completion, response_snapshot):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": "Say this is a test",
            "max_tokens": 5,
            "seed": 0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 1


def test_flash_llama_grammar_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": ["Say", "this", "is", "a", "test"],
            "max_tokens": 5,
            "seed": 0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 5

    all_indexes = [choice["index"] for choice in response["choices"]]
    all_indexes.sort()
    assert all_indexes == [0, 1, 2, 3, 4]
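These tests cover the happy path only. A hedged sketch (not part of this commit) of how the new batch-size guard in the router could also be exercised, assuming the test server keeps the default `--max-client-batch-size` of 32; the test name and prompt count are illustrative:

```python
# Sketch only: exercise the router's batch-size guard added in this commit.
# Assumes the fixture server keeps the default --max-client-batch-size of 32.
def test_flash_llama_completion_batch_too_large(flash_llama_completion):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": ["test"] * 33,  # one more prompt than the default limit
            "max_tokens": 5,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    # The router rejects oversized batches with 422 Unprocessable Entity.
    assert response.status_code == 422
    assert response.json()["error_type"] == "batch size exceeded"
```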
@@ -396,6 +396,10 @@ struct Args {
    /// Display a lot of information about your runtime environment
    #[clap(long, short, action)]
    env: bool,

    /// Control the maximum number of inputs that a client can send in a single request
    #[clap(default_value = "32", long, env)]
    max_client_batch_size: usize,
}

#[derive(Debug)]
@@ -1044,6 +1048,8 @@ fn spawn_webserver(
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut router_args = vec![
        "--max-client-batch-size".to_string(),
        args.max_client_batch_size.to_string(),
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
@@ -141,6 +141,8 @@ pub struct Info {
    pub max_batch_size: Option<usize>,
    #[schema(example = "2")]
    pub validation_workers: usize,
    #[schema(example = "32")]
    pub max_client_batch_size: usize,
    /// Router Info
    #[schema(example = "0.5.0")]
    pub version: &'static str,
@@ -270,28 +272,20 @@ mod prompt_serde {
    use serde::{self, Deserialize, Deserializer};
    use serde_json::Value;

    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = Value::deserialize(deserializer)?;
        match value {
            Value::String(s) => Ok(s),
            Value::Array(arr) => {
                if arr.len() == 1 {
                    match arr[0].as_str() {
                        Some(s) => Ok(s.to_string()),
                        None => Err(serde::de::Error::custom(
                            "Array contains non-string elements",
                        )),
                    }
                } else {
                    Err(serde::de::Error::custom(
                        "Array contains non-string element. Expected string. In general arrays should not be used for prompts. Please use a string instead if possible.",
                    ))
                }
            }

            Value::String(s) => Ok(vec![s]),
            Value::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    Value::String(s) => Ok(s.to_owned()),
                    _ => Err(serde::de::Error::custom("Expected a string")),
                })
                .collect(),
            _ => Err(serde::de::Error::custom(
                "Expected a string or an array of strings",
            )),
@@ -309,7 +303,7 @@ pub struct CompletionRequest {
    /// The prompt to generate completions for.
    #[schema(example = "What is Deep Learning?")]
    #[serde(deserialize_with = "prompt_serde::deserialize")]
    pub prompt: String,
    pub prompt: Vec<String>,

    /// The maximum number of tokens that can be generated in the chat completion.
    #[serde(default)]
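With `prompt` now deserialized through `prompt_serde` into a `Vec<String>`, a bare string and a list of strings are both accepted over the wire. A small sketch of the two request shapes a client can send (the base URL and payload values are placeholders, not from this diff):

```python
# Sketch of the two prompt shapes accepted after this change.
import requests

base_url = "http://localhost:3000"  # assumed local TGI router address

# A bare string is wrapped into a single-element batch by the deserializer...
single = {"model": "tgi", "prompt": "What is Deep Learning?", "max_tokens": 5}

# ...while a list of strings becomes one generation per element, up to the
# configured --max-client-batch-size. Non-string elements are rejected.
batch = {
    "model": "tgi",
    "prompt": ["What is Deep Learning?", "Say this is a test"],
    "max_tokens": 5,
}

for payload in (single, batch):
    r = requests.post(f"{base_url}/v1/completions", json=payload)
    print(len(r.json()["choices"]))  # 1 for the string, 2 for the list
```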
@@ -945,6 +939,20 @@ pub(crate) struct Details {
    pub top_tokens: Vec<Vec<Token>>,
}

impl Default for Details {
    fn default() -> Self {
        Self {
            finish_reason: FinishReason::Length,
            generated_tokens: 0,
            seed: None,
            prefill: Vec::new(),
            tokens: Vec::new(),
            best_of_sequences: None,
            top_tokens: Vec::new(),
        }
    }
}

#[derive(Serialize, ToSchema)]
pub(crate) struct GenerateResponse {
    #[schema(example = "test")]
@@ -77,6 +77,8 @@ struct Args {
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "32", long, env)]
    max_client_batch_size: usize,
}

#[tokio::main]
@@ -111,6 +113,7 @@ async fn main() -> Result<(), RouterError> {
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
    } = args;

    // Launch Tokio runtime
@@ -372,6 +375,7 @@ async fn main() -> Result<(), RouterError> {
        tokenizer_config,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
    )
    .await?;
    Ok(())
@@ -595,126 +595,250 @@ async fn completions(
        ));
    }

    // build the request passing some parameters
    let generate_request = GenerateRequest {
        inputs: req.prompt.to_string(),
        parameters: GenerateParameters {
            best_of: None,
            temperature: req.temperature,
            repetition_penalty: req.repetition_penalty,
            frequency_penalty: req.frequency_penalty,
            top_k: None,
            top_p: req.top_p,
            typical_p: None,
            do_sample: true,
            max_new_tokens,
            return_full_text: None,
            stop: Vec::new(),
            truncate: None,
            watermark: false,
            details: true,
            decoder_input_details: !stream,
            seed,
            top_n_tokens: None,
            grammar: None,
    if req.prompt.len() > info.max_client_batch_size {
        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: format!(
                    "Number of prompts exceeds the maximum allowed batch size of {}",
                    info.max_client_batch_size
                ),
                error_type: "batch size exceeded".to_string(),
            }),
        ));
    }

    let mut generate_requests = Vec::new();
    for prompt in req.prompt.iter() {
        // build the request passing some parameters
        let generate_request = GenerateRequest {
            inputs: prompt.to_string(),
            parameters: GenerateParameters {
                best_of: None,
                temperature: req.temperature,
                repetition_penalty: req.repetition_penalty,
                frequency_penalty: req.frequency_penalty,
                top_k: None,
                top_p: req.top_p,
                typical_p: None,
                do_sample: true,
                max_new_tokens,
                return_full_text: None,
                stop: Vec::new(),
                truncate: None,
                watermark: false,
                details: true,
                decoder_input_details: !stream,
                seed,
                top_n_tokens: None,
                grammar: None,
            },
        };
        generate_requests.push(generate_request);
    }

    if stream {
        let response_streams = FuturesUnordered::new();
        for (index, generate_request) in generate_requests.into_iter().enumerate() {
            let model_id = info.model_id.clone();
            let system_fingerprint =
                format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
            let on_message_callback = move |stream_token: StreamResponse| {
                let event = Event::default();

                let current_time = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                    .as_secs();

                event
                    .json_data(CompletionCompleteChunk {
                        id: "".to_string(),
                        object: "text_completion".to_string(),
                        created: current_time,

                        choices: vec![CompletionComplete {
                            finish_reason: "".to_string(),
                            index: index as u32,
                            logprobs: None,
                            text: stream_token.token.text,
                        }],

                        model: model_id.clone(),
                        system_fingerprint: system_fingerprint.clone(),
                    })
                    .map_or_else(
                        |e| {
                            println!("Failed to serialize ChatCompletionChunk: {:?}", e);
                            Event::default()
                        },
                        |data| data,
                    )
            };

            let (headers, response_stream) = generate_stream_internal(
                infer.clone(),
                compute_type.clone(),
                Json(generate_request),
                on_message_callback,
            )
            .await;

            response_streams.push((headers, Box::pin(response_stream)));
        }

        let stream = async_stream::stream! {
            for response_stream in response_streams {
                let (_headers, mut inner_stream) = response_stream;
                while let Some(event) = inner_stream.next().await {
                    yield event;
                }
            }
        };

        let sse = Sse::new(stream).keep_alive(KeepAlive::default());
        return Ok((HeaderMap::new(), sse).into_response());
    }

    let current_time = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
        .as_secs();

    let responses = FuturesUnordered::new();
    for generate_request in generate_requests.into_iter() {
        responses.push(generate(
            Extension(infer.clone()),
            Extension(compute_type.clone()),
            Json(generate_request),
        ));
    }

    let generate_responses = responses.try_collect::<Vec<_>>().await?;

    let mut prompt_tokens = 0u32;
    let mut completion_tokens = 0u32;
    let mut total_tokens = 0u32;

    let mut headers = HeaderMap::new();

    let mut x_compute_type: Option<String> = None;
    let mut x_compute_time = 0u32;
    let mut x_compute_characters = 0u32;
    let mut x_total_time = 0u32;
    let mut x_validation_time = 0u32;
    let mut x_queue_time = 0u32;
    let mut x_inference_time = 0u32;
    let mut x_time_per_token = 0u32;
    let mut x_prompt_tokens = 0u32;
    let mut x_generated_tokens = 0u32;

    // helper closure to extract a header value or default to 0
    let extract_or_zero = |headers: &HeaderMap, key: &str| {
        headers
            .get(key)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("0")
            .parse::<u32>()
            .unwrap_or(0)
    };

    let choices = generate_responses
        .into_iter()
        .enumerate()
        .map(|(index, (headers, Json(generation)))| {
            let details = generation.details.unwrap_or_default();
            if x_compute_type.is_none() {
                x_compute_type = Some(
                    headers
                        .get("x-compute-type")
                        .unwrap()
                        .to_str()
                        .unwrap()
                        .to_string(),
                );
            }

            // update headers
            x_compute_time += extract_or_zero(&headers, "x-compute-time");
            x_compute_characters += extract_or_zero(&headers, "x-compute-characters");
            x_total_time += extract_or_zero(&headers, "x-total-time");
            x_validation_time += extract_or_zero(&headers, "x-validation-time");
            x_queue_time += extract_or_zero(&headers, "x-queue-time");
            x_inference_time += extract_or_zero(&headers, "x-inference-time");
            x_time_per_token += extract_or_zero(&headers, "x-time-per-token");
            x_prompt_tokens += extract_or_zero(&headers, "x-prompt-tokens");
            x_generated_tokens += extract_or_zero(&headers, "x-generated-tokens");

            // update usage
            prompt_tokens += details.prefill.len() as u32;
            completion_tokens += details.generated_tokens;
            total_tokens += details.prefill.len() as u32 + details.generated_tokens;

            CompletionComplete {
                finish_reason: details.finish_reason.to_string(),
                index: index as u32,
                logprobs: None,
                text: generation.generated_text,
            }
        })
        .collect::<Vec<_>>();

    // Headers similar to `generate` but aggregated
    headers.insert(
        "x-compute-type",
        x_compute_type
            .unwrap_or("unknown".to_string())
            .parse()
            .unwrap(),
    );
    headers.insert(
        "x-compute-time",
        x_compute_time.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-compute-characters",
        x_compute_characters.to_string().parse().unwrap(),
    );
    headers.insert("x-total-time", x_total_time.to_string().parse().unwrap());
    headers.insert(
        "x-validation-time",
        x_validation_time.to_string().parse().unwrap(),
    );
    headers.insert("x-queue-time", x_queue_time.to_string().parse().unwrap());
    headers.insert(
        "x-inference-time",
        x_inference_time.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        x_time_per_token.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-prompt-tokens",
        x_prompt_tokens.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-generated-tokens",
        x_generated_tokens.to_string().parse().unwrap(),
    );

    let response = Completion {
        id: "".to_string(),
        object: "text_completion".to_string(),
        created: current_time,
        model: info.model_id.clone(),
        system_fingerprint: format!("{}-{}", info.version, info.docker_label.unwrap_or("native")),
        choices,
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens,
        },
    };

    if stream {
        let on_message_callback = move |stream_token: StreamResponse| {
            let event = Event::default();

            let current_time = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                .as_secs();

            event
                .json_data(CompletionCompleteChunk {
                    id: "".to_string(),
                    object: "text_completion".to_string(),
                    created: current_time,

                    choices: vec![CompletionComplete {
                        finish_reason: "".to_string(),
                        index: 0,
                        logprobs: None,
                        text: stream_token.token.text,
                    }],

                    model: info.model_id.clone(),
                    system_fingerprint: format!(
                        "{}-{}",
                        info.version,
                        info.docker_label.unwrap_or("native")
                    ),
                })
                .map_or_else(
                    |e| {
                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
                        Event::default()
                    },
                    |data| data,
                )
        };

        let (headers, response_stream) = generate_stream_internal(
            infer,
            compute_type,
            Json(generate_request),
            on_message_callback,
        )
        .await;

        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
        Ok((headers, sse).into_response())
    } else {
        let (headers, Json(generation)) = generate(
            Extension(infer),
            Extension(compute_type),
            Json(generate_request),
        )
        .await?;

        let current_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
            .as_secs();

        let details = generation.details.ok_or((
            // this should never happen but handle if details are missing unexpectedly
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "No details in generation".to_string(),
                error_type: "no details".to_string(),
            }),
        ))?;

        let response = Completion {
            id: "".to_string(),
            object: "text_completion".to_string(),
            created: current_time,
            model: info.model_id.clone(),
            system_fingerprint: format!(
                "{}-{}",
                info.version,
                info.docker_label.unwrap_or("native")
            ),
            choices: vec![CompletionComplete {
                finish_reason: details.finish_reason.to_string(),
                index: 0,
                logprobs: None,
                text: generation.generated_text,
            }],
            usage: Usage {
                prompt_tokens: details.prefill.len() as u32,
                completion_tokens: details.generated_tokens,
                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
            },
        };

        Ok((headers, Json(response)).into_response())
    }
    Ok((headers, Json(response)).into_response())
}

/// Generate tokens
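Because the new handler drains each prompt's stream in turn and tags every chunk with its prompt index, a client can reassemble per-prompt text by grouping on `choices[0]["index"]`. A hedged consumer sketch (not part of this commit); the base URL, the request-level `stream` flag, and the exact SSE framing (`data:` prefix, optional `[DONE]` sentinel) are assumptions here:

```python
# Sketch only: read the SSE stream and reassemble text per prompt using the
# `index` field each CompletionCompleteChunk carries.
import json
from collections import defaultdict

import requests

base_url = "http://localhost:3000"  # assumed local TGI router address
texts = defaultdict(str)

with requests.post(
    f"{base_url}/v1/completions",
    json={
        "model": "tgi",
        "prompt": ["Say", "this", "is", "a", "test"],
        "max_tokens": 5,
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        payload = line.decode("utf-8").removeprefix("data:").strip()
        if not payload or payload == "[DONE]":
            continue
        chunk = json.loads(payload)
        choice = chunk["choices"][0]
        texts[choice["index"]] += choice["text"]

for index in sorted(texts):
    print(index, texts[index])
```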
@@ -1163,6 +1287,7 @@ pub async fn run(
    tokenizer_config: HubTokenizerConfig,
    messages_api_enabled: bool,
    grammar_support: bool,
    max_client_batch_size: usize,
) -> Result<(), axum::BoxError> {
    // OpenAPI documentation
    #[derive(OpenApi)]
@@ -1336,6 +1461,7 @@ pub async fn run(
        max_waiting_tokens,
        max_batch_size,
        validation_workers,
        max_client_batch_size,
        version: env!("CARGO_PKG_VERSION"),
        sha: option_env!("VERGEN_GIT_SHA"),
        docker_label: option_env!("DOCKER_LABEL"),