Add load testing

Olivier Dehaene 2022-10-11 10:36:51 +02:00
parent 1d986983d5
commit fa9a088467
6 changed files with 260 additions and 127 deletions

.gitignore vendored Normal file (1 addition)

@@ -0,0 +1 @@
.idea

k6/load_test.js Normal file (97 additions)

@@ -0,0 +1,97 @@
import http from 'k6/http';
import {check, sleep} from 'k6';

export const options = {
    stages: [
        {duration: '1m', target: 50},
        {duration: '2m', target: 100},
        {duration: '1m', target: 0},
    ],
    hosts: {
        'text-generation-inference.huggingface.co': '127.0.0.1:3000',
    },
};

const SLEEP_DURATION = 1;

function greedy_example(inputs, max_new_tokens, name) {
    let body = JSON.stringify({
        inputs: inputs,
        parameters: {
            max_new_tokens: max_new_tokens,
            do_sample: false,
        }
    });
    let params = {
        headers: {
            'Content-Type': 'application/json',
        },
        tags: {
            name: name
        }
    };
    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
}

function sample_example(inputs, max_new_tokens, name) {
    let body = JSON.stringify({
        inputs: inputs,
        parameters: {
            max_new_tokens: max_new_tokens,
            do_sample: true,
            top_p: 0.9
        }
    });
    let params = {
        headers: {
            'Content-Type': 'application/json',
        },
        tags: {
            name: name
        }
    };
    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
}

export default function () {
    const response_1 = sample_example('A "whatpu" is a small, furry animal native to Tanzania. An example of a sentence that uses the word whatpu is: We were traveling in Africa and we saw these very cute whatpus. To do a "farduddle" means to jump up and down really fast. An example of a sentence that uses the word farduddle is:', 32, 'example-1');
    check(response_1, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_2 = sample_example("A poem about the beauty of science by Alfred Edgar Brittle\\nTitle: The Magic Craft\\nIn the old times", 50, "example-2");
    check(response_2, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_3 = greedy_example("استخراج العدد العاملي في لغة بايثون: ", 30, "example-3");
    check(response_3, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_4 = sample_example("Pour déguster un ortolan, il faut tout d'abord", 32, "example-4");
    check(response_4, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_5 = sample_example("Traduce español de España a español de Argentina\nEl coche es rojo - el auto es rojo\nEl ordenador es nuevo - la computadora es nueva\nel boligrafo es negro -", 16, "example-5");
    check(response_5, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_6 = sample_example("Question: If I put cheese into the fridge, will it melt?\nAnswer:", 32, "example-6");
    check(response_6, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_7 = greedy_example("Question: Where does the Greek Goddess Persephone spend half of the year when she is not with her mother?\nAnswer:", 24, "example-7");
    check(response_7, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);
}
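Usage note (not part of the committed file): assuming the k6 CLI is installed, the scenario can be run locally with "k6 run k6/load_test.js". The hosts entry in options remaps text-generation-inference.huggingface.co to 127.0.0.1:3000, so the generated load is sent to a router listening on that local address rather than to the public hostname, and the stages block ramps virtual users to 50, then 100, then back to 0 over four minutes.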


@@ -1,14 +1,17 @@
use crate::{Db, GenerateRequest};
use bloom_inference_client::{Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient};
use crate::Db;
use bloom_inference_client::{
Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient,
};
use std::sync::Arc;
use tokio::sync::{oneshot, Notify};
use tokio::sync::{Notify, oneshot};
use crate::server::GenerateRequest;
const MAX_LENGTH: usize = 128;
pub struct InferError {}
#[derive(Clone)]
pub(crate) struct Infer {
pub(crate) struct Batcher {
db: Db,
shared: Arc<Shared>,
}
@@ -17,7 +20,7 @@ struct Shared {
batching_task: Notify,
}
impl Infer {
impl Batcher {
pub(crate) fn new(client: ShardedClient) -> Self {
let db = Db::new();
let shared = Arc::new(Shared {
@@ -38,7 +41,7 @@ impl Infer {
self.shared.batching_task.notify_waiters();
match request_rx.await.unwrap() {
Ok(output) => Ok(output),
Err(_) => Err(InferError {})
Err(_) => Err(InferError {}),
}
}
}
@@ -57,19 +60,19 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
let mut max_sequence_length = entry.sequence_length;
let mut request_ids = entry.request_ids;
if total_batch_size <= 16 {
if let Some(batch) = db.next_batch_minimum_size(16, 48) {
let other_cache_entry = infer_batch(batch, &client, &db).await;
if let Some(entry) = other_cache_entry {
batch_cached_ids.push(entry.id);
total_batch_size += entry.request_ids.len();
max_sequence_length =
max_sequence_length.max(entry.sequence_length);
request_ids.extend(entry.request_ids.into_iter());
}
}
}
// if total_batch_size <= 16 {
// if let Some(batch) = db.next_batch_minimum_size(16, 48) {
// let other_cache_entry = infer_batch(batch, &client, &db).await;
//
// if let Some(entry) = other_cache_entry {
// batch_cached_ids.push(entry.id);
// total_batch_size += entry.request_ids.len();
// max_sequence_length =
// max_sequence_length.max(entry.sequence_length);
// request_ids.extend(entry.request_ids.into_iter());
// }
// }
// }
let batch_cached = BatchCached {
id: entry.id,
@@ -87,7 +90,11 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
}
}
async fn infer_batch_cached(batch: BatchCached, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
async fn infer_batch_cached(
batch: BatchCached,
client: &ShardedClient,
db: &Db,
) -> Option<CacheEntry> {
match client.generate_with_cache(batch.clone()).await {
Ok((finished, cache_entry)) => {
send_finished(finished, db);
@@ -109,7 +116,11 @@ async fn infer_batch(batch: Batch, client: &ShardedClient, db: &Db) -> Option<Ca
}
Err(err) => {
println!("{:?}", err);
send_error(err, batch.requests.into_iter().map(|req| req.id).collect(), &db);
send_error(
err,
batch.requests.into_iter().map(|req| req.id).collect(),
&db,
);
None
}
}


@@ -1,5 +1,5 @@
/// This code is massively inspired by Tokio mini-redis
use crate::GenerateRequest;
use crate::server::GenerateRequest;
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::RwLock;
use std::collections::BTreeMap;
@ -44,7 +44,11 @@ impl Db {
Self { shared }
}
pub(crate) fn append(&self, request: GenerateRequest, sender: Sender<Result<String, ClientError>>) {
pub(crate) fn append(
&self,
request: GenerateRequest,
sender: Sender<Result<String, ClientError>>,
) {
let mut state = self.shared.state.write();
let id = state.next_id;
@ -65,7 +69,10 @@ impl Db {
state.entries.insert(id, (request, sender));
}
pub(crate) fn remove(&self, id: &u64) -> Option<(Request, Sender<Result<String, ClientError>>)> {
pub(crate) fn remove(
&self,
id: &u64,
) -> Option<(Request, Sender<Result<String, ClientError>>)> {
let mut state = self.shared.state.write();
state.entries.remove(id)
}
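A rough sketch (not part of the diff) of how append and remove are meant to be paired: the Sender type in these signatures and the request_rx.await in the batching code above imply one oneshot channel per queued request. The function below is illustrative only; in the real code the id is assigned inside append and reaches the batching task through the next batch.

// Illustrative sketch: one oneshot channel per queued request.
use tokio::sync::oneshot;

async fn roundtrip(db: &Db, request: GenerateRequest, id: u64) -> Option<Result<String, ClientError>> {
    // HTTP side: store the request together with the response sender.
    let (response_tx, response_rx) = oneshot::channel();
    db.append(request, response_tx);

    // Batching side: pop the entry by its id and answer through its sender.
    if let Some((_request, sender)) = db.remove(&id) {
        let _ = sender.send(Ok("generated text".to_string()));
    }

    // HTTP side: wait for the generated text (or a ClientError).
    response_rx.await.ok()
}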


@@ -1,105 +1,15 @@
use tokio::time::Instant;
use poem;
use poem::middleware::AddData;
use poem::web::Data;
use poem::{handler, listener::TcpListener, post, web::Json, EndpointExt, Result, Route, Server};
use bloom_inference_client::ShardedClient;
use serde::Deserialize;
use poem;
use poem::listener::TcpListener;
use std::time::Duration;
use poem::http::StatusCode;
use tracing::instrument;
mod server;
mod db;
use db::Db;
mod infer;
use infer::Infer;
#[derive(Clone, Debug, Deserialize)]
struct GenerateParameters {
#[serde(default = "default_temperature")]
temperature: f32,
#[serde(default = "default_top_k")]
top_k: u32,
#[serde(default = "default_top_p")]
top_p: f32,
#[serde(default = "default_do_sample")]
do_sample: bool,
#[serde(default = "default_max_new_tokens")]
max_new_tokens: u32,
}
fn default_temperature() -> f32 {
1.0
}
fn default_top_k() -> u32 {
0
}
fn default_top_p() -> f32 {
1.0
}
fn default_do_sample() -> bool {
false
}
fn default_max_new_tokens() -> u32 {
20
}
#[derive(Clone, Debug, Deserialize)]
struct GenerateRequest {
inputs: String,
#[serde(default = "default_parameters")]
parameters: GenerateParameters,
}
fn default_parameters() -> GenerateParameters {
GenerateParameters {
temperature: default_temperature(),
top_k: default_top_k(),
top_p: default_top_p(),
do_sample: default_do_sample(),
max_new_tokens: default_max_new_tokens(),
}
}
#[handler]
#[instrument(skip(infer), fields(time, time_per_token))]
async fn generate(
infer: Data<&Infer>,
req: Json<GenerateRequest>,
) -> Result<Json<serde_json::Value>> {
let start = Instant::now();
let output = infer
.infer(GenerateRequest {
inputs: req.inputs.clone(),
parameters: req.parameters.clone(),
})
.await;
match output {
Ok(generated_text) => {
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
tracing::Span::current().record("time_per_token", format!("{:?}", start.elapsed() / req.parameters.max_new_tokens));
tracing::info!("response: {}", generated_text);
Ok(Json(serde_json::json!({
"generated_text": generated_text,
})))
}
Err(_) => {
Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))
}
}
}
mod batcher;
use batcher::Batcher;
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
@@ -114,12 +24,8 @@ async fn main() -> Result<(), std::io::Error> {
.expect("Unable to clear cache");
tracing::info!("Connected");
let infer = Infer::new(sharded_client);
let addr = "127.0.0.1:3000".to_string();
let listener = TcpListener::bind(addr);
let app = Route::new()
.at("/generate", post(generate))
.with(AddData::new(infer));
Server::new(TcpListener::bind("127.0.0.1:3000"))
.run(app)
.await
server::run(sharded_client, listener).await
}

router/src/server.rs Normal file (111 additions)

@@ -0,0 +1,111 @@
use poem::{EndpointExt, handler, post, Route, Server};
use poem::http::StatusCode;
use poem::listener::TcpListener;
use poem::middleware::AddData;
use poem::web::{Data, Json};
use tokio::time::Instant;
use crate::{Batcher, ShardedClient};
use tracing::instrument;
use serde::Deserialize;

#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateParameters {
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_top_k")]
    pub top_k: u32,
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    #[serde(default = "default_do_sample")]
    pub do_sample: bool,
    #[serde(default = "default_max_new_tokens")]
    pub max_new_tokens: u32,
}

fn default_temperature() -> f32 {
    1.0
}

fn default_top_k() -> u32 {
    0
}

fn default_top_p() -> f32 {
    1.0
}

fn default_do_sample() -> bool {
    false
}

fn default_max_new_tokens() -> u32 {
    20
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        temperature: default_temperature(),
        top_k: default_top_k(),
        top_p: default_top_p(),
        do_sample: default_do_sample(),
        max_new_tokens: default_max_new_tokens(),
    }
}

#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateRequest {
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

#[handler]
#[instrument(skip(infer), fields(time, time_per_token))]
async fn generate(
    infer: Data<&Batcher>,
    req: Json<GenerateRequest>,
) -> poem::Result<Json<serde_json::Value>> {
    let start = Instant::now();

    let output = infer
        .infer(GenerateRequest {
            inputs: req.inputs.clone(),
            parameters: req.parameters.clone(),
        })
        .await;

    match output {
        Ok(generated_text) => {
            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
            tracing::Span::current().record(
                "time_per_token",
                format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
            );
            tracing::info!("response: {}", generated_text);

            Ok(Json(serde_json::json!({
                "generated_text": generated_text,
            })))
        }
        Err(_) => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)),
    }
}

pub async fn run(client: ShardedClient, listener: TcpListener<String>) -> Result<(), std::io::Error> {
    client
        .clear_cache()
        .await
        .expect("Unable to clear cache");
    tracing::info!("Connected");

    let infer = Batcher::new(client);

    let app = Route::new()
        .at("/generate", post(generate))
        .with(AddData::new(infer));

    Server::new(listener)
        .run(app)
        .await
}
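For reference (not part of the diff), the serde defaults above mean a POST body that only provides inputs is accepted; the handler then sees the equivalent of the fully populated request sketched below, built with the same serde_json macro the handler already uses. The prompt string is a placeholder.

// Illustrative only: the effective request when every parameter is omitted.
let body = serde_json::json!({
    "inputs": "A placeholder prompt",
    "parameters": {
        "temperature": 1.0,
        "top_k": 0,
        "top_p": 1.0,
        "do_sample": false,
        "max_new_tokens": 20
    }
});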