Add load testing

parent 1d986983d5
commit fa9a088467
@@ -0,0 +1 @@
.idea
@@ -0,0 +1,97 @@
import http from 'k6/http';
import {check, sleep} from 'k6';

export const options = {
    stages: [
        {duration: '1m', target: 50},
        {duration: '2m', target: 100},
        {duration: '1m', target: 0},
    ],
    hosts: {
        'text-generation-inference.huggingface.co': '127.0.0.1:3000',
    },
};
const SLEEP_DURATION = 1;

function greedy_example(inputs, max_new_tokens, name) {
    let body = JSON.stringify({
        inputs: inputs,
        parameters: {
            max_new_tokens: max_new_tokens,
            do_sample: false,
        }
    });
    let params = {
        headers: {
            'Content-Type': 'application/json',
        },
        tags: {
            name: name
        }
    };
    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
}

function sample_example(inputs, max_new_tokens, name) {
    let body = JSON.stringify({
        inputs: inputs,
        parameters: {
            max_new_tokens: max_new_tokens,
            do_sample: true,
            top_p: 0.9
        }
    });
    let params = {
        headers: {
            'Content-Type': 'application/json',
        },
        tags: {
            name: name
        }
    };
    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
}

export default function () {
    const response_1 = sample_example('A "whatpu" is a small, furry animal native to Tanzania. An example of a sentence that uses the word whatpu is: We were traveling in Africa and we saw these very cute whatpus. To do a "farduddle" means to jump up and down really fast. An example of a sentence that uses the word farduddle is:', 32, 'example-1');
    check(response_1, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_2 = sample_example("A poem about the beauty of science by Alfred Edgar Brittle\\nTitle: The Magic Craft\\nIn the old times", 50, "example-2");
    check(response_2, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_3 = greedy_example("استخراج العدد العاملي في لغة بايثون: ", 30, "example-3");
    check(response_3, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_4 = sample_example("Pour déguster un ortolan, il faut tout d'abord", 32, "example-4");
    check(response_4, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_5 = sample_example("Traduce español de España a español de Argentina\nEl coche es rojo - el auto es rojo\nEl ordenador es nuevo - la computadora es nueva\nel boligrafo es negro -", 16, "example-5");
    check(response_5, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_6 = sample_example("Question: If I put cheese into the fridge, will it melt?\nAnswer:", 32, "example-6");
    check(response_6, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);

    const response_7 = greedy_example("Question: Where does the Greek Goddess Persephone spend half of the year when she is not with her mother?\nAnswer:", 24, "example-7");
    check(response_7, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(SLEEP_DURATION);
}
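Usage note: the hosts entry in the k6 options remaps text-generation-inference.huggingface.co to 127.0.0.1:3000, so the generated load stays on the locally running router instead of the public hostname, while the stages ramp to 50 virtual users, then to 100, then back down. Assuming the script is saved as load_test.js (the file path is not visible in this view), it can be run with the stock k6 CLI:

    k6 run load_test.js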
@@ -1,14 +1,17 @@
use crate::{Db, GenerateRequest};
use bloom_inference_client::{Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient};
use crate::Db;
use bloom_inference_client::{
    Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient,
};
use std::sync::Arc;
use tokio::sync::{oneshot, Notify};
use tokio::sync::{Notify, oneshot};
use crate::server::GenerateRequest;

const MAX_LENGTH: usize = 128;

pub struct InferError {}

#[derive(Clone)]
pub(crate) struct Infer {
pub(crate) struct Batcher {
    db: Db,
    shared: Arc<Shared>,
}

@@ -17,7 +20,7 @@ struct Shared {
    batching_task: Notify,
}

impl Infer {
impl Batcher {
    pub(crate) fn new(client: ShardedClient) -> Self {
        let db = Db::new();
        let shared = Arc::new(Shared {

@@ -38,7 +41,7 @@ impl Infer {
        self.shared.batching_task.notify_waiters();
        match request_rx.await.unwrap() {
            Ok(output) => Ok(output),
            Err(_) => Err(InferError {})
            Err(_) => Err(InferError {}),
        }
    }
}

@@ -57,19 +60,19 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
        let mut max_sequence_length = entry.sequence_length;
        let mut request_ids = entry.request_ids;

        if total_batch_size <= 16 {
            if let Some(batch) = db.next_batch_minimum_size(16, 48) {
                let other_cache_entry = infer_batch(batch, &client, &db).await;

                if let Some(entry) = other_cache_entry {
                    batch_cached_ids.push(entry.id);
                    total_batch_size += entry.request_ids.len();
                    max_sequence_length =
                        max_sequence_length.max(entry.sequence_length);
                    request_ids.extend(entry.request_ids.into_iter());
                }
            }
        }
        // if total_batch_size <= 16 {
        //     if let Some(batch) = db.next_batch_minimum_size(16, 48) {
        //         let other_cache_entry = infer_batch(batch, &client, &db).await;
        //
        //         if let Some(entry) = other_cache_entry {
        //             batch_cached_ids.push(entry.id);
        //             total_batch_size += entry.request_ids.len();
        //             max_sequence_length =
        //                 max_sequence_length.max(entry.sequence_length);
        //             request_ids.extend(entry.request_ids.into_iter());
        //         }
        //     }
        // }

        let batch_cached = BatchCached {
            id: entry.id,

@@ -87,7 +90,11 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
    }
}

async fn infer_batch_cached(batch: BatchCached, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
async fn infer_batch_cached(
    batch: BatchCached,
    client: &ShardedClient,
    db: &Db,
) -> Option<CacheEntry> {
    match client.generate_with_cache(batch.clone()).await {
        Ok((finished, cache_entry)) => {
            send_finished(finished, db);

@@ -109,7 +116,11 @@ async fn infer_batch(batch: Batch, client: &ShardedClient, db: &Db) -> Option<Ca
        }
        Err(err) => {
            println!("{:?}", err);
            send_error(err, batch.requests.into_iter().map(|req| req.id).collect(), &db);
            send_error(
                err,
                batch.requests.into_iter().map(|req| req.id).collect(),
                &db,
            );
            None
        }
    }
}
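The renamed Batcher keeps the same pattern the old Infer struct used: a caller pushes its request and a oneshot sender into the shared Db, wakes the background batching task through a Notify, then awaits the oneshot receiver for the generated text. The following is a minimal, self-contained sketch of that pattern only; all names are hypothetical and it is not code from this repository (it also uses notify_one() instead of notify_waiters() so a wakeup issued before the worker is parked is not lost):

    // Standalone sketch of the oneshot + Notify request/response pattern.
    // Assumes tokio with the "sync", "rt-multi-thread" and "macros" features.
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};
    use tokio::sync::{oneshot, Notify};

    struct Shared {
        // Queue of (input, response channel) pairs; the real code keeps these in a Db.
        queue: Mutex<VecDeque<(String, oneshot::Sender<String>)>>,
        batching_task: Notify,
    }

    // Caller side: enqueue the request, wake the worker, wait for the reply.
    async fn infer(shared: &Arc<Shared>, inputs: String) -> String {
        let (response_tx, response_rx) = oneshot::channel();
        shared.queue.lock().unwrap().push_back((inputs, response_tx));
        shared.batching_task.notify_one();
        response_rx.await.unwrap()
    }

    // Worker side: drain the queue whenever we are notified.
    async fn batching_task(shared: Arc<Shared>) {
        loop {
            shared.batching_task.notified().await;
            while let Some((inputs, response_tx)) = shared.queue.lock().unwrap().pop_front() {
                // Stand-in for the call to the sharded inference client.
                let _ = response_tx.send(format!("generated for: {inputs}"));
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let shared = Arc::new(Shared {
            queue: Mutex::new(VecDeque::new()),
            batching_task: Notify::new(),
        });
        tokio::spawn(batching_task(shared.clone()));
        println!("{}", infer(&shared, "Hello".to_string()).await);
    }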
@@ -1,5 +1,5 @@
/// This code is massively inspired by Tokio mini-redis
use crate::GenerateRequest;
use crate::server::GenerateRequest;
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::RwLock;
use std::collections::BTreeMap;

@@ -44,7 +44,11 @@ impl Db {
        Self { shared }
    }

    pub(crate) fn append(&self, request: GenerateRequest, sender: Sender<Result<String, ClientError>>) {
    pub(crate) fn append(
        &self,
        request: GenerateRequest,
        sender: Sender<Result<String, ClientError>>,
    ) {
        let mut state = self.shared.state.write();

        let id = state.next_id;

@@ -65,7 +69,10 @@ impl Db {
        state.entries.insert(id, (request, sender));
    }

    pub(crate) fn remove(&self, id: &u64) -> Option<(Request, Sender<Result<String, ClientError>>)> {
    pub(crate) fn remove(
        &self,
        id: &u64,
    ) -> Option<(Request, Sender<Result<String, ClientError>>)> {
        let mut state = self.shared.state.write();
        state.entries.remove(id)
    }
@@ -1,105 +1,15 @@
use tokio::time::Instant;

use poem;
use poem::middleware::AddData;
use poem::web::Data;
use poem::{handler, listener::TcpListener, post, web::Json, EndpointExt, Result, Route, Server};

use bloom_inference_client::ShardedClient;
use serde::Deserialize;
use poem;
use poem::listener::TcpListener;
use std::time::Duration;
use poem::http::StatusCode;
use tracing::instrument;

mod server;

mod db;

use db::Db;

mod infer;

use infer::Infer;

#[derive(Clone, Debug, Deserialize)]
struct GenerateParameters {
    #[serde(default = "default_temperature")]
    temperature: f32,
    #[serde(default = "default_top_k")]
    top_k: u32,
    #[serde(default = "default_top_p")]
    top_p: f32,
    #[serde(default = "default_do_sample")]
    do_sample: bool,
    #[serde(default = "default_max_new_tokens")]
    max_new_tokens: u32,
}

fn default_temperature() -> f32 {
    1.0
}

fn default_top_k() -> u32 {
    0
}

fn default_top_p() -> f32 {
    1.0
}

fn default_do_sample() -> bool {
    false
}

fn default_max_new_tokens() -> u32 {
    20
}

#[derive(Clone, Debug, Deserialize)]
struct GenerateRequest {
    inputs: String,
    #[serde(default = "default_parameters")]
    parameters: GenerateParameters,
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        temperature: default_temperature(),
        top_k: default_top_k(),
        top_p: default_top_p(),
        do_sample: default_do_sample(),
        max_new_tokens: default_max_new_tokens(),
    }
}

#[handler]
#[instrument(skip(infer), fields(time, time_per_token))]
async fn generate(
    infer: Data<&Infer>,
    req: Json<GenerateRequest>,
) -> Result<Json<serde_json::Value>> {
    let start = Instant::now();

    let output = infer
        .infer(GenerateRequest {
            inputs: req.inputs.clone(),
            parameters: req.parameters.clone(),
        })
        .await;

    match output {
        Ok(generated_text) => {
            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
            tracing::Span::current().record("time_per_token", format!("{:?}", start.elapsed() / req.parameters.max_new_tokens));
            tracing::info!("response: {}", generated_text);

            Ok(Json(serde_json::json!({
                "generated_text": generated_text,
            })))
        }
        Err(_) => {
            Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))
        }
    }
}
mod batcher;
use batcher::Batcher;

#[tokio::main]
async fn main() -> Result<(), std::io::Error> {

@@ -114,12 +24,8 @@ async fn main() -> Result<(), std::io::Error> {
        .expect("Unable to clear cache");
    tracing::info!("Connected");

    let infer = Infer::new(sharded_client);
    let addr = "127.0.0.1:3000".to_string();
    let listener = TcpListener::bind(addr);

    let app = Route::new()
        .at("/generate", post(generate))
        .with(AddData::new(infer));
    Server::new(TcpListener::bind("127.0.0.1:3000"))
        .run(app)
        .await
    server::run(sharded_client, listener).await
}
@@ -0,0 +1,111 @@
use poem::{EndpointExt, handler, post, Route, Server};
use poem::http::StatusCode;
use poem::listener::TcpListener;
use poem::middleware::AddData;
use poem::web::{Data, Json};
use tokio::time::Instant;
use crate::{Batcher, ShardedClient};
use tracing::instrument;
use serde::Deserialize;

#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateParameters {
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_top_k")]
    pub top_k: u32,
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    #[serde(default = "default_do_sample")]
    pub do_sample: bool,
    #[serde(default = "default_max_new_tokens")]
    pub max_new_tokens: u32,
}

fn default_temperature() -> f32 {
    1.0
}

fn default_top_k() -> u32 {
    0
}

fn default_top_p() -> f32 {
    1.0
}

fn default_do_sample() -> bool {
    false
}

fn default_max_new_tokens() -> u32 {
    20
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        temperature: default_temperature(),
        top_k: default_top_k(),
        top_p: default_top_p(),
        do_sample: default_do_sample(),
        max_new_tokens: default_max_new_tokens(),
    }
}

#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateRequest {
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

#[handler]
#[instrument(skip(infer), fields(time, time_per_token))]
async fn generate(
    infer: Data<&Batcher>,
    req: Json<GenerateRequest>,
) -> poem::Result<Json<serde_json::Value>> {
    let start = Instant::now();

    let output = infer
        .infer(GenerateRequest {
            inputs: req.inputs.clone(),
            parameters: req.parameters.clone(),
        })
        .await;

    match output {
        Ok(generated_text) => {
            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
            tracing::Span::current().record(
                "time_per_token",
                format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
            );
            tracing::info!("response: {}", generated_text);

            Ok(Json(serde_json::json!({
                "generated_text": generated_text,
            })))
        }
        Err(_) => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)),
    }
}

pub async fn run(client: ShardedClient, listener: TcpListener<String>) -> Result<(), std::io::Error> {
    client
        .clear_cache()
        .await
        .expect("Unable to clear cache");
    tracing::info!("Connected");

    let infer = Batcher::new(client);

    let app = Route::new()
        .at("/generate", post(generate))
        .with(AddData::new(infer));

    Server::new(listener)
        .run(app)
        .await
}
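Once the router is listening on 127.0.0.1:3000 (the bind address used in main), the /generate route can be smoke-tested by hand before launching the k6 script. The request below is a hypothetical example: it only supplies inputs, so every field of GenerateParameters falls back to its default_* value (temperature 1.0, top_k 0, top_p 1.0, do_sample false, max_new_tokens 20):

    curl -s http://127.0.0.1:3000/generate \
        -H 'Content-Type: application/json' \
        -d '{"inputs": "Hello world"}'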