feat: Add token streaming using ServerSideEvents support (#36)

Add token streaming using ServerSideEvents (SSE).

The signature of the SSE events is: 

```rust
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
}

struct StreamResponse {
    token: Token,
    generated_text: Option<String>,
    details: Option<Details>,
}

struct ErrorResponse {
    error: String,
}
```
This commit is contained in:
OlivierDehaene 2023-01-31 11:49:43 +01:00 committed by GitHub
parent cd298bc5e5
commit 7fbfbb0dc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1032 additions and 717 deletions

2
Cargo.lock generated
View File

@ -1829,6 +1829,7 @@ dependencies = [
name = "text-generation-router" name = "text-generation-router"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"async-stream",
"axum", "axum",
"clap 4.0.22", "clap 4.0.22",
"futures", "futures",
@ -1840,6 +1841,7 @@ dependencies = [
"thiserror", "thiserror",
"tokenizers", "tokenizers",
"tokio", "tokio",
"tokio-stream",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
] ]

View File

@ -16,4 +16,4 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies] [dev-dependencies]
float_eq = "1.0.1" float_eq = "1.0.1"
reqwest = { version = "0.11.13", features = ["blocking", "json"] } reqwest = { version = "0.11.13", features = ["blocking", "json"] }
serde = "1.0.150" serde = { version = "1.0.150", features = ["derive"] }

View File

@ -3,7 +3,7 @@
"details": { "details": {
"finish_reason": "length", "finish_reason": "length",
"generated_tokens": 20, "generated_tokens": 20,
"tokens": [ "prefill": [
[ [
10264, 10264,
"Test", "Test",
@ -13,7 +13,9 @@
8821, 8821,
" request", " request",
-11.895094 -11.895094
]
], ],
"tokens": [
[ [
17, 17,
".", ".",

View File

@ -3,12 +3,14 @@
"details": { "details": {
"finish_reason": "length", "finish_reason": "length",
"generated_tokens": 20, "generated_tokens": 20,
"tokens": [ "prefill": [
[ [
0, 0,
"<pad>", "<pad>",
null null
]
], ],
"tokens": [
[ [
259, 259,
"", "",

View File

@ -7,10 +7,10 @@ service TextGenerationService {
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache /// Empties batch cache
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Generate tokens for a batch /// Prefill batch and decode first token
rpc Generate (GenerateRequest) returns (GenerateResponse); rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Generate tokens for a list of cached batches /// Decode token for a list of prefilled batches
rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse); rpc Decode (DecodeRequest) returns (DecodeResponse);
} }
/// Empty request /// Empty request
@ -70,44 +70,60 @@ message Batch {
} }
message GeneratedText { message GeneratedText {
/// Request
Request request = 1;
/// Output /// Output
string output_text = 2; string text = 1;
/// Number of generated tokens /// Number of generated tokens
uint32 generated_tokens = 3; uint32 generated_tokens = 2;
/// Tokens
repeated string tokens = 4;
/// Token IDs
repeated uint32 token_ids = 5;
/// Logprobs
repeated float logprobs = 6;
/// Finish reason /// Finish reason
string finish_reason = 7; string finish_reason = 3;
/// Seed /// Seed
optional uint64 seed = 8; optional uint64 seed = 4;
} }
message GenerateRequest { message PrefillTokens {
/// Prefill Token IDs
repeated uint32 ids = 1;
/// Prefill Logprobs
repeated float logprobs = 2;
/// Prefill tokens
repeated string texts = 3;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
PrefillTokens prefill_tokens = 2;
/// Token ID
uint32 token_id = 3;
/// Logprob
float token_logprob = 4;
/// Text
string token_text = 5;
/// Complete generated text
GeneratedText generated_text = 6;
}
message PrefillRequest {
/// Batch /// Batch
Batch batch = 1; Batch batch = 1;
} }
message GenerateResponse { message PrefillResponse {
/// Finished requests /// Generation
repeated GeneratedText generated_texts = 1; repeated Generation generations = 1;
/// Next batch (cached) /// Next batch (cached)
optional Batch batch = 2; optional Batch batch = 2;
} }
message GenerateWithCacheRequest { message DecodeRequest {
/// Cached batches /// Cached batches
repeated Batch batches = 1; repeated Batch batches = 1;
} }
message GenerateWithCacheResponse { message DecodeResponse {
/// Finished requests /// Decodes
repeated GeneratedText generated_texts = 1; repeated Generation generations = 1;
/// Next batch (cached) /// Next batch (cached)
optional Batch batch = 2; optional Batch batch = 2;
} }

View File

@ -13,6 +13,7 @@ name = "text-generation-router"
path = "src/main.rs" path = "src/main.rs"
[dependencies] [dependencies]
async-stream = "0.3.3"
axum = { version = "0.5.16", features = ["json", "serde_json"] } axum = { version = "0.5.16", features = ["json", "serde_json"] }
text-generation-client = { path = "client" } text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] } clap = { version = "4.0.15", features = ["derive", "env"] }
@ -24,6 +25,7 @@ serde_json = "1.0.85"
thiserror = "1.0.37" thiserror = "1.0.37"
tokenizers = "0.13.0" tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11"
tracing = "0.1.36" tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["json"] } tracing-subscriber = { version = "0.3.15", features = ["json"] }

View File

@ -70,36 +70,36 @@ impl Client {
/// Generate one token for each request in the given batch /// Generate one token for each request in the given batch
/// ///
/// Returns a list of generated texts of request that met their stopping criteria /// Returns Generation for each request in batch
/// and the next cached batch /// and the next cached batch
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> { pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) }); let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
let response = self let response = self
.stub .stub
.generate(request) .prefill(request)
.instrument(info_span!("generate")) .instrument(info_span!("prefill"))
.await? .await?
.into_inner(); .into_inner();
Ok((response.generated_texts, response.batch)) Ok((response.generations, response.batch))
} }
/// Generate one token for each request in the given cached batch /// Generate one token for each request in the given cached batches
/// ///
/// Returns a list of generated texts of request that met their stopping criteria /// Returns Generation for each request in batches
/// and the next cached batch /// and the next cached batch
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn generate_with_cache( pub async fn decode(
&mut self, &mut self,
batches: Vec<Batch>, batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> { ) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(GenerateWithCacheRequest { batches }); let request = tonic::Request::new(DecodeRequest { batches });
let response = self let response = self
.stub .stub
.generate_with_cache(request) .decode(request)
.instrument(info_span!("generate_with_cache")) .instrument(info_span!("decode"))
.await? .await?
.into_inner(); .into_inner();
Ok((response.generated_texts, response.batch)) Ok((response.generations, response.batch))
} }
} }

View File

@ -7,7 +7,8 @@ mod sharded_client;
pub use client::Client; pub use client::Client;
pub use pb::generate::v1::{ pub use pb::generate::v1::{
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
StoppingCriteriaParameters,
}; };
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;
use thiserror::Error; use thiserror::Error;

View File

@ -1,6 +1,6 @@
/// Multi shard Client /// Multi shard Client
use crate::Result; use crate::Result;
use crate::{Batch, Client, GeneratedText}; use crate::{Batch, Client, Generation};
use futures::future::join_all; use futures::future::join_all;
use futures::future::select_all; use futures::future::select_all;
use tonic::transport::Uri; use tonic::transport::Uri;
@ -37,39 +37,6 @@ impl ShardedClient {
Self::from_master_client(master_client).await Self::from_master_client(master_client).await
} }
/// Generate one token for each request in the given batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.generate(batch.clone())))
.collect();
// As soon as we receive one response, we can return as all shards will return the same
let (result, _, _) = select_all(futures).await;
result
}
/// Generate one token for each request in the given cached batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
pub async fn generate_with_cache(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.generate_with_cache(batches.clone())))
.collect();
// As soon as we receive one response, we can return as all shards will return the same
let (result, _, _) = select_all(futures).await;
result
}
/// Clear the past generations cache /// Clear the past generations cache
pub async fn clear_cache(&mut self) -> Result<()> { pub async fn clear_cache(&mut self) -> Result<()> {
let futures: Vec<_> = self let futures: Vec<_> = self
@ -79,4 +46,37 @@ impl ShardedClient {
.collect(); .collect();
join_all(futures).await.into_iter().collect() join_all(futures).await.into_iter().collect()
} }
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
// As soon as we receive one response, we can return as all shards will return the same
let (result, _, _) = select_all(futures).await;
result
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
pub async fn decode(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
// As soon as we receive one response, we can return as all shards will return the same
let (result, _, _) = select_all(futures).await;
result
}
} }

View File

@ -1,236 +0,0 @@
/// Batching and inference logic
use crate::{Db, Entry};
use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode;
use axum::Json;
use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
use tracing::instrument;
/// Batcher
#[derive(Clone)]
pub struct Batcher {
/// Request database
db: Db,
/// Shared state
shared: Arc<Shared>,
}
/// Batcher shared state
struct Shared {
/// Batching background Tokio task notifier
batching_task: Notify,
}
impl Batcher {
pub(crate) fn new(
client: ShardedClient,
max_batch_size: usize,
max_waiting_tokens: usize,
) -> Self {
// Batcher shared state
let db = Db::new();
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client,
max_batch_size,
max_waiting_tokens,
db.clone(),
shared.clone(),
));
Self { db, shared }
}
/// Add a new request to the database and return a future that will generate the text
pub(crate) async fn infer(
&self,
input_length: usize,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
// One shot channel to communicate with the background batching task
let (response_tx, response_rx) = oneshot::channel();
// Try to append the request to the database
self.db.append(Entry {
request,
response_tx,
input_length,
time: Instant::now(),
batch_time: None,
});
// Notify the background task that we have a new entry in the database that needs
// to be batched
self.shared.batching_task.notify_one();
// Await on the response from the background task
// We can safely unwrap as the background task will never drop the sender
response_rx
.await
.unwrap()
.map_err(|err| InferError::GenerationError(err.to_string()))
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[instrument(skip(client, db, shared))]
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
max_waiting_tokens: usize,
db: Db,
shared: Arc<Shared>,
) {
// Minimum batch size after which we try to add more requests
let limit_min_batch_size = (max_batch_size / 2) as u32;
// Infinite loop
loop {
// Wait for a notification from the Batcher struct
shared.batching_task.notified().await;
// Get the next batch from the DB
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the DB
while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let mut batches = vec![batch];
// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
let min_size = match waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
_ if waiting_tokens >= max_waiting_tokens => None,
// Minimum size criteria
_ => Some(limit_min_batch_size as usize),
};
// Try to get a new batch
if let Some((mut new_entries, new_batch)) =
db.next_batch(min_size, max_batch_size - batch_size as usize)
{
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.generate(new_batch), &mut new_entries).await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
}
}
cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await;
waiting_tokens += 1;
}
}
}
}
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
async fn wrap_future(
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
match future.await {
Ok((generated_texts, next_batch)) => {
send_generated(generated_texts, entries);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
send_error(err, entries);
None
}
}
}
/// Send errors to the Batcher for all `entries`
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Err(error.clone())).unwrap_or(());
});
}
/// Send `generated_text` to the Batcher for all `finished`
fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
finished.into_iter().for_each(|output| {
// We can `expect` here as the request id should always be in the entries
let entry = entries
.remove(&output.request.unwrap().id)
.expect("ID not found in entries. This is a bug.");
let response = InferResponse {
output_text: output.output_text,
generated_tokens: output.generated_tokens,
token_ids: output.token_ids,
tokens: output.tokens,
logprobs: output.logprobs,
finish_reason: output.finish_reason,
seed: output.seed,
queued: entry.time,
start: entry.batch_time.unwrap(), // unwrap is always valid
end: Instant::now(),
};
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Ok(response)).unwrap_or(());
});
}
#[derive(Debug)]
pub(crate) struct InferResponse {
pub(crate) output_text: String,
pub(crate) generated_tokens: u32,
pub(crate) token_ids: Vec<u32>,
pub(crate) tokens: Vec<String>,
pub(crate) logprobs: Vec<f32>,
pub(crate) finish_reason: String,
pub(crate) seed: Option<u64>,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) end: Instant,
}
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
}
/// Convert to Axum supported format
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self {
match err {
InferError::GenerationError(_) => (
StatusCode::FAILED_DEPENDENCY,
Json(ErrorResponse {
error: err.to_string(),
}),
),
}
}
}

View File

@ -1,14 +1,16 @@
/// This code is massively inspired by Tokio mini-redis /// This code is massively inspired by Tokio mini-redis
use crate::InferResponse; use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::{GenerateParameters, GenerateRequest}; use crate::{GenerateParameters, GenerateRequest};
use nohash_hasher::{BuildNoHashHasher, IntMap}; use nohash_hasher::{BuildNoHashHasher, IntMap};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ use text_generation_client::{
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
}; };
use tokio::sync::oneshot::Sender; use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::OwnedSemaphorePermit;
use tokio::time::Instant; use tokio::time::Instant;
/// Database entry /// Database entry
@ -16,14 +18,16 @@ use tokio::time::Instant;
pub(crate) struct Entry { pub(crate) struct Entry {
/// Request /// Request
pub request: GenerateRequest, pub request: GenerateRequest,
/// Response sender to communicate between the Batcher and the batching_task /// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: Sender<Result<InferResponse, ClientError>>, pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Number of tokens in the input /// Number of tokens in the input
pub input_length: usize, pub input_length: usize,
/// Instant when this entry was created /// Instant when this entry was created
pub time: Instant, pub time: Instant,
/// Instant when this entry was added to a batch /// Instant when this entry was added to a batch
pub batch_time: Option<Instant>, pub batch_time: Option<Instant>,
/// Permit
pub _permit: OwnedSemaphorePermit,
} }
/// Request Database /// Request Database

354
router/src/infer.rs Normal file
View File

@ -0,0 +1,354 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest;
use crate::{Db, Entry, Token};
use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
use thiserror::Error;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;
/// Inference struct
#[derive(Clone)]
pub struct Infer {
/// Validation
validation: Validation,
/// Request database
db: Db,
/// Shared state
shared: Arc<Shared>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
}
/// Infer shared state
struct Shared {
/// Batching background Tokio task notifier
batching_task: Notify,
}
impl Infer {
pub(crate) fn new(
client: ShardedClient,
validation: Validation,
max_batch_size: usize,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
) -> Self {
// Infer shared state
let db = Db::new();
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client,
max_batch_size,
max_waiting_tokens,
db.clone(),
shared.clone(),
));
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
Self {
validation,
db,
shared,
limit_concurrent_requests: semaphore,
}
}
/// Add a new request to the database and return a stream of InferStreamResponse
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
// This permit will live as long as Entry
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
// Validate request
let (input_length, validated_request) = self.validation.validate(request).await?;
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
// Append the request to the database
self.db.append(Entry {
request: validated_request,
response_tx,
input_length,
time: Instant::now(),
batch_time: None,
_permit: permit,
});
// Notify the background task that we have a new entry in the database that needs
// to be batched
self.shared.batching_task.notify_one();
// Return stream
Ok(UnboundedReceiverStream::new(response_rx))
}
/// Add a new request to the database and return a InferResponse
pub(crate) async fn generate(
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
// Create stream
let mut stream = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
// Iterate on stream
while let Some(response) = stream.next().await {
match response? {
// Add prefill tokens
InferStreamResponse::Prefill(tokens) => {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
result_prefill = tokens
.ids
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| Token(id, text, logprob))
.collect();
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
// Final message
// Set return values
InferStreamResponse::End {
token,
generated_text,
start,
queued,
} => {
result_tokens.push(token);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
}
}
}
// Check that we received a `InferStreamResponse::End` message
if let (Some(generated_text), Some(queued), Some(start)) =
(result_generated_text, result_queued, result_start)
{
Ok(InferResponse {
prefill: result_prefill,
tokens: result_tokens,
generated_text,
queued,
start,
})
} else {
Err(InferError::IncompleteGeneration)
}
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[instrument(skip(client, db, shared))]
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
max_waiting_tokens: usize,
db: Db,
shared: Arc<Shared>,
) {
// Minimum batch size after which we try to add more requests
let limit_min_batch_size = (max_batch_size / 2) as u32;
// Infinite loop
loop {
// Wait for a notification from the Infer struct
shared.batching_task.notified().await;
// Get the next batch from the DB
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the DB
while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let mut batches = vec![batch];
// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
let min_size = match waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
_ if waiting_tokens >= max_waiting_tokens => None,
// Minimum size criteria
_ => Some(limit_min_batch_size as usize),
};
// Try to get a new batch
if let Some((mut new_entries, new_batch)) =
db.next_batch(min_size, max_batch_size - batch_size as usize)
{
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.prefill(new_batch), &mut new_entries).await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
}
}
cached_batch = wrap_future(client.decode(batches), &mut entries).await;
waiting_tokens += 1;
}
}
}
}
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
async fn wrap_future(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
match future.await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
send_error(err, entries);
None
}
}
}
/// Send errors to Infer for all `entries`
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(InferError::GenerationError(error.to_string())))
.unwrap_or(());
});
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&generation.request_id)
.expect("ID not found in entries. This is a bug.");
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
.unwrap_or(());
}
// Create last Token
let token = Token(
generation.token_id,
generation.token_text,
generation.token_logprob,
);
if let Some(generated_text) = generation.generated_text {
// Remove entry as this is the last message
// We can `expect` here as the request id should always be in the entries
let entry = entries
.remove(&generation.request_id)
.expect("ID not found in entries. This is a bug.");
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.time,
start: entry.batch_time.unwrap(),
}))
.unwrap_or(());
} else {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::Token(token)))
.unwrap_or(());
}
});
}
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
// Intermediate messages
Token(Token),
// Last message
End {
token: Token,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
},
}
#[derive(Debug)]
pub(crate) struct InferResponse {
pub(crate) prefill: Vec<Token>,
pub(crate) tokens: Vec<Token>,
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
}
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
#[error("Model is overloaded")]
Overloaded(#[from] TryAcquireError),
#[error("Input validation error: {0}")]
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
}

View File

@ -1,11 +1,11 @@
/// Text Generation Inference Webserver /// Text Generation Inference Webserver
mod batcher;
mod db; mod db;
mod infer;
pub mod server; pub mod server;
mod validation; mod validation;
use batcher::{Batcher, InferResponse};
use db::{Db, Entry}; use db::{Db, Entry};
use infer::Infer;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validation::Validation; use validation::Validation;
@ -69,21 +69,34 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters, pub parameters: GenerateParameters,
} }
#[derive(Debug, Serialize)]
pub struct Token(u32, String, f32);
#[derive(Serialize)] #[derive(Serialize)]
pub(crate) struct Details { pub(crate) struct Details {
pub finish_reason: String, pub finish_reason: String,
pub generated_tokens: u32, pub generated_tokens: u32,
pub seed: Option<u64>, pub seed: Option<u64>,
pub tokens: Vec<(u32, String, f32)>, #[serde(skip_serializing_if = "Option::is_none")]
pub prefill: Option<Vec<Token>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tokens: Option<Vec<Token>>,
} }
#[derive(Serialize)] #[derive(Serialize)]
pub(crate) struct GeneratedText { pub(crate) struct GenerateResponse {
pub generated_text: String, pub generated_text: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<Details>, pub details: Option<Details>,
} }
#[derive(Serialize)]
pub(crate) struct StreamResponse {
pub token: Token,
pub generated_text: Option<String>,
pub details: Option<Details>,
}
#[derive(Serialize)] #[derive(Serialize)]
pub(crate) struct ErrorResponse { pub(crate) struct ErrorResponse {
pub error: String, pub error: String,

View File

@ -1,51 +1,35 @@
/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{ use crate::{
Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
StreamResponse, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode}; use axum::http::{HeaderMap, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{Json, Router};
use futures::Stream;
use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::signal; use tokio::signal;
use tokio::sync::Semaphore;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::StreamExt;
use tracing::instrument; use tracing::instrument;
// Server shared state
#[derive(Clone)]
struct ServerState {
validation: Validation,
batcher: Batcher,
limit_concurrent_requests: Arc<Semaphore>,
}
/// Health check method /// Health check method
#[instrument(skip(state), fields(time, time_per_token))] #[instrument(skip(infer))]
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check. // be a bit too slow for a health check.
// What we should do instead if check if the gRPC channels are still healthy. // What we should do instead if check if the gRPC channels are still healthy.
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
(
StatusCode::TOO_MANY_REQUESTS,
Json(ErrorResponse {
error: "Model is overloaded".to_string(),
}),
)
})?;
// Send a small inference request // Send a small inference request
state infer
.batcher .generate(GenerateRequest {
.infer(
1,
GenerateRequest {
inputs: "liveness".to_string(), inputs: "liveness".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
temperature: 1.0, temperature: 1.0,
@ -57,15 +41,14 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
details: false, details: false,
seed: None, seed: None,
}, },
}, })
)
.await?; .await?;
Ok(()) Ok(())
} }
/// Generate method /// Generate method
#[instrument( #[instrument(
skip(state), skip(infer),
fields( fields(
total_time, total_time,
validation_time, validation_time,
@ -76,56 +59,28 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
) )
)] )]
async fn generate( async fn generate(
state: Extension<ServerState>, infer: Extension<Infer>,
req: Json<GenerateRequest>, req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> { ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
let start_time = Instant::now(); let start_time = Instant::now();
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
tracing::error!("Model is overloaded");
(
StatusCode::TOO_MANY_REQUESTS,
Json(ErrorResponse {
error: "Model is overloaded".to_string(),
}),
)
})?;
// Validate request
let details = req.0.parameters.details;
let (input_length, validated_request) =
state.validation.validate(req.0).await.map_err(|err| {
tracing::error!("{}", err.to_string());
err
})?;
// Inference // Inference
let response = state let details = req.0.parameters.details;
.batcher let response = infer.generate(req.0).await.map_err(|err| {
.infer(input_length, validated_request)
.await
.map_err(|err| {
tracing::error!("{}", err.to_string()); tracing::error!("{}", err.to_string());
err err
})?; })?;
// Token details // Token details
let details = match details { let details = match details {
true => { true => Some(Details {
let tokens = response finish_reason: response.generated_text.finish_reason,
.token_ids generated_tokens: response.generated_text.generated_tokens,
.into_iter() prefill: Some(response.prefill),
.zip(response.tokens.into_iter()) tokens: Some(response.tokens),
.zip(response.logprobs.into_iter()) seed: response.generated_text.seed,
.map(|((id, text), logprob)| (id, text, logprob)) }),
.collect();
Some(Details {
seed: response.seed,
finish_reason: response.finish_reason,
generated_tokens: response.generated_tokens,
tokens,
})
}
false => None, false => None,
}; };
@ -133,8 +88,8 @@ async fn generate(
let total_time = start_time.elapsed(); let total_time = start_time.elapsed();
let validation_time = response.queued - start_time; let validation_time = response.queued - start_time;
let queue_time = response.start - response.queued; let queue_time = response.start - response.queued;
let inference_time = response.end - response.start; let inference_time = Instant::now() - response.start;
let time_per_token = inference_time / response.generated_tokens; let time_per_token = inference_time / response.generated_text.generated_tokens;
// Headers // Headers
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
@ -160,22 +115,143 @@ async fn generate(
); );
// Tracing metadata // Tracing metadata
tracing::Span::current().record("total_time", format!("{:?}", total_time)); span.record("total_time", format!("{:?}", total_time));
tracing::Span::current().record("validation_time", format!("{:?}", validation_time)); span.record("validation_time", format!("{:?}", validation_time));
tracing::Span::current().record("queue_time", format!("{:?}", queue_time)); span.record("queue_time", format!("{:?}", queue_time));
tracing::Span::current().record("inference_time", format!("{:?}", inference_time)); span.record("inference_time", format!("{:?}", inference_time));
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token)); span.record("time_per_token", format!("{:?}", time_per_token));
tracing::Span::current().record("seed", format!("{:?}", response.seed)); span.record("seed", format!("{:?}", response.generated_text.seed));
tracing::info!("Output: {}", response.output_text); tracing::info!("Output: {}", response.generated_text.text);
// Send response // Send response
let response = vec![GeneratedText { let response = vec![GenerateResponse {
generated_text: response.output_text, generated_text: response.generated_text.text,
details, details,
}]; }];
Ok((headers, Json(response))) Ok((headers, Json(response)))
} }
/// Generate stream method
#[instrument(
skip(infer),
fields(
total_time,
validation_time,
queue_time,
inference_time,
time_per_token
)
)]
async fn generate_stream(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let span = tracing::Span::current();
let start_time = Instant::now();
let stream = async_stream::stream! {
// Inference
let mut end_reached = false;
let mut error = false;
let details = req.0.parameters.details;
match infer.generate_stream(req.0).await {
Ok(mut response_stream) => {
// Server Side Event stream
while let Some(response) = response_stream.next().await {
match response {
Ok(response) => {
match response {
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
// StreamResponse
let stream_token = StreamResponse {
token,
generated_text: None,
details: None,
};
yield Ok(Event::default().json_data(stream_token).unwrap())
}
// Yield event for last token and compute timings
InferStreamResponse::End {
token,
generated_text,
start,
queued,
} => {
// Token details
let details = match details {
true => Some(Details {
finish_reason: generated_text.finish_reason,
generated_tokens: generated_text.generated_tokens,
prefill: None,
tokens: None,
seed: generated_text.seed,
}),
false => None,
};
// Timings
let total_time = start_time.elapsed();
let validation_time = queued - start_time;
let queue_time = start - queued;
let inference_time = Instant::now() - start;
let time_per_token = inference_time / generated_text.generated_tokens;
// Tracing metadata
span.record("total_time", format!("{:?}", total_time));
span
.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time));
span
.record("inference_time", format!("{:?}", inference_time));
span
.record("time_per_token", format!("{:?}", time_per_token));
tracing::info!(parent: &span, "Output: {}", generated_text.text);
// StreamResponse
end_reached = true;
let stream_token = StreamResponse {
token,
generated_text: Some(generated_text.text),
details
};
yield Ok(Event::default().json_data(stream_token).unwrap())
}
}
}
// Trace and yield error
Err(err) => {
error = true;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err))
}
}
}
},
// Trace and yield error
Err(err) => {
error = true;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err))
}
}
// Check if generation reached the end
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err))
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
}
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
@ -189,21 +265,23 @@ pub async fn run(
addr: SocketAddr, addr: SocketAddr,
) { ) {
// Create state // Create state
let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
let validation = Validation::new(validation_workers, tokenizer, max_input_length); let validation = Validation::new(validation_workers, tokenizer, max_input_length);
let shared_state = ServerState { let infer = Infer::new(
client,
validation, validation,
batcher, max_batch_size,
limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)), max_waiting_tokens,
}; max_concurrent_requests,
);
// Create router // Create router
let app = Router::new() let app = Router::new()
.route("/", post(generate)) .route("/", post(generate))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/", get(health)) .route("/", get(health))
.route("/health", get(health)) .route("/health", get(health))
.layer(Extension(shared_state.clone())); .layer(Extension(infer));
// Run server // Run server
axum::Server::bind(&addr) axum::Server::bind(&addr)
@ -240,3 +318,32 @@ async fn shutdown_signal() {
tracing::info!("signal received, starting graceful shutdown"); tracing::info!("signal received, starting graceful shutdown");
} }
/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self {
let status_code = match err {
InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
};
(
status_code,
Json(ErrorResponse {
error: err.to_string(),
}),
)
}
}
impl From<InferError> for Event {
fn from(err: InferError) -> Self {
Event::default()
.json_data(ErrorResponse {
error: err.to_string(),
})
.unwrap()
}
}

View File

@ -1,7 +1,5 @@
/// Payload validation logic /// Payload validation logic
use crate::{ErrorResponse, GenerateRequest}; use crate::GenerateRequest;
use axum::http::StatusCode;
use axum::Json;
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
@ -161,14 +159,3 @@ pub enum ValidationError {
#[error("tokenizer error {0}")] #[error("tokenizer error {0}")]
Tokenizer(String), Tokenizer(String),
} }
impl From<ValidationError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: ValidationError) -> Self {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
}),
)
}
}

