parent
ce960be0a5
commit
1539d3cbbe
|
@ -1087,6 +1087,12 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nohash-hasher"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "nom"
|
name = "nom"
|
||||||
version = "7.1.1"
|
version = "7.1.1"
|
||||||
|
@ -1826,6 +1832,7 @@ dependencies = [
|
||||||
"axum",
|
"axum",
|
||||||
"clap 4.0.22",
|
"clap 4.0.22",
|
||||||
"futures",
|
"futures",
|
||||||
|
"nohash-hasher",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -17,6 +17,7 @@ axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
||||||
text-generation-client = { path = "client" }
|
text-generation-client = { path = "client" }
|
||||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||||
futures = "0.3.24"
|
futures = "0.3.24"
|
||||||
|
nohash-hasher = "0.2.0"
|
||||||
parking_lot = "0.12.1"
|
parking_lot = "0.12.1"
|
||||||
serde = "1.0.145"
|
serde = "1.0.145"
|
||||||
serde_json = "1.0.85"
|
serde_json = "1.0.85"
|
||||||
|
|
|
@ -3,6 +3,7 @@ use crate::{Db, Entry};
|
||||||
use crate::{ErrorResponse, GenerateRequest};
|
use crate::{ErrorResponse, GenerateRequest};
|
||||||
use axum::http::StatusCode;
|
use axum::http::StatusCode;
|
||||||
use axum::Json;
|
use axum::Json;
|
||||||
|
use nohash_hasher::IntMap;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
||||||
|
@ -104,8 +105,8 @@ async fn batching_task(
|
||||||
// Get the next batch from the DB
|
// Get the next batch from the DB
|
||||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
// waiting in the DB
|
// waiting in the DB
|
||||||
while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
|
while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
|
||||||
let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
|
let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
|
||||||
let mut waiting_tokens = 1;
|
let mut waiting_tokens = 1;
|
||||||
|
|
||||||
// We loop until we do not receive any cached batch from the inference server (== until
|
// We loop until we do not receive any cached batch from the inference server (== until
|
||||||
|
@ -113,7 +114,6 @@ async fn batching_task(
|
||||||
while let Some(batch) = cached_batch {
|
while let Some(batch) = cached_batch {
|
||||||
// Get current batch info
|
// Get current batch info
|
||||||
let batch_size = batch.size;
|
let batch_size = batch.size;
|
||||||
let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
|
|
||||||
let mut batches = vec![batch];
|
let mut batches = vec![batch];
|
||||||
|
|
||||||
// If the current batch is too small, we try to add more requests to it
|
// If the current batch is too small, we try to add more requests to it
|
||||||
|
@ -127,24 +127,23 @@ async fn batching_task(
|
||||||
};
|
};
|
||||||
|
|
||||||
// Try to get a new batch
|
// Try to get a new batch
|
||||||
if let Some((new_request_ids, new_batch)) =
|
if let Some((mut new_entries, new_batch)) =
|
||||||
db.next_batch(min_size, max_batch_size - batch_size as usize)
|
db.next_batch(min_size, max_batch_size - batch_size as usize)
|
||||||
{
|
{
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
let new_cached_batch =
|
let new_cached_batch =
|
||||||
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
|
wrap_future(client.generate(new_batch), &mut new_entries).await;
|
||||||
// Reset waiting counter
|
// Reset waiting counter
|
||||||
waiting_tokens = 1;
|
waiting_tokens = 1;
|
||||||
// Extend current batch with the new batch
|
// Extend current batch with the new batch
|
||||||
if let Some(new_cached_batch) = new_cached_batch {
|
if let Some(new_cached_batch) = new_cached_batch {
|
||||||
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
|
entries.extend(new_entries);
|
||||||
batches.push(new_cached_batch);
|
batches.push(new_cached_batch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cached_batch =
|
cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await;
|
||||||
wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
|
|
||||||
waiting_tokens += 1;
|
waiting_tokens += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -154,39 +153,36 @@ async fn batching_task(
|
||||||
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
|
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
|
||||||
async fn wrap_future(
|
async fn wrap_future(
|
||||||
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
|
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
|
||||||
request_ids: Vec<u64>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
db: &Db,
|
|
||||||
) -> Option<Batch> {
|
) -> Option<Batch> {
|
||||||
match future.await {
|
match future.await {
|
||||||
Ok((generated_texts, next_batch)) => {
|
Ok((generated_texts, next_batch)) => {
|
||||||
send_generated(generated_texts, db);
|
send_generated(generated_texts, entries);
|
||||||
next_batch
|
next_batch
|
||||||
}
|
}
|
||||||
// If we have an error, we discard the whole batch
|
// If we have an error, we discard the whole batch
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
send_error(err, request_ids, db);
|
send_error(err, entries);
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send errors to the Batcher for all `request_ids`
|
/// Send errors to the Batcher for all `entries`
|
||||||
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
|
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
request_ids.into_iter().for_each(|id| {
|
entries.drain().for_each(|(_, entry)| {
|
||||||
// We can `expect` here as the request id should always be in the DB
|
|
||||||
let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
|
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
entry.response_tx.send(Err(error.clone())).unwrap_or(());
|
entry.response_tx.send(Err(error.clone())).unwrap_or(());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send `generated_text` to the Batcher for all `finished`
|
/// Send `generated_text` to the Batcher for all `finished`
|
||||||
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
|
fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
|
||||||
finished.into_iter().for_each(|output| {
|
finished.into_iter().for_each(|output| {
|
||||||
// We can `expect` here as the request id should always be in the DB
|
// We can `expect` here as the request id should always be in the entries
|
||||||
let entry = db
|
let entry = entries
|
||||||
.remove(&output.request.unwrap().id)
|
.remove(&output.request.unwrap().id)
|
||||||
.expect("ID not found in db. This is a bug.");
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
let response = InferResponse {
|
let response = InferResponse {
|
||||||
output_text: output.output_text,
|
output_text: output.output_text,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use crate::InferResponse;
|
|
||||||
/// This code is massively inspired by Tokio mini-redis
|
/// This code is massively inspired by Tokio mini-redis
|
||||||
|
use crate::InferResponse;
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::{GenerateParameters, GenerateRequest};
|
||||||
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -112,18 +113,12 @@ impl Db {
|
||||||
state.entries.insert(id, entry);
|
state.entries.insert(id, entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Remove an entry from the database if it exists
|
|
||||||
pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
|
|
||||||
let mut state = self.shared.state.lock();
|
|
||||||
state.entries.remove(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the next batch
|
// Get the next batch
|
||||||
pub(crate) fn next_batch(
|
pub(crate) fn next_batch(
|
||||||
&self,
|
&self,
|
||||||
min_size: Option<usize>,
|
min_size: Option<usize>,
|
||||||
max_size: usize,
|
max_size: usize,
|
||||||
) -> Option<(Vec<u64>, Batch)> {
|
) -> Option<(IntMap<u64, Entry>, Batch)> {
|
||||||
// Acquire lock
|
// Acquire lock
|
||||||
let mut state = self.shared.state.lock();
|
let mut state = self.shared.state.lock();
|
||||||
|
|
||||||
|
@ -135,13 +130,19 @@ impl Db {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ids.iter().for_each(|id| {
|
|
||||||
// Set batch_time for each request
|
|
||||||
state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
|
|
||||||
});
|
|
||||||
|
|
||||||
// Batch size
|
// Batch size
|
||||||
let size = requests.len();
|
let size = requests.len();
|
||||||
|
|
||||||
|
let mut entries = IntMap::with_capacity_and_hasher(size, BuildNoHashHasher::default());
|
||||||
|
ids.iter().for_each(|id| {
|
||||||
|
// Remove entry from db
|
||||||
|
let mut entry = state.entries.remove(id).unwrap();
|
||||||
|
// Set batch_time
|
||||||
|
entry.batch_time = Some(Instant::now());
|
||||||
|
// Insert in entries IntMap
|
||||||
|
entries.insert(*id, entry);
|
||||||
|
});
|
||||||
|
|
||||||
let batch = Batch {
|
let batch = Batch {
|
||||||
id: state.next_batch_id,
|
id: state.next_batch_id,
|
||||||
requests,
|
requests,
|
||||||
|
@ -152,7 +153,7 @@ impl Db {
|
||||||
// Increment batch id
|
// Increment batch id
|
||||||
state.next_batch_id += 1;
|
state.next_batch_id += 1;
|
||||||
|
|
||||||
return Some((ids, batch));
|
return Some((entries, batch));
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue