feat: prefill chunking (#2600)
* wip * rollback * refactor to use prefix/postfix namming + fix all_input_ids_tensor * maybe patching vlms? * fix filter and concat * wip, no filter, no concat * current * add prepare_for_prefill * working * load tested * re-create slots * re-create slots * fix slot_filtering_indices * feedback loop * remove log * fix benchmarker * fix vlm and seq2seq * rename to cache and input lengths * fix prefill logprobs * fix launcher * fix logprobs? * idk at this point * max input length * omfg * remove debugging lines * fix tests * fix mllama * fix cargo tests * remove support chunking for paged * Fixing non blocked attentions * Fixing dtype + AMD, Ipex targets. * lint fix. * rename * Fix prefix_caching variable, remove defaults in server (confusing a lot of the times). * Add simple resolution when user specifies ATTENTION=paged. * Put back non default simple tests. * Fix env name --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
704a58c807
commit
a6a0c97ed9
|
@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
|
|||
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
|
||||
ENV VLLM_MOE_PADDING=0
|
||||
ENV ATTENTION=paged
|
||||
ENV USE_PREFIX_CACHING=0
|
||||
ENV PREFIX_CACHING=0
|
||||
ENV PREFILL_CHUNKING=0
|
||||
ENV ROCM_USE_SKINNY_GEMM=1
|
||||
|
||||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||
|
|
|
@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
|
|||
|
||||
FROM ${PLATFORM} AS final
|
||||
ENV ATTENTION=paged
|
||||
ENV USE_PREFIX_CACHING=0
|
||||
ENV PREFIX_CACHING=0
|
||||
ENV PREFILL_CHUNKING=0
|
||||
ENV CUDA_GRAPHS=0
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
||||
|
|
|
@ -158,7 +158,8 @@ impl Client {
|
|||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
cache_len: 0,
|
||||
chunk_len: None,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
|
@ -217,8 +218,13 @@ impl Client {
|
|||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
cached_batch: Option<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||
let request = tonic::Request::new(PrefillRequest {
|
||||
batch: Some(batch),
|
||||
cached_batch,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
|
|
|
@ -134,11 +134,12 @@ impl ShardedClient {
|
|||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
cached_batch: Option<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||
|
@ -245,7 +246,8 @@ impl Health for ShardedClient {
|
|||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
cache_len: 0,
|
||||
chunk_len: None,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
|
@ -255,7 +257,7 @@ impl Health for ShardedClient {
|
|||
max_tokens: 2,
|
||||
max_blocks: 1,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
self.clone().prefill(batch, None).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
|
|||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
||||
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
|
@ -36,18 +36,14 @@ impl BackendV2 {
|
|||
speculate: u32,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||
} else {
|
||||
Attention::Paged
|
||||
};
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else {
|
||||
16
|
||||
let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
|
||||
let block_size = match attention.as_str() {
|
||||
"flashinfer" => 1,
|
||||
"flashdecoding" => 256,
|
||||
"paged" => 16,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||
/// Batching and inference logic
|
||||
use crate::client::{
|
||||
Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
|
||||
};
|
||||
use crate::queue::{Entry, Queue};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
||||
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
|
@ -31,27 +33,22 @@ impl BackendV3 {
|
|||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
shard_info: InfoResponse,
|
||||
) -> Self {
|
||||
let prefix_caching =
|
||||
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
|
||||
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
|
||||
let attention: String = std::env::var("ATTENTION").expect("attention env var");
|
||||
if shard_info.support_chunking {
|
||||
tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
|
||||
}
|
||||
|
||||
let attention: Attention = attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
||||
let block_size = attention.block_size();
|
||||
let block_size = shard_info.block_size;
|
||||
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
shard_info.requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
shard_info.use_prefix_caching,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
max_batch_total_tokens,
|
||||
shard_info.support_chunking,
|
||||
);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
|
@ -63,6 +60,7 @@ impl BackendV3 {
|
|||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.support_chunking,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
));
|
||||
|
@ -127,6 +125,7 @@ pub(crate) async fn batching_task(
|
|||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
support_chunking: bool,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
) {
|
||||
|
@ -147,7 +146,7 @@ pub(crate) async fn batching_task(
|
|||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
@ -158,10 +157,24 @@ pub(crate) async fn batching_task(
|
|||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let current_tokens = batch.current_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
|
||||
let (min_size, max_size, prefill_token_budget) = if support_chunking {
|
||||
// Since the next batch will be concatenated with the current batch,
|
||||
// the current batch tokens must be subtracted to the prefill budget
|
||||
let prefill_token_budget =
|
||||
max_batch_prefill_tokens.saturating_sub(current_tokens);
|
||||
// We can ignore min_size and max_size
|
||||
// Models than rely on max_size cannot support chunking
|
||||
// Regarding min_size, chunking allow us to consistently run at the compute
|
||||
// bound, making min_size useless.
|
||||
(None, None, prefill_token_budget)
|
||||
} else {
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
|
@ -173,24 +186,34 @@ pub(crate) async fn batching_task(
|
|||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size =
|
||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||
|
||||
(min_size, max_size, max_batch_prefill_tokens)
|
||||
};
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
if let Some((new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
let counter = if support_chunking {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "chunking")
|
||||
} else {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
};
|
||||
counter.increment(1);
|
||||
}
|
||||
|
||||
let cached_batch = if support_chunking {
|
||||
// Concat current batch to the new one
|
||||
batches.pop()
|
||||
} else {
|
||||
// Request are waiting only if we don't support chunking
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
|
@ -201,17 +224,23 @@ pub(crate) async fn batching_task(
|
|||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
None
|
||||
};
|
||||
entries.extend(new_entries);
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
let new_cached_batch =
|
||||
prefill(&mut client, new_batch, cached_batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
} else if support_chunking {
|
||||
// New cached batch is empty, no work left
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -244,13 +273,14 @@ pub(crate) async fn batching_task(
|
|||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
cached_batch: Option<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
match client.prefill(batch, cached_batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
|
@ -259,6 +289,10 @@ async fn prefill(
|
|||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
|
|
|
@ -158,7 +158,8 @@ impl Client {
|
|||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
cache_len: 0,
|
||||
chunk_len: None,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
|
@ -217,13 +218,23 @@ impl Client {
|
|||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
cached_batch: Option<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||
let request = tonic::Request::new(PrefillRequest {
|
||||
batch: Some(batch),
|
||||
cached_batch,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||
PrefillTimings::new(
|
||||
response.concat_ns,
|
||||
response.forward_ns,
|
||||
response.decode_ns,
|
||||
response.total_ns,
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -252,14 +263,16 @@ impl Client {
|
|||
}
|
||||
|
||||
pub struct PrefillTimings {
|
||||
pub concat: Option<Duration>,
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl PrefillTimings {
|
||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
concat: concat_ns.map(Duration::from_nanos),
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
|
|
|
@ -29,15 +29,6 @@ pub trait Health {
|
|||
async fn model_health(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShardInfo {
|
||||
pub requires_padding: bool,
|
||||
pub dtype: String,
|
||||
pub device_type: String,
|
||||
pub window_size: Option<u32>,
|
||||
pub speculate: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientError {
|
||||
#[error("Could not connect to Text Generation server: {0}")]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::client::{ClientError, Result};
|
||||
use crate::client::Health;
|
||||
/// Multi shard Client
|
||||
use crate::client::{Health, ShardInfo};
|
||||
use crate::client::{ClientError, Result};
|
||||
|
||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||
use crate::client::{
|
||||
|
@ -49,13 +49,13 @@ impl ShardedClient {
|
|||
|
||||
/// Get the model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.info())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
|
@ -135,11 +135,12 @@ impl ShardedClient {
|
|||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
cached_batch: Option<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||
|
@ -194,18 +195,6 @@ impl ShardedClient {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<InfoResponse> for ShardInfo {
|
||||
fn from(value: InfoResponse) -> Self {
|
||||
Self {
|
||||
requires_padding: value.requires_padding,
|
||||
dtype: value.dtype,
|
||||
device_type: value.device_type,
|
||||
window_size: value.window_size,
|
||||
speculate: value.speculate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Health for ShardedClient {
|
||||
async fn device_health(&self) -> Result<()> {
|
||||
|
@ -246,8 +235,9 @@ impl Health for ShardedClient {
|
|||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
cache_len: 0,
|
||||
adapter_id: None,
|
||||
chunk_len: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
|
@ -256,7 +246,7 @@ impl Health for ShardedClient {
|
|||
max_tokens: 2,
|
||||
max_blocks: 1,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
self.clone().prefill(batch, None).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,14 @@ pub struct BackendInfo {
|
|||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
#[schema(example = "false")]
|
||||
pub support_chunking: bool,
|
||||
#[schema(example = "false")]
|
||||
pub prefix_caching: bool,
|
||||
#[schema(example = "flashinfer")]
|
||||
pub attention_impl: String,
|
||||
#[schema(example = "1")]
|
||||
pub block_size: u32,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
@ -110,6 +118,10 @@ pub async fn connect_backend(
|
|||
model_device_type: shard_info.device_type.clone(),
|
||||
model_dtype: shard_info.dtype.clone(),
|
||||
speculate: shard_info.speculate as usize,
|
||||
support_chunking: shard_info.support_chunking,
|
||||
prefix_caching: shard_info.use_prefix_caching,
|
||||
attention_impl: shard_info.attention_impl.clone(),
|
||||
block_size: shard_info.block_size,
|
||||
};
|
||||
|
||||
let backend = BackendV3::new(
|
||||
|
@ -119,9 +131,7 @@ pub async fn connect_backend(
|
|||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
shard_info,
|
||||
);
|
||||
|
||||
tracing::info!("Using backend V3");
|
||||
|
|
|
@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
|
|||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_batch_size) = max_batch_size {
|
||||
if max_batch_size == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
|
@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
}
|
||||
}
|
||||
|
||||
let (backend, _backend_info) = connect_backend(
|
||||
let (backend, backend_info) = connect_backend(
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
master_shard_uds_path,
|
||||
|
@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
|
|||
)
|
||||
.await?;
|
||||
|
||||
// Validate remaining args now that the backend is known
|
||||
let support_chunking = backend_info.support_chunking;
|
||||
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
if max_batch_prefill_tokens > max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
backend,
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::client::{
|
|||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::{max, min};
|
||||
use std::cmp::max;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
|
@ -50,6 +50,7 @@ impl Queue {
|
|||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
support_chunking: bool,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||
|
@ -62,6 +63,7 @@ impl Queue {
|
|||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
support_chunking,
|
||||
queue_receiver,
|
||||
));
|
||||
|
||||
|
@ -87,6 +89,10 @@ impl Queue {
|
|||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
if prefill_token_budget == 0 || token_budget == 0 {
|
||||
return None;
|
||||
};
|
||||
|
||||
// Create response channel
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
// Send next batch command to the background task managing the state
|
||||
|
@ -108,6 +114,7 @@ impl Queue {
|
|||
}
|
||||
|
||||
// Background task responsible of the queue state
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn queue_task(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
|
@ -115,6 +122,7 @@ async fn queue_task(
|
|||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
support_chunking: bool,
|
||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||
) {
|
||||
let mut state = State::new(
|
||||
|
@ -124,6 +132,7 @@ async fn queue_task(
|
|||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
support_chunking,
|
||||
);
|
||||
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
|
@ -166,12 +175,14 @@ struct State {
|
|||
/// Paged Attention block size
|
||||
block_size: u32,
|
||||
|
||||
/// Sliding window
|
||||
window_size: Option<u32>,
|
||||
|
||||
/// Speculation amount
|
||||
speculate: u32,
|
||||
|
||||
/// Whether the model allow the prefill chunking
|
||||
/// If it does, the last request in the batch will be split to exactly match the prefill
|
||||
/// token budget
|
||||
support_chunking: bool,
|
||||
|
||||
/// Paged Attention Block Allocation
|
||||
block_allocator: Option<BlockAllocator>,
|
||||
}
|
||||
|
@ -184,6 +195,7 @@ impl State {
|
|||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
support_chunking: bool,
|
||||
) -> Self {
|
||||
let block_allocator = (!requires_padding).then(|| {
|
||||
BlockAllocator::new(
|
||||
|
@ -199,8 +211,8 @@ impl State {
|
|||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
support_chunking,
|
||||
block_allocator,
|
||||
}
|
||||
}
|
||||
|
@ -287,32 +299,7 @@ impl State {
|
|||
}
|
||||
None
|
||||
}
|
||||
Some(_block_allocator) => {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
let max_new_tokens = match self.window_size {
|
||||
None => entry.request.stopping_parameters.max_new_tokens,
|
||||
Some(window_size) => min(
|
||||
window_size.saturating_sub(entry.request.input_length),
|
||||
entry.request.stopping_parameters.max_new_tokens,
|
||||
),
|
||||
};
|
||||
decode_tokens += max_new_tokens;
|
||||
|
||||
if prefill_tokens > prefill_token_budget
|
||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
||||
{
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
self.entries.push_front((id, entry));
|
||||
break;
|
||||
}
|
||||
|
||||
let tokens = entry.request.input_length
|
||||
+ entry.request.stopping_parameters.max_new_tokens
|
||||
+ self.speculate
|
||||
- 1;
|
||||
|
||||
Some(block_allocator) => {
|
||||
// If users wants the prefill logprobs, we cannot reuse the cache.
|
||||
// So no input_ids for the radix tree.
|
||||
let input_ids = if entry.request.decoder_input_details {
|
||||
|
@ -321,10 +308,73 @@ impl State {
|
|||
entry.request.input_ids.clone()
|
||||
};
|
||||
|
||||
Some((tokens, input_ids))
|
||||
let tokens = entry.request.input_length
|
||||
+ entry.request.stopping_parameters.max_new_tokens
|
||||
+ self.speculate
|
||||
- 1;
|
||||
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
||||
|
||||
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: not enough free blocks");
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
Some(mut block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
|
||||
if block_allocation.prefix_len == entry.request.input_length {
|
||||
// The whole request was found in the radix trie
|
||||
// However, for the transformer forward to work, we need to
|
||||
// have at least one token of postfix.
|
||||
block_allocation.prefix_len -= 1;
|
||||
}
|
||||
|
||||
block_allocation
|
||||
}
|
||||
};
|
||||
batch.push((id, entry, block_allocation));
|
||||
|
||||
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
|
||||
|
||||
if prefill_tokens + postfix_len > prefill_token_budget {
|
||||
// Entry is over budget
|
||||
if self.support_chunking {
|
||||
// We support chunking, just set postfix_len to exactly match prefill_token_budget
|
||||
let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
|
||||
if chunk_len > 0 {
|
||||
// Push this entry inside the batch
|
||||
batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
|
||||
} else {
|
||||
// We cannot prefill even one token for this entry
|
||||
// Add it back to the queue
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
tracing::debug!(
|
||||
"Matched budget: prefill_tokens={} == {prefill_token_budget}",
|
||||
prefill_tokens + postfix_len
|
||||
);
|
||||
break 'entry_loop;
|
||||
} else {
|
||||
// We don't support chunking, this entry needs to go back to the buffer
|
||||
// Add it back to the front
|
||||
tracing::debug!(
|
||||
"Over budget: prefill_tokens={} > {prefill_token_budget}",
|
||||
prefill_tokens + postfix_len
|
||||
);
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
}
|
||||
|
||||
prefill_tokens += postfix_len;
|
||||
|
||||
Some(block_allocation)
|
||||
}
|
||||
};
|
||||
batch.push((id, entry, block_allocation, None));
|
||||
if Some(batch.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
|
@ -342,7 +392,7 @@ impl State {
|
|||
// Batch is too small
|
||||
if batch.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for (id, entry, _) in batch.into_iter().rev() {
|
||||
for (id, entry, _, _) in batch.into_iter().rev() {
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
return None;
|
||||
|
@ -353,29 +403,7 @@ impl State {
|
|||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
for (id, mut entry, block_allocation) in batch {
|
||||
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
|
||||
(block_allocation, &self.block_allocator)
|
||||
{
|
||||
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
||||
match block_allocator.allocate(tokens, input_ids).await {
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: not enough free blocks");
|
||||
self.entries.push_front((id, entry));
|
||||
continue;
|
||||
}
|
||||
Some(block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
Some(block_allocation)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
tracing::debug!("Accepting entry");
|
||||
for (id, mut entry, block_allocation, chunk_len) in batch {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
|
@ -427,8 +455,9 @@ impl State {
|
|||
top_n_tokens: entry.request.top_n_tokens,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len,
|
||||
cache_len: prefix_len,
|
||||
adapter_id: entry.request.adapter_id.clone(),
|
||||
chunk_len,
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
|
@ -436,12 +465,6 @@ impl State {
|
|||
batch_entries.insert(id, entry);
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Final batch size
|
||||
let size = batch_requests.len() as u32;
|
||||
next_batch_span.record("batch_size", size);
|
||||
|
@ -531,7 +554,7 @@ mod tests {
|
|||
request: ValidGenerateRequest {
|
||||
inputs: vec![],
|
||||
input_ids: Some(Arc::new(vec![])),
|
||||
input_length: 0,
|
||||
input_length: 1,
|
||||
add_special_tokens: true,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
|
@ -567,7 +590,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_append() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
|
@ -583,7 +606,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||
|
||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
|
@ -591,7 +614,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -623,7 +646,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -643,7 +666,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 2);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -676,14 +699,14 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
|
@ -691,7 +714,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -724,7 +747,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -740,7 +763,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -765,7 +788,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_speculate() {
|
||||
let queue = Queue::new(false, 1, false, None, 2, 16);
|
||||
let queue = Queue::new(true, 1, false, None, 2, 16, false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -784,7 +807,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
|
|
|
@ -158,7 +158,8 @@ async fn prefill(
|
|||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
cache_len: 0,
|
||||
chunk_len: None,
|
||||
adapter_id: None,
|
||||
})
|
||||
.collect();
|
||||
|
@ -173,7 +174,7 @@ async fn prefill(
|
|||
|
||||
// Run prefill
|
||||
let start_time = Instant::now();
|
||||
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
|
||||
let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;
|
||||
|
||||
// Get latency
|
||||
let latency = start_time.elapsed();
|
||||
|
|
|
@ -178,6 +178,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
.clear_cache(None)
|
||||
.await
|
||||
.expect("Unable to clear cache");
|
||||
|
||||
tracing::info!("Connected");
|
||||
|
||||
// Run app
|
||||
|
|
|
@ -9,13 +9,16 @@ import subprocess
|
|||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import docker
|
||||
import pytest
|
||||
import base64
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
||||
from docker.errors import NotFound
|
||||
from syrupy.extensions.json import JSONSnapshotExtension
|
||||
|
||||
from text_generation import AsyncClient
|
||||
from text_generation.types import (
|
||||
BestOfSequence,
|
||||
|
@ -403,6 +406,7 @@ def launcher(event_loop):
|
|||
print(" ".join(args), file=sys.stderr)
|
||||
|
||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||
env["PREFILL_CHUNKING"] = "1"
|
||||
|
||||
if not use_flash_attention:
|
||||
env["USE_FLASH_ATTENTION"] = "false"
|
||||
|
@ -501,6 +505,7 @@ def launcher(event_loop):
|
|||
|
||||
env = {
|
||||
"LOG_LEVEL": "info,text_generation_router=debug",
|
||||
"PREFILL_CHUNKING": "1",
|
||||
}
|
||||
if not use_flash_attention:
|
||||
env["USE_FLASH_ATTENTION"] = "false"
|
||||
|
@ -642,3 +647,22 @@ def generate_multi():
|
|||
return responses
|
||||
|
||||
return generate_load_inner
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
@pytest.fixture
|
||||
def chicken():
|
||||
path = Path(__file__).parent / "images" / "chicken_on_money.png"
|
||||
|
||||
with open(path, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cow_beach():
|
||||
path = Path(__file__).parent / "images" / "cow_beach.png"
|
||||
|
||||
with open(path, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
import base64
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
|
|||
return flash_pali_gemma_handle.client
|
||||
|
||||
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
def get_cow_beach():
|
||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
|
||||
cow = get_cow_beach()
|
||||
inputs = f"![]({cow})Where is the cow standing?\n"
|
||||
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
|
||||
inputs = f"![]({cow_beach})Where is the cow standing?\n"
|
||||
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
|
||||
|
||||
assert response.generated_text == "beach"
|
||||
|
@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
|
|||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
cow_beach = get_cow_beach()
|
||||
async def test_flash_pali_gemma_two_images(
|
||||
flash_pali_gemma, response_snapshot, chicken, cow_beach
|
||||
):
|
||||
response = await flash_pali_gemma.generate(
|
||||
f"caption![]({chicken})![]({cow_beach})\n",
|
||||
max_new_tokens=20,
|
||||
|
|
|
@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle):
|
|||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
|
||||
|
||||
class Weather(BaseModel):
|
||||
unit: str
|
||||
temperature: List[int]
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
import base64
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -16,22 +15,8 @@ async def idefics(idefics_handle):
|
|||
return idefics_handle.client
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
def get_cow_beach():
|
||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idefics(idefics, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
async def test_idefics(idefics, response_snapshot, chicken):
|
||||
response = await idefics.generate(
|
||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||
max_new_tokens=10,
|
||||
|
@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
|
|||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_idefics_two_images(idefics, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
cow_beach = get_cow_beach()
|
||||
async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
|
||||
response = await idefics.generate(
|
||||
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
||||
max_new_tokens=20,
|
||||
|
@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
|
|||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
async def test_idefics_load(idefics, generate_load, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
|
||||
responses = await generate_load(
|
||||
idefics,
|
||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||
|
|
|
@ -1,18 +1,4 @@
|
|||
import pytest
|
||||
import base64
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
def get_cow_beach():
|
||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
async def test_flash_idefics2_next_simple(
|
||||
flash_idefics2_next, response_snapshot, chicken
|
||||
):
|
||||
response = await flash_idefics2_next.generate(
|
||||
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
||||
max_new_tokens=10,
|
||||
|
@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
cow_beach = get_cow_beach()
|
||||
async def test_flash_idefics2_two_images(
|
||||
flash_idefics2_next, response_snapshot, chicken, cow_beach
|
||||
):
|
||||
response = await flash_idefics2_next.generate(
|
||||
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
||||
max_new_tokens=20,
|
||||
|
@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_idefics2_next_load(
|
||||
flash_idefics2_next, generate_load, response_snapshot
|
||||
flash_idefics2_next, generate_load, response_snapshot, chicken
|
||||
):
|
||||
chicken = get_chicken()
|
||||
responses = await generate_load(
|
||||
flash_idefics2_next,
|
||||
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
||||
|
|
|
@ -1,12 +1,4 @@
|
|||
import pytest
|
||||
import base64
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
|
|||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
|
||||
response = await flash_llava_next.generate(
|
||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||
max_new_tokens=10,
|
||||
|
@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_load(
|
||||
flash_llava_next, generate_load, response_snapshot
|
||||
flash_llava_next, generate_load, response_snapshot, chicken
|
||||
):
|
||||
chicken = get_chicken()
|
||||
responses = await generate_load(
|
||||
flash_llava_next,
|
||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
import base64
|
||||
import asyncio
|
||||
|
||||
|
||||
|
@ -15,22 +14,8 @@ async def mllama(mllama_handle):
|
|||
return mllama_handle.client
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
def get_cow_beach():
|
||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mllama_simpl(mllama, response_snapshot):
|
||||
# chicken = get_chicken()
|
||||
response = await mllama.chat(
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
|
|
|
@ -68,7 +68,7 @@ fn get_config(
|
|||
|
||||
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
||||
let compute_capability = gpu::get_cuda_capability();
|
||||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
||||
let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
|
||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||
if let Some(config) = config {
|
||||
if prefix_caching.is_none() {
|
||||
|
@ -124,6 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
|||
}
|
||||
}
|
||||
}
|
||||
if attention == Some("paged".to_string()) && prefix_caching.is_none() {
|
||||
tracing::info!("Disabling prefix caching on paged attention");
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
|
||||
let attention = attention.unwrap_or("flashinfer".to_string());
|
||||
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
||||
|
@ -1678,7 +1682,7 @@ fn main() -> Result<(), LauncherError> {
|
|||
};
|
||||
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
||||
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
||||
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
|
||||
std::env::set_var("PREFIX_CACHING", prefix_caching);
|
||||
std::env::set_var("ATTENTION", attention);
|
||||
|
||||
let max_input_tokens = {
|
||||
|
@ -1729,12 +1733,6 @@ fn main() -> Result<(), LauncherError> {
|
|||
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
|
||||
max_batch_prefill_tokens, max_input_tokens
|
||||
)));
|
||||
}
|
||||
|
||||
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
||||
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
||||
|
@ -1788,12 +1786,6 @@ fn main() -> Result<(), LauncherError> {
|
|||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
max_batch_prefill_tokens, max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
|
|
|
@ -34,6 +34,10 @@ message InfoResponse {
|
|||
string device_type = 3;
|
||||
optional uint32 window_size = 4;
|
||||
uint32 speculate = 5;
|
||||
bool support_chunking = 6;
|
||||
bool use_prefix_caching = 7;
|
||||
string attention_impl = 8;
|
||||
uint32 block_size = 9;
|
||||
}
|
||||
|
||||
/// Empty request
|
||||
|
@ -135,10 +139,14 @@ message Request {
|
|||
repeated uint32 slots = 10;
|
||||
/// LORA adapter index
|
||||
optional string adapter_id = 11;
|
||||
/// Prefix length that can be retrieved from the KV cache.
|
||||
uint32 prefix_len = 12;
|
||||
/// Tokens that can be retrieved from the KV cache.
|
||||
/// This value is set for the first prefill and never reset
|
||||
uint32 cache_len = 12;
|
||||
/// Context truncation
|
||||
bool add_special_tokens = 13;
|
||||
/// Chunk of tokens that must be computed for the first prefill
|
||||
/// This value is set for the first prefill and never reset
|
||||
optional uint32 chunk_len = 14;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
|
@ -163,6 +171,8 @@ message CachedBatch {
|
|||
uint32 size = 3;
|
||||
/// Maximum number of tokens this batch will grow to
|
||||
uint32 max_tokens = 4;
|
||||
/// Number of tokens in the next forward
|
||||
uint32 current_tokens = 5;
|
||||
}
|
||||
|
||||
enum FinishReason {
|
||||
|
@ -220,6 +230,8 @@ message FilterBatchResponse {
|
|||
message PrefillRequest {
|
||||
/// Batch
|
||||
Batch batch = 1;
|
||||
/// Optional cached batch
|
||||
CachedBatch cached_batch = 2;
|
||||
}
|
||||
|
||||
message PrefillResponse {
|
||||
|
@ -233,6 +245,8 @@ message PrefillResponse {
|
|||
uint64 decode_ns = 4;
|
||||
/// Total elapsed time in nanoseconds
|
||||
uint64 total_ns = 5;
|
||||
/// Concatenate elapsed time in nanoseconds
|
||||
optional uint64 concat_ns = 6;
|
||||
}
|
||||
|
||||
message DecodeRequest {
|
||||
|
|
|
@ -18,45 +18,6 @@ use tracing::warn;
|
|||
use utoipa::ToSchema;
|
||||
use validation::Validation;
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum Attention {
|
||||
Paged,
|
||||
FlashDecoding,
|
||||
FlashInfer,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
pub fn block_size(&self) -> u32 {
|
||||
match self {
|
||||
Attention::FlashDecoding => 256,
|
||||
Attention::FlashInfer => 1,
|
||||
Attention::Paged => 16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ParseError;
|
||||
|
||||
impl std::fmt::Display for ParseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Cannot parse attention value")
|
||||
}
|
||||
}
|
||||
impl std::error::Error for ParseError {}
|
||||
|
||||
impl std::str::FromStr for Attention {
|
||||
type Err = ParseError;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"paged" => Ok(Attention::Paged),
|
||||
"flashdecoding" => Ok(Attention::FlashDecoding),
|
||||
"flashinfer" => Ok(Attention::FlashInfer),
|
||||
_ => Err(ParseError),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Hub type
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct HubModelInfo {
|
||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
import os
|
||||
from text_generation_server.pb import generate_pb2
|
||||
|
||||
os.environ["USE_PREFIX_CACHING"] = "1"
|
||||
os.environ["PREFIX_CACHING"] = "1"
|
||||
os.environ["ATTENTION"] = "flashinfer"
|
||||
|
||||
|
||||
|
|
|
@ -9,6 +9,9 @@ from typing import Callable, Any
|
|||
|
||||
|
||||
class ExceptionInterceptor(AsyncServerInterceptor):
|
||||
def __init__(self, shutdown_callback):
|
||||
self.shutdown_callback = shutdown_callback
|
||||
|
||||
async def intercept(
|
||||
self,
|
||||
method: Callable,
|
||||
|
@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
|
|||
|
||||
# Runtime Error cannot be recovered from
|
||||
if isinstance(err, RuntimeError):
|
||||
exit(1)
|
||||
self.shutdown_callback()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -1,16 +1,12 @@
|
|||
from dataclasses import dataclass
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
|
||||
if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
input_lengths: torch.Tensor
|
||||
prefix_lengths: torch.Tensor
|
||||
cache_lengths: torch.Tensor
|
||||
cu_seqlen_q: Optional[torch.Tensor]
|
||||
cu_seqlen_k: Optional[torch.Tensor]
|
||||
max_q: int
|
||||
|
@ -19,13 +15,13 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
|||
def __init__(
|
||||
self,
|
||||
input_lengths,
|
||||
prefix_lengths,
|
||||
cache_lengths,
|
||||
cu_seqlen_q=None,
|
||||
max_q=None,
|
||||
max_k=None,
|
||||
):
|
||||
self.input_lengths = input_lengths
|
||||
self.prefix_lengths = prefix_lengths
|
||||
self.cache_lengths = cache_lengths
|
||||
device = self.input_lengths.device
|
||||
shape = self.input_lengths.shape
|
||||
if cu_seqlen_q is None:
|
||||
|
@ -43,7 +39,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
|||
# cuda graphs don't like this and this is necessary to clamp within mistral
|
||||
# Although FA2 might not want the clamping
|
||||
# cu_seqlen_k[0] = 0
|
||||
total = self.input_lengths + self.prefix_lengths
|
||||
total = self.input_lengths + self.cache_lengths
|
||||
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
||||
|
||||
self.cu_seqlen_q = cu_seqlen_q
|
||||
|
@ -54,19 +50,3 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
|||
def clamp(self, max):
|
||||
# Flash decoding doesn't need to clamp
|
||||
return self
|
||||
|
||||
else:
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
input_lengths: torch.Tensor
|
||||
prefix_lengths: torch.Tensor
|
||||
cu_seqlen_q: torch.Tensor
|
||||
max_q: int
|
||||
max_k: int
|
||||
|
||||
def clamp(self, max):
|
||||
if SYSTEM == "rocm":
|
||||
return self
|
||||
self.input_lengths = torch.clamp(self.input_lengths, max=max)
|
||||
return self
|
||||
|
|
|
@ -123,7 +123,7 @@ def paged_attention(
|
|||
else:
|
||||
if softcap is not None:
|
||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||
input_lengths = seqlen.input_lengths
|
||||
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||
from vllm._C import ops
|
||||
|
||||
out = torch.empty_like(query)
|
||||
|
@ -244,7 +244,8 @@ if ATTENTION == "flashinfer":
|
|||
window_left=window_size_left,
|
||||
)
|
||||
|
||||
elif V2:
|
||||
elif ATTENTION == "flashdecoding":
|
||||
if V2:
|
||||
|
||||
def attention(
|
||||
q,
|
||||
|
@ -284,7 +285,7 @@ elif V2:
|
|||
None,
|
||||
)[0]
|
||||
|
||||
else:
|
||||
else:
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor,
|
||||
|
@ -302,7 +303,9 @@ else:
|
|||
"window_size_left is only available with flash attn v2"
|
||||
)
|
||||
if softcap is not None:
|
||||
raise NotImplementedError("softcap is only available with flash attn v2")
|
||||
raise NotImplementedError(
|
||||
"softcap is only available with flash attn v2"
|
||||
)
|
||||
|
||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||
if k.shape[1] != q.shape[1]:
|
||||
|
@ -350,11 +353,123 @@ else:
|
|||
)
|
||||
return out
|
||||
|
||||
elif ATTENTION == "paged":
|
||||
if V2:
|
||||
|
||||
def attention(
|
||||
q,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
softcap=0.0,
|
||||
):
|
||||
out = torch.empty_like(q)
|
||||
if window_size_left <= 0 and window_size_left != -1:
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
None,
|
||||
None,
|
||||
None, # block_tables,
|
||||
None,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
window_size_left,
|
||||
0,
|
||||
softcap,
|
||||
False,
|
||||
None,
|
||||
)[0]
|
||||
|
||||
else:
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap=None,
|
||||
):
|
||||
if window_size_left != -1:
|
||||
raise NotImplementedError(
|
||||
"window_size_left is only available with flash attn v2"
|
||||
)
|
||||
if softcap is not None:
|
||||
raise NotImplementedError(
|
||||
"softcap is only available with flash attn v2"
|
||||
)
|
||||
|
||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||
if k.shape[1] != q.shape[1]:
|
||||
# MQA expand
|
||||
if k.shape[1] == 1:
|
||||
k = k.expand(-1, q.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = k.shape
|
||||
k = (
|
||||
k.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
if v.shape[1] != q.shape[1]:
|
||||
# MQA expand
|
||||
if v.shape[1] == 1:
|
||||
v = v.expand(-1, q.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = v.shape
|
||||
v = (
|
||||
v.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
|
||||
out = torch.empty_like(q)
|
||||
flash_attn_cuda.fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
return out
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknwon attention {ATTENTION}")
|
||||
|
||||
|
||||
# Prefill in the cache with every kind of attention, unless we
|
||||
# have a configuration that requires flash-attention v1, which
|
||||
# does not support block tables.
|
||||
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
|
||||
PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
|
|
|
@ -699,7 +699,6 @@ def check_args(
|
|||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
|
|
|
@ -66,6 +66,7 @@ def paged_attention(
|
|||
softcap: Optional[float] = None,
|
||||
):
|
||||
out = torch.empty_like(query)
|
||||
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
|
@ -74,7 +75,7 @@ def paged_attention(
|
|||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
seqlen.input_lengths,
|
||||
input_lengths,
|
||||
BLOCK_SIZE,
|
||||
max_s,
|
||||
None,
|
||||
|
|
|
@ -104,7 +104,7 @@ def paged_attention(
|
|||
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
||||
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
input_lengths = seqlen.input_lengths
|
||||
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||
|
||||
out = torch.empty_like(query)
|
||||
|
||||
|
|
|
@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
|
|||
request_ids=[r.id for r in self.requests],
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
current_tokens=len(self.input_ids),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
|
|||
aspect_ratio_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
|
||||
pixel_values.shape
|
||||
)
|
||||
(
|
||||
batch_size,
|
||||
num_concurrent_media,
|
||||
num_tiles,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
) = pixel_values.shape
|
||||
|
||||
pixel_values = pixel_values.reshape(
|
||||
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -5,9 +5,14 @@ from typing import Dict, Optional
|
|||
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
|
||||
ATTENTION = os.environ["ATTENTION"]
|
||||
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
|
||||
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
|
||||
"1",
|
||||
"true",
|
||||
}
|
||||
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
|
||||
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||
ATTENTION = os.getenv("ATTENTION")
|
||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
||||
assert (
|
||||
ATTENTION in _expected
|
||||
|
@ -18,7 +23,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
|
|||
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
|
||||
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
|
||||
assert TGI_WIGGLE_ROOM > 0
|
||||
assert TGI_WIGGLE_ROOM < 1
|
||||
|
||||
|
|
|
@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
|
|||
request_ids=[r.id for r in self.requests],
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
current_tokens=len(self),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -116,6 +116,7 @@ class MambaBatch(Batch):
|
|||
request_ids=[r.id for r in self.requests],
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
current_tokens=len(self),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from typing import Iterable, Optional, Tuple, List, Dict
|
||||
from text_generation_server.pb.generate_pb2 import Request
|
||||
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import (
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.flash_causal_lm import (
|
||||
|
@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
|||
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
|
||||
max=config.text_config.vocab_size - 1
|
||||
)
|
||||
if isinstance(batch.input_ids, list):
|
||||
if len(batch) > 1:
|
||||
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
|
||||
else:
|
||||
input_ids = batch.input_ids[0]
|
||||
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
|
||||
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
|
||||
|
||||
if image_inputs is not None:
|
||||
|
@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
|||
class MllamaCausalLM(VlmCausalLM):
|
||||
def forward(
|
||||
self,
|
||||
batch: VlmCausalLMBatch,
|
||||
batch: MllamaCausalLMBatch,
|
||||
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# Model Forward
|
||||
|
@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
|
|||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
max_s = batch.max_seqlen
|
||||
max_s = batch.max_current_length
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
speculative_ids = batch.speculative_ids
|
||||
|
@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
|
|||
input_lengths = (
|
||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||
).view(-1)
|
||||
prefix_lens_tensor = (
|
||||
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
|
||||
cache_lengths_tensor = (
|
||||
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
|
||||
).reshape(-1)
|
||||
|
||||
# Add Copy the block tables for all members
|
||||
|
@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
|
|||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
prefix_lens_tensor = batch.prefix_lens_tensor
|
||||
max_s = batch.max_seqlen
|
||||
cache_lengths_tensor = batch.cache_lengths_tensor
|
||||
max_s = batch.max_current_length
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
|
@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
|
|||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
# Try to find an associated cuda graph
|
||||
bs = input_ids.shape[0]
|
||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
|
@ -269,26 +278,24 @@ class MllamaCausalLM(VlmCausalLM):
|
|||
# Only run cuda graphs when there's no images.
|
||||
or batch.cross_attention_states is not None
|
||||
):
|
||||
input_lengths = input_lengths + prefix_lens_tensor
|
||||
if PREFIX_CACHING:
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
)
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
input_lengths_tensor=input_lengths,
|
||||
prefix_lens_tensor=prefix_lens_tensor,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
):
|
||||
max_k = (input_lengths + prefix_lens_tensor).max().item()
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths,
|
||||
prefix_lengths=prefix_lens_tensor,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=cu_seqlen_prefill,
|
||||
max_q=max_s,
|
||||
max_k=max_k,
|
||||
max_q=batch.max_input_length,
|
||||
max_k=batch.max_current_length,
|
||||
)
|
||||
|
||||
if batch.pixel_values is not None:
|
||||
|
@ -330,20 +337,32 @@ class MllamaCausalLM(VlmCausalLM):
|
|||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
)
|
||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||
else:
|
||||
cuda_graph["block_tables"][
|
||||
: block_tables.shape[0], : block_tables.shape[1]
|
||||
] = block_tables
|
||||
|
||||
# XXX: This is working only because block 0 is reserved for the healthcheck
|
||||
# so it doesn't matter if we override it with bogus values.
|
||||
cuda_graph["slots"].fill_(0)
|
||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||
cuda_graph["input_lengths"].zero_()
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
||||
input_lengths + prefix_lens_tensor
|
||||
)
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||
cuda_graph["cache_lengths"].zero_()
|
||||
cuda_graph["cache_lengths"][
|
||||
: cache_lengths_tensor.shape[0]
|
||||
] = cache_lengths_tensor
|
||||
|
||||
with self._forward_context(
|
||||
block_tables=cuda_graph["block_tables"],
|
||||
cu_seqlen_prefill=None,
|
||||
input_lengths_tensor=cuda_graph["input_lengths"],
|
||||
cache_lengths_tensor=cuda_graph["cache_lengths"],
|
||||
state=cuda_graph["state"],
|
||||
):
|
||||
# Replay the graph
|
||||
cuda_graph["graph"].replay()
|
||||
|
||||
|
|
|
@ -5,8 +5,17 @@ from abc import ABC, abstractmethod
|
|||
from typing import List, Tuple, Optional, TypeVar, Type, Dict
|
||||
from collections import defaultdict
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.models.globals import (
|
||||
ATTENTION,
|
||||
PREFIX_CACHING,
|
||||
BLOCK_SIZE,
|
||||
PREFILL_CHUNKING,
|
||||
)
|
||||
from text_generation_server.models.types import Batch, Generation
|
||||
from text_generation_server.utils.log import log_master
|
||||
from text_generation_server.utils.prefill_chunking import set_support_chunking
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||
from text_generation_server.adapters.weights import LayerAdapterWeights
|
||||
|
@ -31,6 +40,7 @@ class Model(ABC):
|
|||
sliding_window: Optional[int] = None,
|
||||
speculate: Optional[int] = None,
|
||||
adapter_id: str = BASE_MODEL_ADAPTER_ID,
|
||||
support_chunking: bool = False,
|
||||
):
|
||||
self.model_id = model_id
|
||||
self.model = model.eval()
|
||||
|
@ -60,6 +70,29 @@ class Model(ABC):
|
|||
speculate = get_speculate()
|
||||
self.speculate = speculate
|
||||
|
||||
support_chunking = support_chunking and PREFILL_CHUNKING
|
||||
|
||||
if speculate != 0 and support_chunking:
|
||||
log_master(
|
||||
logger.warning,
|
||||
"Prefill chunking does not support speculation yet. "
|
||||
"Prefill chunking will be turned off",
|
||||
)
|
||||
support_chunking = False
|
||||
if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
|
||||
log_master(
|
||||
logger.warning,
|
||||
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
|
||||
)
|
||||
support_chunking = False
|
||||
|
||||
log_master(
|
||||
logger.info, f"Using experimental prefill chunking = {support_chunking}"
|
||||
)
|
||||
|
||||
self.support_chunking = support_chunking
|
||||
set_support_chunking(support_chunking)
|
||||
|
||||
self.has_position_ids = (
|
||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||
is not None
|
||||
|
@ -78,6 +111,10 @@ class Model(ABC):
|
|||
device_type=self.device.type,
|
||||
window_size=self.sliding_window,
|
||||
speculate=self.speculate,
|
||||
support_chunking=self.support_chunking,
|
||||
use_prefix_caching=PREFIX_CACHING,
|
||||
attention_impl=ATTENTION,
|
||||
block_size=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
request_ids=[r.id for r in self.requests],
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
current_tokens=len(self.decoder_input_ids),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -74,6 +74,14 @@ class Tokens:
|
|||
def __len__(self):
|
||||
return len(self.token_ids)
|
||||
|
||||
def __add__(self, other: "Tokens") -> "Tokens":
|
||||
return Tokens(
|
||||
self.token_ids + other.token_ids,
|
||||
self.logprobs + other.logprobs,
|
||||
self.texts + other.texts,
|
||||
self.is_special + other.is_special,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Generation:
|
||||
|
|
|
@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM):
|
|||
model_id=model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
# FIXME: VLM do not work with context chunking yet
|
||||
support_chunking=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -295,7 +297,7 @@ class VlmCausalLM(FlashCausalLM):
|
|||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
max_s = batch.max_seqlen
|
||||
max_s = batch.max_current_length
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
speculative_ids = batch.speculative_ids
|
||||
|
@ -314,8 +316,8 @@ class VlmCausalLM(FlashCausalLM):
|
|||
input_lengths = (
|
||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||
).view(-1)
|
||||
prefix_lens_tensor = (
|
||||
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
|
||||
cache_lengths_tensor = (
|
||||
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
|
||||
).reshape(-1)
|
||||
|
||||
# Add Copy the block tables for all members
|
||||
|
@ -337,8 +339,8 @@ class VlmCausalLM(FlashCausalLM):
|
|||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
prefix_lens_tensor = batch.prefix_lens_tensor
|
||||
max_s = batch.max_seqlen
|
||||
cache_lengths_tensor = batch.cache_lengths_tensor
|
||||
max_s = batch.max_current_length
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
|
@ -347,7 +349,6 @@ class VlmCausalLM(FlashCausalLM):
|
|||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
# Try to find an associated cuda graph
|
||||
bs = input_ids.shape[0]
|
||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
|
@ -357,26 +358,24 @@ class VlmCausalLM(FlashCausalLM):
|
|||
else:
|
||||
cuda_graph = None
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
input_lengths = input_lengths + prefix_lens_tensor
|
||||
if PREFIX_CACHING:
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
)
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
input_lengths_tensor=input_lengths,
|
||||
prefix_lens_tensor=prefix_lens_tensor,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
):
|
||||
max_k = (input_lengths + prefix_lens_tensor).max().item()
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths,
|
||||
prefix_lengths=prefix_lens_tensor,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=cu_seqlen_prefill,
|
||||
max_q=max_s,
|
||||
max_k=max_k,
|
||||
max_q=batch.max_input_length,
|
||||
max_k=batch.max_current_length,
|
||||
)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
|
@ -411,20 +410,32 @@ class VlmCausalLM(FlashCausalLM):
|
|||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
)
|
||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||
else:
|
||||
cuda_graph["block_tables"][
|
||||
: block_tables.shape[0], : block_tables.shape[1]
|
||||
] = block_tables
|
||||
cuda_graph["slots"].fill_(-1)
|
||||
|
||||
# XXX: This is working only because block 0 is reserved for the healthcheck
|
||||
# so it doesn't matter if we override it with bogus values.
|
||||
cuda_graph["slots"].fill_(0)
|
||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||
cuda_graph["input_lengths"].zero_()
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
||||
input_lengths + prefix_lens_tensor
|
||||
)
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||
cuda_graph["cache_lengths"].zero_()
|
||||
cuda_graph["cache_lengths"][
|
||||
: cache_lengths_tensor.shape[0]
|
||||
] = cache_lengths_tensor
|
||||
|
||||
with self._forward_context(
|
||||
block_tables=cuda_graph["block_tables"],
|
||||
cu_seqlen_prefill=None,
|
||||
input_lengths_tensor=cuda_graph["input_lengths"],
|
||||
cache_lengths_tensor=cuda_graph["cache_lengths"],
|
||||
state=cuda_graph["state"],
|
||||
):
|
||||
# Replay the graph
|
||||
cuda_graph["graph"].replay()
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from text_generation_server.cache import Cache
|
|||
from text_generation_server.interceptor import ExceptionInterceptor
|
||||
from text_generation_server.models import Model, get_model_with_lora_adapters
|
||||
from text_generation_server.utils.adapter import AdapterInfo
|
||||
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
|
||||
|
||||
try:
|
||||
from text_generation_server.models.pali_gemma import PaliGemmaBatch
|
||||
|
@ -46,9 +47,12 @@ class SignalHandler:
|
|||
signal.signal(signal.SIGINT, self.exit_gracefully)
|
||||
signal.signal(signal.SIGTERM, self.exit_gracefully)
|
||||
|
||||
def set_keep_processing(self, value: bool):
|
||||
self.KEEP_PROCESSING = value
|
||||
|
||||
def exit_gracefully(self, signum, frame):
|
||||
print(f"Exiting gracefully: Signal {signum}")
|
||||
self.KEEP_PROCESSING = False
|
||||
self.set_keep_processing(False)
|
||||
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
|
@ -96,6 +100,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||
|
||||
async def Warmup(self, request, context):
|
||||
set_max_prefill_tokens(request.max_prefill_tokens)
|
||||
|
||||
if self.quantize in {"exl2", "gptq"}:
|
||||
try:
|
||||
# When using GPTQ, Exllama kernels need some global kernels
|
||||
|
@ -150,6 +156,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
|
||||
concat_ns = None
|
||||
if self.model.support_chunking:
|
||||
if request.HasField("cached_batch"):
|
||||
cached_batch = self.cache.pop(request.cached_batch.id)
|
||||
if cached_batch is None:
|
||||
raise ValueError(
|
||||
f"Batch ID {request.cached_batch.id} not found in cache."
|
||||
)
|
||||
start_concat = time.time_ns()
|
||||
batch = self.model.batch_type.concatenate([cached_batch, batch])
|
||||
concat_ns = time.time_ns() - start_concat
|
||||
|
||||
generations, next_batch, timings = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
|
||||
|
@ -159,6 +177,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
forward_ns=timings[0],
|
||||
decode_ns=timings[1],
|
||||
total_ns=time.time_ns() - start,
|
||||
concat_ns=concat_ns,
|
||||
)
|
||||
|
||||
async def Decode(self, request, context):
|
||||
|
@ -252,10 +271,12 @@ def serve(
|
|||
logger.exception("Error when initializing model")
|
||||
raise
|
||||
|
||||
signal_handler = SignalHandler()
|
||||
|
||||
set_adapter_to_index(adapter_to_index)
|
||||
server = aio.server(
|
||||
interceptors=[
|
||||
ExceptionInterceptor(),
|
||||
ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
|
||||
UDSOpenTelemetryAioServerInterceptor(),
|
||||
],
|
||||
options=[
|
||||
|
@ -276,7 +297,6 @@ def serve(
|
|||
await server.start()
|
||||
|
||||
logger.info("Server started at {}".format(local_url))
|
||||
signal_handler = SignalHandler()
|
||||
while signal_handler.KEEP_PROCESSING:
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
|
|
@ -120,8 +120,12 @@ def _load_and_merge(
|
|||
if adapter.id == BASE_MODEL_ADAPTER_ID:
|
||||
raise ValueError("Base model adapter cannot be merged.")
|
||||
|
||||
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
|
||||
load_module_map(
|
||||
(
|
||||
module_map,
|
||||
adapter_config,
|
||||
adapter_weight_names,
|
||||
adapter_tokenizer,
|
||||
) = load_module_map(
|
||||
model_id,
|
||||
adapter.revision,
|
||||
adapter.id,
|
||||
|
@ -129,7 +133,6 @@ def _load_and_merge(
|
|||
weight_names,
|
||||
trust_remote_code,
|
||||
)
|
||||
)
|
||||
|
||||
adapters_to_merge.append((module_map, adapter_config))
|
||||
merged_weight_names = merged_weight_names.union(adapter_weight_names)
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
from typing import Optional
|
||||
|
||||
SUPPORT_CHUNKING: Optional[bool] = None
|
||||
MAX_PREFILL_TOKENS: Optional[int] = None
|
||||
|
||||
|
||||
def set_support_chunking(support_chunking: bool):
|
||||
global SUPPORT_CHUNKING
|
||||
SUPPORT_CHUNKING = support_chunking
|
||||
|
||||
|
||||
def get_support_chunking() -> bool:
|
||||
global SUPPORT_CHUNKING
|
||||
return SUPPORT_CHUNKING
|
||||
|
||||
|
||||
def set_max_prefill_tokens(max_prefill_tokens: int):
|
||||
global MAX_PREFILL_TOKENS
|
||||
MAX_PREFILL_TOKENS = max_prefill_tokens
|
||||
|
||||
|
||||
def get_max_prefill_tokens() -> int:
|
||||
global MAX_PREFILL_TOKENS
|
||||
return MAX_PREFILL_TOKENS
|
|
@ -7,6 +7,7 @@ from typing import List, Tuple, Union
|
|||
import torch
|
||||
|
||||
|
||||
# FIXME: this should be optimized
|
||||
def find_segments(
|
||||
adapter_indices: Union[torch.Tensor, List[int]]
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
|
|
Loading…
Reference in New Issue