View File

@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom):
def test_causal_lm_generate_token(default_bloom, default_bloom_batch): def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
sequence_length = len(default_bloom_batch.all_input_ids[0]) sequence_length = len(default_bloom_batch.all_input_ids[0])
generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch) generations, next_batch = default_bloom.generate_token(default_bloom_batch)
assert generated_texts == [] assert len(generations) == len(default_bloom_batch)
assert isinstance(next_batch, CausalLMBatch) assert isinstance(next_batch, CausalLMBatch)
assert not next_batch.keys_head_dim_last assert not next_batch.keys_head_dim_last
@ -122,24 +122,30 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert all( assert all(
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values] [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
) )
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 10264 for generation in generations])
assert all([generation.token_text == "Test" for generation in generations])
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch): def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
next_batch = default_bloom_batch next_batch = default_bloom_batch
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1): for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(default_bloom_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert ( assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest"
) )
assert generated_texts[0].request == default_bloom_batch.requests[0] assert generations[0].request_id == default_bloom_batch.requests[0].id
assert ( assert (
generated_texts[0].generated_tokens generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens
) )
@ -152,17 +158,19 @@ def test_causal_lm_generate_token_completion_multi(
for i in range( for i in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1 default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
): ):
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(default_multi_requests_bloom_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None assert next_batch is not None
assert len(generated_texts) == 1 assert len(generations) == 2
assert generated_texts[0].output_text == "TestTestTestTestTestTest" assert generations[1].generated_text.text == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert ( assert (
generated_texts[0].generated_tokens generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[1].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
) )
@ -171,19 +179,22 @@ def test_causal_lm_generate_token_completion_multi(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 1 - 1
): ):
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert ( assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest"
) )
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert ( assert (
generated_texts[0].generated_tokens generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
) )
@ -243,17 +254,19 @@ def test_batch_concatenate(
for _ in range( for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2 default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
): ):
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None assert next_batch is not None
assert len(generated_texts) == 1 assert len(generations) == 3
assert generated_texts[0].output_text == "TestTestTestTestTestTest" assert generations[2].generated_text.text == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert ( assert (
generated_texts[0].generated_tokens generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[2].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
) )
@ -262,19 +275,20 @@ def test_batch_concatenate(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 2 - 2
): ):
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None assert next_batch is not None
assert len(generated_texts) == 1 assert len(generations) == 2
assert ( assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest"
) )
assert generated_texts[0].request == default_bloom_batch.requests[0] assert generations[0].request_id == default_bloom_batch.requests[0].id
assert ( assert (
generated_texts[0].generated_tokens generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens
) )
@ -284,18 +298,21 @@ def test_batch_concatenate(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 4 - 4
): ):
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_bloom.generate_token(next_batch) generations, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert ( assert (
generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest"
) )
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert ( assert (
generated_texts[0].generated_tokens generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
) )

View File

@ -88,11 +88,9 @@ def test_causal_lm_batch_type(default_causal_lm):
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
sequence_length = len(default_causal_lm_batch.all_input_ids[0]) sequence_length = len(default_causal_lm_batch.all_input_ids[0])
generated_texts, next_batch = default_causal_lm.generate_token( generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
default_causal_lm_batch
)
assert generated_texts == [] assert len(generations) == len(next_batch)
assert isinstance(next_batch, CausalLMBatch) assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == next_batch.size assert len(next_batch.all_input_ids) == next_batch.size
@ -121,6 +119,11 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert all( assert all(
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
) )
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 13 for generation in generations])
assert all([generation.token_text == "." for generation in generations])
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion( def test_causal_lm_generate_token_completion(
@ -128,18 +131,17 @@ def test_causal_lm_generate_token_completion(
): ):
next_batch = default_causal_lm_batch next_batch = default_causal_lm_batch
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1): for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
assert generated_texts[0].request == default_causal_lm_batch.requests[0] assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
assert ( assert (
generated_texts[0].generated_tokens generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
) )
@ -152,19 +154,20 @@ def test_causal_lm_generate_token_completion_multi(
for i in range( for i in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
): ):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None assert next_batch is not None
assert len(generated_texts) == 1 assert len(generations) == 2
assert generated_texts[0].output_text == "Test.java:784)" assert generations[1].generated_text.text == "Test.java:784)"
assert ( assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] generations[1].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
) )
assert ( assert (
generated_texts[0].generated_tokens generations[1].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
) )
@ -173,19 +176,20 @@ def test_causal_lm_generate_token_completion_multi(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 1 - 1
): ):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
assert ( assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
) )
assert ( assert (
generated_texts[0].generated_tokens generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
) )
@ -244,19 +248,20 @@ def test_batch_concatenate(
for _ in range( for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
): ):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None assert next_batch is not None
assert len(generated_texts) == 1 assert len(generations) == 3
assert generated_texts[0].output_text == "Test.java:784)" assert generations[2].generated_text.text == "Test.java:784)"
assert ( assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] generations[2].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
) )
assert ( assert (
generated_texts[0].generated_tokens generations[2].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
) )
@ -265,17 +270,17 @@ def test_batch_concatenate(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 2 - 2
): ):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None assert next_batch is not None
assert len(generated_texts) == 1 assert len(generations) == 2
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
assert generated_texts[0].request == default_causal_lm_batch.requests[0] assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert ( assert (
generated_texts[0].generated_tokens generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
) )
@ -285,18 +290,19 @@ def test_batch_concatenate(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 4 - 4
): ):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_causal_lm.generate_token(next_batch) generations, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
assert ( assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
) )
assert ( assert (
generated_texts[0].generated_tokens generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
) )

View File

