feat: Add token streaming using ServerSideEvents support (#36)
Add token streaming using ServerSideEvents (SSE). The signature of the SSE events is: ```rust struct Details { finish_reason: String, generated_tokens: u32, seed: Option<u64>, } struct StreamResponse { token: Token, generated_text: Option<String>, details: Option<Details>, } struct ErrorResponse { error: String, } ```
This commit is contained in:
parent
cd298bc5e5
commit
7fbfbb0dc5
|
@ -1829,6 +1829,7 @@ dependencies = [
|
||||||
name = "text-generation-router"
|
name = "text-generation-router"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"async-stream",
|
||||||
"axum",
|
"axum",
|
||||||
"clap 4.0.22",
|
"clap 4.0.22",
|
||||||
"futures",
|
"futures",
|
||||||
|
@ -1840,6 +1841,7 @@ dependencies = [
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|
|
@ -16,4 +16,4 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
float_eq = "1.0.1"
|
float_eq = "1.0.1"
|
||||||
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
|
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
|
||||||
serde = "1.0.150"
|
serde = { version = "1.0.150", features = ["derive"] }
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
"details": {
|
"details": {
|
||||||
"finish_reason": "length",
|
"finish_reason": "length",
|
||||||
"generated_tokens": 20,
|
"generated_tokens": 20,
|
||||||
"tokens": [
|
"prefill": [
|
||||||
[
|
[
|
||||||
10264,
|
10264,
|
||||||
"Test",
|
"Test",
|
||||||
|
@ -13,7 +13,9 @@
|
||||||
8821,
|
8821,
|
||||||
" request",
|
" request",
|
||||||
-11.895094
|
-11.895094
|
||||||
],
|
]
|
||||||
|
],
|
||||||
|
"tokens": [
|
||||||
[
|
[
|
||||||
17,
|
17,
|
||||||
".",
|
".",
|
||||||
|
|
|
@ -3,12 +3,14 @@
|
||||||
"details": {
|
"details": {
|
||||||
"finish_reason": "length",
|
"finish_reason": "length",
|
||||||
"generated_tokens": 20,
|
"generated_tokens": 20,
|
||||||
"tokens": [
|
"prefill": [
|
||||||
[
|
[
|
||||||
0,
|
0,
|
||||||
"<pad>",
|
"<pad>",
|
||||||
null
|
null
|
||||||
],
|
]
|
||||||
|
],
|
||||||
|
"tokens": [
|
||||||
[
|
[
|
||||||
259,
|
259,
|
||||||
"",
|
"",
|
||||||
|
|
|
@ -7,10 +7,10 @@ service TextGenerationService {
|
||||||
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
|
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
|
||||||
/// Empties batch cache
|
/// Empties batch cache
|
||||||
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
|
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
|
||||||
/// Generate tokens for a batch
|
/// Prefill batch and decode first token
|
||||||
rpc Generate (GenerateRequest) returns (GenerateResponse);
|
rpc Prefill (PrefillRequest) returns (PrefillResponse);
|
||||||
/// Generate tokens for a list of cached batches
|
/// Decode token for a list of prefilled batches
|
||||||
rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse);
|
rpc Decode (DecodeRequest) returns (DecodeResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Empty request
|
/// Empty request
|
||||||
|
@ -70,44 +70,60 @@ message Batch {
|
||||||
}
|
}
|
||||||
|
|
||||||
message GeneratedText {
|
message GeneratedText {
|
||||||
/// Request
|
|
||||||
Request request = 1;
|
|
||||||
/// Output
|
/// Output
|
||||||
string output_text = 2;
|
string text = 1;
|
||||||
/// Number of generated tokens
|
/// Number of generated tokens
|
||||||
uint32 generated_tokens = 3;
|
uint32 generated_tokens = 2;
|
||||||
/// Tokens
|
|
||||||
repeated string tokens = 4;
|
|
||||||
/// Token IDs
|
|
||||||
repeated uint32 token_ids = 5;
|
|
||||||
/// Logprobs
|
|
||||||
repeated float logprobs = 6;
|
|
||||||
/// Finish reason
|
/// Finish reason
|
||||||
string finish_reason = 7;
|
string finish_reason = 3;
|
||||||
/// Seed
|
/// Seed
|
||||||
optional uint64 seed = 8;
|
optional uint64 seed = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateRequest {
|
message PrefillTokens {
|
||||||
|
/// Prefill Token IDs
|
||||||
|
repeated uint32 ids = 1;
|
||||||
|
/// Prefill Logprobs
|
||||||
|
repeated float logprobs = 2;
|
||||||
|
/// Prefill tokens
|
||||||
|
repeated string texts = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Generation {
|
||||||
|
/// Request ID
|
||||||
|
uint64 request_id = 1;
|
||||||
|
/// Prefill tokens (optional)
|
||||||
|
PrefillTokens prefill_tokens = 2;
|
||||||
|
/// Token ID
|
||||||
|
uint32 token_id = 3;
|
||||||
|
/// Logprob
|
||||||
|
float token_logprob = 4;
|
||||||
|
/// Text
|
||||||
|
string token_text = 5;
|
||||||
|
/// Complete generated text
|
||||||
|
GeneratedText generated_text = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PrefillRequest {
|
||||||
/// Batch
|
/// Batch
|
||||||
Batch batch = 1;
|
Batch batch = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateResponse {
|
message PrefillResponse {
|
||||||
/// Finished requests
|
/// Generation
|
||||||
repeated GeneratedText generated_texts = 1;
|
repeated Generation generations = 1;
|
||||||
/// Next batch (cached)
|
/// Next batch (cached)
|
||||||
optional Batch batch = 2;
|
optional Batch batch = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateWithCacheRequest {
|
message DecodeRequest {
|
||||||
/// Cached batches
|
/// Cached batches
|
||||||
repeated Batch batches = 1;
|
repeated Batch batches = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateWithCacheResponse {
|
message DecodeResponse {
|
||||||
/// Finished requests
|
/// Decodes
|
||||||
repeated GeneratedText generated_texts = 1;
|
repeated Generation generations = 1;
|
||||||
/// Next batch (cached)
|
/// Next batch (cached)
|
||||||
optional Batch batch = 2;
|
optional Batch batch = 2;
|
||||||
}
|
}
|
|
@ -13,6 +13,7 @@ name = "text-generation-router"
|
||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
async-stream = "0.3.3"
|
||||||
axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
||||||
text-generation-client = { path = "client" }
|
text-generation-client = { path = "client" }
|
||||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||||
|
@ -24,6 +25,7 @@ serde_json = "1.0.85"
|
||||||
thiserror = "1.0.37"
|
thiserror = "1.0.37"
|
||||||
tokenizers = "0.13.0"
|
tokenizers = "0.13.0"
|
||||||
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
|
tokio-stream = "0.1.11"
|
||||||
tracing = "0.1.36"
|
tracing = "0.1.36"
|
||||||
tracing-subscriber = { version = "0.3.15", features = ["json"] }
|
tracing-subscriber = { version = "0.3.15", features = ["json"] }
|
||||||
|
|
||||||
|
|
|
@ -70,36 +70,36 @@ impl Client {
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
///
|
///
|
||||||
/// Returns a list of generated texts of request that met their stopping criteria
|
/// Returns Generation for each request in batch
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||||
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
|
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
|
||||||
let response = self
|
let response = self
|
||||||
.stub
|
.stub
|
||||||
.generate(request)
|
.prefill(request)
|
||||||
.instrument(info_span!("generate"))
|
.instrument(info_span!("prefill"))
|
||||||
.await?
|
.await?
|
||||||
.into_inner();
|
.into_inner();
|
||||||
Ok((response.generated_texts, response.batch))
|
Ok((response.generations, response.batch))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given cached batch
|
/// Generate one token for each request in the given cached batches
|
||||||
///
|
///
|
||||||
/// Returns a list of generated texts of request that met their stopping criteria
|
/// Returns Generation for each request in batches
|
||||||
/// and the next cached batch
|
/// and the next cached batch
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn generate_with_cache(
|
pub async fn decode(
|
||||||
&mut self,
|
&mut self,
|
||||||
batches: Vec<Batch>,
|
batches: Vec<Batch>,
|
||||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||||
let request = tonic::Request::new(GenerateWithCacheRequest { batches });
|
let request = tonic::Request::new(DecodeRequest { batches });
|
||||||
let response = self
|
let response = self
|
||||||
.stub
|
.stub
|
||||||
.generate_with_cache(request)
|
.decode(request)
|
||||||
.instrument(info_span!("generate_with_cache"))
|
.instrument(info_span!("decode"))
|
||||||
.await?
|
.await?
|
||||||
.into_inner();
|
.into_inner();
|
||||||
Ok((response.generated_texts, response.batch))
|
Ok((response.generations, response.batch))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,8 @@ mod sharded_client;
|
||||||
|
|
||||||
pub use client::Client;
|
pub use client::Client;
|
||||||
pub use pb::generate::v1::{
|
pub use pb::generate::v1::{
|
||||||
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
|
||||||
|
StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/// Multi shard Client
|
/// Multi shard Client
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use crate::{Batch, Client, GeneratedText};
|
use crate::{Batch, Client, Generation};
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
use futures::future::select_all;
|
use futures::future::select_all;
|
||||||
use tonic::transport::Uri;
|
use tonic::transport::Uri;
|
||||||
|
@ -37,39 +37,6 @@ impl ShardedClient {
|
||||||
Self::from_master_client(master_client).await
|
Self::from_master_client(master_client).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
|
||||||
///
|
|
||||||
/// Returns a list of generated texts of request that met their stopping criteria
|
|
||||||
/// and the next cached batch
|
|
||||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
|
||||||
let futures: Vec<_> = self
|
|
||||||
.clients
|
|
||||||
.iter_mut()
|
|
||||||
.map(|client| Box::pin(client.generate(batch.clone())))
|
|
||||||
.collect();
|
|
||||||
// As soon as we receive one response, we can return as all shards will return the same
|
|
||||||
let (result, _, _) = select_all(futures).await;
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate one token for each request in the given cached batch
|
|
||||||
///
|
|
||||||
/// Returns a list of generated texts of request that met their stopping criteria
|
|
||||||
/// and the next cached batch
|
|
||||||
pub async fn generate_with_cache(
|
|
||||||
&mut self,
|
|
||||||
batches: Vec<Batch>,
|
|
||||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
|
||||||
let futures: Vec<_> = self
|
|
||||||
.clients
|
|
||||||
.iter_mut()
|
|
||||||
.map(|client| Box::pin(client.generate_with_cache(batches.clone())))
|
|
||||||
.collect();
|
|
||||||
// As soon as we receive one response, we can return as all shards will return the same
|
|
||||||
let (result, _, _) = select_all(futures).await;
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Clear the past generations cache
|
/// Clear the past generations cache
|
||||||
pub async fn clear_cache(&mut self) -> Result<()> {
|
pub async fn clear_cache(&mut self) -> Result<()> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
|
@ -79,4 +46,37 @@ impl ShardedClient {
|
||||||
.collect();
|
.collect();
|
||||||
join_all(futures).await.into_iter().collect()
|
join_all(futures).await.into_iter().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given batch
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batch
|
||||||
|
/// and the next cached batch
|
||||||
|
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||||
|
.collect();
|
||||||
|
// As soon as we receive one response, we can return as all shards will return the same
|
||||||
|
let (result, _, _) = select_all(futures).await;
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given cached batches
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batches
|
||||||
|
/// and the next cached batch
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
batches: Vec<Batch>,
|
||||||
|
) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||||
|
.collect();
|
||||||
|
// As soon as we receive one response, we can return as all shards will return the same
|
||||||
|
let (result, _, _) = select_all(futures).await;
|
||||||
|
result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,236 +0,0 @@
|
||||||
/// Batching and inference logic
|
|
||||||
use crate::{Db, Entry};
|
|
||||||
use crate::{ErrorResponse, GenerateRequest};
|
|
||||||
use axum::http::StatusCode;
|
|
||||||
use axum::Json;
|
|
||||||
use nohash_hasher::IntMap;
|
|
||||||
use std::future::Future;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::sync::{oneshot, Notify};
|
|
||||||
use tokio::time::Instant;
|
|
||||||
use tracing::instrument;
|
|
||||||
|
|
||||||
/// Batcher
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Batcher {
|
|
||||||
/// Request database
|
|
||||||
db: Db,
|
|
||||||
/// Shared state
|
|
||||||
shared: Arc<Shared>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Batcher shared state
|
|
||||||
struct Shared {
|
|
||||||
/// Batching background Tokio task notifier
|
|
||||||
batching_task: Notify,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Batcher {
|
|
||||||
pub(crate) fn new(
|
|
||||||
client: ShardedClient,
|
|
||||||
max_batch_size: usize,
|
|
||||||
max_waiting_tokens: usize,
|
|
||||||
) -> Self {
|
|
||||||
// Batcher shared state
|
|
||||||
let db = Db::new();
|
|
||||||
let shared = Arc::new(Shared {
|
|
||||||
batching_task: Notify::new(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Spawn batching background task that contains all the inference logic
|
|
||||||
tokio::spawn(batching_task(
|
|
||||||
client,
|
|
||||||
max_batch_size,
|
|
||||||
max_waiting_tokens,
|
|
||||||
db.clone(),
|
|
||||||
shared.clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
Self { db, shared }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a new request to the database and return a future that will generate the text
|
|
||||||
pub(crate) async fn infer(
|
|
||||||
&self,
|
|
||||||
input_length: usize,
|
|
||||||
request: GenerateRequest,
|
|
||||||
) -> Result<InferResponse, InferError> {
|
|
||||||
// One shot channel to communicate with the background batching task
|
|
||||||
let (response_tx, response_rx) = oneshot::channel();
|
|
||||||
|
|
||||||
// Try to append the request to the database
|
|
||||||
self.db.append(Entry {
|
|
||||||
request,
|
|
||||||
response_tx,
|
|
||||||
input_length,
|
|
||||||
time: Instant::now(),
|
|
||||||
batch_time: None,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Notify the background task that we have a new entry in the database that needs
|
|
||||||
// to be batched
|
|
||||||
self.shared.batching_task.notify_one();
|
|
||||||
|
|
||||||
// Await on the response from the background task
|
|
||||||
// We can safely unwrap as the background task will never drop the sender
|
|
||||||
response_rx
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.map_err(|err| InferError::GenerationError(err.to_string()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Batching logic
|
|
||||||
/// Will be launched in a background Tokio task
|
|
||||||
///
|
|
||||||
/// Batches requests and sends them to the inference server
|
|
||||||
#[instrument(skip(client, db, shared))]
|
|
||||||
async fn batching_task(
|
|
||||||
mut client: ShardedClient,
|
|
||||||
max_batch_size: usize,
|
|
||||||
max_waiting_tokens: usize,
|
|
||||||
db: Db,
|
|
||||||
shared: Arc<Shared>,
|
|
||||||
) {
|
|
||||||
// Minimum batch size after which we try to add more requests
|
|
||||||
let limit_min_batch_size = (max_batch_size / 2) as u32;
|
|
||||||
|
|
||||||
// Infinite loop
|
|
||||||
loop {
|
|
||||||
// Wait for a notification from the Batcher struct
|
|
||||||
shared.batching_task.notified().await;
|
|
||||||
|
|
||||||
// Get the next batch from the DB
|
|
||||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
|
||||||
// waiting in the DB
|
|
||||||
while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
|
|
||||||
let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
|
|
||||||
let mut waiting_tokens = 1;
|
|
||||||
|
|
||||||
// We loop until we do not receive any cached batch from the inference server (== until
|
|
||||||
// all requests have met their stopping criteria)
|
|
||||||
while let Some(batch) = cached_batch {
|
|
||||||
// Get current batch info
|
|
||||||
let batch_size = batch.size;
|
|
||||||
let mut batches = vec![batch];
|
|
||||||
|
|
||||||
// If the current batch is too small, we try to add more requests to it
|
|
||||||
if batch_size <= limit_min_batch_size {
|
|
||||||
let min_size = match waiting_tokens {
|
|
||||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
|
||||||
// to add a new batch even though its size might be small
|
|
||||||
_ if waiting_tokens >= max_waiting_tokens => None,
|
|
||||||
// Minimum size criteria
|
|
||||||
_ => Some(limit_min_batch_size as usize),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Try to get a new batch
|
|
||||||
if let Some((mut new_entries, new_batch)) =
|
|
||||||
db.next_batch(min_size, max_batch_size - batch_size as usize)
|
|
||||||
{
|
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
|
||||||
let new_cached_batch =
|
|
||||||
wrap_future(client.generate(new_batch), &mut new_entries).await;
|
|
||||||
// Reset waiting counter
|
|
||||||
waiting_tokens = 1;
|
|
||||||
// Extend current batch with the new batch
|
|
||||||
if let Some(new_cached_batch) = new_cached_batch {
|
|
||||||
entries.extend(new_entries);
|
|
||||||
batches.push(new_cached_batch);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await;
|
|
||||||
waiting_tokens += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
|
|
||||||
async fn wrap_future(
|
|
||||||
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
|
|
||||||
entries: &mut IntMap<u64, Entry>,
|
|
||||||
) -> Option<Batch> {
|
|
||||||
match future.await {
|
|
||||||
Ok((generated_texts, next_batch)) => {
|
|
||||||
send_generated(generated_texts, entries);
|
|
||||||
next_batch
|
|
||||||
}
|
|
||||||
// If we have an error, we discard the whole batch
|
|
||||||
Err(err) => {
|
|
||||||
send_error(err, entries);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send errors to the Batcher for all `entries`
|
|
||||||
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|
||||||
entries.drain().for_each(|(_, entry)| {
|
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
|
||||||
entry.response_tx.send(Err(error.clone())).unwrap_or(());
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send `generated_text` to the Batcher for all `finished`
|
|
||||||
fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
|
|
||||||
finished.into_iter().for_each(|output| {
|
|
||||||
// We can `expect` here as the request id should always be in the entries
|
|
||||||
let entry = entries
|
|
||||||
.remove(&output.request.unwrap().id)
|
|
||||||
.expect("ID not found in entries. This is a bug.");
|
|
||||||
|
|
||||||
let response = InferResponse {
|
|
||||||
output_text: output.output_text,
|
|
||||||
generated_tokens: output.generated_tokens,
|
|
||||||
token_ids: output.token_ids,
|
|
||||||
tokens: output.tokens,
|
|
||||||
logprobs: output.logprobs,
|
|
||||||
finish_reason: output.finish_reason,
|
|
||||||
seed: output.seed,
|
|
||||||
queued: entry.time,
|
|
||||||
start: entry.batch_time.unwrap(), // unwrap is always valid
|
|
||||||
end: Instant::now(),
|
|
||||||
};
|
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
|
||||||
entry.response_tx.send(Ok(response)).unwrap_or(());
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub(crate) struct InferResponse {
|
|
||||||
pub(crate) output_text: String,
|
|
||||||
pub(crate) generated_tokens: u32,
|
|
||||||
pub(crate) token_ids: Vec<u32>,
|
|
||||||
pub(crate) tokens: Vec<String>,
|
|
||||||
pub(crate) logprobs: Vec<f32>,
|
|
||||||
pub(crate) finish_reason: String,
|
|
||||||
pub(crate) seed: Option<u64>,
|
|
||||||
pub(crate) queued: Instant,
|
|
||||||
pub(crate) start: Instant,
|
|
||||||
pub(crate) end: Instant,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum InferError {
|
|
||||||
#[error("Request failed during generation: {0}")]
|
|
||||||
GenerationError(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert to Axum supported format
|
|
||||||
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|
||||||
fn from(err: InferError) -> Self {
|
|
||||||
match err {
|
|
||||||
InferError::GenerationError(_) => (
|
|
||||||
StatusCode::FAILED_DEPENDENCY,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: err.to_string(),
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,14 +1,16 @@
|
||||||
/// This code is massively inspired by Tokio mini-redis
|
/// This code is massively inspired by Tokio mini-redis
|
||||||
use crate::InferResponse;
|
use crate::infer::InferError;
|
||||||
|
use crate::infer::InferStreamResponse;
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::{GenerateParameters, GenerateRequest};
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::{
|
use text_generation_client::{
|
||||||
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
use tokio::sync::oneshot::Sender;
|
use tokio::sync::mpsc::UnboundedSender;
|
||||||
|
use tokio::sync::OwnedSemaphorePermit;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
|
||||||
/// Database entry
|
/// Database entry
|
||||||
|
@ -16,14 +18,16 @@ use tokio::time::Instant;
|
||||||
pub(crate) struct Entry {
|
pub(crate) struct Entry {
|
||||||
/// Request
|
/// Request
|
||||||
pub request: GenerateRequest,
|
pub request: GenerateRequest,
|
||||||
/// Response sender to communicate between the Batcher and the batching_task
|
/// Response sender to communicate between the Infer struct and the batching_task
|
||||||
pub response_tx: Sender<Result<InferResponse, ClientError>>,
|
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||||
/// Number of tokens in the input
|
/// Number of tokens in the input
|
||||||
pub input_length: usize,
|
pub input_length: usize,
|
||||||
/// Instant when this entry was created
|
/// Instant when this entry was created
|
||||||
pub time: Instant,
|
pub time: Instant,
|
||||||
/// Instant when this entry was added to a batch
|
/// Instant when this entry was added to a batch
|
||||||
pub batch_time: Option<Instant>,
|
pub batch_time: Option<Instant>,
|
||||||
|
/// Permit
|
||||||
|
pub _permit: OwnedSemaphorePermit,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Request Database
|
/// Request Database
|
||||||
|
|
|
@ -0,0 +1,354 @@
|
||||||
|
/// Batching and inference logic
|
||||||
|
use crate::validation::{Validation, ValidationError};
|
||||||
|
use crate::GenerateRequest;
|
||||||
|
use crate::{Db, Entry, Token};
|
||||||
|
use nohash_hasher::IntMap;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use text_generation_client::{
|
||||||
|
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||||
|
};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
|
use tokio_stream::StreamExt;
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
/// Inference struct
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Infer {
|
||||||
|
/// Validation
|
||||||
|
validation: Validation,
|
||||||
|
/// Request database
|
||||||
|
db: Db,
|
||||||
|
/// Shared state
|
||||||
|
shared: Arc<Shared>,
|
||||||
|
/// Inference limit
|
||||||
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Infer shared state
|
||||||
|
struct Shared {
|
||||||
|
/// Batching background Tokio task notifier
|
||||||
|
batching_task: Notify,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Infer {
|
||||||
|
pub(crate) fn new(
|
||||||
|
client: ShardedClient,
|
||||||
|
validation: Validation,
|
||||||
|
max_batch_size: usize,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
) -> Self {
|
||||||
|
// Infer shared state
|
||||||
|
let db = Db::new();
|
||||||
|
let shared = Arc::new(Shared {
|
||||||
|
batching_task: Notify::new(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Spawn batching background task that contains all the inference logic
|
||||||
|
tokio::spawn(batching_task(
|
||||||
|
client,
|
||||||
|
max_batch_size,
|
||||||
|
max_waiting_tokens,
|
||||||
|
db.clone(),
|
||||||
|
shared.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Inference limit with a semaphore
|
||||||
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
|
||||||
|
Self {
|
||||||
|
validation,
|
||||||
|
db,
|
||||||
|
shared,
|
||||||
|
limit_concurrent_requests: semaphore,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a new request to the database and return a stream of InferStreamResponse
|
||||||
|
pub(crate) async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||||
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
|
// This permit will live as long as Entry
|
||||||
|
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
|
||||||
|
|
||||||
|
// Validate request
|
||||||
|
let (input_length, validated_request) = self.validation.validate(request).await?;
|
||||||
|
|
||||||
|
// MPSC channel to communicate with the background batching task
|
||||||
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
// Append the request to the database
|
||||||
|
self.db.append(Entry {
|
||||||
|
request: validated_request,
|
||||||
|
response_tx,
|
||||||
|
input_length,
|
||||||
|
time: Instant::now(),
|
||||||
|
batch_time: None,
|
||||||
|
_permit: permit,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Notify the background task that we have a new entry in the database that needs
|
||||||
|
// to be batched
|
||||||
|
self.shared.batching_task.notify_one();
|
||||||
|
|
||||||
|
// Return stream
|
||||||
|
Ok(UnboundedReceiverStream::new(response_rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a new request to the database and return a InferResponse
|
||||||
|
pub(crate) async fn generate(
|
||||||
|
&self,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> Result<InferResponse, InferError> {
|
||||||
|
// Create stream
|
||||||
|
let mut stream = self.generate_stream(request).await?;
|
||||||
|
|
||||||
|
// Return values
|
||||||
|
let mut result_prefill = Vec::new();
|
||||||
|
let mut result_tokens = Vec::new();
|
||||||
|
let mut result_generated_text = None;
|
||||||
|
let mut result_start = None;
|
||||||
|
let mut result_queued = None;
|
||||||
|
|
||||||
|
// Iterate on stream
|
||||||
|
while let Some(response) = stream.next().await {
|
||||||
|
match response? {
|
||||||
|
// Add prefill tokens
|
||||||
|
InferStreamResponse::Prefill(tokens) => {
|
||||||
|
// Create Token objects
|
||||||
|
// We do that here instead of in the Python code as Rust for loops are faster
|
||||||
|
result_prefill = tokens
|
||||||
|
.ids
|
||||||
|
.into_iter()
|
||||||
|
.zip(tokens.logprobs.into_iter())
|
||||||
|
.zip(tokens.texts.into_iter())
|
||||||
|
.map(|((id, logprob), text)| Token(id, text, logprob))
|
||||||
|
.collect();
|
||||||
|
}
|
||||||
|
// Push last token
|
||||||
|
InferStreamResponse::Token(token) => result_tokens.push(token),
|
||||||
|
// Final message
|
||||||
|
// Set return values
|
||||||
|
InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
generated_text,
|
||||||
|
start,
|
||||||
|
queued,
|
||||||
|
} => {
|
||||||
|
result_tokens.push(token);
|
||||||
|
result_generated_text = Some(generated_text);
|
||||||
|
result_start = Some(start);
|
||||||
|
result_queued = Some(queued)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we received a `InferStreamResponse::End` message
|
||||||
|
if let (Some(generated_text), Some(queued), Some(start)) =
|
||||||
|
(result_generated_text, result_queued, result_start)
|
||||||
|
{
|
||||||
|
Ok(InferResponse {
|
||||||
|
prefill: result_prefill,
|
||||||
|
tokens: result_tokens,
|
||||||
|
generated_text,
|
||||||
|
queued,
|
||||||
|
start,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(InferError::IncompleteGeneration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Batching logic
|
||||||
|
/// Will be launched in a background Tokio task
|
||||||
|
///
|
||||||
|
/// Batches requests and sends them to the inference server
|
||||||
|
#[instrument(skip(client, db, shared))]
|
||||||
|
async fn batching_task(
|
||||||
|
mut client: ShardedClient,
|
||||||
|
max_batch_size: usize,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
db: Db,
|
||||||
|
shared: Arc<Shared>,
|
||||||
|
) {
|
||||||
|
// Minimum batch size after which we try to add more requests
|
||||||
|
let limit_min_batch_size = (max_batch_size / 2) as u32;
|
||||||
|
|
||||||
|
// Infinite loop
|
||||||
|
loop {
|
||||||
|
// Wait for a notification from the Infer struct
|
||||||
|
shared.batching_task.notified().await;
|
||||||
|
|
||||||
|
// Get the next batch from the DB
|
||||||
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
|
// waiting in the DB
|
||||||
|
while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
|
||||||
|
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
|
||||||
|
let mut waiting_tokens = 1;
|
||||||
|
|
||||||
|
// We loop until we do not receive any cached batch from the inference server (== until
|
||||||
|
// all requests have met their stopping criteria)
|
||||||
|
while let Some(batch) = cached_batch {
|
||||||
|
// Get current batch info
|
||||||
|
let batch_size = batch.size;
|
||||||
|
let mut batches = vec![batch];
|
||||||
|
|
||||||
|
// If the current batch is too small, we try to add more requests to it
|
||||||
|
if batch_size <= limit_min_batch_size {
|
||||||
|
let min_size = match waiting_tokens {
|
||||||
|
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||||
|
// to add a new batch even though its size might be small
|
||||||
|
_ if waiting_tokens >= max_waiting_tokens => None,
|
||||||
|
// Minimum size criteria
|
||||||
|
_ => Some(limit_min_batch_size as usize),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Try to get a new batch
|
||||||
|
if let Some((mut new_entries, new_batch)) =
|
||||||
|
db.next_batch(min_size, max_batch_size - batch_size as usize)
|
||||||
|
{
|
||||||
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
|
let new_cached_batch =
|
||||||
|
wrap_future(client.prefill(new_batch), &mut new_entries).await;
|
||||||
|
// Reset waiting counter
|
||||||
|
waiting_tokens = 1;
|
||||||
|
// Extend current batch with the new batch
|
||||||
|
if let Some(new_cached_batch) = new_cached_batch {
|
||||||
|
entries.extend(new_entries);
|
||||||
|
batches.push(new_cached_batch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cached_batch = wrap_future(client.decode(batches), &mut entries).await;
|
||||||
|
waiting_tokens += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
|
||||||
|
async fn wrap_future(
|
||||||
|
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
||||||
|
entries: &mut IntMap<u64, Entry>,
|
||||||
|
) -> Option<Batch> {
|
||||||
|
match future.await {
|
||||||
|
Ok((generations, next_batch)) => {
|
||||||
|
send_generations(generations, entries);
|
||||||
|
next_batch
|
||||||
|
}
|
||||||
|
// If we have an error, we discard the whole batch
|
||||||
|
Err(err) => {
|
||||||
|
send_error(err, entries);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send errors to Infer for all `entries`
|
||||||
|
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
entries.drain().for_each(|(_, entry)| {
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Err(InferError::GenerationError(error.to_string())))
|
||||||
|
.unwrap_or(());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||||
|
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
generations.into_iter().for_each(|generation| {
|
||||||
|
// Get entry
|
||||||
|
// We can `expect` here as the request id should always be in the entries
|
||||||
|
let entry = entries
|
||||||
|
.get(&generation.request_id)
|
||||||
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
|
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||||
|
// Send message
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
|
||||||
|
.unwrap_or(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create last Token
|
||||||
|
let token = Token(
|
||||||
|
generation.token_id,
|
||||||
|
generation.token_text,
|
||||||
|
generation.token_logprob,
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(generated_text) = generation.generated_text {
|
||||||
|
// Remove entry as this is the last message
|
||||||
|
// We can `expect` here as the request id should always be in the entries
|
||||||
|
let entry = entries
|
||||||
|
.remove(&generation.request_id)
|
||||||
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
|
// Send message
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
generated_text,
|
||||||
|
queued: entry.time,
|
||||||
|
start: entry.batch_time.unwrap(),
|
||||||
|
}))
|
||||||
|
.unwrap_or(());
|
||||||
|
} else {
|
||||||
|
// Send message
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::Token(token)))
|
||||||
|
.unwrap_or(());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) enum InferStreamResponse {
|
||||||
|
// Optional first message
|
||||||
|
Prefill(PrefillTokens),
|
||||||
|
// Intermediate messages
|
||||||
|
Token(Token),
|
||||||
|
// Last message
|
||||||
|
End {
|
||||||
|
token: Token,
|
||||||
|
generated_text: GeneratedText,
|
||||||
|
start: Instant,
|
||||||
|
queued: Instant,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct InferResponse {
|
||||||
|
pub(crate) prefill: Vec<Token>,
|
||||||
|
pub(crate) tokens: Vec<Token>,
|
||||||
|
pub(crate) generated_text: GeneratedText,
|
||||||
|
pub(crate) queued: Instant,
|
||||||
|
pub(crate) start: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum InferError {
|
||||||
|
#[error("Request failed during generation: {0}")]
|
||||||
|
GenerationError(String),
|
||||||
|
#[error("Model is overloaded")]
|
||||||
|
Overloaded(#[from] TryAcquireError),
|
||||||
|
#[error("Input validation error: {0}")]
|
||||||
|
ValidationError(#[from] ValidationError),
|
||||||
|
#[error("Incomplete generation")]
|
||||||
|
IncompleteGeneration,
|
||||||
|
}
|
|
@ -1,11 +1,11 @@
|
||||||
/// Text Generation Inference Webserver
|
/// Text Generation Inference Webserver
|
||||||
mod batcher;
|
|
||||||
mod db;
|
mod db;
|
||||||
|
mod infer;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
mod validation;
|
mod validation;
|
||||||
|
|
||||||
use batcher::{Batcher, InferResponse};
|
|
||||||
use db::{Db, Entry};
|
use db::{Db, Entry};
|
||||||
|
use infer::Infer;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
|
@ -69,21 +69,34 @@ pub(crate) struct GenerateRequest {
|
||||||
pub parameters: GenerateParameters,
|
pub parameters: GenerateParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct Token(u32, String, f32);
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub(crate) struct Details {
|
pub(crate) struct Details {
|
||||||
pub finish_reason: String,
|
pub finish_reason: String,
|
||||||
pub generated_tokens: u32,
|
pub generated_tokens: u32,
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
pub tokens: Vec<(u32, String, f32)>,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prefill: Option<Vec<Token>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tokens: Option<Vec<Token>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub(crate) struct GeneratedText {
|
pub(crate) struct GenerateResponse {
|
||||||
pub generated_text: String,
|
pub generated_text: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub details: Option<Details>,
|
pub details: Option<Details>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub(crate) struct StreamResponse {
|
||||||
|
pub token: Token,
|
||||||
|
pub generated_text: Option<String>,
|
||||||
|
pub details: Option<Details>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub(crate) struct ErrorResponse {
|
pub(crate) struct ErrorResponse {
|
||||||
pub error: String,
|
pub error: String,
|
||||||
|
|
|
@ -1,71 +1,54 @@
|
||||||
|
/// HTTP Server logic
|
||||||
|
use crate::infer::{InferError, InferStreamResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
|
Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
|
||||||
|
StreamResponse, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, StatusCode};
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::IntoResponse;
|
use axum::response::IntoResponse;
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
|
use futures::Stream;
|
||||||
|
use std::convert::Infallible;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tokio::sync::Semaphore;
|
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
use tokio_stream::StreamExt;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
// Server shared state
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct ServerState {
|
|
||||||
validation: Validation,
|
|
||||||
batcher: Batcher,
|
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Health check method
|
/// Health check method
|
||||||
#[instrument(skip(state), fields(time, time_per_token))]
|
#[instrument(skip(infer))]
|
||||||
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||||
// be a bit too slow for a health check.
|
// be a bit too slow for a health check.
|
||||||
// What we should do instead if check if the gRPC channels are still healthy.
|
// What we should do instead if check if the gRPC channels are still healthy.
|
||||||
|
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
|
||||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
|
||||||
(
|
|
||||||
StatusCode::TOO_MANY_REQUESTS,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Model is overloaded".to_string(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Send a small inference request
|
// Send a small inference request
|
||||||
state
|
infer
|
||||||
.batcher
|
.generate(GenerateRequest {
|
||||||
.infer(
|
inputs: "liveness".to_string(),
|
||||||
1,
|
parameters: GenerateParameters {
|
||||||
GenerateRequest {
|
temperature: 1.0,
|
||||||
inputs: "liveness".to_string(),
|
top_k: 0,
|
||||||
parameters: GenerateParameters {
|
top_p: 1.0,
|
||||||
temperature: 1.0,
|
do_sample: false,
|
||||||
top_k: 0,
|
max_new_tokens: 1,
|
||||||
top_p: 1.0,
|
stop: vec![],
|
||||||
do_sample: false,
|
details: false,
|
||||||
max_new_tokens: 1,
|
seed: None,
|
||||||
stop: vec![],
|
|
||||||
details: false,
|
|
||||||
seed: None,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
})
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate method
|
/// Generate method
|
||||||
#[instrument(
|
#[instrument(
|
||||||
skip(state),
|
skip(infer),
|
||||||
fields(
|
fields(
|
||||||
total_time,
|
total_time,
|
||||||
validation_time,
|
validation_time,
|
||||||
|
@ -76,56 +59,28 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
async fn generate(
|
async fn generate(
|
||||||
state: Extension<ServerState>,
|
infer: Extension<Infer>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let span = tracing::Span::current();
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
|
||||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
|
||||||
tracing::error!("Model is overloaded");
|
|
||||||
(
|
|
||||||
StatusCode::TOO_MANY_REQUESTS,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Model is overloaded".to_string(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Validate request
|
|
||||||
let details = req.0.parameters.details;
|
|
||||||
let (input_length, validated_request) =
|
|
||||||
state.validation.validate(req.0).await.map_err(|err| {
|
|
||||||
tracing::error!("{}", err.to_string());
|
|
||||||
err
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
let response = state
|
let details = req.0.parameters.details;
|
||||||
.batcher
|
let response = infer.generate(req.0).await.map_err(|err| {
|
||||||
.infer(input_length, validated_request)
|
tracing::error!("{}", err.to_string());
|
||||||
.await
|
err
|
||||||
.map_err(|err| {
|
})?;
|
||||||
tracing::error!("{}", err.to_string());
|
|
||||||
err
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Token details
|
// Token details
|
||||||
let details = match details {
|
let details = match details {
|
||||||
true => {
|
true => Some(Details {
|
||||||
let tokens = response
|
finish_reason: response.generated_text.finish_reason,
|
||||||
.token_ids
|
generated_tokens: response.generated_text.generated_tokens,
|
||||||
.into_iter()
|
prefill: Some(response.prefill),
|
||||||
.zip(response.tokens.into_iter())
|
tokens: Some(response.tokens),
|
||||||
.zip(response.logprobs.into_iter())
|
seed: response.generated_text.seed,
|
||||||
.map(|((id, text), logprob)| (id, text, logprob))
|
}),
|
||||||
.collect();
|
|
||||||
Some(Details {
|
|
||||||
seed: response.seed,
|
|
||||||
finish_reason: response.finish_reason,
|
|
||||||
generated_tokens: response.generated_tokens,
|
|
||||||
tokens,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
false => None,
|
false => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -133,8 +88,8 @@ async fn generate(
|
||||||
let total_time = start_time.elapsed();
|
let total_time = start_time.elapsed();
|
||||||
let validation_time = response.queued - start_time;
|
let validation_time = response.queued - start_time;
|
||||||
let queue_time = response.start - response.queued;
|
let queue_time = response.start - response.queued;
|
||||||
let inference_time = response.end - response.start;
|
let inference_time = Instant::now() - response.start;
|
||||||
let time_per_token = inference_time / response.generated_tokens;
|
let time_per_token = inference_time / response.generated_text.generated_tokens;
|
||||||
|
|
||||||
// Headers
|
// Headers
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
|
@ -160,22 +115,143 @@ async fn generate(
|
||||||
);
|
);
|
||||||
|
|
||||||
// Tracing metadata
|
// Tracing metadata
|
||||||
tracing::Span::current().record("total_time", format!("{:?}", total_time));
|
span.record("total_time", format!("{:?}", total_time));
|
||||||
tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
|
span.record("validation_time", format!("{:?}", validation_time));
|
||||||
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
|
span.record("queue_time", format!("{:?}", queue_time));
|
||||||
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
|
span.record("inference_time", format!("{:?}", inference_time));
|
||||||
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
|
span.record("time_per_token", format!("{:?}", time_per_token));
|
||||||
tracing::Span::current().record("seed", format!("{:?}", response.seed));
|
span.record("seed", format!("{:?}", response.generated_text.seed));
|
||||||
tracing::info!("Output: {}", response.output_text);
|
tracing::info!("Output: {}", response.generated_text.text);
|
||||||
|
|
||||||
// Send response
|
// Send response
|
||||||
let response = vec![GeneratedText {
|
let response = vec![GenerateResponse {
|
||||||
generated_text: response.output_text,
|
generated_text: response.generated_text.text,
|
||||||
details,
|
details,
|
||||||
}];
|
}];
|
||||||
Ok((headers, Json(response)))
|
Ok((headers, Json(response)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate stream method
|
||||||
|
#[instrument(
|
||||||
|
skip(infer),
|
||||||
|
fields(
|
||||||
|
total_time,
|
||||||
|
validation_time,
|
||||||
|
queue_time,
|
||||||
|
inference_time,
|
||||||
|
time_per_token
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn generate_stream(
|
||||||
|
infer: Extension<Infer>,
|
||||||
|
req: Json<GenerateRequest>,
|
||||||
|
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
||||||
|
let span = tracing::Span::current();
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
let stream = async_stream::stream! {
|
||||||
|
// Inference
|
||||||
|
let mut end_reached = false;
|
||||||
|
let mut error = false;
|
||||||
|
let details = req.0.parameters.details;
|
||||||
|
|
||||||
|
match infer.generate_stream(req.0).await {
|
||||||
|
Ok(mut response_stream) => {
|
||||||
|
// Server Side Event stream
|
||||||
|
while let Some(response) = response_stream.next().await {
|
||||||
|
match response {
|
||||||
|
Ok(response) => {
|
||||||
|
match response {
|
||||||
|
// Prefill is ignored
|
||||||
|
InferStreamResponse::Prefill(_) => {}
|
||||||
|
// Yield event for every new token
|
||||||
|
InferStreamResponse::Token(token) => {
|
||||||
|
// StreamResponse
|
||||||
|
let stream_token = StreamResponse {
|
||||||
|
token,
|
||||||
|
generated_text: None,
|
||||||
|
details: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
yield Ok(Event::default().json_data(stream_token).unwrap())
|
||||||
|
}
|
||||||
|
// Yield event for last token and compute timings
|
||||||
|
InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
generated_text,
|
||||||
|
start,
|
||||||
|
queued,
|
||||||
|
} => {
|
||||||
|
// Token details
|
||||||
|
let details = match details {
|
||||||
|
true => Some(Details {
|
||||||
|
finish_reason: generated_text.finish_reason,
|
||||||
|
generated_tokens: generated_text.generated_tokens,
|
||||||
|
prefill: None,
|
||||||
|
tokens: None,
|
||||||
|
seed: generated_text.seed,
|
||||||
|
}),
|
||||||
|
false => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Timings
|
||||||
|
let total_time = start_time.elapsed();
|
||||||
|
let validation_time = queued - start_time;
|
||||||
|
let queue_time = start - queued;
|
||||||
|
let inference_time = Instant::now() - start;
|
||||||
|
let time_per_token = inference_time / generated_text.generated_tokens;
|
||||||
|
|
||||||
|
// Tracing metadata
|
||||||
|
span.record("total_time", format!("{:?}", total_time));
|
||||||
|
span
|
||||||
|
.record("validation_time", format!("{:?}", validation_time));
|
||||||
|
span.record("queue_time", format!("{:?}", queue_time));
|
||||||
|
span
|
||||||
|
.record("inference_time", format!("{:?}", inference_time));
|
||||||
|
span
|
||||||
|
.record("time_per_token", format!("{:?}", time_per_token));
|
||||||
|
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
||||||
|
|
||||||
|
// StreamResponse
|
||||||
|
end_reached = true;
|
||||||
|
let stream_token = StreamResponse {
|
||||||
|
token,
|
||||||
|
generated_text: Some(generated_text.text),
|
||||||
|
details
|
||||||
|
};
|
||||||
|
|
||||||
|
yield Ok(Event::default().json_data(stream_token).unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Trace and yield error
|
||||||
|
Err(err) => {
|
||||||
|
error = true;
|
||||||
|
tracing::error!("{}", err.to_string());
|
||||||
|
yield Ok(Event::from(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// Trace and yield error
|
||||||
|
Err(err) => {
|
||||||
|
error = true;
|
||||||
|
tracing::error!("{}", err.to_string());
|
||||||
|
yield Ok(Event::from(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if generation reached the end
|
||||||
|
// Skip if we already sent an error
|
||||||
|
if !end_reached && !error {
|
||||||
|
let err = InferError::IncompleteGeneration;
|
||||||
|
tracing::error!("{}", err.to_string());
|
||||||
|
yield Ok(Event::from(err))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||||
|
}
|
||||||
|
|
||||||
/// Serving method
|
/// Serving method
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
|
@ -189,21 +265,23 @@ pub async fn run(
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
) {
|
) {
|
||||||
// Create state
|
// Create state
|
||||||
let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
|
|
||||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
||||||
let shared_state = ServerState {
|
let infer = Infer::new(
|
||||||
|
client,
|
||||||
validation,
|
validation,
|
||||||
batcher,
|
max_batch_size,
|
||||||
limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
|
max_waiting_tokens,
|
||||||
};
|
max_concurrent_requests,
|
||||||
|
);
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/", post(generate))
|
.route("/", post(generate))
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
|
.route("/generate_stream", post(generate_stream))
|
||||||
.route("/", get(health))
|
.route("/", get(health))
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
.layer(Extension(shared_state.clone()));
|
.layer(Extension(infer));
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
axum::Server::bind(&addr)
|
axum::Server::bind(&addr)
|
||||||
|
@ -240,3 +318,32 @@ async fn shutdown_signal() {
|
||||||
|
|
||||||
tracing::info!("signal received, starting graceful shutdown");
|
tracing::info!("signal received, starting graceful shutdown");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert to Axum supported formats
|
||||||
|
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||||
|
fn from(err: InferError) -> Self {
|
||||||
|
let status_code = match err {
|
||||||
|
InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
|
||||||
|
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
};
|
||||||
|
|
||||||
|
(
|
||||||
|
status_code,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: err.to_string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<InferError> for Event {
|
||||||
|
fn from(err: InferError) -> Self {
|
||||||
|
Event::default()
|
||||||
|
.json_data(ErrorResponse {
|
||||||
|
error: err.to_string(),
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::{ErrorResponse, GenerateRequest};
|
use crate::GenerateRequest;
|
||||||
use axum::http::StatusCode;
|
|
||||||
use axum::Json;
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
@ -161,14 +159,3 @@ pub enum ValidationError {
|
||||||
#[error("tokenizer error {0}")]
|
#[error("tokenizer error {0}")]
|
||||||
Tokenizer(String),
|
Tokenizer(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ValidationError> for (StatusCode, Json<ErrorResponse>) {
|
|
||||||
fn from(err: ValidationError) -> Self {
|
|
||||||
(
|
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: err.to_string(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom):
|
||||||
|
|
||||||
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
||||||
sequence_length = len(default_bloom_batch.all_input_ids[0])
|
sequence_length = len(default_bloom_batch.all_input_ids[0])
|
||||||
generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
|
generations, next_batch = default_bloom.generate_token(default_bloom_batch)
|
||||||
|
|
||||||
assert generated_texts == []
|
assert len(generations) == len(default_bloom_batch)
|
||||||
assert isinstance(next_batch, CausalLMBatch)
|
assert isinstance(next_batch, CausalLMBatch)
|
||||||
assert not next_batch.keys_head_dim_last
|
assert not next_batch.keys_head_dim_last
|
||||||
|
|
||||||
|
@ -122,24 +122,30 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
||||||
assert all(
|
assert all(
|
||||||
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
|
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
|
||||||
)
|
)
|
||||||
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
|
assert all([generation.token_id.item() == 10264 for generation in generations])
|
||||||
|
assert all([generation.token_text == "Test" for generation in generations])
|
||||||
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
|
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
|
||||||
next_batch = default_bloom_batch
|
next_batch = default_bloom_batch
|
||||||
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
|
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(default_bloom_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
|
generations[0].generated_text.text
|
||||||
|
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == default_bloom_batch.requests[0]
|
assert generations[0].request_id == default_bloom_batch.requests[0].id
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -152,17 +158,19 @@ def test_causal_lm_generate_token_completion_multi(
|
||||||
for i in range(
|
for i in range(
|
||||||
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
|
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(default_multi_requests_bloom_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "TestTestTestTestTestTest"
|
assert generations[1].generated_text.text == "TestTestTestTestTestTest"
|
||||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
generations[1].generated_text.generated_tokens
|
||||||
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -171,19 +179,22 @@ def test_causal_lm_generate_token_completion_multi(
|
||||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 1
|
- 1
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
|
generations[0].generated_text.text
|
||||||
|
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
generations[0].generated_text.generated_tokens
|
||||||
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -243,17 +254,19 @@ def test_batch_concatenate(
|
||||||
for _ in range(
|
for _ in range(
|
||||||
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
|
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 3
|
||||||
assert generated_texts[0].output_text == "TestTestTestTestTestTest"
|
assert generations[2].generated_text.text == "TestTestTestTestTestTest"
|
||||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
generations[2].generated_text.generated_tokens
|
||||||
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -262,19 +275,20 @@ def test_batch_concatenate(
|
||||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 2
|
- 2
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
|
generations[0].generated_text.text
|
||||||
|
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == default_bloom_batch.requests[0]
|
assert generations[0].request_id == default_bloom_batch.requests[0].id
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -284,18 +298,21 @@ def test_batch_concatenate(
|
||||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 4
|
- 4
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
|
generations[0].generated_text.text
|
||||||
|
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
generations[0].generated_text.generated_tokens
|
||||||
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
|
@ -88,11 +88,9 @@ def test_causal_lm_batch_type(default_causal_lm):
|
||||||
|
|
||||||
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
||||||
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
|
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(
|
generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
|
||||||
default_causal_lm_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
assert isinstance(next_batch, CausalLMBatch)
|
assert isinstance(next_batch, CausalLMBatch)
|
||||||
|
|
||||||
assert len(next_batch.all_input_ids) == next_batch.size
|
assert len(next_batch.all_input_ids) == next_batch.size
|
||||||
|
@ -121,6 +119,11 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
||||||
assert all(
|
assert all(
|
||||||
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
|
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
|
||||||
)
|
)
|
||||||
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
|
assert all([generation.token_id.item() == 13 for generation in generations])
|
||||||
|
assert all([generation.token_text == "." for generation in generations])
|
||||||
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
def test_causal_lm_generate_token_completion(
|
def test_causal_lm_generate_token_completion(
|
||||||
|
@ -128,18 +131,17 @@ def test_causal_lm_generate_token_completion(
|
||||||
):
|
):
|
||||||
next_batch = default_causal_lm_batch
|
next_batch = default_causal_lm_batch
|
||||||
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
|
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
|
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||||
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
|
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
|
||||||
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -152,19 +154,20 @@ def test_causal_lm_generate_token_completion_multi(
|
||||||
for i in range(
|
for i in range(
|
||||||
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
|
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "Test.java:784)"
|
assert generations[1].generated_text.text == "Test.java:784)"
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
|
generations[1].request_id
|
||||||
|
== default_multi_requests_causal_lm_batch.requests[1].id
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[1].generated_text.generated_tokens
|
||||||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -173,19 +176,20 @@ def test_causal_lm_generate_token_completion_multi(
|
||||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 1
|
- 1
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
|
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
|
generations[0].request_id
|
||||||
|
== default_multi_requests_causal_lm_batch.requests[0].id
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -244,19 +248,20 @@ def test_batch_concatenate(
|
||||||
for _ in range(
|
for _ in range(
|
||||||
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
|
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 3
|
||||||
assert generated_texts[0].output_text == "Test.java:784)"
|
assert generations[2].generated_text.text == "Test.java:784)"
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
|
generations[2].request_id
|
||||||
|
== default_multi_requests_causal_lm_batch.requests[1].id
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[2].generated_text.generated_tokens
|
||||||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -265,17 +270,17 @@ def test_batch_concatenate(
|
||||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 2
|
- 2
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
|
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||||
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
|
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -285,18 +290,19 @@ def test_batch_concatenate(
|
||||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||||
- 4
|
- 4
|
||||||
):
|
):
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
|
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
|
generations[0].request_id
|
||||||
|
== default_multi_requests_causal_lm_batch.requests[0].id
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
|
@ -50,18 +50,17 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
|
||||||
next_batch = batch
|
next_batch = batch
|
||||||
|
|
||||||
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
||||||
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
|
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
|
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "def test_get_all_users_with_"
|
assert generations[0].generated_text.text == "def test_get_all_users_with_"
|
||||||
assert generated_texts[0].request == batch.requests[0]
|
assert generations[0].request_id == batch.requests[0].id
|
||||||
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== batch.stopping_criterias[0].max_new_tokens
|
== batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -76,20 +75,19 @@ def test_fim_santacoder_generate_token_completion(
|
||||||
next_batch = batch
|
next_batch = batch
|
||||||
|
|
||||||
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
||||||
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
|
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
|
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].output_text
|
generations[0].generated_text.text
|
||||||
== """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
|
== """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
|
||||||
)
|
)
|
||||||
assert generated_texts[0].request == batch.requests[0]
|
assert generations[0].request_id == batch.requests[0].id
|
||||||
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
|
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].generated_tokens
|
generations[0].generated_text.generated_tokens
|
||||||
== batch.stopping_criterias[0].max_new_tokens
|
== batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
|
|
|
@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
|
||||||
|
|
||||||
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
|
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
|
||||||
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
|
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(
|
generations, next_batch = default_seq2seq_lm.generate_token(
|
||||||
default_seq2seq_lm_batch
|
default_seq2seq_lm_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
assert isinstance(next_batch, Seq2SeqLMBatch)
|
assert isinstance(next_batch, Seq2SeqLMBatch)
|
||||||
|
|
||||||
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
|
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
|
||||||
|
@ -145,6 +145,11 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
|
||||||
for p in next_batch.past_key_values
|
for p in next_batch.past_key_values
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
|
assert all([generation.token_id.item() == 259 for generation in generations])
|
||||||
|
assert all([generation.token_text == "" for generation in generations])
|
||||||
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
def test_seq2seq_lm_generate_token_completion(
|
def test_seq2seq_lm_generate_token_completion(
|
||||||
|
@ -152,16 +157,16 @@ def test_seq2seq_lm_generate_token_completion(
|
||||||
):
|
):
|
||||||
next_batch = default_seq2seq_lm_batch
|
next_batch = default_seq2seq_lm_batch
|
||||||
for _ in range(6):
|
for _ in range(6):
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "a few weeks"
|
assert generations[0].generated_text.text == "a few weeks"
|
||||||
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
|
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
|
||||||
assert generated_texts[0].generated_tokens == 7
|
assert generations[0].generated_text.generated_tokens == 7
|
||||||
|
|
||||||
|
|
||||||
def test_seq2seq_lm_generate_token_completion_multi(
|
def test_seq2seq_lm_generate_token_completion_multi(
|
||||||
|
@ -170,33 +175,33 @@ def test_seq2seq_lm_generate_token_completion_multi(
|
||||||
next_batch = default_multi_requests_seq2seq_lm_batch
|
next_batch = default_multi_requests_seq2seq_lm_batch
|
||||||
|
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "a few "
|
assert generations[1].generated_text.text == "a few "
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request
|
generations[1].request_id
|
||||||
== default_multi_requests_seq2seq_lm_batch.requests[1]
|
== default_multi_requests_seq2seq_lm_batch.requests[1].id
|
||||||
)
|
)
|
||||||
assert generated_texts[0].generated_tokens == 5
|
assert generations[1].generated_text.generated_tokens == 5
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "a few weeks"
|
assert generations[0].generated_text.text == "a few weeks"
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request
|
generations[0].request_id
|
||||||
== default_multi_requests_seq2seq_lm_batch.requests[0]
|
== default_multi_requests_seq2seq_lm_batch.requests[0].id
|
||||||
)
|
)
|
||||||
assert generated_texts[0].generated_tokens == 7
|
assert generations[0].generated_text.generated_tokens == 7
|
||||||
|
|
||||||
|
|
||||||
def test_batch_concatenate(
|
def test_batch_concatenate(
|
||||||
|
@ -291,35 +296,35 @@ def test_batch_concatenate(
|
||||||
)
|
)
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert generated_texts == []
|
assert len(generations) == len(next_batch)
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 3
|
||||||
assert generated_texts[0].output_text == "a few "
|
assert generations[2].generated_text.text == "a few "
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request
|
generations[2].request_id
|
||||||
== default_multi_requests_seq2seq_lm_batch.requests[1]
|
== default_multi_requests_seq2seq_lm_batch.requests[1].id
|
||||||
)
|
)
|
||||||
assert generated_texts[0].generated_tokens == 5
|
assert generations[2].generated_text.generated_tokens == 5
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is not None
|
assert next_batch is not None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 2
|
||||||
assert generated_texts[0].output_text == "a few weeks"
|
assert generations[0].generated_text.text == "a few weeks"
|
||||||
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
|
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
|
||||||
assert generated_texts[0].generated_tokens == 7
|
assert generations[0].generated_text.generated_tokens == 7
|
||||||
|
|
||||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||||
assert next_batch is None
|
assert next_batch is None
|
||||||
|
|
||||||
assert len(generated_texts) == 1
|
assert len(generations) == 1
|
||||||
assert generated_texts[0].output_text == "a few weeks"
|
assert generations[0].generated_text.text == "a few weeks"
|
||||||
assert (
|
assert (
|
||||||
generated_texts[0].request
|
generations[0].request_id
|
||||||
== default_multi_requests_seq2seq_lm_batch.requests[0]
|
== default_multi_requests_seq2seq_lm_batch.requests[0].id
|
||||||
)
|
)
|
||||||
assert generated_texts[0].generated_tokens == 7
|
assert generations[0].generated_text.generated_tokens == 7
|
||||||
|
|
|
@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
|
||||||
from typing import Optional, Tuple, List, Type
|
from typing import Optional, Tuple, List, Type
|
||||||
|
|
||||||
from text_generation.models import Model
|
from text_generation.models import Model
|
||||||
from text_generation.models.types import GeneratedText, Batch
|
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
|
||||||
from text_generation.pb import generate_pb2
|
from text_generation.pb import generate_pb2
|
||||||
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
|
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||||
|
|
||||||
|
@ -23,7 +23,6 @@ class CausalLMBatch(Batch):
|
||||||
|
|
||||||
# All tokens
|
# All tokens
|
||||||
all_input_ids: List[torch.Tensor]
|
all_input_ids: List[torch.Tensor]
|
||||||
all_logprobs: List[Optional[torch.Tensor]]
|
|
||||||
|
|
||||||
# Lengths of all generations present in the batch
|
# Lengths of all generations present in the batch
|
||||||
input_lengths: List[int]
|
input_lengths: List[int]
|
||||||
|
@ -57,7 +56,6 @@ class CausalLMBatch(Batch):
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
all_logprobs = []
|
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
for r in pb.requests:
|
for r in pb.requests:
|
||||||
|
@ -67,7 +65,6 @@ class CausalLMBatch(Batch):
|
||||||
stopping_criterias.append(
|
stopping_criterias.append(
|
||||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||||
)
|
)
|
||||||
all_logprobs.append(None)
|
|
||||||
|
|
||||||
pad_to_multiple_of = 8 if device.type == "cuda" else None
|
pad_to_multiple_of = 8 if device.type == "cuda" else None
|
||||||
tokenized_inputs = tokenizer(
|
tokenized_inputs = tokenizer(
|
||||||
|
@ -89,7 +86,6 @@ class CausalLMBatch(Batch):
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_logprobs=all_logprobs,
|
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
|
@ -107,7 +103,6 @@ class CausalLMBatch(Batch):
|
||||||
requests = []
|
requests = []
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
all_logprobs = []
|
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
|
|
||||||
|
@ -124,7 +119,6 @@ class CausalLMBatch(Batch):
|
||||||
requests.extend(batch.requests)
|
requests.extend(batch.requests)
|
||||||
input_lengths.extend(batch.input_lengths)
|
input_lengths.extend(batch.input_lengths)
|
||||||
all_input_ids.extend(batch.all_input_ids)
|
all_input_ids.extend(batch.all_input_ids)
|
||||||
all_logprobs.extend(batch.all_logprobs)
|
|
||||||
next_token_choosers.extend(batch.next_token_choosers)
|
next_token_choosers.extend(batch.next_token_choosers)
|
||||||
stopping_criterias.extend(batch.stopping_criterias)
|
stopping_criterias.extend(batch.stopping_criterias)
|
||||||
|
|
||||||
|
@ -225,7 +219,6 @@ class CausalLMBatch(Batch):
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_logprobs=all_logprobs,
|
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
|
@ -234,6 +227,9 @@ class CausalLMBatch(Batch):
|
||||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.requests)
|
||||||
|
|
||||||
|
|
||||||
class CausalLM(Model):
|
class CausalLM(Model):
|
||||||
def __init__(self, model_name: str, quantize=False):
|
def __init__(self, model_name: str, quantize=False):
|
||||||
|
@ -289,7 +285,7 @@ class CausalLM(Model):
|
||||||
|
|
||||||
def generate_token(
|
def generate_token(
|
||||||
self, batch: CausalLMBatch
|
self, batch: CausalLMBatch
|
||||||
) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
|
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
|
||||||
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
||||||
context_manager = (
|
context_manager = (
|
||||||
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
|
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
|
||||||
|
@ -309,14 +305,13 @@ class CausalLM(Model):
|
||||||
next_batch_input_lengths = []
|
next_batch_input_lengths = []
|
||||||
next_batch_input_ids = []
|
next_batch_input_ids = []
|
||||||
next_batch_all_input_ids = []
|
next_batch_all_input_ids = []
|
||||||
next_batch_all_logprobs = []
|
|
||||||
|
|
||||||
# Metadata
|
# Metadata
|
||||||
next_batch_size = 0
|
next_batch_size = 0
|
||||||
next_batch_max_sequence_length = 0
|
next_batch_max_sequence_length = 0
|
||||||
|
|
||||||
# Finished requests
|
# Results
|
||||||
generated_texts: List[GeneratedText] = []
|
generations: List[Generation] = []
|
||||||
|
|
||||||
# Zipped iterator
|
# Zipped iterator
|
||||||
iterator = zip(
|
iterator = zip(
|
||||||
|
@ -326,7 +321,6 @@ class CausalLM(Model):
|
||||||
batch.next_token_choosers,
|
batch.next_token_choosers,
|
||||||
batch.stopping_criterias,
|
batch.stopping_criterias,
|
||||||
batch.all_input_ids,
|
batch.all_input_ids,
|
||||||
batch.all_logprobs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
|
@ -337,44 +331,36 @@ class CausalLM(Model):
|
||||||
next_token_chooser,
|
next_token_chooser,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
all_input_ids,
|
all_input_ids,
|
||||||
all_logprobs,
|
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
# Select next token
|
# Select next token
|
||||||
tokens, logprobs = next_token_chooser(all_input_ids, logits)
|
tokens, logprobs = next_token_chooser(all_input_ids, logits)
|
||||||
next_token = tokens[-1].view(1, 1)
|
next_token_id = tokens[-1].view(1, 1)
|
||||||
|
|
||||||
# Append next token to all tokens
|
# Append next token to all tokens
|
||||||
all_input_ids = torch.cat([all_input_ids, next_token])
|
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||||
new_input_length = input_length + 1
|
new_input_length = input_length + 1
|
||||||
|
|
||||||
if all_logprobs is None:
|
# Generated token
|
||||||
# logprobs of all prompt tokens (except the first one) and the generated token
|
next_token_logprob = logprobs[-1, next_token_id]
|
||||||
all_logprobs = logprobs.gather(1, all_input_ids[1:])
|
next_token_id_squeezed = next_token_id.squeeze()
|
||||||
else:
|
next_token_text = self.tokenizer.decode(
|
||||||
# logprob of the generated token
|
next_token_id_squeezed,
|
||||||
next_token_logprob = logprobs[-1, next_token]
|
clean_up_tokenization_spaces=False,
|
||||||
all_logprobs = torch.cat([all_logprobs, next_token_logprob])
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
# Evaluate stopping criteria
|
||||||
stop, reason = stopping_criteria(
|
stop, reason = stopping_criteria(
|
||||||
next_token.squeeze(),
|
next_token_id_squeezed,
|
||||||
self.tokenizer.decode(
|
next_token_text,
|
||||||
next_token.squeeze(), clean_up_tokenization_spaces=False
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
# Decode generated tokens
|
# Decode generated tokens
|
||||||
generated_text = self.decode(
|
generated_text = self.decode(
|
||||||
all_input_ids[-stopping_criteria.current_tokens :, 0]
|
all_input_ids[-stopping_criteria.current_tokens :, 0]
|
||||||
)
|
)
|
||||||
output_text = request.inputs + generated_text
|
output_text = request.inputs + generated_text
|
||||||
# Slice with input_length to remove padding
|
|
||||||
token_ids = all_input_ids[-new_input_length:]
|
|
||||||
tokens = self.tokenizer.batch_decode(token_ids)
|
|
||||||
# Add NaN for the first prompt token
|
|
||||||
logprobs = [float("nan")] + all_logprobs[-input_length:].squeeze(
|
|
||||||
1
|
|
||||||
).tolist()
|
|
||||||
|
|
||||||
# Get seed
|
# Get seed
|
||||||
if isinstance(next_token_chooser.choice, Sampling):
|
if isinstance(next_token_chooser.choice, Sampling):
|
||||||
|
@ -382,39 +368,58 @@ class CausalLM(Model):
|
||||||
else:
|
else:
|
||||||
seed = None
|
seed = None
|
||||||
|
|
||||||
# Add to the list of finished generations with the original request
|
generated_text = GeneratedText(
|
||||||
generated_texts.append(
|
output_text, stopping_criteria.current_tokens, reason, seed
|
||||||
GeneratedText(
|
|
||||||
request=request,
|
|
||||||
output_text=output_text,
|
|
||||||
generated_tokens=stopping_criteria.current_tokens,
|
|
||||||
tokens=tokens,
|
|
||||||
token_ids=token_ids.squeeze(1).tolist(),
|
|
||||||
logprobs=logprobs,
|
|
||||||
reason=reason,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# add to the next batch
|
|
||||||
else:
|
else:
|
||||||
|
# Keep request in the batch
|
||||||
|
generated_text = None
|
||||||
next_batch_keep_indices.append(i)
|
next_batch_keep_indices.append(i)
|
||||||
next_batch_input_ids.append(next_token)
|
next_batch_input_ids.append(next_token_id)
|
||||||
next_batch_all_input_ids.append(all_input_ids)
|
next_batch_all_input_ids.append(all_input_ids)
|
||||||
next_batch_all_logprobs.append(all_logprobs)
|
|
||||||
next_batch_size += 1
|
next_batch_size += 1
|
||||||
next_batch_input_lengths.append(new_input_length)
|
next_batch_input_lengths.append(new_input_length)
|
||||||
next_batch_max_sequence_length = max(
|
next_batch_max_sequence_length = max(
|
||||||
next_batch_max_sequence_length, new_input_length
|
next_batch_max_sequence_length, new_input_length
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if stopping_criteria.current_tokens == 1:
|
||||||
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
|
prefill_logprobs = [float("nan")] + logprobs.gather(
|
||||||
|
1, all_input_ids[1:]
|
||||||
|
).squeeze(1)[-new_input_length:-1].tolist()
|
||||||
|
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||||
|
prefill_texts = self.tokenizer.batch_decode(
|
||||||
|
prefill_token_ids,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
prefill_tokens = PrefillTokens(
|
||||||
|
prefill_token_ids, prefill_logprobs, prefill_texts
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefill_tokens = None
|
||||||
|
|
||||||
|
generation = Generation(
|
||||||
|
request.id,
|
||||||
|
prefill_tokens,
|
||||||
|
next_token_id_squeezed,
|
||||||
|
next_token_logprob,
|
||||||
|
next_token_text,
|
||||||
|
generated_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
generations.append(generation)
|
||||||
|
|
||||||
# We finished all generations in the batch; there is no next batch
|
# We finished all generations in the batch; there is no next batch
|
||||||
if not next_batch_keep_indices:
|
if not next_batch_keep_indices:
|
||||||
return generated_texts, None
|
return generations, None
|
||||||
|
|
||||||
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
|
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
|
||||||
# If we finished at least one generation, we need to evict the indices of the generations that finished
|
# If we finished at least one generation, we need to evict the indices of the generations that finished
|
||||||
# from the values of the next batch
|
# from the values of the next batch
|
||||||
if generated_texts:
|
if len(next_batch_keep_indices) != len(batch):
|
||||||
# Apply indices to attention mask, past key values and other items that need to be cached
|
# Apply indices to attention mask, past key values and other items that need to be cached
|
||||||
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
|
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
|
||||||
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
|
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
|
||||||
|
@ -461,7 +466,6 @@ class CausalLM(Model):
|
||||||
position_ids=next_batch_position_ids,
|
position_ids=next_batch_position_ids,
|
||||||
past_key_values=next_batch_past_key_values,
|
past_key_values=next_batch_past_key_values,
|
||||||
all_input_ids=next_batch_all_input_ids,
|
all_input_ids=next_batch_all_input_ids,
|
||||||
all_logprobs=next_batch_all_logprobs,
|
|
||||||
input_lengths=next_batch_input_lengths,
|
input_lengths=next_batch_input_lengths,
|
||||||
next_token_choosers=next_batch_next_token_choosers,
|
next_token_choosers=next_batch_next_token_choosers,
|
||||||
stopping_criterias=next_batch_stopping_criterias,
|
stopping_criterias=next_batch_stopping_criterias,
|
||||||
|
@ -469,4 +473,4 @@ class CausalLM(Model):
|
||||||
max_sequence_length=next_batch_max_sequence_length,
|
max_sequence_length=next_batch_max_sequence_length,
|
||||||
keys_head_dim_last=batch.keys_head_dim_last,
|
keys_head_dim_last=batch.keys_head_dim_last,
|
||||||
)
|
)
|
||||||
return generated_texts, next_batch
|
return generations, next_batch
|
||||||
|
|
|
@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokeniz
|
||||||
from typing import Optional, Tuple, List, Type
|
from typing import Optional, Tuple, List, Type
|
||||||
|
|
||||||
from text_generation.models import Model
|
from text_generation.models import Model
|
||||||
from text_generation.models.types import GeneratedText, Batch
|
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
|
||||||
from text_generation.pb import generate_pb2
|
from text_generation.pb import generate_pb2
|
||||||
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
|
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||||
|
|
||||||
|
@ -30,7 +30,6 @@ class Seq2SeqLMBatch(Batch):
|
||||||
# Lengths of all generations present in the batch
|
# Lengths of all generations present in the batch
|
||||||
input_lengths: List[int]
|
input_lengths: List[int]
|
||||||
decoder_input_lengths: List[int]
|
decoder_input_lengths: List[int]
|
||||||
decoder_logprobs: List[Optional[torch.Tensor]]
|
|
||||||
|
|
||||||
# Generation helpers
|
# Generation helpers
|
||||||
next_token_choosers: List[NextTokenChooser]
|
next_token_choosers: List[NextTokenChooser]
|
||||||
|
@ -64,7 +63,6 @@ class Seq2SeqLMBatch(Batch):
|
||||||
|
|
||||||
decoder_input_ids = []
|
decoder_input_ids = []
|
||||||
decoder_input_lengths = []
|
decoder_input_lengths = []
|
||||||
decoder_logprobs = []
|
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
for r in pb.requests:
|
for r in pb.requests:
|
||||||
|
@ -77,7 +75,6 @@ class Seq2SeqLMBatch(Batch):
|
||||||
stopping_criterias.append(
|
stopping_criterias.append(
|
||||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||||
)
|
)
|
||||||
decoder_logprobs.append(None)
|
|
||||||
|
|
||||||
# Tokenize batch
|
# Tokenize batch
|
||||||
pad_to_multiple_of = 8 if device.type == "cuda" else None
|
pad_to_multiple_of = 8 if device.type == "cuda" else None
|
||||||
|
@ -102,7 +99,6 @@ class Seq2SeqLMBatch(Batch):
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
decoder_input_lengths=decoder_input_lengths,
|
decoder_input_lengths=decoder_input_lengths,
|
||||||
decoder_logprobs=decoder_logprobs,
|
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
size=len(pb.requests),
|
size=len(pb.requests),
|
||||||
|
@ -125,7 +121,6 @@ class Seq2SeqLMBatch(Batch):
|
||||||
requests = []
|
requests = []
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
decoder_input_lengths = []
|
decoder_input_lengths = []
|
||||||
decoder_logprobs = []
|
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
|
|
||||||
|
@ -146,7 +141,6 @@ class Seq2SeqLMBatch(Batch):
|
||||||
requests.extend(batch.requests)
|
requests.extend(batch.requests)
|
||||||
input_lengths.extend(batch.input_lengths)
|
input_lengths.extend(batch.input_lengths)
|
||||||
decoder_input_lengths.extend(batch.decoder_input_lengths)
|
decoder_input_lengths.extend(batch.decoder_input_lengths)
|
||||||
decoder_logprobs.extend(batch.decoder_logprobs)
|
|
||||||
next_token_choosers.extend(batch.next_token_choosers)
|
next_token_choosers.extend(batch.next_token_choosers)
|
||||||
stopping_criterias.extend(batch.stopping_criterias)
|
stopping_criterias.extend(batch.stopping_criterias)
|
||||||
|
|
||||||
|
@ -283,7 +277,6 @@ class Seq2SeqLMBatch(Batch):
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
decoder_input_lengths=decoder_input_lengths,
|
decoder_input_lengths=decoder_input_lengths,
|
||||||
decoder_logprobs=decoder_logprobs,
|
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
size=total_batch_size,
|
size=total_batch_size,
|
||||||
|
@ -291,6 +284,9 @@ class Seq2SeqLMBatch(Batch):
|
||||||
max_decoder_input_length=max_decoder_input_length,
|
max_decoder_input_length=max_decoder_input_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.requests)
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqLM(Model):
|
class Seq2SeqLM(Model):
|
||||||
def __init__(self, model_name: str, quantize=False):
|
def __init__(self, model_name: str, quantize=False):
|
||||||
|
@ -364,7 +360,7 @@ class Seq2SeqLM(Model):
|
||||||
|
|
||||||
def generate_token(
|
def generate_token(
|
||||||
self, batch: Seq2SeqLMBatch
|
self, batch: Seq2SeqLMBatch
|
||||||
) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]:
|
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
|
||||||
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
||||||
context_manager = (
|
context_manager = (
|
||||||
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
|
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
|
||||||
|
@ -386,7 +382,6 @@ class Seq2SeqLM(Model):
|
||||||
next_batch_input_lengths = []
|
next_batch_input_lengths = []
|
||||||
next_batch_decoder_input_ids = []
|
next_batch_decoder_input_ids = []
|
||||||
next_batch_decoder_input_lengths = []
|
next_batch_decoder_input_lengths = []
|
||||||
next_batch_decoder_logprobs = []
|
|
||||||
|
|
||||||
# Metadata
|
# Metadata
|
||||||
next_batch_size = 0
|
next_batch_size = 0
|
||||||
|
@ -394,14 +389,13 @@ class Seq2SeqLM(Model):
|
||||||
next_batch_max_decoder_input_length = 0
|
next_batch_max_decoder_input_length = 0
|
||||||
|
|
||||||
# Finished requests
|
# Finished requests
|
||||||
generated_texts: List[GeneratedText] = []
|
generations: List[Generation] = []
|
||||||
|
|
||||||
# Zipped iterator
|
# Zipped iterator
|
||||||
iterator = zip(
|
iterator = zip(
|
||||||
batch.requests,
|
batch.requests,
|
||||||
batch.input_lengths,
|
batch.input_lengths,
|
||||||
batch.decoder_input_lengths,
|
batch.decoder_input_lengths,
|
||||||
batch.decoder_logprobs,
|
|
||||||
logits,
|
logits,
|
||||||
batch.next_token_choosers,
|
batch.next_token_choosers,
|
||||||
batch.stopping_criterias,
|
batch.stopping_criterias,
|
||||||
|
@ -414,7 +408,6 @@ class Seq2SeqLM(Model):
|
||||||
request,
|
request,
|
||||||
input_length,
|
input_length,
|
||||||
decoder_input_length,
|
decoder_input_length,
|
||||||
decoder_logprobs,
|
|
||||||
logits,
|
logits,
|
||||||
next_token_chooser,
|
next_token_chooser,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
|
@ -422,35 +415,28 @@ class Seq2SeqLM(Model):
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
# Select next token
|
# Select next token
|
||||||
next_token, logprobs = next_token_chooser(decoder_input_ids, logits)
|
next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
|
||||||
|
|
||||||
# Append next token to decoder tokens
|
# Append next token to decoder tokens
|
||||||
decoder_input_ids = torch.cat([decoder_input_ids, next_token])
|
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
|
||||||
new_decoder_input_length = decoder_input_length + 1
|
new_decoder_input_length = decoder_input_length + 1
|
||||||
|
|
||||||
next_token_logprob = logprobs[-1, next_token]
|
# Generated token
|
||||||
if decoder_logprobs is None:
|
next_token_logprob = logprobs[-1, next_token_id]
|
||||||
decoder_logprobs = next_token_logprob
|
next_token_id_squeezed = next_token_id.squeeze()
|
||||||
else:
|
next_token_text = self.tokenizer.decode(
|
||||||
decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
|
next_token_id_squeezed,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
# Evaluate stopping criteria
|
||||||
stop, reason = stopping_criteria(
|
stop, reason = stopping_criteria(next_token_id, next_token_text)
|
||||||
next_token.squeeze(),
|
|
||||||
self.tokenizer.decode(
|
|
||||||
next_token.squeeze(), clean_up_tokenization_spaces=False
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if stop:
|
if stop:
|
||||||
# Slice with decoder_input_length to remove padding
|
# Slice with decoder_input_length to remove padding
|
||||||
# Decode all tokens
|
# Decode all tokens
|
||||||
token_ids = decoder_input_ids[-new_decoder_input_length:]
|
output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
|
||||||
output_text = self.decode(token_ids)
|
|
||||||
tokens = self.tokenizer.batch_decode(token_ids)
|
|
||||||
# Add NaN for the bos token
|
|
||||||
logprobs = [float("nan")] + decoder_logprobs[
|
|
||||||
-decoder_input_length:
|
|
||||||
].tolist()
|
|
||||||
|
|
||||||
# Get seed
|
# Get seed
|
||||||
if isinstance(next_token_chooser.choice, Sampling):
|
if isinstance(next_token_chooser.choice, Sampling):
|
||||||
|
@ -458,27 +444,17 @@ class Seq2SeqLM(Model):
|
||||||
else:
|
else:
|
||||||
seed = None
|
seed = None
|
||||||
|
|
||||||
# Add to the list of finished generations with the original request
|
generated_text = GeneratedText(
|
||||||
generated_texts.append(
|
output_text, stopping_criteria.current_tokens, reason, seed
|
||||||
GeneratedText(
|
|
||||||
request=request,
|
|
||||||
output_text=output_text,
|
|
||||||
generated_tokens=stopping_criteria.current_tokens,
|
|
||||||
tokens=tokens,
|
|
||||||
token_ids=token_ids.tolist(),
|
|
||||||
logprobs=logprobs,
|
|
||||||
reason=reason,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# add to the next batch
|
|
||||||
else:
|
else:
|
||||||
|
# Keep request in the batch
|
||||||
|
generated_text = None
|
||||||
next_batch_keep_indices.append(i)
|
next_batch_keep_indices.append(i)
|
||||||
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
|
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
|
||||||
next_batch_size += 1
|
next_batch_size += 1
|
||||||
next_batch_input_lengths.append(input_length)
|
next_batch_input_lengths.append(input_length)
|
||||||
next_batch_decoder_input_lengths.append(new_decoder_input_length)
|
next_batch_decoder_input_lengths.append(new_decoder_input_length)
|
||||||
next_batch_decoder_logprobs.append(decoder_logprobs)
|
|
||||||
next_batch_max_input_length = max(
|
next_batch_max_input_length = max(
|
||||||
next_batch_max_input_length, input_length
|
next_batch_max_input_length, input_length
|
||||||
)
|
)
|
||||||
|
@ -486,14 +462,39 @@ class Seq2SeqLM(Model):
|
||||||
next_batch_max_decoder_input_length, new_decoder_input_length
|
next_batch_max_decoder_input_length, new_decoder_input_length
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if stopping_criteria.current_tokens == 1:
|
||||||
|
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
|
||||||
|
prefill_texts = self.tokenizer.batch_decode(
|
||||||
|
prefill_token_ids,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
prefill_tokens = PrefillTokens(
|
||||||
|
prefill_token_ids, [float("nan")], prefill_texts
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefill_tokens = None
|
||||||
|
|
||||||
|
generation = Generation(
|
||||||
|
request.id,
|
||||||
|
prefill_tokens,
|
||||||
|
next_token_id_squeezed,
|
||||||
|
next_token_logprob,
|
||||||
|
next_token_text,
|
||||||
|
generated_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
generations.append(generation)
|
||||||
|
|
||||||
# We finished all generations in the batch; there is no next batch
|
# We finished all generations in the batch; there is no next batch
|
||||||
if not next_batch_keep_indices:
|
if not next_batch_keep_indices:
|
||||||
return generated_texts, None
|
return generations, None
|
||||||
|
|
||||||
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
|
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
|
||||||
# If we finished at least one generation, we need to evict the indices of the generations that finished
|
# If we finished at least one generation, we need to evict the indices of the generations that finished
|
||||||
# from the values of the next batch
|
# from the values of the next batch
|
||||||
if generated_texts:
|
if len(next_batch_keep_indices) != len(batch):
|
||||||
# Apply indices to attention mask, past key values and other items that need to be cached
|
# Apply indices to attention mask, past key values and other items that need to be cached
|
||||||
next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
|
next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
|
||||||
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
|
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
|
||||||
|
@ -551,11 +552,10 @@ class Seq2SeqLM(Model):
|
||||||
past_key_values=next_batch_past_key_values,
|
past_key_values=next_batch_past_key_values,
|
||||||
input_lengths=next_batch_input_lengths,
|
input_lengths=next_batch_input_lengths,
|
||||||
decoder_input_lengths=next_batch_decoder_input_lengths,
|
decoder_input_lengths=next_batch_decoder_input_lengths,
|
||||||
decoder_logprobs=next_batch_decoder_logprobs,
|
|
||||||
next_token_choosers=next_batch_next_token_choosers,
|
next_token_choosers=next_batch_next_token_choosers,
|
||||||
stopping_criterias=next_batch_stopping_criterias,
|
stopping_criterias=next_batch_stopping_criterias,
|
||||||
size=next_batch_size,
|
size=next_batch_size,
|
||||||
max_input_length=next_batch_max_input_length,
|
max_input_length=next_batch_max_input_length,
|
||||||
max_decoder_input_length=next_batch_max_decoder_input_length,
|
max_decoder_input_length=next_batch_max_decoder_input_length,
|
||||||
)
|
)
|
||||||
return generated_texts, next_batch
|
return generations, next_batch
|
||||||
|
|
|
@ -29,26 +29,61 @@ class Batch(ABC):
|
||||||
def concatenate(cls, batches: List["Batch"]) -> "Batch":
|
def concatenate(cls, batches: List["Batch"]) -> "Batch":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __len__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GeneratedText:
|
class GeneratedText:
|
||||||
request: generate_pb2.Request
|
text: str
|
||||||
output_text: str
|
|
||||||
generated_tokens: int
|
generated_tokens: int
|
||||||
tokens: List[str]
|
finish_reason: str
|
||||||
token_ids: List[int]
|
|
||||||
logprobs: List[float]
|
|
||||||
reason: str
|
|
||||||
seed: Optional[int]
|
seed: Optional[int]
|
||||||
|
|
||||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
def to_pb(self) -> generate_pb2.GeneratedText:
|
||||||
return generate_pb2.GeneratedText(
|
return generate_pb2.GeneratedText(
|
||||||
request=self.request,
|
text=self.text,
|
||||||
output_text=self.output_text,
|
|
||||||
generated_tokens=self.generated_tokens,
|
generated_tokens=self.generated_tokens,
|
||||||
tokens=self.tokens,
|
finish_reason=self.finish_reason,
|
||||||
token_ids=self.token_ids,
|
|
||||||
logprobs=self.logprobs,
|
|
||||||
finish_reason=self.reason,
|
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PrefillTokens:
|
||||||
|
token_ids: List[int]
|
||||||
|
logprobs: List[float]
|
||||||
|
texts: List[str]
|
||||||
|
|
||||||
|
def to_pb(self) -> generate_pb2.PrefillTokens:
|
||||||
|
return generate_pb2.PrefillTokens(
|
||||||
|
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.token_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Generation:
|
||||||
|
request_id: int
|
||||||
|
prefill_tokens: Optional[PrefillTokens]
|
||||||
|
token_id: int
|
||||||
|
token_logprob: float
|
||||||
|
token_text: str
|
||||||
|
generated_text: Optional[GeneratedText]
|
||||||
|
|
||||||
|
def to_pb(self) -> generate_pb2.Generation:
|
||||||
|
return generate_pb2.Generation(
|
||||||
|
request_id=self.request_id,
|
||||||
|
prefill_tokens=self.prefill_tokens.to_pb()
|
||||||
|
if self.prefill_tokens is not None
|
||||||
|
else None,
|
||||||
|
token_id=self.token_id,
|
||||||
|
token_logprob=self.token_logprob,
|
||||||
|
token_text=self.token_text,
|
||||||
|
generated_text=self.generated_text.to_pb()
|
||||||
|
if self.generated_text is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
|
@ -27,22 +27,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
return generate_pb2.ClearCacheResponse()
|
return generate_pb2.ClearCacheResponse()
|
||||||
|
|
||||||
async def Generate(self, request, context):
|
async def Prefill(self, request, context):
|
||||||
batch = self.model.batch_type.from_pb(
|
batch = self.model.batch_type.from_pb(
|
||||||
request.batch, self.model.tokenizer, self.model.device
|
request.batch, self.model.tokenizer, self.model.device
|
||||||
)
|
)
|
||||||
|
|
||||||
generated_texts, next_batch = self.model.generate_token(batch)
|
generations, next_batch = self.model.generate_token(batch)
|
||||||
self.cache.set(next_batch)
|
self.cache.set(next_batch)
|
||||||
|
|
||||||
return generate_pb2.GenerateResponse(
|
return generate_pb2.PrefillResponse(
|
||||||
generated_texts=[
|
generations=[generation.to_pb() for generation in generations],
|
||||||
generated_text.to_pb() for generated_text in generated_texts
|
|
||||||
],
|
|
||||||
batch=next_batch.to_pb() if next_batch else None,
|
batch=next_batch.to_pb() if next_batch else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def GenerateWithCache(self, request, context):
|
async def Decode(self, request, context):
|
||||||
if len(request.batches) == 0:
|
if len(request.batches) == 0:
|
||||||
raise ValueError("Must provide at least one batch")
|
raise ValueError("Must provide at least one batch")
|
||||||
|
|
||||||
|
@ -58,13 +56,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
else:
|
else:
|
||||||
batch = batches[0]
|
batch = batches[0]
|
||||||
|
|
||||||
generated_texts, next_batch = self.model.generate_token(batch)
|
generations, next_batch = self.model.generate_token(batch)
|
||||||
self.cache.set(next_batch)
|
self.cache.set(next_batch)
|
||||||
|
|
||||||
return generate_pb2.GenerateWithCacheResponse(
|
return generate_pb2.DecodeResponse(
|
||||||
generated_texts=[
|
generations=[generation.to_pb() for generation in generations],
|
||||||
generated_text.to_pb() for generated_text in generated_texts
|
|
||||||
],
|
|
||||||
batch=next_batch.to_pb() if next_batch else None,
|
batch=next_batch.to_pb() if next_batch else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue