feat: prefill chunking (#2600)
* wip * rollback * refactor to use prefix/postfix namming + fix all_input_ids_tensor * maybe patching vlms? * fix filter and concat * wip, no filter, no concat * current * add prepare_for_prefill * working * load tested * re-create slots * re-create slots * fix slot_filtering_indices * feedback loop * remove log * fix benchmarker * fix vlm and seq2seq * rename to cache and input lengths * fix prefill logprobs * fix launcher * fix logprobs? * idk at this point * max input length * omfg * remove debugging lines * fix tests * fix mllama * fix cargo tests * remove support chunking for paged * Fixing non blocked attentions * Fixing dtype + AMD, Ipex targets. * lint fix. * rename * Fix prefix_caching variable, remove defaults in server (confusing a lot of the times). * Add simple resolution when user specifies ATTENTION=paged. * Put back non default simple tests. * Fix env name --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
704a58c807
commit
a6a0c97ed9
|
@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
|
||||||
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
|
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
|
||||||
ENV VLLM_MOE_PADDING=0
|
ENV VLLM_MOE_PADDING=0
|
||||||
ENV ATTENTION=paged
|
ENV ATTENTION=paged
|
||||||
ENV USE_PREFIX_CACHING=0
|
ENV PREFIX_CACHING=0
|
||||||
|
ENV PREFILL_CHUNKING=0
|
||||||
ENV ROCM_USE_SKINNY_GEMM=1
|
ENV ROCM_USE_SKINNY_GEMM=1
|
||||||
|
|
||||||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||||
|
|
|
@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
|
||||||
|
|
||||||
FROM ${PLATFORM} AS final
|
FROM ${PLATFORM} AS final
|
||||||
ENV ATTENTION=paged
|
ENV ATTENTION=paged
|
||||||
ENV USE_PREFIX_CACHING=0
|
ENV PREFIX_CACHING=0
|
||||||
|
ENV PREFILL_CHUNKING=0
|
||||||
ENV CUDA_GRAPHS=0
|
ENV CUDA_GRAPHS=0
|
||||||
ENTRYPOINT ["text-generation-launcher"]
|
ENTRYPOINT ["text-generation-launcher"]
|
||||||
CMD ["--json-output"]
|
CMD ["--json-output"]
|
||||||
|
|
|
@ -158,7 +158,8 @@ impl Client {
|
||||||
// Blocks and slots will be set on the server side if we use paged attention
|
// Blocks and slots will be set on the server side if we use paged attention
|
||||||
blocks: vec![],
|
blocks: vec![],
|
||||||
slots: vec![],
|
slots: vec![],
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
|
chunk_len: None,
|
||||||
// Set sampling parameters to also take these ops into account in the max memory
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 0.9,
|
temperature: 0.9,
|
||||||
|
@ -217,8 +218,13 @@ impl Client {
|
||||||
pub async fn prefill(
|
pub async fn prefill(
|
||||||
&mut self,
|
&mut self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
let request = tonic::Request::new(PrefillRequest {
|
||||||
|
batch: Some(batch),
|
||||||
|
cached_batch,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
let response = self.stub.prefill(request).await?.into_inner();
|
let response = self.stub.prefill(request).await?.into_inner();
|
||||||
Ok((
|
Ok((
|
||||||
response.generations,
|
response.generations,
|
||||||
|
|
|
@ -134,11 +134,12 @@ impl ShardedClient {
|
||||||
pub async fn prefill(
|
pub async fn prefill(
|
||||||
&mut self,
|
&mut self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
|
||||||
.collect();
|
.collect();
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||||
|
@ -245,7 +246,8 @@ impl Health for ShardedClient {
|
||||||
// Block 0 is reserved for health checks
|
// Block 0 is reserved for health checks
|
||||||
blocks: vec![0],
|
blocks: vec![0],
|
||||||
slots: (0..16).collect(),
|
slots: (0..16).collect(),
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
|
chunk_len: None,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
};
|
};
|
||||||
let batch = Batch {
|
let batch = Batch {
|
||||||
|
@ -255,7 +257,7 @@ impl Health for ShardedClient {
|
||||||
max_tokens: 2,
|
max_tokens: 2,
|
||||||
max_blocks: 1,
|
max_blocks: 1,
|
||||||
};
|
};
|
||||||
self.clone().prefill(batch).await?;
|
self.clone().prefill(batch, None).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
use text_generation_router::validation::ValidGenerateRequest;
|
use text_generation_router::validation::ValidGenerateRequest;
|
||||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
use tokio::sync::{mpsc, Notify};
|
use tokio::sync::{mpsc, Notify};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
@ -36,18 +36,14 @@ impl BackendV2 {
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
// Infer shared state
|
||||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
|
||||||
attention
|
let block_size = match attention.as_str() {
|
||||||
.parse()
|
"flashinfer" => 1,
|
||||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
"flashdecoding" => 256,
|
||||||
} else {
|
"paged" => 16,
|
||||||
Attention::Paged
|
_ => unreachable!(),
|
||||||
};
|
|
||||||
let block_size = if attention == Attention::FlashDecoding {
|
|
||||||
256
|
|
||||||
} else {
|
|
||||||
16
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||||
let batching_task_notifier = Arc::new(Notify::new());
|
let batching_task_notifier = Arc::new(Notify::new());
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
|
||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
|
use crate::client::{
|
||||||
|
Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
|
||||||
|
};
|
||||||
use crate::queue::{Entry, Queue};
|
use crate::queue::{Entry, Queue};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
use text_generation_router::validation::ValidGenerateRequest;
|
use text_generation_router::validation::ValidGenerateRequest;
|
||||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
use tokio::sync::{mpsc, Notify};
|
use tokio::sync::{mpsc, Notify};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
@ -31,27 +33,22 @@ impl BackendV3 {
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
requires_padding: bool,
|
shard_info: InfoResponse,
|
||||||
window_size: Option<u32>,
|
|
||||||
speculate: u32,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let prefix_caching =
|
if shard_info.support_chunking {
|
||||||
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
|
tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
|
||||||
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
|
}
|
||||||
let attention: String = std::env::var("ATTENTION").expect("attention env var");
|
|
||||||
|
|
||||||
let attention: Attention = attention
|
let block_size = shard_info.block_size;
|
||||||
.parse()
|
|
||||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
|
||||||
let block_size = attention.block_size();
|
|
||||||
|
|
||||||
let queue = Queue::new(
|
let queue = Queue::new(
|
||||||
requires_padding,
|
shard_info.requires_padding,
|
||||||
block_size,
|
block_size,
|
||||||
prefix_caching,
|
shard_info.use_prefix_caching,
|
||||||
window_size,
|
shard_info.window_size,
|
||||||
speculate,
|
shard_info.speculate,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
|
shard_info.support_chunking,
|
||||||
);
|
);
|
||||||
let batching_task_notifier = Arc::new(Notify::new());
|
let batching_task_notifier = Arc::new(Notify::new());
|
||||||
|
|
||||||
|
@ -63,6 +60,7 @@ impl BackendV3 {
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
|
shard_info.support_chunking,
|
||||||
queue.clone(),
|
queue.clone(),
|
||||||
batching_task_notifier.clone(),
|
batching_task_notifier.clone(),
|
||||||
));
|
));
|
||||||
|
@ -127,6 +125,7 @@ pub(crate) async fn batching_task(
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
|
support_chunking: bool,
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
notifier: Arc<Notify>,
|
notifier: Arc<Notify>,
|
||||||
) {
|
) {
|
||||||
|
@ -147,7 +146,7 @@ pub(crate) async fn batching_task(
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
let mut waiting_tokens = 1;
|
let mut waiting_tokens = 1;
|
||||||
|
@ -158,10 +157,24 @@ pub(crate) async fn batching_task(
|
||||||
// Get current batch info
|
// Get current batch info
|
||||||
let batch_size = batch.size;
|
let batch_size = batch.size;
|
||||||
let batch_max_tokens = batch.max_tokens;
|
let batch_max_tokens = batch.max_tokens;
|
||||||
|
let current_tokens = batch.current_tokens;
|
||||||
let mut batches = vec![batch];
|
let mut batches = vec![batch];
|
||||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||||
|
|
||||||
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
|
|
||||||
|
let (min_size, max_size, prefill_token_budget) = if support_chunking {
|
||||||
|
// Since the next batch will be concatenated with the current batch,
|
||||||
|
// the current batch tokens must be subtracted to the prefill budget
|
||||||
|
let prefill_token_budget =
|
||||||
|
max_batch_prefill_tokens.saturating_sub(current_tokens);
|
||||||
|
// We can ignore min_size and max_size
|
||||||
|
// Models than rely on max_size cannot support chunking
|
||||||
|
// Regarding min_size, chunking allow us to consistently run at the compute
|
||||||
|
// bound, making min_size useless.
|
||||||
|
(None, None, prefill_token_budget)
|
||||||
|
} else {
|
||||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||||
// to add a new batch even though its size might be small
|
// to add a new batch even though its size might be small
|
||||||
|
@ -173,24 +186,34 @@ pub(crate) async fn batching_task(
|
||||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
|
||||||
let max_size =
|
let max_size =
|
||||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||||
|
|
||||||
|
(min_size, max_size, max_batch_prefill_tokens)
|
||||||
|
};
|
||||||
|
|
||||||
// Try to get a new batch
|
// Try to get a new batch
|
||||||
if let Some((mut new_entries, new_batch, span)) = queue
|
if let Some((new_entries, new_batch, span)) = queue
|
||||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
// Tracking metrics
|
// Tracking metrics
|
||||||
if min_size.is_some() {
|
if min_size.is_some() {
|
||||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||||
.increment(1);
|
.increment(1);
|
||||||
|
} else {
|
||||||
|
let counter = if support_chunking {
|
||||||
|
metrics::counter!("tgi_batch_concat", "reason" => "chunking")
|
||||||
} else {
|
} else {
|
||||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||||
.increment(1);
|
};
|
||||||
|
counter.increment(1);
|
||||||
}
|
}
|
||||||
|
let cached_batch = if support_chunking {
|
||||||
|
// Concat current batch to the new one
|
||||||
|
batches.pop()
|
||||||
|
} else {
|
||||||
|
// Request are waiting only if we don't support chunking
|
||||||
entries.iter_mut().for_each(|(_, entry)| {
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
// Create a new span to add the info that this entry is waiting
|
// Create a new span to add the info that this entry is waiting
|
||||||
// because a new batch is being computed
|
// because a new batch is being computed
|
||||||
|
@ -201,17 +224,23 @@ pub(crate) async fn batching_task(
|
||||||
// Update entry
|
// Update entry
|
||||||
entry.temp_span = Some(entry_waiting_span);
|
entry.temp_span = Some(entry_waiting_span);
|
||||||
});
|
});
|
||||||
|
None
|
||||||
|
};
|
||||||
|
entries.extend(new_entries);
|
||||||
|
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
let new_cached_batch =
|
||||||
|
prefill(&mut client, new_batch, cached_batch, &mut entries)
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
// Reset waiting counter
|
// Reset waiting counter
|
||||||
waiting_tokens = 1;
|
waiting_tokens = 1;
|
||||||
// Extend current batch with the new batch
|
// Extend current batch with the new batch
|
||||||
if let Some(new_cached_batch) = new_cached_batch {
|
if let Some(new_cached_batch) = new_cached_batch {
|
||||||
entries.extend(new_entries);
|
|
||||||
batches.push(new_cached_batch);
|
batches.push(new_cached_batch);
|
||||||
|
} else if support_chunking {
|
||||||
|
// New cached batch is empty, no work left
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,13 +273,14 @@ pub(crate) async fn batching_task(
|
||||||
async fn prefill(
|
async fn prefill(
|
||||||
client: &mut ShardedClient,
|
client: &mut ShardedClient,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<CachedBatch> {
|
) -> Option<CachedBatch> {
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let batch_id = batch.id;
|
let batch_id = batch.id;
|
||||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||||
|
|
||||||
match client.prefill(batch).await {
|
match client.prefill(batch, cached_batch).await {
|
||||||
Ok((generations, next_batch, timings)) => {
|
Ok((generations, next_batch, timings)) => {
|
||||||
let start_filtering_time = Instant::now();
|
let start_filtering_time = Instant::now();
|
||||||
// Send generated tokens and filter stopped entries
|
// Send generated tokens and filter stopped entries
|
||||||
|
@ -259,6 +289,10 @@ async fn prefill(
|
||||||
// Filter next batch and remove requests that were stopped
|
// Filter next batch and remove requests that were stopped
|
||||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||||
|
|
||||||
|
if let Some(concat_duration) = timings.concat {
|
||||||
|
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||||
|
.record(concat_duration.as_secs_f64());
|
||||||
|
}
|
||||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
||||||
.record(timings.forward.as_secs_f64());
|
.record(timings.forward.as_secs_f64());
|
||||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||||
|
|
|
@ -158,7 +158,8 @@ impl Client {
|
||||||
// Blocks and slots will be set on the server side if we use paged attention
|
// Blocks and slots will be set on the server side if we use paged attention
|
||||||
blocks: vec![],
|
blocks: vec![],
|
||||||
slots: vec![],
|
slots: vec![],
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
|
chunk_len: None,
|
||||||
// Set sampling parameters to also take these ops into account in the max memory
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 0.9,
|
temperature: 0.9,
|
||||||
|
@ -217,13 +218,23 @@ impl Client {
|
||||||
pub async fn prefill(
|
pub async fn prefill(
|
||||||
&mut self,
|
&mut self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
let request = tonic::Request::new(PrefillRequest {
|
||||||
|
batch: Some(batch),
|
||||||
|
cached_batch,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
let response = self.stub.prefill(request).await?.into_inner();
|
let response = self.stub.prefill(request).await?.into_inner();
|
||||||
Ok((
|
Ok((
|
||||||
response.generations,
|
response.generations,
|
||||||
response.batch,
|
response.batch,
|
||||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
PrefillTimings::new(
|
||||||
|
response.concat_ns,
|
||||||
|
response.forward_ns,
|
||||||
|
response.decode_ns,
|
||||||
|
response.total_ns,
|
||||||
|
),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,14 +263,16 @@ impl Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct PrefillTimings {
|
pub struct PrefillTimings {
|
||||||
|
pub concat: Option<Duration>,
|
||||||
pub forward: Duration,
|
pub forward: Duration,
|
||||||
pub decode: Duration,
|
pub decode: Duration,
|
||||||
pub total: Duration,
|
pub total: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PrefillTimings {
|
impl PrefillTimings {
|
||||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
concat: concat_ns.map(Duration::from_nanos),
|
||||||
forward: Duration::from_nanos(forward_ns),
|
forward: Duration::from_nanos(forward_ns),
|
||||||
decode: Duration::from_nanos(decode_ns),
|
decode: Duration::from_nanos(decode_ns),
|
||||||
total: Duration::from_nanos(total_ns),
|
total: Duration::from_nanos(total_ns),
|
||||||
|
|
|
@ -29,15 +29,6 @@ pub trait Health {
|
||||||
async fn model_health(&self) -> Result<()>;
|
async fn model_health(&self) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ShardInfo {
|
|
||||||
pub requires_padding: bool,
|
|
||||||
pub dtype: String,
|
|
||||||
pub device_type: String,
|
|
||||||
pub window_size: Option<u32>,
|
|
||||||
pub speculate: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug, Clone)]
|
#[derive(Error, Debug, Clone)]
|
||||||
pub enum ClientError {
|
pub enum ClientError {
|
||||||
#[error("Could not connect to Text Generation server: {0}")]
|
#[error("Could not connect to Text Generation server: {0}")]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::client::{ClientError, Result};
|
use crate::client::Health;
|
||||||
/// Multi shard Client
|
/// Multi shard Client
|
||||||
use crate::client::{Health, ShardInfo};
|
use crate::client::{ClientError, Result};
|
||||||
|
|
||||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||||
use crate::client::{
|
use crate::client::{
|
||||||
|
@ -49,13 +49,13 @@ impl ShardedClient {
|
||||||
|
|
||||||
/// Get the model info
|
/// Get the model info
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| client.info())
|
.map(|client| client.info())
|
||||||
.collect();
|
.collect();
|
||||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
join_all(futures).await.pop().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GRPC health check
|
/// GRPC health check
|
||||||
|
@ -135,11 +135,12 @@ impl ShardedClient {
|
||||||
pub async fn prefill(
|
pub async fn prefill(
|
||||||
&mut self,
|
&mut self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
|
||||||
.collect();
|
.collect();
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||||
|
@ -194,18 +195,6 @@ impl ShardedClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<InfoResponse> for ShardInfo {
|
|
||||||
fn from(value: InfoResponse) -> Self {
|
|
||||||
Self {
|
|
||||||
requires_padding: value.requires_padding,
|
|
||||||
dtype: value.dtype,
|
|
||||||
device_type: value.device_type,
|
|
||||||
window_size: value.window_size,
|
|
||||||
speculate: value.speculate,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Health for ShardedClient {
|
impl Health for ShardedClient {
|
||||||
async fn device_health(&self) -> Result<()> {
|
async fn device_health(&self) -> Result<()> {
|
||||||
|
@ -246,8 +235,9 @@ impl Health for ShardedClient {
|
||||||
// Block 0 is reserved for health checks
|
// Block 0 is reserved for health checks
|
||||||
blocks: vec![0],
|
blocks: vec![0],
|
||||||
slots: (0..16).collect(),
|
slots: (0..16).collect(),
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
|
chunk_len: None,
|
||||||
};
|
};
|
||||||
let batch = Batch {
|
let batch = Batch {
|
||||||
id: u64::MAX,
|
id: u64::MAX,
|
||||||
|
@ -256,7 +246,7 @@ impl Health for ShardedClient {
|
||||||
max_tokens: 2,
|
max_tokens: 2,
|
||||||
max_blocks: 1,
|
max_blocks: 1,
|
||||||
};
|
};
|
||||||
self.clone().prefill(batch).await?;
|
self.clone().prefill(batch, None).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,14 @@ pub struct BackendInfo {
|
||||||
pub max_waiting_tokens: usize,
|
pub max_waiting_tokens: usize,
|
||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, example = "null")]
|
||||||
pub max_batch_size: Option<usize>,
|
pub max_batch_size: Option<usize>,
|
||||||
|
#[schema(example = "false")]
|
||||||
|
pub support_chunking: bool,
|
||||||
|
#[schema(example = "false")]
|
||||||
|
pub prefix_caching: bool,
|
||||||
|
#[schema(example = "flashinfer")]
|
||||||
|
pub attention_impl: String,
|
||||||
|
#[schema(example = "1")]
|
||||||
|
pub block_size: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
@ -110,6 +118,10 @@ pub async fn connect_backend(
|
||||||
model_device_type: shard_info.device_type.clone(),
|
model_device_type: shard_info.device_type.clone(),
|
||||||
model_dtype: shard_info.dtype.clone(),
|
model_dtype: shard_info.dtype.clone(),
|
||||||
speculate: shard_info.speculate as usize,
|
speculate: shard_info.speculate as usize,
|
||||||
|
support_chunking: shard_info.support_chunking,
|
||||||
|
prefix_caching: shard_info.use_prefix_caching,
|
||||||
|
attention_impl: shard_info.attention_impl.clone(),
|
||||||
|
block_size: shard_info.block_size,
|
||||||
};
|
};
|
||||||
|
|
||||||
let backend = BackendV3::new(
|
let backend = BackendV3::new(
|
||||||
|
@ -119,9 +131,7 @@ pub async fn connect_backend(
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
shard_info.requires_padding,
|
shard_info,
|
||||||
shard_info.window_size,
|
|
||||||
shard_info.speculate,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
tracing::info!("Using backend V3");
|
tracing::info!("Using backend V3");
|
||||||
|
|
|
@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
|
||||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
|
||||||
}
|
|
||||||
|
|
||||||
if validation_workers == 0 {
|
if validation_workers == 0 {
|
||||||
return Err(RouterError::ArgumentValidation(
|
return Err(RouterError::ArgumentValidation(
|
||||||
"`validation_workers` must be > 0".to_string(),
|
"`validation_workers` must be > 0".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
|
||||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(max_batch_size) = max_batch_size {
|
if let Some(max_batch_size) = max_batch_size {
|
||||||
if max_batch_size == 0 {
|
if max_batch_size == 0 {
|
||||||
return Err(RouterError::ArgumentValidation(
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (backend, _backend_info) = connect_backend(
|
let (backend, backend_info) = connect_backend(
|
||||||
max_input_tokens,
|
max_input_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
master_shard_uds_path,
|
master_shard_uds_path,
|
||||||
|
@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Validate remaining args now that the backend is known
|
||||||
|
let support_chunking = backend_info.support_chunking;
|
||||||
|
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
|
||||||
|
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||||
|
}
|
||||||
|
if max_batch_prefill_tokens > max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
server::run(
|
server::run(
|
||||||
backend,
|
backend,
|
||||||
|
|
|
@ -4,7 +4,7 @@ use crate::client::{
|
||||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use std::cmp::{max, min};
|
use std::cmp::max;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use text_generation_router::infer::InferError;
|
use text_generation_router::infer::InferError;
|
||||||
use text_generation_router::infer::InferStreamResponse;
|
use text_generation_router::infer::InferStreamResponse;
|
||||||
|
@ -50,6 +50,7 @@ impl Queue {
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
|
support_chunking: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Create channel
|
// Create channel
|
||||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||||
|
@ -62,6 +63,7 @@ impl Queue {
|
||||||
window_size,
|
window_size,
|
||||||
speculate,
|
speculate,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
|
support_chunking,
|
||||||
queue_receiver,
|
queue_receiver,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -87,6 +89,10 @@ impl Queue {
|
||||||
prefill_token_budget: u32,
|
prefill_token_budget: u32,
|
||||||
token_budget: u32,
|
token_budget: u32,
|
||||||
) -> Option<NextBatch> {
|
) -> Option<NextBatch> {
|
||||||
|
if prefill_token_budget == 0 || token_budget == 0 {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
// Create response channel
|
// Create response channel
|
||||||
let (response_sender, response_receiver) = oneshot::channel();
|
let (response_sender, response_receiver) = oneshot::channel();
|
||||||
// Send next batch command to the background task managing the state
|
// Send next batch command to the background task managing the state
|
||||||
|
@ -108,6 +114,7 @@ impl Queue {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Background task responsible of the queue state
|
// Background task responsible of the queue state
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn queue_task(
|
async fn queue_task(
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
|
@ -115,6 +122,7 @@ async fn queue_task(
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
|
support_chunking: bool,
|
||||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||||
) {
|
) {
|
||||||
let mut state = State::new(
|
let mut state = State::new(
|
||||||
|
@ -124,6 +132,7 @@ async fn queue_task(
|
||||||
window_size,
|
window_size,
|
||||||
speculate,
|
speculate,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
|
support_chunking,
|
||||||
);
|
);
|
||||||
|
|
||||||
while let Some(cmd) = receiver.recv().await {
|
while let Some(cmd) = receiver.recv().await {
|
||||||
|
@ -166,12 +175,14 @@ struct State {
|
||||||
/// Paged Attention block size
|
/// Paged Attention block size
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
|
|
||||||
/// Sliding window
|
|
||||||
window_size: Option<u32>,
|
|
||||||
|
|
||||||
/// Speculation amount
|
/// Speculation amount
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
|
|
||||||
|
/// Whether the model allow the prefill chunking
|
||||||
|
/// If it does, the last request in the batch will be split to exactly match the prefill
|
||||||
|
/// token budget
|
||||||
|
support_chunking: bool,
|
||||||
|
|
||||||
/// Paged Attention Block Allocation
|
/// Paged Attention Block Allocation
|
||||||
block_allocator: Option<BlockAllocator>,
|
block_allocator: Option<BlockAllocator>,
|
||||||
}
|
}
|
||||||
|
@ -184,6 +195,7 @@ impl State {
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
|
support_chunking: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let block_allocator = (!requires_padding).then(|| {
|
let block_allocator = (!requires_padding).then(|| {
|
||||||
BlockAllocator::new(
|
BlockAllocator::new(
|
||||||
|
@ -199,8 +211,8 @@ impl State {
|
||||||
next_id: 0,
|
next_id: 0,
|
||||||
next_batch_id: 0,
|
next_batch_id: 0,
|
||||||
block_size,
|
block_size,
|
||||||
window_size,
|
|
||||||
speculate,
|
speculate,
|
||||||
|
support_chunking,
|
||||||
block_allocator,
|
block_allocator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -287,32 +299,7 @@ impl State {
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
Some(_block_allocator) => {
|
Some(block_allocator) => {
|
||||||
prefill_tokens += entry.request.input_length;
|
|
||||||
let max_new_tokens = match self.window_size {
|
|
||||||
None => entry.request.stopping_parameters.max_new_tokens,
|
|
||||||
Some(window_size) => min(
|
|
||||||
window_size.saturating_sub(entry.request.input_length),
|
|
||||||
entry.request.stopping_parameters.max_new_tokens,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
decode_tokens += max_new_tokens;
|
|
||||||
|
|
||||||
if prefill_tokens > prefill_token_budget
|
|
||||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
|
||||||
{
|
|
||||||
// Entry is over budget
|
|
||||||
// Add it back to the front
|
|
||||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
|
||||||
self.entries.push_front((id, entry));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let tokens = entry.request.input_length
|
|
||||||
+ entry.request.stopping_parameters.max_new_tokens
|
|
||||||
+ self.speculate
|
|
||||||
- 1;
|
|
||||||
|
|
||||||
// If users wants the prefill logprobs, we cannot reuse the cache.
|
// If users wants the prefill logprobs, we cannot reuse the cache.
|
||||||
// So no input_ids for the radix tree.
|
// So no input_ids for the radix tree.
|
||||||
let input_ids = if entry.request.decoder_input_details {
|
let input_ids = if entry.request.decoder_input_details {
|
||||||
|
@ -321,10 +308,73 @@ impl State {
|
||||||
entry.request.input_ids.clone()
|
entry.request.input_ids.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
Some((tokens, input_ids))
|
let tokens = entry.request.input_length
|
||||||
|
+ entry.request.stopping_parameters.max_new_tokens
|
||||||
|
+ self.speculate
|
||||||
|
- 1;
|
||||||
|
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
||||||
|
|
||||||
|
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
|
||||||
|
None => {
|
||||||
|
// Entry is over budget
|
||||||
|
// Add it back to the front
|
||||||
|
tracing::debug!("Over budget: not enough free blocks");
|
||||||
|
self.entries.push_front((id, entry));
|
||||||
|
break 'entry_loop;
|
||||||
|
}
|
||||||
|
Some(mut block_allocation) => {
|
||||||
|
tracing::debug!("Allocation: {block_allocation:?}");
|
||||||
|
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||||
|
|
||||||
|
if block_allocation.prefix_len == entry.request.input_length {
|
||||||
|
// The whole request was found in the radix trie
|
||||||
|
// However, for the transformer forward to work, we need to
|
||||||
|
// have at least one token of postfix.
|
||||||
|
block_allocation.prefix_len -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
block_allocation
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
batch.push((id, entry, block_allocation));
|
|
||||||
|
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
|
||||||
|
|
||||||
|
if prefill_tokens + postfix_len > prefill_token_budget {
|
||||||
|
// Entry is over budget
|
||||||
|
if self.support_chunking {
|
||||||
|
// We support chunking, just set postfix_len to exactly match prefill_token_budget
|
||||||
|
let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
|
||||||
|
if chunk_len > 0 {
|
||||||
|
// Push this entry inside the batch
|
||||||
|
batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
|
||||||
|
} else {
|
||||||
|
// We cannot prefill even one token for this entry
|
||||||
|
// Add it back to the queue
|
||||||
|
self.entries.push_front((id, entry));
|
||||||
|
}
|
||||||
|
tracing::debug!(
|
||||||
|
"Matched budget: prefill_tokens={} == {prefill_token_budget}",
|
||||||
|
prefill_tokens + postfix_len
|
||||||
|
);
|
||||||
|
break 'entry_loop;
|
||||||
|
} else {
|
||||||
|
// We don't support chunking, this entry needs to go back to the buffer
|
||||||
|
// Add it back to the front
|
||||||
|
tracing::debug!(
|
||||||
|
"Over budget: prefill_tokens={} > {prefill_token_budget}",
|
||||||
|
prefill_tokens + postfix_len
|
||||||
|
);
|
||||||
|
self.entries.push_front((id, entry));
|
||||||
|
break 'entry_loop;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prefill_tokens += postfix_len;
|
||||||
|
|
||||||
|
Some(block_allocation)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
batch.push((id, entry, block_allocation, None));
|
||||||
if Some(batch.len()) == max_size {
|
if Some(batch.len()) == max_size {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -342,7 +392,7 @@ impl State {
|
||||||
// Batch is too small
|
// Batch is too small
|
||||||
if batch.len() < min_size {
|
if batch.len() < min_size {
|
||||||
// Add back entries to the queue in the correct order
|
// Add back entries to the queue in the correct order
|
||||||
for (id, entry, _) in batch.into_iter().rev() {
|
for (id, entry, _, _) in batch.into_iter().rev() {
|
||||||
self.entries.push_front((id, entry));
|
self.entries.push_front((id, entry));
|
||||||
}
|
}
|
||||||
return None;
|
return None;
|
||||||
|
@ -353,29 +403,7 @@ impl State {
|
||||||
let mut batch_entries =
|
let mut batch_entries =
|
||||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||||
|
|
||||||
for (id, mut entry, block_allocation) in batch {
|
for (id, mut entry, block_allocation, chunk_len) in batch {
|
||||||
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
|
|
||||||
(block_allocation, &self.block_allocator)
|
|
||||||
{
|
|
||||||
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
|
||||||
match block_allocator.allocate(tokens, input_ids).await {
|
|
||||||
None => {
|
|
||||||
// Entry is over budget
|
|
||||||
// Add it back to the front
|
|
||||||
tracing::debug!("Over budget: not enough free blocks");
|
|
||||||
self.entries.push_front((id, entry));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
Some(block_allocation) => {
|
|
||||||
tracing::debug!("Allocation: {block_allocation:?}");
|
|
||||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
|
||||||
Some(block_allocation)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
tracing::debug!("Accepting entry");
|
|
||||||
// Create a new span to link the batch back to this entry
|
// Create a new span to link the batch back to this entry
|
||||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||||
// Add relationships
|
// Add relationships
|
||||||
|
@ -427,8 +455,9 @@ impl State {
|
||||||
top_n_tokens: entry.request.top_n_tokens,
|
top_n_tokens: entry.request.top_n_tokens,
|
||||||
blocks,
|
blocks,
|
||||||
slots,
|
slots,
|
||||||
prefix_len,
|
cache_len: prefix_len,
|
||||||
adapter_id: entry.request.adapter_id.clone(),
|
adapter_id: entry.request.adapter_id.clone(),
|
||||||
|
chunk_len,
|
||||||
});
|
});
|
||||||
// Set batch_time
|
// Set batch_time
|
||||||
entry.batch_time = Some(Instant::now());
|
entry.batch_time = Some(Instant::now());
|
||||||
|
@ -436,12 +465,6 @@ impl State {
|
||||||
batch_entries.insert(id, entry);
|
batch_entries.insert(id, entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty batch
|
|
||||||
if batch_requests.is_empty() {
|
|
||||||
tracing::debug!("Filterered out all entries");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final batch size
|
// Final batch size
|
||||||
let size = batch_requests.len() as u32;
|
let size = batch_requests.len() as u32;
|
||||||
next_batch_span.record("batch_size", size);
|
next_batch_span.record("batch_size", size);
|
||||||
|
@ -531,7 +554,7 @@ mod tests {
|
||||||
request: ValidGenerateRequest {
|
request: ValidGenerateRequest {
|
||||||
inputs: vec![],
|
inputs: vec![],
|
||||||
input_ids: Some(Arc::new(vec![])),
|
input_ids: Some(Arc::new(vec![])),
|
||||||
input_length: 0,
|
input_length: 1,
|
||||||
add_special_tokens: true,
|
add_special_tokens: true,
|
||||||
truncate: 0,
|
truncate: 0,
|
||||||
decoder_input_details: false,
|
decoder_input_details: false,
|
||||||
|
@ -567,7 +590,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_append() {
|
async fn test_append() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
|
|
||||||
assert_eq!(state.next_id, 0);
|
assert_eq!(state.next_id, 0);
|
||||||
|
@ -583,7 +606,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_empty() {
|
async fn test_next_batch_empty() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
|
|
||||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
||||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||||
|
@ -591,7 +614,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_min_size() {
|
async fn test_next_batch_min_size() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -623,7 +646,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_max_size() {
|
async fn test_next_batch_max_size() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -643,7 +666,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_token_budget() {
|
async fn test_next_batch_token_budget() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 2);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -676,14 +699,14 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_append() {
|
async fn test_queue_append() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_empty() {
|
async fn test_queue_next_batch_empty() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
|
|
||||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||||
|
@ -691,7 +714,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_min_size() {
|
async fn test_queue_next_batch_min_size() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -724,7 +747,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_max_size() {
|
async fn test_queue_next_batch_max_size() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -740,7 +763,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_budget() {
|
async fn test_queue_next_batch_token_budget() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -765,7 +788,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_speculate() {
|
async fn test_queue_next_batch_token_speculate() {
|
||||||
let queue = Queue::new(false, 1, false, None, 2, 16);
|
let queue = Queue::new(true, 1, false, None, 2, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -784,7 +807,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_dropped_receiver() {
|
async fn test_queue_next_batch_dropped_receiver() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry, _) = default_entry();
|
let (entry, _) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
|
|
||||||
|
|
|
@ -158,7 +158,8 @@ async fn prefill(
|
||||||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||||
blocks: vec![],
|
blocks: vec![],
|
||||||
slots: vec![],
|
slots: vec![],
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
|
chunk_len: None,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
@ -173,7 +174,7 @@ async fn prefill(
|
||||||
|
|
||||||
// Run prefill
|
// Run prefill
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
|
let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;
|
||||||
|
|
||||||
// Get latency
|
// Get latency
|
||||||
let latency = start_time.elapsed();
|
let latency = start_time.elapsed();
|
||||||
|
|
|
@ -178,6 +178,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
.clear_cache(None)
|
.clear_cache(None)
|
||||||
.await
|
.await
|
||||||
.expect("Unable to clear cache");
|
.expect("Unable to clear cache");
|
||||||
|
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
// Run app
|
// Run app
|
||||||
|
|
|
@ -9,13 +9,16 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import docker
|
import docker
|
||||||
import pytest
|
import pytest
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
||||||
from docker.errors import NotFound
|
from docker.errors import NotFound
|
||||||
from syrupy.extensions.json import JSONSnapshotExtension
|
from syrupy.extensions.json import JSONSnapshotExtension
|
||||||
|
|
||||||
from text_generation import AsyncClient
|
from text_generation import AsyncClient
|
||||||
from text_generation.types import (
|
from text_generation.types import (
|
||||||
BestOfSequence,
|
BestOfSequence,
|
||||||
|
@ -403,6 +406,7 @@ def launcher(event_loop):
|
||||||
print(" ".join(args), file=sys.stderr)
|
print(" ".join(args), file=sys.stderr)
|
||||||
|
|
||||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||||
|
env["PREFILL_CHUNKING"] = "1"
|
||||||
|
|
||||||
if not use_flash_attention:
|
if not use_flash_attention:
|
||||||
env["USE_FLASH_ATTENTION"] = "false"
|
env["USE_FLASH_ATTENTION"] = "false"
|
||||||
|
@ -501,6 +505,7 @@ def launcher(event_loop):
|
||||||
|
|
||||||
env = {
|
env = {
|
||||||
"LOG_LEVEL": "info,text_generation_router=debug",
|
"LOG_LEVEL": "info,text_generation_router=debug",
|
||||||
|
"PREFILL_CHUNKING": "1",
|
||||||
}
|
}
|
||||||
if not use_flash_attention:
|
if not use_flash_attention:
|
||||||
env["USE_FLASH_ATTENTION"] = "false"
|
env["USE_FLASH_ATTENTION"] = "false"
|
||||||
|
@ -642,3 +647,22 @@ def generate_multi():
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
return generate_load_inner
|
return generate_load_inner
|
||||||
|
|
||||||
|
|
||||||
|
# TODO fix the server parsser to count inline image tokens correctly
|
||||||
|
@pytest.fixture
|
||||||
|
def chicken():
|
||||||
|
path = Path(__file__).parent / "images" / "chicken_on_money.png"
|
||||||
|
|
||||||
|
with open(path, "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cow_beach():
|
||||||
|
path = Path(__file__).parent / "images" / "cow_beach.png"
|
||||||
|
|
||||||
|
with open(path, "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
|
||||||
return flash_pali_gemma_handle.client
|
return flash_pali_gemma_handle.client
|
||||||
|
|
||||||
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cow_beach():
|
|
||||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
|
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
|
||||||
cow = get_cow_beach()
|
inputs = f"![]({cow_beach})Where is the cow standing?\n"
|
||||||
inputs = f"![]({cow})Where is the cow standing?\n"
|
|
||||||
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
|
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
|
||||||
|
|
||||||
assert response.generated_text == "beach"
|
assert response.generated_text == "beach"
|
||||||
|
@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
|
async def test_flash_pali_gemma_two_images(
|
||||||
chicken = get_chicken()
|
flash_pali_gemma, response_snapshot, chicken, cow_beach
|
||||||
cow_beach = get_cow_beach()
|
):
|
||||||
response = await flash_pali_gemma.generate(
|
response = await flash_pali_gemma.generate(
|
||||||
f"caption![]({chicken})![]({cow_beach})\n",
|
f"caption![]({chicken})![]({cow_beach})\n",
|
||||||
max_new_tokens=20,
|
max_new_tokens=20,
|
||||||
|
|
|
@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle):
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
|
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
|
||||||
|
|
||||||
class Weather(BaseModel):
|
class Weather(BaseModel):
|
||||||
unit: str
|
unit: str
|
||||||
temperature: List[int]
|
temperature: List[int]
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -16,22 +15,8 @@ async def idefics(idefics_handle):
|
||||||
return idefics_handle.client
|
return idefics_handle.client
|
||||||
|
|
||||||
|
|
||||||
# TODO fix the server parsser to count inline image tokens correctly
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cow_beach():
|
|
||||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_idefics(idefics, response_snapshot):
|
async def test_idefics(idefics, response_snapshot, chicken):
|
||||||
chicken = get_chicken()
|
|
||||||
response = await idefics.generate(
|
response = await idefics.generate(
|
||||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_idefics_two_images(idefics, response_snapshot):
|
async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
|
||||||
chicken = get_chicken()
|
|
||||||
cow_beach = get_cow_beach()
|
|
||||||
response = await idefics.generate(
|
response = await idefics.generate(
|
||||||
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
||||||
max_new_tokens=20,
|
max_new_tokens=20,
|
||||||
|
@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
|
||||||
|
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_idefics_load(idefics, generate_load, response_snapshot):
|
async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
|
||||||
chicken = get_chicken()
|
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
idefics,
|
idefics,
|
||||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
|
|
|
@ -1,18 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
|
|
||||||
|
|
||||||
# TODO fix the server parsser to count inline image tokens correctly
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cow_beach():
|
|
||||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
|
async def test_flash_idefics2_next_simple(
|
||||||
chicken = get_chicken()
|
flash_idefics2_next, response_snapshot, chicken
|
||||||
|
):
|
||||||
response = await flash_idefics2_next.generate(
|
response = await flash_idefics2_next.generate(
|
||||||
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
|
async def test_flash_idefics2_two_images(
|
||||||
chicken = get_chicken()
|
flash_idefics2_next, response_snapshot, chicken, cow_beach
|
||||||
cow_beach = get_cow_beach()
|
):
|
||||||
response = await flash_idefics2_next.generate(
|
response = await flash_idefics2_next.generate(
|
||||||
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
||||||
max_new_tokens=20,
|
max_new_tokens=20,
|
||||||
|
@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_idefics2_next_load(
|
async def test_flash_idefics2_next_load(
|
||||||
flash_idefics2_next, generate_load, response_snapshot
|
flash_idefics2_next, generate_load, response_snapshot, chicken
|
||||||
):
|
):
|
||||||
chicken = get_chicken()
|
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
flash_idefics2_next,
|
flash_idefics2_next,
|
||||||
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
||||||
|
|
|
@ -1,12 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
|
|
||||||
|
|
||||||
# TODO fix the server parsser to count inline image tokens correctly
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
|
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
|
||||||
chicken = get_chicken()
|
|
||||||
response = await flash_llava_next.generate(
|
response = await flash_llava_next.generate(
|
||||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llava_next_load(
|
async def test_flash_llava_next_load(
|
||||||
flash_llava_next, generate_load, response_snapshot
|
flash_llava_next, generate_load, response_snapshot, chicken
|
||||||
):
|
):
|
||||||
chicken = get_chicken()
|
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
flash_llava_next,
|
flash_llava_next,
|
||||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,22 +14,8 @@ async def mllama(mllama_handle):
|
||||||
return mllama_handle.client
|
return mllama_handle.client
|
||||||
|
|
||||||
|
|
||||||
# TODO fix the server parsser to count inline image tokens correctly
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cow_beach():
|
|
||||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mllama_simpl(mllama, response_snapshot):
|
async def test_mllama_simpl(mllama, response_snapshot):
|
||||||
# chicken = get_chicken()
|
|
||||||
response = await mllama.chat(
|
response = await mllama.chat(
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
|
|
|
@ -68,7 +68,7 @@ fn get_config(
|
||||||
|
|
||||||
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
||||||
let compute_capability = gpu::get_cuda_capability();
|
let compute_capability = gpu::get_cuda_capability();
|
||||||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
|
||||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||||
if let Some(config) = config {
|
if let Some(config) = config {
|
||||||
if prefix_caching.is_none() {
|
if prefix_caching.is_none() {
|
||||||
|
@ -124,6 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if attention == Some("paged".to_string()) && prefix_caching.is_none() {
|
||||||
|
tracing::info!("Disabling prefix caching on paged attention");
|
||||||
|
prefix_caching = Some("0".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
let attention = attention.unwrap_or("flashinfer".to_string());
|
let attention = attention.unwrap_or("flashinfer".to_string());
|
||||||
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
||||||
|
@ -1678,7 +1682,7 @@ fn main() -> Result<(), LauncherError> {
|
||||||
};
|
};
|
||||||
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
||||||
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
||||||
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
|
std::env::set_var("PREFIX_CACHING", prefix_caching);
|
||||||
std::env::set_var("ATTENTION", attention);
|
std::env::set_var("ATTENTION", attention);
|
||||||
|
|
||||||
let max_input_tokens = {
|
let max_input_tokens = {
|
||||||
|
@ -1729,12 +1733,6 @@ fn main() -> Result<(), LauncherError> {
|
||||||
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
|
||||||
return Err(LauncherError::ArgumentValidation(format!(
|
|
||||||
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
|
|
||||||
max_batch_prefill_tokens, max_input_tokens
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
||||||
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
||||||
|
@ -1788,12 +1786,6 @@ fn main() -> Result<(), LauncherError> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
|
||||||
return Err(LauncherError::ArgumentValidation(format!(
|
|
||||||
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
|
||||||
max_batch_prefill_tokens, max_batch_total_tokens
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
return Err(LauncherError::ArgumentValidation(format!(
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||||
|
|
|
@ -34,6 +34,10 @@ message InfoResponse {
|
||||||
string device_type = 3;
|
string device_type = 3;
|
||||||
optional uint32 window_size = 4;
|
optional uint32 window_size = 4;
|
||||||
uint32 speculate = 5;
|
uint32 speculate = 5;
|
||||||
|
bool support_chunking = 6;
|
||||||
|
bool use_prefix_caching = 7;
|
||||||
|
string attention_impl = 8;
|
||||||
|
uint32 block_size = 9;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Empty request
|
/// Empty request
|
||||||
|
@ -135,10 +139,14 @@ message Request {
|
||||||
repeated uint32 slots = 10;
|
repeated uint32 slots = 10;
|
||||||
/// LORA adapter index
|
/// LORA adapter index
|
||||||
optional string adapter_id = 11;
|
optional string adapter_id = 11;
|
||||||
/// Prefix length that can be retrieved from the KV cache.
|
/// Tokens that can be retrieved from the KV cache.
|
||||||
uint32 prefix_len = 12;
|
/// This value is set for the first prefill and never reset
|
||||||
|
uint32 cache_len = 12;
|
||||||
/// Context truncation
|
/// Context truncation
|
||||||
bool add_special_tokens = 13;
|
bool add_special_tokens = 13;
|
||||||
|
/// Chunk of tokens that must be computed for the first prefill
|
||||||
|
/// This value is set for the first prefill and never reset
|
||||||
|
optional uint32 chunk_len = 14;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Batch {
|
message Batch {
|
||||||
|
@ -163,6 +171,8 @@ message CachedBatch {
|
||||||
uint32 size = 3;
|
uint32 size = 3;
|
||||||
/// Maximum number of tokens this batch will grow to
|
/// Maximum number of tokens this batch will grow to
|
||||||
uint32 max_tokens = 4;
|
uint32 max_tokens = 4;
|
||||||
|
/// Number of tokens in the next forward
|
||||||
|
uint32 current_tokens = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum FinishReason {
|
enum FinishReason {
|
||||||
|
@ -220,6 +230,8 @@ message FilterBatchResponse {
|
||||||
message PrefillRequest {
|
message PrefillRequest {
|
||||||
/// Batch
|
/// Batch
|
||||||
Batch batch = 1;
|
Batch batch = 1;
|
||||||
|
/// Optional cached batch
|
||||||
|
CachedBatch cached_batch = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message PrefillResponse {
|
message PrefillResponse {
|
||||||
|
@ -233,6 +245,8 @@ message PrefillResponse {
|
||||||
uint64 decode_ns = 4;
|
uint64 decode_ns = 4;
|
||||||
/// Total elapsed time in nanoseconds
|
/// Total elapsed time in nanoseconds
|
||||||
uint64 total_ns = 5;
|
uint64 total_ns = 5;
|
||||||
|
/// Concatenate elapsed time in nanoseconds
|
||||||
|
optional uint64 concat_ns = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
message DecodeRequest {
|
message DecodeRequest {
|
||||||
|
|
|
@ -18,45 +18,6 @@ use tracing::warn;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
#[derive(PartialEq)]
|
|
||||||
pub enum Attention {
|
|
||||||
Paged,
|
|
||||||
FlashDecoding,
|
|
||||||
FlashInfer,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Attention {
|
|
||||||
pub fn block_size(&self) -> u32 {
|
|
||||||
match self {
|
|
||||||
Attention::FlashDecoding => 256,
|
|
||||||
Attention::FlashInfer => 1,
|
|
||||||
Attention::Paged => 16,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ParseError;
|
|
||||||
|
|
||||||
impl std::fmt::Display for ParseError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "Cannot parse attention value")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl std::error::Error for ParseError {}
|
|
||||||
|
|
||||||
impl std::str::FromStr for Attention {
|
|
||||||
type Err = ParseError;
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
||||||
match s {
|
|
||||||
"paged" => Ok(Attention::Paged),
|
|
||||||
"flashdecoding" => Ok(Attention::FlashDecoding),
|
|
||||||
"flashinfer" => Ok(Attention::FlashInfer),
|
|
||||||
_ => Err(ParseError),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Hub type
|
/// Hub type
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
pub struct HubModelInfo {
|
pub struct HubModelInfo {
|
||||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
||||||
import os
|
import os
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
|
|
||||||
os.environ["USE_PREFIX_CACHING"] = "1"
|
os.environ["PREFIX_CACHING"] = "1"
|
||||||
os.environ["ATTENTION"] = "flashinfer"
|
os.environ["ATTENTION"] = "flashinfer"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,9 @@ from typing import Callable, Any
|
||||||
|
|
||||||
|
|
||||||
class ExceptionInterceptor(AsyncServerInterceptor):
|
class ExceptionInterceptor(AsyncServerInterceptor):
|
||||||
|
def __init__(self, shutdown_callback):
|
||||||
|
self.shutdown_callback = shutdown_callback
|
||||||
|
|
||||||
async def intercept(
|
async def intercept(
|
||||||
self,
|
self,
|
||||||
method: Callable,
|
method: Callable,
|
||||||
|
@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
|
||||||
|
|
||||||
# Runtime Error cannot be recovered from
|
# Runtime Error cannot be recovered from
|
||||||
if isinstance(err, RuntimeError):
|
if isinstance(err, RuntimeError):
|
||||||
exit(1)
|
self.shutdown_callback()
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
@ -1,16 +1,12 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.models.globals import ATTENTION
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
if ATTENTION in {"flashinfer", "flashdecoding"}:
|
@dataclass
|
||||||
|
class Seqlen:
|
||||||
@dataclass
|
|
||||||
class Seqlen:
|
|
||||||
input_lengths: torch.Tensor
|
input_lengths: torch.Tensor
|
||||||
prefix_lengths: torch.Tensor
|
cache_lengths: torch.Tensor
|
||||||
cu_seqlen_q: Optional[torch.Tensor]
|
cu_seqlen_q: Optional[torch.Tensor]
|
||||||
cu_seqlen_k: Optional[torch.Tensor]
|
cu_seqlen_k: Optional[torch.Tensor]
|
||||||
max_q: int
|
max_q: int
|
||||||
|
@ -19,13 +15,13 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
prefix_lengths,
|
cache_lengths,
|
||||||
cu_seqlen_q=None,
|
cu_seqlen_q=None,
|
||||||
max_q=None,
|
max_q=None,
|
||||||
max_k=None,
|
max_k=None,
|
||||||
):
|
):
|
||||||
self.input_lengths = input_lengths
|
self.input_lengths = input_lengths
|
||||||
self.prefix_lengths = prefix_lengths
|
self.cache_lengths = cache_lengths
|
||||||
device = self.input_lengths.device
|
device = self.input_lengths.device
|
||||||
shape = self.input_lengths.shape
|
shape = self.input_lengths.shape
|
||||||
if cu_seqlen_q is None:
|
if cu_seqlen_q is None:
|
||||||
|
@ -43,7 +39,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||||
# cuda graphs don't like this and this is necessary to clamp within mistral
|
# cuda graphs don't like this and this is necessary to clamp within mistral
|
||||||
# Although FA2 might not want the clamping
|
# Although FA2 might not want the clamping
|
||||||
# cu_seqlen_k[0] = 0
|
# cu_seqlen_k[0] = 0
|
||||||
total = self.input_lengths + self.prefix_lengths
|
total = self.input_lengths + self.cache_lengths
|
||||||
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
||||||
|
|
||||||
self.cu_seqlen_q = cu_seqlen_q
|
self.cu_seqlen_q = cu_seqlen_q
|
||||||
|
@ -54,19 +50,3 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||||
def clamp(self, max):
|
def clamp(self, max):
|
||||||
# Flash decoding doesn't need to clamp
|
# Flash decoding doesn't need to clamp
|
||||||
return self
|
return self
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Seqlen:
|
|
||||||
input_lengths: torch.Tensor
|
|
||||||
prefix_lengths: torch.Tensor
|
|
||||||
cu_seqlen_q: torch.Tensor
|
|
||||||
max_q: int
|
|
||||||
max_k: int
|
|
||||||
|
|
||||||
def clamp(self, max):
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
return self
|
|
||||||
self.input_lengths = torch.clamp(self.input_lengths, max=max)
|
|
||||||
return self
|
|
||||||
|
|
|
@ -123,7 +123,7 @@ def paged_attention(
|
||||||
else:
|
else:
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||||
input_lengths = seqlen.input_lengths
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
from vllm._C import ops
|
from vllm._C import ops
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
@ -244,7 +244,8 @@ if ATTENTION == "flashinfer":
|
||||||
window_left=window_size_left,
|
window_left=window_size_left,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif V2:
|
elif ATTENTION == "flashdecoding":
|
||||||
|
if V2:
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q,
|
q,
|
||||||
|
@ -284,7 +285,7 @@ elif V2:
|
||||||
None,
|
None,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
|
@ -302,7 +303,9 @@ else:
|
||||||
"window_size_left is only available with flash attn v2"
|
"window_size_left is only available with flash attn v2"
|
||||||
)
|
)
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise NotImplementedError("softcap is only available with flash attn v2")
|
raise NotImplementedError(
|
||||||
|
"softcap is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
if k.shape[1] != q.shape[1]:
|
if k.shape[1] != q.shape[1]:
|
||||||
|
@ -350,11 +353,123 @@ else:
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
elif ATTENTION == "paged":
|
||||||
|
if V2:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
|
softcap=0.0,
|
||||||
|
):
|
||||||
|
out = torch.empty_like(q)
|
||||||
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
|
q,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
out,
|
||||||
|
seqlen.cu_seqlen_q,
|
||||||
|
seqlen.cu_seqlen_k,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None, # block_tables,
|
||||||
|
None,
|
||||||
|
seqlen.max_q,
|
||||||
|
seqlen.max_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
window_size_left,
|
||||||
|
0,
|
||||||
|
softcap,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
window_size_left: int = -1,
|
||||||
|
causal: bool = True,
|
||||||
|
softcap=None,
|
||||||
|
):
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"window_size_left is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
if softcap is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"softcap is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
|
if k.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if k.shape[1] == 1:
|
||||||
|
k = k.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = k.shape
|
||||||
|
k = (
|
||||||
|
k.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
if v.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if v.shape[1] == 1:
|
||||||
|
v = v.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = v.shape
|
||||||
|
v = (
|
||||||
|
v.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
out = torch.empty_like(q)
|
||||||
|
flash_attn_cuda.fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
seqlen.cu_seqlen_q,
|
||||||
|
seqlen.cu_seqlen_q,
|
||||||
|
seqlen.max_q,
|
||||||
|
seqlen.max_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unknwon attention {ATTENTION}")
|
||||||
|
|
||||||
|
|
||||||
# Prefill in the cache with every kind of attention, unless we
|
# Prefill in the cache with every kind of attention, unless we
|
||||||
# have a configuration that requires flash-attention v1, which
|
# have a configuration that requires flash-attention v1, which
|
||||||
# does not support block tables.
|
# does not support block tables.
|
||||||
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
|
PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PREFILL_IN_KV_CACHE",
|
"PREFILL_IN_KV_CACHE",
|
||||||
|
|
|
@ -699,7 +699,6 @@ def check_args(
|
||||||
|
|
||||||
|
|
||||||
class _attention(torch.autograd.Function):
|
class _attention(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(
|
def forward(
|
||||||
ctx,
|
ctx,
|
||||||
|
|
|
@ -66,6 +66,7 @@ def paged_attention(
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
|
@ -74,7 +75,7 @@ def paged_attention(
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
seqlen.input_lengths,
|
input_lengths,
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
max_s,
|
max_s,
|
||||||
None,
|
None,
|
||||||
|
|
|
@ -104,7 +104,7 @@ def paged_attention(
|
||||||
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
||||||
|
|
||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
input_lengths = seqlen.input_lengths
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
|
|
|
@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
|
||||||
request_ids=[r.id for r in self.requests],
|
request_ids=[r.id for r in self.requests],
|
||||||
size=len(self),
|
size=len(self),
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
|
current_tokens=len(self.input_ids),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
|
||||||
aspect_ratio_ids: torch.Tensor,
|
aspect_ratio_ids: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
|
(
|
||||||
pixel_values.shape
|
batch_size,
|
||||||
)
|
num_concurrent_media,
|
||||||
|
num_tiles,
|
||||||
|
num_channels,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
) = pixel_values.shape
|
||||||
|
|
||||||
pixel_values = pixel_values.reshape(
|
pixel_values = pixel_values.reshape(
|
||||||
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
|
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -5,9 +5,14 @@ from typing import Dict, Optional
|
||||||
|
|
||||||
from text_generation_server.utils.log import log_master
|
from text_generation_server.utils.log import log_master
|
||||||
|
|
||||||
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
|
ATTENTION = os.environ["ATTENTION"]
|
||||||
|
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
|
||||||
|
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
}
|
||||||
|
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
|
||||||
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||||
ATTENTION = os.getenv("ATTENTION")
|
|
||||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
_expected = {"paged", "flashdecoding", "flashinfer"}
|
||||||
assert (
|
assert (
|
||||||
ATTENTION in _expected
|
ATTENTION in _expected
|
||||||
|
@ -18,7 +23,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
|
||||||
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
||||||
|
|
||||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||||
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
|
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
|
||||||
assert TGI_WIGGLE_ROOM > 0
|
assert TGI_WIGGLE_ROOM > 0
|
||||||
assert TGI_WIGGLE_ROOM < 1
|
assert TGI_WIGGLE_ROOM < 1
|
||||||
|
|
||||||
|
|
|
@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
request_ids=[r.id for r in self.requests],
|
request_ids=[r.id for r in self.requests],
|
||||||
size=len(self),
|
size=len(self),
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
|
current_tokens=len(self),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -116,6 +116,7 @@ class MambaBatch(Batch):
|
||||||
request_ids=[r.id for r in self.requests],
|
request_ids=[r.id for r in self.requests],
|
||||||
size=len(self),
|
size=len(self),
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
|
current_tokens=len(self),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
from io import BytesIO
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from typing import Iterable, Optional, Tuple, List, Dict
|
from typing import Iterable, Optional, Tuple, List, Dict
|
||||||
from text_generation_server.pb.generate_pb2 import Request
|
from text_generation_server.pb.generate_pb2 import Request
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import (
|
from transformers import (
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
|
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.models.flash_causal_lm import (
|
from text_generation_server.models.flash_causal_lm import (
|
||||||
|
@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
||||||
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
|
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
|
||||||
max=config.text_config.vocab_size - 1
|
max=config.text_config.vocab_size - 1
|
||||||
)
|
)
|
||||||
|
if isinstance(batch.input_ids, list):
|
||||||
|
if len(batch) > 1:
|
||||||
|
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
|
||||||
|
else:
|
||||||
|
input_ids = batch.input_ids[0]
|
||||||
|
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
|
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
|
||||||
|
|
||||||
if image_inputs is not None:
|
if image_inputs is not None:
|
||||||
|
@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
||||||
class MllamaCausalLM(VlmCausalLM):
|
class MllamaCausalLM(VlmCausalLM):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
batch: VlmCausalLMBatch,
|
batch: MllamaCausalLMBatch,
|
||||||
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
|
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
|
@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
|
||||||
block_tables = batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots = batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths = batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
max_s = batch.max_seqlen
|
max_s = batch.max_current_length
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
speculative_ids = batch.speculative_ids
|
speculative_ids = batch.speculative_ids
|
||||||
|
@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
|
||||||
input_lengths = (
|
input_lengths = (
|
||||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||||
).view(-1)
|
).view(-1)
|
||||||
prefix_lens_tensor = (
|
cache_lengths_tensor = (
|
||||||
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
|
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
|
||||||
).reshape(-1)
|
).reshape(-1)
|
||||||
|
|
||||||
# Add Copy the block tables for all members
|
# Add Copy the block tables for all members
|
||||||
|
@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
|
||||||
block_tables = batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots = batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths = batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
prefix_lens_tensor = batch.prefix_lens_tensor
|
cache_lengths_tensor = batch.cache_lengths_tensor
|
||||||
max_s = batch.max_seqlen
|
max_s = batch.max_current_length
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||||
|
@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
|
||||||
# This makes sure the max_s for the decode pass is correct.
|
# This makes sure the max_s for the decode pass is correct.
|
||||||
max_s = min(self.max_past(), max_s)
|
max_s = min(self.max_past(), max_s)
|
||||||
|
|
||||||
bs = input_ids.shape[0]
|
|
||||||
# Try to find an associated cuda graph
|
# Try to find an associated cuda graph
|
||||||
bs = input_ids.shape[0]
|
bs = input_ids.shape[0]
|
||||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||||
|
@ -269,26 +278,24 @@ class MllamaCausalLM(VlmCausalLM):
|
||||||
# Only run cuda graphs when there's no images.
|
# Only run cuda graphs when there's no images.
|
||||||
or batch.cross_attention_states is not None
|
or batch.cross_attention_states is not None
|
||||||
):
|
):
|
||||||
input_lengths = input_lengths + prefix_lens_tensor
|
|
||||||
if PREFIX_CACHING:
|
if PREFIX_CACHING:
|
||||||
block_tables = block_tables_to_ragged(
|
block_tables = block_tables_to_ragged(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
input_lengths=batch.input_lengths,
|
input_lengths=batch.input_lengths,
|
||||||
prefix_lens=batch.prefix_lens,
|
cache_lengths=batch.cache_lengths,
|
||||||
)
|
)
|
||||||
with self._forward_context(
|
with self._forward_context(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
input_lengths_tensor=input_lengths,
|
input_lengths_tensor=input_lengths,
|
||||||
prefix_lens_tensor=prefix_lens_tensor,
|
cache_lengths_tensor=cache_lengths_tensor,
|
||||||
):
|
):
|
||||||
max_k = (input_lengths + prefix_lens_tensor).max().item()
|
|
||||||
seqlen = Seqlen(
|
seqlen = Seqlen(
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
prefix_lengths=prefix_lens_tensor,
|
cache_lengths=cache_lengths_tensor,
|
||||||
cu_seqlen_q=cu_seqlen_prefill,
|
cu_seqlen_q=cu_seqlen_prefill,
|
||||||
max_q=max_s,
|
max_q=batch.max_input_length,
|
||||||
max_k=max_k,
|
max_k=batch.max_current_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch.pixel_values is not None:
|
if batch.pixel_values is not None:
|
||||||
|
@ -330,20 +337,32 @@ class MllamaCausalLM(VlmCausalLM):
|
||||||
block_tables = block_tables_to_ragged(
|
block_tables = block_tables_to_ragged(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
input_lengths=batch.input_lengths,
|
input_lengths=batch.input_lengths,
|
||||||
prefix_lens=batch.prefix_lens,
|
cache_lengths=batch.cache_lengths,
|
||||||
)
|
)
|
||||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||||
else:
|
else:
|
||||||
cuda_graph["block_tables"][
|
cuda_graph["block_tables"][
|
||||||
: block_tables.shape[0], : block_tables.shape[1]
|
: block_tables.shape[0], : block_tables.shape[1]
|
||||||
] = block_tables
|
] = block_tables
|
||||||
|
|
||||||
|
# XXX: This is working only because block 0 is reserved for the healthcheck
|
||||||
|
# so it doesn't matter if we override it with bogus values.
|
||||||
cuda_graph["slots"].fill_(0)
|
cuda_graph["slots"].fill_(0)
|
||||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||||
cuda_graph["input_lengths"].zero_()
|
cuda_graph["input_lengths"].zero_()
|
||||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||||
input_lengths + prefix_lens_tensor
|
cuda_graph["cache_lengths"].zero_()
|
||||||
)
|
cuda_graph["cache_lengths"][
|
||||||
|
: cache_lengths_tensor.shape[0]
|
||||||
|
] = cache_lengths_tensor
|
||||||
|
|
||||||
|
with self._forward_context(
|
||||||
|
block_tables=cuda_graph["block_tables"],
|
||||||
|
cu_seqlen_prefill=None,
|
||||||
|
input_lengths_tensor=cuda_graph["input_lengths"],
|
||||||
|
cache_lengths_tensor=cuda_graph["cache_lengths"],
|
||||||
|
state=cuda_graph["state"],
|
||||||
|
):
|
||||||
# Replay the graph
|
# Replay the graph
|
||||||
cuda_graph["graph"].replay()
|
cuda_graph["graph"].replay()
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,17 @@ from abc import ABC, abstractmethod
|
||||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict
|
from typing import List, Tuple, Optional, TypeVar, Type, Dict
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from text_generation_server.models.globals import (
|
||||||
|
ATTENTION,
|
||||||
|
PREFIX_CACHING,
|
||||||
|
BLOCK_SIZE,
|
||||||
|
PREFILL_CHUNKING,
|
||||||
|
)
|
||||||
from text_generation_server.models.types import Batch, Generation
|
from text_generation_server.models.types import Batch, Generation
|
||||||
|
from text_generation_server.utils.log import log_master
|
||||||
|
from text_generation_server.utils.prefill_chunking import set_support_chunking
|
||||||
from text_generation_server.utils.speculate import get_speculate
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||||
from text_generation_server.adapters.weights import LayerAdapterWeights
|
from text_generation_server.adapters.weights import LayerAdapterWeights
|
||||||
|
@ -31,6 +40,7 @@ class Model(ABC):
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
speculate: Optional[int] = None,
|
speculate: Optional[int] = None,
|
||||||
adapter_id: str = BASE_MODEL_ADAPTER_ID,
|
adapter_id: str = BASE_MODEL_ADAPTER_ID,
|
||||||
|
support_chunking: bool = False,
|
||||||
):
|
):
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.model = model.eval()
|
self.model = model.eval()
|
||||||
|
@ -60,6 +70,29 @@ class Model(ABC):
|
||||||
speculate = get_speculate()
|
speculate = get_speculate()
|
||||||
self.speculate = speculate
|
self.speculate = speculate
|
||||||
|
|
||||||
|
support_chunking = support_chunking and PREFILL_CHUNKING
|
||||||
|
|
||||||
|
if speculate != 0 and support_chunking:
|
||||||
|
log_master(
|
||||||
|
logger.warning,
|
||||||
|
"Prefill chunking does not support speculation yet. "
|
||||||
|
"Prefill chunking will be turned off",
|
||||||
|
)
|
||||||
|
support_chunking = False
|
||||||
|
if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
|
||||||
|
log_master(
|
||||||
|
logger.warning,
|
||||||
|
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
|
||||||
|
)
|
||||||
|
support_chunking = False
|
||||||
|
|
||||||
|
log_master(
|
||||||
|
logger.info, f"Using experimental prefill chunking = {support_chunking}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.support_chunking = support_chunking
|
||||||
|
set_support_chunking(support_chunking)
|
||||||
|
|
||||||
self.has_position_ids = (
|
self.has_position_ids = (
|
||||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||||
is not None
|
is not None
|
||||||
|
@ -78,6 +111,10 @@ class Model(ABC):
|
||||||
device_type=self.device.type,
|
device_type=self.device.type,
|
||||||
window_size=self.sliding_window,
|
window_size=self.sliding_window,
|
||||||
speculate=self.speculate,
|
speculate=self.speculate,
|
||||||
|
support_chunking=self.support_chunking,
|
||||||
|
use_prefix_caching=PREFIX_CACHING,
|
||||||
|
attention_impl=ATTENTION,
|
||||||
|
block_size=BLOCK_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
|
||||||
request_ids=[r.id for r in self.requests],
|
request_ids=[r.id for r in self.requests],
|
||||||
size=len(self),
|
size=len(self),
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
|
current_tokens=len(self.decoder_input_ids),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -74,6 +74,14 @@ class Tokens:
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.token_ids)
|
return len(self.token_ids)
|
||||||
|
|
||||||
|
def __add__(self, other: "Tokens") -> "Tokens":
|
||||||
|
return Tokens(
|
||||||
|
self.token_ids + other.token_ids,
|
||||||
|
self.logprobs + other.logprobs,
|
||||||
|
self.texts + other.texts,
|
||||||
|
self.is_special + other.is_special,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Generation:
|
class Generation:
|
||||||
|
|
|
@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
# FIXME: VLM do not work with context chunking yet
|
||||||
|
support_chunking=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -295,7 +297,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
block_tables = batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots = batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths = batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
max_s = batch.max_seqlen
|
max_s = batch.max_current_length
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
speculative_ids = batch.speculative_ids
|
speculative_ids = batch.speculative_ids
|
||||||
|
@ -314,8 +316,8 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
input_lengths = (
|
input_lengths = (
|
||||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||||
).view(-1)
|
).view(-1)
|
||||||
prefix_lens_tensor = (
|
cache_lengths_tensor = (
|
||||||
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
|
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
|
||||||
).reshape(-1)
|
).reshape(-1)
|
||||||
|
|
||||||
# Add Copy the block tables for all members
|
# Add Copy the block tables for all members
|
||||||
|
@ -337,8 +339,8 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
block_tables = batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots = batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths = batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
prefix_lens_tensor = batch.prefix_lens_tensor
|
cache_lengths_tensor = batch.cache_lengths_tensor
|
||||||
max_s = batch.max_seqlen
|
max_s = batch.max_current_length
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||||
|
@ -347,7 +349,6 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
# This makes sure the max_s for the decode pass is correct.
|
# This makes sure the max_s for the decode pass is correct.
|
||||||
max_s = min(self.max_past(), max_s)
|
max_s = min(self.max_past(), max_s)
|
||||||
|
|
||||||
bs = input_ids.shape[0]
|
|
||||||
# Try to find an associated cuda graph
|
# Try to find an associated cuda graph
|
||||||
bs = input_ids.shape[0]
|
bs = input_ids.shape[0]
|
||||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||||
|
@ -357,26 +358,24 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
else:
|
else:
|
||||||
cuda_graph = None
|
cuda_graph = None
|
||||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||||
input_lengths = input_lengths + prefix_lens_tensor
|
if ATTENTION == "flashinfer":
|
||||||
if PREFIX_CACHING:
|
|
||||||
block_tables = block_tables_to_ragged(
|
block_tables = block_tables_to_ragged(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
input_lengths=batch.input_lengths,
|
input_lengths=batch.input_lengths,
|
||||||
prefix_lens=batch.prefix_lens,
|
cache_lengths=batch.cache_lengths,
|
||||||
)
|
)
|
||||||
with self._forward_context(
|
with self._forward_context(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
input_lengths_tensor=input_lengths,
|
input_lengths_tensor=input_lengths,
|
||||||
prefix_lens_tensor=prefix_lens_tensor,
|
cache_lengths_tensor=cache_lengths_tensor,
|
||||||
):
|
):
|
||||||
max_k = (input_lengths + prefix_lens_tensor).max().item()
|
|
||||||
seqlen = Seqlen(
|
seqlen = Seqlen(
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
prefix_lengths=prefix_lens_tensor,
|
cache_lengths=cache_lengths_tensor,
|
||||||
cu_seqlen_q=cu_seqlen_prefill,
|
cu_seqlen_q=cu_seqlen_prefill,
|
||||||
max_q=max_s,
|
max_q=batch.max_input_length,
|
||||||
max_k=max_k,
|
max_k=batch.max_current_length,
|
||||||
)
|
)
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
@ -411,20 +410,32 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
block_tables = block_tables_to_ragged(
|
block_tables = block_tables_to_ragged(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
input_lengths=batch.input_lengths,
|
input_lengths=batch.input_lengths,
|
||||||
prefix_lens=batch.prefix_lens,
|
cache_lengths=batch.cache_lengths,
|
||||||
)
|
)
|
||||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||||
else:
|
else:
|
||||||
cuda_graph["block_tables"][
|
cuda_graph["block_tables"][
|
||||||
: block_tables.shape[0], : block_tables.shape[1]
|
: block_tables.shape[0], : block_tables.shape[1]
|
||||||
] = block_tables
|
] = block_tables
|
||||||
cuda_graph["slots"].fill_(-1)
|
|
||||||
|
# XXX: This is working only because block 0 is reserved for the healthcheck
|
||||||
|
# so it doesn't matter if we override it with bogus values.
|
||||||
|
cuda_graph["slots"].fill_(0)
|
||||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||||
cuda_graph["input_lengths"].zero_()
|
cuda_graph["input_lengths"].zero_()
|
||||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||||
input_lengths + prefix_lens_tensor
|
cuda_graph["cache_lengths"].zero_()
|
||||||
)
|
cuda_graph["cache_lengths"][
|
||||||
|
: cache_lengths_tensor.shape[0]
|
||||||
|
] = cache_lengths_tensor
|
||||||
|
|
||||||
|
with self._forward_context(
|
||||||
|
block_tables=cuda_graph["block_tables"],
|
||||||
|
cu_seqlen_prefill=None,
|
||||||
|
input_lengths_tensor=cuda_graph["input_lengths"],
|
||||||
|
cache_lengths_tensor=cuda_graph["cache_lengths"],
|
||||||
|
state=cuda_graph["state"],
|
||||||
|
):
|
||||||
# Replay the graph
|
# Replay the graph
|
||||||
cuda_graph["graph"].replay()
|
cuda_graph["graph"].replay()
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ from text_generation_server.cache import Cache
|
||||||
from text_generation_server.interceptor import ExceptionInterceptor
|
from text_generation_server.interceptor import ExceptionInterceptor
|
||||||
from text_generation_server.models import Model, get_model_with_lora_adapters
|
from text_generation_server.models import Model, get_model_with_lora_adapters
|
||||||
from text_generation_server.utils.adapter import AdapterInfo
|
from text_generation_server.utils.adapter import AdapterInfo
|
||||||
|
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from text_generation_server.models.pali_gemma import PaliGemmaBatch
|
from text_generation_server.models.pali_gemma import PaliGemmaBatch
|
||||||
|
@ -46,9 +47,12 @@ class SignalHandler:
|
||||||
signal.signal(signal.SIGINT, self.exit_gracefully)
|
signal.signal(signal.SIGINT, self.exit_gracefully)
|
||||||
signal.signal(signal.SIGTERM, self.exit_gracefully)
|
signal.signal(signal.SIGTERM, self.exit_gracefully)
|
||||||
|
|
||||||
|
def set_keep_processing(self, value: bool):
|
||||||
|
self.KEEP_PROCESSING = value
|
||||||
|
|
||||||
def exit_gracefully(self, signum, frame):
|
def exit_gracefully(self, signum, frame):
|
||||||
print(f"Exiting gracefully: Signal {signum}")
|
print(f"Exiting gracefully: Signal {signum}")
|
||||||
self.KEEP_PROCESSING = False
|
self.set_keep_processing(False)
|
||||||
|
|
||||||
|
|
||||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
|
@ -96,6 +100,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||||
|
|
||||||
async def Warmup(self, request, context):
|
async def Warmup(self, request, context):
|
||||||
|
set_max_prefill_tokens(request.max_prefill_tokens)
|
||||||
|
|
||||||
if self.quantize in {"exl2", "gptq"}:
|
if self.quantize in {"exl2", "gptq"}:
|
||||||
try:
|
try:
|
||||||
# When using GPTQ, Exllama kernels need some global kernels
|
# When using GPTQ, Exllama kernels need some global kernels
|
||||||
|
@ -150,6 +156,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
concat_ns = None
|
||||||
|
if self.model.support_chunking:
|
||||||
|
if request.HasField("cached_batch"):
|
||||||
|
cached_batch = self.cache.pop(request.cached_batch.id)
|
||||||
|
if cached_batch is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Batch ID {request.cached_batch.id} not found in cache."
|
||||||
|
)
|
||||||
|
start_concat = time.time_ns()
|
||||||
|
batch = self.model.batch_type.concatenate([cached_batch, batch])
|
||||||
|
concat_ns = time.time_ns() - start_concat
|
||||||
|
|
||||||
generations, next_batch, timings = self.model.generate_token(batch)
|
generations, next_batch, timings = self.model.generate_token(batch)
|
||||||
self.cache.set(next_batch)
|
self.cache.set(next_batch)
|
||||||
|
|
||||||
|
@ -159,6 +177,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
forward_ns=timings[0],
|
forward_ns=timings[0],
|
||||||
decode_ns=timings[1],
|
decode_ns=timings[1],
|
||||||
total_ns=time.time_ns() - start,
|
total_ns=time.time_ns() - start,
|
||||||
|
concat_ns=concat_ns,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def Decode(self, request, context):
|
async def Decode(self, request, context):
|
||||||
|
@ -252,10 +271,12 @@ def serve(
|
||||||
logger.exception("Error when initializing model")
|
logger.exception("Error when initializing model")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
signal_handler = SignalHandler()
|
||||||
|
|
||||||
set_adapter_to_index(adapter_to_index)
|
set_adapter_to_index(adapter_to_index)
|
||||||
server = aio.server(
|
server = aio.server(
|
||||||
interceptors=[
|
interceptors=[
|
||||||
ExceptionInterceptor(),
|
ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
|
||||||
UDSOpenTelemetryAioServerInterceptor(),
|
UDSOpenTelemetryAioServerInterceptor(),
|
||||||
],
|
],
|
||||||
options=[
|
options=[
|
||||||
|
@ -276,7 +297,6 @@ def serve(
|
||||||
await server.start()
|
await server.start()
|
||||||
|
|
||||||
logger.info("Server started at {}".format(local_url))
|
logger.info("Server started at {}".format(local_url))
|
||||||
signal_handler = SignalHandler()
|
|
||||||
while signal_handler.KEEP_PROCESSING:
|
while signal_handler.KEEP_PROCESSING:
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
|
|
@ -120,8 +120,12 @@ def _load_and_merge(
|
||||||
if adapter.id == BASE_MODEL_ADAPTER_ID:
|
if adapter.id == BASE_MODEL_ADAPTER_ID:
|
||||||
raise ValueError("Base model adapter cannot be merged.")
|
raise ValueError("Base model adapter cannot be merged.")
|
||||||
|
|
||||||
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
|
(
|
||||||
load_module_map(
|
module_map,
|
||||||
|
adapter_config,
|
||||||
|
adapter_weight_names,
|
||||||
|
adapter_tokenizer,
|
||||||
|
) = load_module_map(
|
||||||
model_id,
|
model_id,
|
||||||
adapter.revision,
|
adapter.revision,
|
||||||
adapter.id,
|
adapter.id,
|
||||||
|
@ -129,7 +133,6 @@ def _load_and_merge(
|
||||||
weight_names,
|
weight_names,
|
||||||
trust_remote_code,
|
trust_remote_code,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
adapters_to_merge.append((module_map, adapter_config))
|
adapters_to_merge.append((module_map, adapter_config))
|
||||||
merged_weight_names = merged_weight_names.union(adapter_weight_names)
|
merged_weight_names = merged_weight_names.union(adapter_weight_names)
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
SUPPORT_CHUNKING: Optional[bool] = None
|
||||||
|
MAX_PREFILL_TOKENS: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_support_chunking(support_chunking: bool):
|
||||||
|
global SUPPORT_CHUNKING
|
||||||
|
SUPPORT_CHUNKING = support_chunking
|
||||||
|
|
||||||
|
|
||||||
|
def get_support_chunking() -> bool:
|
||||||
|
global SUPPORT_CHUNKING
|
||||||
|
return SUPPORT_CHUNKING
|
||||||
|
|
||||||
|
|
||||||
|
def set_max_prefill_tokens(max_prefill_tokens: int):
|
||||||
|
global MAX_PREFILL_TOKENS
|
||||||
|
MAX_PREFILL_TOKENS = max_prefill_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_prefill_tokens() -> int:
|
||||||
|
global MAX_PREFILL_TOKENS
|
||||||
|
return MAX_PREFILL_TOKENS
|
|
@ -7,6 +7,7 @@ from typing import List, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# FIXME: this should be optimized
|
||||||
def find_segments(
|
def find_segments(
|
||||||
adapter_indices: Union[torch.Tensor, List[int]]
|
adapter_indices: Union[torch.Tensor, List[int]]
|
||||||
) -> Tuple[List[int], List[int]]:
|
) -> Tuple[List[int], List[int]]:
|
||||||
|
|
Loading…
Reference in New Issue