@ -50,18 +50,17 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
next_batch = batch next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_santacoder.generate_token(next_batch) generations, next_batch = default_santacoder.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_santacoder.generate_token(next_batch) generations, next_batch = default_santacoder.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert generated_texts[0].output_text == "def test_get_all_users_with_" assert generations[0].generated_text.text == "def test_get_all_users_with_"
assert generated_texts[0].request == batch.requests[0] assert generations[0].request_id == batch.requests[0].id
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
assert ( assert (
generated_texts[0].generated_tokens generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens == batch.stopping_criterias[0].max_new_tokens
) )
@ -76,20 +75,19 @@ def test_fim_santacoder_generate_token_completion(
next_batch = batch next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_santacoder.generate_token(next_batch) generations, next_batch = default_santacoder.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_santacoder.generate_token(next_batch) generations, next_batch = default_santacoder.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert ( assert (
generated_texts[0].output_text generations[0].generated_text.text
== """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value""" == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
) )
assert generated_texts[0].request == batch.requests[0] assert generations[0].request_id == batch.requests[0].id
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
assert ( assert (
generated_texts[0].generated_tokens generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens == batch.stopping_criterias[0].max_new_tokens
) )

View File

@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch): def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
generated_texts, next_batch = default_seq2seq_lm.generate_token( generations, next_batch = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch default_seq2seq_lm_batch
) )
assert generated_texts == [] assert len(generations) == len(next_batch)
assert isinstance(next_batch, Seq2SeqLMBatch) assert isinstance(next_batch, Seq2SeqLMBatch)
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids) assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
@ -145,6 +145,11 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
for p in next_batch.past_key_values for p in next_batch.past_key_values
] ]
) )
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 259 for generation in generations])
assert all([generation.token_text == "" for generation in generations])
assert generations[0].request_id == 0
def test_seq2seq_lm_generate_token_completion( def test_seq2seq_lm_generate_token_completion(
@ -152,16 +157,16 @@ def test_seq2seq_lm_generate_token_completion(
): ):
next_batch = default_seq2seq_lm_batch next_batch = default_seq2seq_lm_batch
for _ in range(6): for _ in range(6):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert generated_texts[0].output_text == "a few weeks" assert generations[0].generated_text.text == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generated_texts[0].generated_tokens == 7 assert generations[0].generated_text.generated_tokens == 7
def test_seq2seq_lm_generate_token_completion_multi( def test_seq2seq_lm_generate_token_completion_multi(
@ -170,33 +175,33 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch = default_multi_requests_seq2seq_lm_batch next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4): for i in range(4):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None assert next_batch is not None
assert len(generated_texts) == 1 assert len(generations) == 2
assert generated_texts[0].output_text == "a few " assert generations[1].generated_text.text == "a few "
assert ( assert (
generated_texts[0].request generations[1].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1] == default_multi_requests_seq2seq_lm_batch.requests[1].id
) )
assert generated_texts[0].generated_tokens == 5 assert generations[1].generated_text.generated_tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert generated_texts[0].output_text == "a few weeks" assert generations[0].generated_text.text == "a few weeks"
assert ( assert (
generated_texts[0].request generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0] == default_multi_requests_seq2seq_lm_batch.requests[0].id
) )
assert generated_texts[0].generated_tokens == 7 assert generations[0].generated_text.generated_tokens == 7
def test_batch_concatenate( def test_batch_concatenate(
@ -291,35 +296,35 @@ def test_batch_concatenate(
) )
for _ in range(3): for _ in range(3):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == [] assert len(generations) == len(next_batch)
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None assert next_batch is not None
assert len(generated_texts) == 1 assert len(generations) == 3
assert generated_texts[0].output_text == "a few " assert generations[2].generated_text.text == "a few "
assert ( assert (
generated_texts[0].request generations[2].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1] == default_multi_requests_seq2seq_lm_batch.requests[1].id
) )
assert generated_texts[0].generated_tokens == 5 assert generations[2].generated_text.generated_tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None assert next_batch is not None
assert len(generated_texts) == 1 assert len(generations) == 2
assert generated_texts[0].output_text == "a few weeks" assert generations[0].generated_text.text == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generated_texts[0].generated_tokens == 7 assert generations[0].generated_text.generated_tokens == 7
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None assert next_batch is None
assert len(generated_texts) == 1 assert len(generations) == 1
assert generated_texts[0].output_text == "a few weeks" assert generations[0].generated_text.text == "a few weeks"
assert ( assert (
generated_texts[0].request generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0] == default_multi_requests_seq2seq_lm_batch.requests[0].id
) )
assert generated_texts[0].generated_tokens == 7 assert generations[0].generated_text.generated_tokens == 7

View File

@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
from typing import Optional, Tuple, List, Type from typing import Optional, Tuple, List, Type
from text_generation.models import Model from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
from text_generation.pb import generate_pb2 from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -23,7 +23,6 @@ class CausalLMBatch(Batch):
# All tokens # All tokens
all_input_ids: List[torch.Tensor] all_input_ids: List[torch.Tensor]
all_logprobs: List[Optional[torch.Tensor]]
# Lengths of all generations present in the batch # Lengths of all generations present in the batch
input_lengths: List[int] input_lengths: List[int]
@ -57,7 +56,6 @@ class CausalLMBatch(Batch):
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
input_lengths = [] input_lengths = []
all_logprobs = []
# Parse batch # Parse batch
for r in pb.requests: for r in pb.requests:
@ -67,7 +65,6 @@ class CausalLMBatch(Batch):
stopping_criterias.append( stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
) )
all_logprobs.append(None)
pad_to_multiple_of = 8 if device.type == "cuda" else None pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
@ -89,7 +86,6 @@ class CausalLMBatch(Batch):
position_ids=position_ids, position_ids=position_ids,
past_key_values=None, past_key_values=None,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_logprobs=all_logprobs,
input_lengths=input_lengths, input_lengths=input_lengths,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
@ -107,7 +103,6 @@ class CausalLMBatch(Batch):
requests = [] requests = []
input_lengths = [] input_lengths = []
all_input_ids = [] all_input_ids = []
all_logprobs = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
@ -124,7 +119,6 @@ class CausalLMBatch(Batch):
requests.extend(batch.requests) requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
all_logprobs.extend(batch.all_logprobs)
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
@ -225,7 +219,6 @@ class CausalLMBatch(Batch):
position_ids=position_ids, position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_logprobs=all_logprobs,
input_lengths=input_lengths, input_lengths=input_lengths,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
@ -234,6 +227,9 @@ class CausalLMBatch(Batch):
keys_head_dim_last=batches[0].keys_head_dim_last, keys_head_dim_last=batches[0].keys_head_dim_last,
) )
def __len__(self):
return len(self.requests)
class CausalLM(Model): class CausalLM(Model):
def __init__(self, model_name: str, quantize=False): def __init__(self, model_name: str, quantize=False):
@ -289,7 +285,7 @@ class CausalLM(Model):
def generate_token( def generate_token(
self, batch: CausalLMBatch self, batch: CausalLMBatch
) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]: ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU # For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = ( context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode torch.no_grad if self.device.type == "cpu" else torch.inference_mode
@ -309,14 +305,13 @@ class CausalLM(Model):
next_batch_input_lengths = [] next_batch_input_lengths = []
next_batch_input_ids = [] next_batch_input_ids = []
next_batch_all_input_ids = [] next_batch_all_input_ids = []
next_batch_all_logprobs = []
# Metadata # Metadata
next_batch_size = 0 next_batch_size = 0
next_batch_max_sequence_length = 0 next_batch_max_sequence_length = 0
# Finished requests # Results
generated_texts: List[GeneratedText] = [] generations: List[Generation] = []
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
@ -326,7 +321,6 @@ class CausalLM(Model):
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
batch.all_logprobs,
) )
# For each member of the batch # For each member of the batch
@ -337,44 +331,36 @@ class CausalLM(Model):
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
all_input_ids, all_input_ids,
all_logprobs,
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
tokens, logprobs = next_token_chooser(all_input_ids, logits) tokens, logprobs = next_token_chooser(all_input_ids, logits)
next_token = tokens[-1].view(1, 1) next_token_id = tokens[-1].view(1, 1)
# Append next token to all tokens # Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token]) all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1 new_input_length = input_length + 1
if all_logprobs is None: # Generated token
# logprobs of all prompt tokens (except the first one) and the generated token next_token_logprob = logprobs[-1, next_token_id]
all_logprobs = logprobs.gather(1, all_input_ids[1:]) next_token_id_squeezed = next_token_id.squeeze()
else: next_token_text = self.tokenizer.decode(
# logprob of the generated token next_token_id_squeezed,
next_token_logprob = logprobs[-1, next_token] clean_up_tokenization_spaces=False,
all_logprobs = torch.cat([all_logprobs, next_token_logprob]) skip_special_tokens=False,
)
# Evaluate stopping criteria # Evaluate stopping criteria
stop, reason = stopping_criteria( stop, reason = stopping_criteria(
next_token.squeeze(), next_token_id_squeezed,
self.tokenizer.decode( next_token_text,
next_token.squeeze(), clean_up_tokenization_spaces=False
),
) )
if stop: if stop:
# Decode generated tokens # Decode generated tokens
generated_text = self.decode( generated_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0] all_input_ids[-stopping_criteria.current_tokens :, 0]
) )
output_text = request.inputs + generated_text output_text = request.inputs + generated_text
# Slice with input_length to remove padding
token_ids = all_input_ids[-new_input_length:]
tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the first prompt token
logprobs = [float("nan")] + all_logprobs[-input_length:].squeeze(
1
).tolist()
# Get seed # Get seed
if isinstance(next_token_chooser.choice, Sampling): if isinstance(next_token_chooser.choice, Sampling):
@ -382,39 +368,58 @@ class CausalLM(Model):
else: else:
seed = None seed = None
# Add to the list of finished generations with the original request generated_text = GeneratedText(
generated_texts.append( output_text, stopping_criteria.current_tokens, reason, seed
GeneratedText(
request=request,
output_text=output_text,
generated_tokens=stopping_criteria.current_tokens,
tokens=tokens,
token_ids=token_ids.squeeze(1).tolist(),
logprobs=logprobs,
reason=reason,
seed=seed,
) )
)
# add to the next batch
else: else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i) next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token) next_batch_input_ids.append(next_token_id)
next_batch_all_input_ids.append(all_input_ids) next_batch_all_input_ids.append(all_input_ids)
next_batch_all_logprobs.append(all_logprobs)
next_batch_size += 1 next_batch_size += 1
next_batch_input_lengths.append(new_input_length) next_batch_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max( next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length next_batch_max_sequence_length, new_input_length
) )
# Prefill
if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids[1:]
).squeeze(1)[-new_input_length:-1].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)
generations.append(generation)
# We finished all generations in the batch; there is no next batch # We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices: if not next_batch_keep_indices:
return generated_texts, None return generations, None
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0) next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
# If we finished at least one generation, we need to evict the indices of the generations that finished # If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch # from the values of the next batch
if generated_texts: if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached # Apply indices to attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
next_batch_position_ids = batch.position_ids[next_batch_keep_indices] next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
@ -461,7 +466,6 @@ class CausalLM(Model):
position_ids=next_batch_position_ids, position_ids=next_batch_position_ids,
past_key_values=next_batch_past_key_values, past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids, all_input_ids=next_batch_all_input_ids,
all_logprobs=next_batch_all_logprobs,
input_lengths=next_batch_input_lengths, input_lengths=next_batch_input_lengths,
next_token_choosers=next_batch_next_token_choosers, next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias, stopping_criterias=next_batch_stopping_criterias,
@ -469,4 +473,4 @@ class CausalLM(Model):
max_sequence_length=next_batch_max_sequence_length, max_sequence_length=next_batch_max_sequence_length,
keys_head_dim_last=batch.keys_head_dim_last, keys_head_dim_last=batch.keys_head_dim_last,
) )
return generated_texts, next_batch return generations, next_batch

