Add load testing

parent 1d986983d5
commit fa9a088467

@@ -0,0 +1 @@
+.idea

@@ -0,0 +1,97 @@
+import http from 'k6/http';
+import {check, sleep} from 'k6';
+
+export const options = {
+    stages: [
+        {duration: '1m', target: 50},
+        {duration: '2m', target: 100},
+        {duration: '1m', target: 0},
+    ],
+    hosts: {
+        'text-generation-inference.huggingface.co': '127.0.0.1:3000',
+    },
+};
+const SLEEP_DURATION = 1;
+
+function greedy_example(inputs, max_new_tokens, name) {
+    let body = JSON.stringify({
+        inputs: inputs,
+        parameters: {
+            max_new_tokens: max_new_tokens,
+            do_sample: false,
+        }
+    });
+    let params = {
+        headers: {
+            'Content-Type': 'application/json',
+        },
+        tags: {
+            name: name
+        }
+    };
+    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+}
+
+function sample_example(inputs, max_new_tokens, name) {
+    let body = JSON.stringify({
+        inputs: inputs,
+        parameters: {
+            max_new_tokens: max_new_tokens,
+            do_sample: true,
+            top_p: 0.9
+        }
+    });
+    let params = {
+        headers: {
+            'Content-Type': 'application/json',
+        },
+        tags: {
+            name: name
+        }
+    };
+    return http.post('http://text-generation-inference.huggingface.co/generate', body, params);
+}
+
+export default function () {
+    const response_1 = sample_example('A "whatpu" is a small, furry animal native to Tanzania. An example of a sentence that uses the word whatpu is: We were traveling in Africa and we saw these very cute whatpus. To do a "farduddle" means to jump up and down really fast. An example of a sentence that uses the word farduddle is:', 32, 'example-1');
+    check(response_1, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_2 = sample_example("A poem about the beauty of science by Alfred Edgar Brittle\\nTitle: The Magic Craft\\nIn the old times", 50, "example-2");
+    check(response_2, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_3 = greedy_example("استخراج العدد العاملي في لغة بايثون: ", 30, "example-3");
+    check(response_3, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_4 = sample_example("Pour déguster un ortolan, il faut tout d'abord", 32, "example-4");
+    check(response_4, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_5 = sample_example("Traduce español de España a español de Argentina\nEl coche es rojo - el auto es rojo\nEl ordenador es nuevo - la computadora es nueva\nel boligrafo es negro -", 16, "example-5");
+    check(response_5, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_6 = sample_example("Question: If I put cheese into the fridge, will it melt?\nAnswer:", 32, "example-6");
+    check(response_6, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+
+    const response_7 = greedy_example("Question: Where does the Greek Goddess Persephone spend half of the year when she is not with her mother?\nAnswer:", 24, "example-7");
+    check(response_7, {
+        'is status 200': (r) => r.status === 200,
+    });
+    sleep(SLEEP_DURATION);
+}
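
Note on the new k6 script: the hosts entry pins text-generation-inference.huggingface.co to 127.0.0.1:3000, so the whole test runs against a local router rather than the public hostname. Assuming the script is saved as load_test.js (the filename is an assumption; it is not shown in this diff), it is launched with the standard k6 CLI as k6 run load_test.js. The stages then ramp virtual users to 50 over one minute, to 100 over two minutes, and back to 0 over the final minute, with each virtual user sleeping SLEEP_DURATION seconds between requests.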

@@ -1,14 +1,17 @@
-use crate::{Db, GenerateRequest};
-use bloom_inference_client::{Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient};
+use crate::Db;
+use bloom_inference_client::{
+    Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient,
+};
 use std::sync::Arc;
-use tokio::sync::{oneshot, Notify};
+use tokio::sync::{Notify, oneshot};
+use crate::server::GenerateRequest;
 
 const MAX_LENGTH: usize = 128;
 
 pub struct InferError {}
 
 #[derive(Clone)]
-pub(crate) struct Infer {
+pub(crate) struct Batcher {
     db: Db,
     shared: Arc<Shared>,
 }
@@ -17,7 +20,7 @@ struct Shared {
     batching_task: Notify,
 }
 
-impl Infer {
+impl Batcher {
     pub(crate) fn new(client: ShardedClient) -> Self {
         let db = Db::new();
         let shared = Arc::new(Shared {
@@ -38,7 +41,7 @@ impl Infer {
         self.shared.batching_task.notify_waiters();
         match request_rx.await.unwrap() {
             Ok(output) => Ok(output),
-            Err(_) => Err(InferError {})
+            Err(_) => Err(InferError {}),
         }
     }
 }
@@ -57,19 +60,19 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
                 let mut max_sequence_length = entry.sequence_length;
                 let mut request_ids = entry.request_ids;
 
-                if total_batch_size <= 16 {
-                    if let Some(batch) = db.next_batch_minimum_size(16, 48) {
-                        let other_cache_entry = infer_batch(batch, &client, &db).await;
-
-                        if let Some(entry) = other_cache_entry {
-                            batch_cached_ids.push(entry.id);
-                            total_batch_size += entry.request_ids.len();
-                            max_sequence_length =
-                                max_sequence_length.max(entry.sequence_length);
-                            request_ids.extend(entry.request_ids.into_iter());
-                        }
-                    }
-                }
+                // if total_batch_size <= 16 {
+                //     if let Some(batch) = db.next_batch_minimum_size(16, 48) {
+                //         let other_cache_entry = infer_batch(batch, &client, &db).await;
+                //
+                //         if let Some(entry) = other_cache_entry {
+                //             batch_cached_ids.push(entry.id);
+                //             total_batch_size += entry.request_ids.len();
+                //             max_sequence_length =
+                //                 max_sequence_length.max(entry.sequence_length);
+                //             request_ids.extend(entry.request_ids.into_iter());
+                //         }
+                //     }
+                // }
 
                 let batch_cached = BatchCached {
                     id: entry.id,
@@ -87,7 +90,11 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
     }
 }
 
-async fn infer_batch_cached(batch: BatchCached, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
+async fn infer_batch_cached(
+    batch: BatchCached,
+    client: &ShardedClient,
+    db: &Db,
+) -> Option<CacheEntry> {
     match client.generate_with_cache(batch.clone()).await {
         Ok((finished, cache_entry)) => {
             send_finished(finished, db);
@@ -109,7 +116,11 @@ async fn infer_batch(batch: Batch, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
         }
         Err(err) => {
             println!("{:?}", err);
-            send_error(err, batch.requests.into_iter().map(|req| req.id).collect(), &db);
+            send_error(
+                err,
+                batch.requests.into_iter().map(|req| req.id).collect(),
+                &db,
+            );
             None
         }
     }
 }
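
The Infer → Batcher rename above is otherwise mechanical (rustfmt re-wraps plus the GenerateRequest import moving to crate::server). For orientation, a minimal sketch of how a caller drives it; this is hypothetical caller code, not part of the diff, and assumes a connected ShardedClient and a GenerateRequest already in hand:

    let batcher = Batcher::new(sharded_client);
    // infer() appends the request to the Db, wakes the batching task,
    // and awaits the per-request oneshot result.
    match batcher.infer(request).await {
        Ok(generated_text) => println!("{}", generated_text),
        Err(InferError {}) => eprintln!("inference failed"),
    }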

@@ -1,5 +1,5 @@
 /// This code is massively inspired by Tokio mini-redis
-use crate::GenerateRequest;
+use crate::server::GenerateRequest;
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::RwLock;
 use std::collections::BTreeMap;
@@ -44,7 +44,11 @@ impl Db {
         Self { shared }
     }
 
-    pub(crate) fn append(&self, request: GenerateRequest, sender: Sender<Result<String, ClientError>>) {
+    pub(crate) fn append(
+        &self,
+        request: GenerateRequest,
+        sender: Sender<Result<String, ClientError>>,
+    ) {
         let mut state = self.shared.state.write();
 
         let id = state.next_id;
@@ -65,7 +69,10 @@ impl Db {
         state.entries.insert(id, (request, sender));
     }
 
-    pub(crate) fn remove(&self, id: &u64) -> Option<(Request, Sender<Result<String, ClientError>>)> {
+    pub(crate) fn remove(
+        &self,
+        id: &u64,
+    ) -> Option<(Request, Sender<Result<String, ClientError>>)> {
         let mut state = self.shared.state.write();
         state.entries.remove(id)
    }
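
Both Db changes are formatting-only. As a sketch of the contract that append/remove implement for the batcher (hypothetical caller; Sender is assumed here to be tokio's oneshot::Sender, consistent with the oneshot import in the batcher module):

    // Enqueue a request together with a one-shot reply channel...
    let (request_tx, request_rx) = tokio::sync::oneshot::channel();
    db.append(generate_request, request_tx);
    // ...the batching task later resolves it with generated text or an error.
    let result: Result<String, ClientError> = request_rx.await.unwrap();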

@@ -1,105 +1,15 @@
-use tokio::time::Instant;
-
-use poem;
-use poem::middleware::AddData;
-use poem::web::Data;
-use poem::{handler, listener::TcpListener, post, web::Json, EndpointExt, Result, Route, Server};
-
 use bloom_inference_client::ShardedClient;
-use serde::Deserialize;
+use poem;
+use poem::listener::TcpListener;
 use std::time::Duration;
-use poem::http::StatusCode;
-use tracing::instrument;
+
+mod server;
 
 mod db;
 
 use db::Db;
 
-mod infer;
-use infer::Infer;
-
-#[derive(Clone, Debug, Deserialize)]
-struct GenerateParameters {
-    #[serde(default = "default_temperature")]
-    temperature: f32,
-    #[serde(default = "default_top_k")]
-    top_k: u32,
-    #[serde(default = "default_top_p")]
-    top_p: f32,
-    #[serde(default = "default_do_sample")]
-    do_sample: bool,
-    #[serde(default = "default_max_new_tokens")]
-    max_new_tokens: u32,
-}
-
-fn default_temperature() -> f32 {
-    1.0
-}
-
-fn default_top_k() -> u32 {
-    0
-}
-
-fn default_top_p() -> f32 {
-    1.0
-}
-
-fn default_do_sample() -> bool {
-    false
-}
-
-fn default_max_new_tokens() -> u32 {
-    20
-}
-
-#[derive(Clone, Debug, Deserialize)]
-struct GenerateRequest {
-    inputs: String,
-    #[serde(default = "default_parameters")]
-    parameters: GenerateParameters,
-}
-
-fn default_parameters() -> GenerateParameters {
-    GenerateParameters {
-        temperature: default_temperature(),
-        top_k: default_top_k(),
-        top_p: default_top_p(),
-        do_sample: default_do_sample(),
-        max_new_tokens: default_max_new_tokens(),
-    }
-}
-
-#[handler]
-#[instrument(skip(infer), fields(time, time_per_token))]
-async fn generate(
-    infer: Data<&Infer>,
-    req: Json<GenerateRequest>,
-) -> Result<Json<serde_json::Value>> {
-    let start = Instant::now();
-
-    let output = infer
-        .infer(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
-        .await;
-
-    match output {
-        Ok(generated_text) => {
-            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
-            tracing::Span::current().record("time_per_token", format!("{:?}", start.elapsed() / req.parameters.max_new_tokens));
-            tracing::info!("response: {}", generated_text);
-
-            Ok(Json(serde_json::json!({
-                "generated_text": generated_text,
-            })))
-        }
-        Err(_) => {
-            Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))
-        }
-    }
-}
+mod batcher;
+
+use batcher::Batcher;
 
 #[tokio::main]
 async fn main() -> Result<(), std::io::Error> {
@@ -114,12 +24,8 @@ async fn main() -> Result<(), std::io::Error> {
         .expect("Unable to clear cache");
     tracing::info!("Connected");
 
-    let infer = Infer::new(sharded_client);
+    let addr = "127.0.0.1:3000".to_string();
+    let listener = TcpListener::bind(addr);
 
-    let app = Route::new()
-        .at("/generate", post(generate))
-        .with(AddData::new(infer));
-    Server::new(TcpListener::bind("127.0.0.1:3000"))
-        .run(app)
-        .await
+    server::run(sharded_client, listener).await
 }

@@ -0,0 +1,111 @@
+use poem::{EndpointExt, handler, post, Route, Server};
+use poem::http::StatusCode;
+use poem::listener::TcpListener;
+use poem::middleware::AddData;
+use poem::web::{Data, Json};
+use tokio::time::Instant;
+use crate::{Batcher, ShardedClient};
+use tracing::instrument;
+use serde::Deserialize;
+
+#[derive(Clone, Debug, Deserialize)]
+pub(crate) struct GenerateParameters {
+    #[serde(default = "default_temperature")]
+    pub temperature: f32,
+    #[serde(default = "default_top_k")]
+    pub top_k: u32,
+    #[serde(default = "default_top_p")]
+    pub top_p: f32,
+    #[serde(default = "default_do_sample")]
+    pub do_sample: bool,
+    #[serde(default = "default_max_new_tokens")]
+    pub max_new_tokens: u32,
+}
+
+fn default_temperature() -> f32 {
+    1.0
+}
+
+fn default_top_k() -> u32 {
+    0
+}
+
+fn default_top_p() -> f32 {
+    1.0
+}
+
+fn default_do_sample() -> bool {
+    false
+}
+
+fn default_max_new_tokens() -> u32 {
+    20
+}
+
+fn default_parameters() -> GenerateParameters {
+    GenerateParameters {
+        temperature: default_temperature(),
+        top_k: default_top_k(),
+        top_p: default_top_p(),
+        do_sample: default_do_sample(),
+        max_new_tokens: default_max_new_tokens(),
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub(crate) struct GenerateRequest {
+    pub inputs: String,
+    #[serde(default = "default_parameters")]
+    pub parameters: GenerateParameters,
+}
+
+
+#[handler]
+#[instrument(skip(infer), fields(time, time_per_token))]
+async fn generate(
+    infer: Data<&Batcher>,
+    req: Json<GenerateRequest>,
+) -> poem::Result<Json<serde_json::Value>> {
+    let start = Instant::now();
+
+    let output = infer
+        .infer(GenerateRequest {
+            inputs: req.inputs.clone(),
+            parameters: req.parameters.clone(),
+        })
+        .await;
+
+    match output {
+        Ok(generated_text) => {
+            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
+            tracing::Span::current().record(
+                "time_per_token",
+                format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
+            );
+            tracing::info!("response: {}", generated_text);
+
+            Ok(Json(serde_json::json!({
+                "generated_text": generated_text,
+            })))
+        }
+        Err(_) => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)),
+    }
+}
+
+pub async fn run(client: ShardedClient, listener: TcpListener<String>) -> Result<(), std::io::Error> {
+    client
+        .clear_cache()
+        .await
+        .expect("Unable to clear cache");
+    tracing::info!("Connected");
+
+    let infer = Batcher::new(client);
+
+    let app = Route::new()
+        .at("/generate", post(generate))
+        .with(AddData::new(infer));
+
+    Server::new(listener)
+        .run(app)
+        .await
+}
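
Every GenerateParameters field above carries a serde default, so a POST to the /generate route needs only inputs; the k6 script at the top of this commit exercises the fuller form. A representative request body, with values taken from the load test:

    {
        "inputs": "Pour déguster un ortolan, il faut tout d'abord",
        "parameters": {"max_new_tokens": 32, "do_sample": true, "top_p": 0.9}
    }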