View File

@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokeniz
from typing import Optional, Tuple, List, Type from typing import Optional, Tuple, List, Type
from text_generation.models import Model from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
from text_generation.pb import generate_pb2 from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -30,7 +30,6 @@ class Seq2SeqLMBatch(Batch):
# Lengths of all generations present in the batch # Lengths of all generations present in the batch
input_lengths: List[int] input_lengths: List[int]
decoder_input_lengths: List[int] decoder_input_lengths: List[int]
decoder_logprobs: List[Optional[torch.Tensor]]
# Generation helpers # Generation helpers
next_token_choosers: List[NextTokenChooser] next_token_choosers: List[NextTokenChooser]
@ -64,7 +63,6 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids = [] decoder_input_ids = []
decoder_input_lengths = [] decoder_input_lengths = []
decoder_logprobs = []
# Parse batch # Parse batch
for r in pb.requests: for r in pb.requests:
@ -77,7 +75,6 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias.append( stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
) )
decoder_logprobs.append(None)
# Tokenize batch # Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None pad_to_multiple_of = 8 if device.type == "cuda" else None
@ -102,7 +99,6 @@ class Seq2SeqLMBatch(Batch):
past_key_values=None, past_key_values=None,
input_lengths=input_lengths, input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths, decoder_input_lengths=decoder_input_lengths,
decoder_logprobs=decoder_logprobs,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=len(pb.requests), size=len(pb.requests),
@ -125,7 +121,6 @@ class Seq2SeqLMBatch(Batch):
requests = [] requests = []
input_lengths = [] input_lengths = []
decoder_input_lengths = [] decoder_input_lengths = []
decoder_logprobs = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
@ -146,7 +141,6 @@ class Seq2SeqLMBatch(Batch):
requests.extend(batch.requests) requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths)
decoder_logprobs.extend(batch.decoder_logprobs)
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
@ -283,7 +277,6 @@ class Seq2SeqLMBatch(Batch):
past_key_values=past_key_values, past_key_values=past_key_values,
input_lengths=input_lengths, input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths, decoder_input_lengths=decoder_input_lengths,
decoder_logprobs=decoder_logprobs,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=total_batch_size, size=total_batch_size,
@ -291,6 +284,9 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length=max_decoder_input_length, max_decoder_input_length=max_decoder_input_length,
) )
def __len__(self):
return len(self.requests)
class Seq2SeqLM(Model): class Seq2SeqLM(Model):
def __init__(self, model_name: str, quantize=False): def __init__(self, model_name: str, quantize=False):
@ -364,7 +360,7 @@ class Seq2SeqLM(Model):
def generate_token( def generate_token(
self, batch: Seq2SeqLMBatch self, batch: Seq2SeqLMBatch
) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]: ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU # For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = ( context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode torch.no_grad if self.device.type == "cpu" else torch.inference_mode
@ -386,7 +382,6 @@ class Seq2SeqLM(Model):
next_batch_input_lengths = [] next_batch_input_lengths = []
next_batch_decoder_input_ids = [] next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = [] next_batch_decoder_input_lengths = []
next_batch_decoder_logprobs = []
# Metadata # Metadata
next_batch_size = 0 next_batch_size = 0
@ -394,14 +389,13 @@ class Seq2SeqLM(Model):
next_batch_max_decoder_input_length = 0 next_batch_max_decoder_input_length = 0
# Finished requests # Finished requests
generated_texts: List[GeneratedText] = [] generations: List[Generation] = []
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
batch.requests, batch.requests,
batch.input_lengths, batch.input_lengths,
batch.decoder_input_lengths, batch.decoder_input_lengths,
batch.decoder_logprobs,
logits, logits,
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
@ -414,7 +408,6 @@ class Seq2SeqLM(Model):
request, request,
input_length, input_length,
decoder_input_length, decoder_input_length,
decoder_logprobs,
logits, logits,
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
@ -422,35 +415,28 @@ class Seq2SeqLM(Model):
decoder_input_ids, decoder_input_ids,
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
next_token, logprobs = next_token_chooser(decoder_input_ids, logits) next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
# Append next token to decoder tokens # Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token]) decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
new_decoder_input_length = decoder_input_length + 1 new_decoder_input_length = decoder_input_length + 1
next_token_logprob = logprobs[-1, next_token] # Generated token
if decoder_logprobs is None: next_token_logprob = logprobs[-1, next_token_id]
decoder_logprobs = next_token_logprob next_token_id_squeezed = next_token_id.squeeze()
else: next_token_text = self.tokenizer.decode(
decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob]) next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
# Evaluate stopping criteria # Evaluate stopping criteria
stop, reason = stopping_criteria( stop, reason = stopping_criteria(next_token_id, next_token_text)
next_token.squeeze(),
self.tokenizer.decode(
next_token.squeeze(), clean_up_tokenization_spaces=False
),
)
if stop: if stop:
# Slice with decoder_input_length to remove padding # Slice with decoder_input_length to remove padding
# Decode all tokens # Decode all tokens
token_ids = decoder_input_ids[-new_decoder_input_length:] output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
output_text = self.decode(token_ids)
tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the bos token
logprobs = [float("nan")] + decoder_logprobs[
-decoder_input_length:
].tolist()
# Get seed # Get seed
if isinstance(next_token_chooser.choice, Sampling): if isinstance(next_token_chooser.choice, Sampling):
@ -458,27 +444,17 @@ class Seq2SeqLM(Model):
else: else:
seed = None seed = None
# Add to the list of finished generations with the original request generated_text = GeneratedText(
generated_texts.append( output_text, stopping_criteria.current_tokens, reason, seed
GeneratedText(
request=request,
output_text=output_text,
generated_tokens=stopping_criteria.current_tokens,
tokens=tokens,
token_ids=token_ids.tolist(),
logprobs=logprobs,
reason=reason,
seed=seed,
) )
)
# add to the next batch
else: else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i) next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0)) next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1 next_batch_size += 1
next_batch_input_lengths.append(input_length) next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length) next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_decoder_logprobs.append(decoder_logprobs)
next_batch_max_input_length = max( next_batch_max_input_length = max(
next_batch_max_input_length, input_length next_batch_max_input_length, input_length
) )
@ -486,14 +462,39 @@ class Seq2SeqLM(Model):
next_batch_max_decoder_input_length, new_decoder_input_length next_batch_max_decoder_input_length, new_decoder_input_length
) )
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
generated_text,
)
generations.append(generation)
# We finished all generations in the batch; there is no next batch # We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices: if not next_batch_keep_indices:
return generated_texts, None return generations, None
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids) next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
# If we finished at least one generation, we need to evict the indices of the generations that finished # If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch # from the values of the next batch
if generated_texts: if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached # Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids = batch.input_ids[next_batch_keep_indices] next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
@ -551,11 +552,10 @@ class Seq2SeqLM(Model):
past_key_values=next_batch_past_key_values, past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths, input_lengths=next_batch_input_lengths,
decoder_input_lengths=next_batch_decoder_input_lengths, decoder_input_lengths=next_batch_decoder_input_lengths,
decoder_logprobs=next_batch_decoder_logprobs,
next_token_choosers=next_batch_next_token_choosers, next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias, stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size, size=next_batch_size,
max_input_length=next_batch_max_input_length, max_input_length=next_batch_max_input_length,
max_decoder_input_length=next_batch_max_decoder_input_length, max_decoder_input_length=next_batch_max_decoder_input_length,
) )
return generated_texts, next_batch return generations, next_batch

View File

@ -29,26 +29,61 @@ class Batch(ABC):
def concatenate(cls, batches: List["Batch"]) -> "Batch": def concatenate(cls, batches: List["Batch"]) -> "Batch":
raise NotImplementedError raise NotImplementedError
@abstractmethod
def __len__(self):
raise NotImplementedError
@dataclass @dataclass
class GeneratedText: class GeneratedText:
request: generate_pb2.Request text: str
output_text: str
generated_tokens: int generated_tokens: int
tokens: List[str] finish_reason: str
token_ids: List[int]
logprobs: List[float]
reason: str
seed: Optional[int] seed: Optional[int]
def to_pb(self) -> generate_pb2.GeneratedText: def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText( return generate_pb2.GeneratedText(
request=self.request, text=self.text,
output_text=self.output_text,
generated_tokens=self.generated_tokens, generated_tokens=self.generated_tokens,
tokens=self.tokens, finish_reason=self.finish_reason,
token_ids=self.token_ids,
logprobs=self.logprobs,
finish_reason=self.reason,
seed=self.seed, seed=self.seed,
) )
@dataclass
class PrefillTokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
def to_pb(self) -> generate_pb2.PrefillTokens:
return generate_pb2.PrefillTokens(
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
)
def __len__(self):
return len(self.token_ids)
@dataclass
class Generation:
request_id: int
prefill_tokens: Optional[PrefillTokens]
token_id: int
token_logprob: float
token_text: str
generated_text: Optional[GeneratedText]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
request_id=self.request_id,
prefill_tokens=self.prefill_tokens.to_pb()
if self.prefill_tokens is not None
else None,
token_id=self.token_id,
token_logprob=self.token_logprob,
token_text=self.token_text,
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
)

View File

@ -27,22 +27,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
self.cache.clear() self.cache.clear()
return generate_pb2.ClearCacheResponse() return generate_pb2.ClearCacheResponse()
async def Generate(self, request, context): async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb( batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.device request.batch, self.model.tokenizer, self.model.device
) )
generated_texts, next_batch = self.model.generate_token(batch) generations, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch) self.cache.set(next_batch)
return generate_pb2.GenerateResponse( return generate_pb2.PrefillResponse(
generated_texts=[ generations=[generation.to_pb() for generation in generations],
generated_text.to_pb() for generated_text in generated_texts
],
batch=next_batch.to_pb() if next_batch else None, batch=next_batch.to_pb() if next_batch else None,
) )
async def GenerateWithCache(self, request, context): async def Decode(self, request, context):
if len(request.batches) == 0: if len(request.batches) == 0:
raise ValueError("Must provide at least one batch") raise ValueError("Must provide at least one batch")
@ -58,13 +56,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
else: else:
batch = batches[0] batch = batches[0]
generated_texts, next_batch = self.model.generate_token(batch) generations, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch) self.cache.set(next_batch)
return generate_pb2.GenerateWithCacheResponse( return generate_pb2.DecodeResponse(
generated_texts=[ generations=[generation.to_pb() for generation in generations],
generated_text.to_pb() for generated_text in generated_texts
],
batch=next_batch.to_pb() if next_batch else None, batch=next_batch.to_pb() if next_batch else None,
) )