diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 192cd111..3c9b1d71 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -164,6 +164,7 @@ enum FinishReason {
     FINISH_REASON_LENGTH = 0;
     FINISH_REASON_EOS_TOKEN = 1;
     FINISH_REASON_STOP_SEQUENCE = 2;
+    FINISH_REASON_TERMINATED = 3;
 }

 message GeneratedText {
@@ -198,11 +199,11 @@ message Generation {
     optional GeneratedText generated_text = 4;
     /// Top tokens
     repeated Tokens top_tokens = 5;
-    /// Current length of the request: prompt tokens + number of generated tokens until this point
-    uint32 current_length = 6;
+    /// Current length of the cache: prompt tokens + number of generated tokens until this point
+    uint32 cache_length = 6;
 }

-message UpdatedRequest {
+message KeptRequest {
     /// Request ID
     uint64 id = 1;
     /// Paged attention blocks
@@ -211,16 +212,23 @@ message UpdatedRequest {
     repeated uint32 slots = 3;
 }

+/// kept_requests + terminated_request_ids might not cover all requests from the
+/// cached batch as some requests can be filtered out without requiring to generate text
+/// for example if the client dropped its connection to the router
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated UpdatedRequest updated_requests = 2;
+    repeated KeptRequest kept_requests = 2;
+    /// Requests to terminate and generate text for
+    repeated uint64 terminated_request_ids = 3;
 }

 message FilterBatchResponse {
     /// Filtered Batch (cached)
     CachedBatch batch = 1;
+    /// Terminated generations
+    repeated GeneratedText terminated_generations = 2;
 }
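
Note (not part of the patch): to make the new RPC shape concrete, here is a rough sketch of how a FilterBatch call could be assembled and read with the compiled Python stubs. It assumes the stubs are importable as `generate_pb2` (the name used throughout the server code in this patch); the import path, ids, blocks and slots are invented for illustration.

```python
from text_generation_server.pb import generate_pb2  # import path assumed

# Keep request 0 (with its paged attention blocks/slots), force-terminate request 1,
# and silently drop every other request (e.g. a client that disconnected).
filter_request = generate_pb2.FilterBatchRequest(
    batch_id=0,
    kept_requests=[generate_pb2.KeptRequest(id=0, blocks=[0, 1], slots=[0, 1, 2, 3])],
    terminated_request_ids=[1],
)

# The response carries the filtered cached batch plus one GeneratedText per
# terminated request, each finished with FINISH_REASON_TERMINATED.
# response = stub.FilterBatch(filter_request)
# for generated in response.terminated_generations:
#     print(generated.text, generated.finish_reason)
```
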
diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs
index 8cefd313..90f43270 100644
--- a/router/client/src/v3/client.rs
+++ b/router/client/src/v3/client.rs
@@ -90,11 +90,13 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        updated_requests: Vec<UpdatedRequest>,
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
     ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            updated_requests,
+            kept_requests,
+            terminated_request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs
index df2bb380..ea7486ee 100644
--- a/router/client/src/v3/mod.rs
+++ b/router/client/src/v3/mod.rs
@@ -7,7 +7,7 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens, UpdatedRequest,
+    HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs
index a066176c..e1b35a21 100644
--- a/router/client/src/v3/sharded_client.rs
+++ b/router/client/src/v3/sharded_client.rs
@@ -9,8 +9,8 @@ use tonic::transport::Uri;
 use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
 use v3::{
-    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
+    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, KeptRequest,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };

 #[derive(Debug, Clone)]
@@ -84,12 +84,19 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        updated_requests: Vec<UpdatedRequest>,
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
     ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
+            .map(|client| {
+                Box::pin(client.filter_batch(
+                    batch_id,
+                    kept_requests.clone(),
+                    terminated_request_ids.clone(),
+                ))
+            })
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index 3e7cde89..db618034 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -1,4 +1,4 @@
-use std::cmp::{max, min};
+use std::cmp::min;
 use thiserror::Error;
 use tokio::sync::{mpsc, oneshot};

@@ -16,8 +16,9 @@ impl BlockAllocation {
         self.slots.len()
     }

-    pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
-        let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
+    pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> {
+        let remaining_tokens =
+            (self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length);
         self.block_allocator
             .clone()
             .extend(self, remaining_tokens)
@@ -131,6 +132,7 @@ async fn block_allocator_task(
                 let decode_tokens = min(decode_tokens, block_size);
                 let tokens = prompt_tokens + decode_tokens;

+                // FIXME: window size is not working
                 // Apply window size
                 let (required_blocks, repeats) = {
                     let (tokens, repeats) = match window_size {
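
Note (not part of the patch): the `extend` change above replaces `max(self.prompt_tokens + self.decode_tokens - current_length, 1)` with a saturating subtraction, so an allocation whose cache already covers its token budget requests zero extra slots instead of at least one, and the unsigned subtraction can no longer underflow. A small sketch of the arithmetic only:

```python
def remaining_tokens(prompt_tokens: int, decode_tokens: int, cache_length: int) -> int:
    # Python equivalent of (prompt_tokens + decode_tokens).saturating_sub(cache_length)
    return max(prompt_tokens + decode_tokens - cache_length, 0)

assert remaining_tokens(prompt_tokens=10, decode_tokens=20, cache_length=12) == 18
assert remaining_tokens(prompt_tokens=10, decode_tokens=20, cache_length=35) == 0
```
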
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index 9a7b1084..cbe1fbd0 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -5,7 +5,7 @@ use crate::validation::{
     ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::{max, min};
+use std::cmp::max;
 use std::collections::VecDeque;
 use text_generation_client::v3::{
     Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
@@ -33,8 +33,8 @@ pub(crate) struct Entry {
     pub batch_time: Option<Instant>,
     /// Block Allocation
     pub block_allocation: Option<BlockAllocation>,
-    /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
-    pub current_length: u32,
+    /// Cache length (in tokens) of the request (prompt tokens + generated_tokens)
+    pub cache_length: u32,
 }

 /// Request Queue
@@ -164,9 +164,6 @@ struct State {
     /// Paged Attention block size
     block_size: u32,

-    /// Sliding window
-    window_size: Option<u32>,
-
     /// Speculation amount
     speculate: u32,

@@ -190,7 +187,6 @@ impl State {
             next_id: 0,
             next_batch_id: 0,
             block_size,
-            window_size,
             speculate,
             block_allocator,
         }
@@ -276,18 +272,7 @@ impl State {
                 }
                 Some(block_allocator) => {
                     prefill_tokens += entry.request.input_length;
-                    let max_new_tokens = match self.window_size {
-                        None => entry.request.stopping_parameters.max_new_tokens,
-                        Some(window_size) => min(
-                            window_size.saturating_sub(entry.request.input_length),
-                            entry.request.stopping_parameters.max_new_tokens,
-                        ),
-                    };
-                    decode_tokens += max_new_tokens;
-
-                    if prefill_tokens > prefill_token_budget
-                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                    {
+                    if prefill_tokens > prefill_token_budget {
                         // Entry is over budget
                         // Add it back to the front
                         tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -296,7 +281,7 @@ impl State {
                     }

                     let decode_tokens =
-                        entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
+                        entry.request.stopping_parameters.max_new_tokens + self.speculate;
                     match block_allocator
                         .allocate(entry.request.input_length, decode_tokens)
                         .await
@@ -500,7 +485,7 @@ mod tests {
             queue_time: Instant::now(),
             batch_time: None,
             block_allocation: None,
-            current_length: 0,
+            cache_length: 0,
         };
         (entry, receiver_tx)
     }
diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
index b76c5c50..fa1a9ac7 100644
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -10,7 +10,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
+use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@@ -88,7 +88,7 @@ impl Scheduler for SchedulerV3 {
             queue_time: Instant::now(),
             batch_time: None,
             block_allocation: None,
-            current_length: input_length,
+            cache_length: 0,
         });

         // Notify the background task that we have a new entry in the queue that needs
@@ -350,7 +350,7 @@ async fn filter_batch(
                     .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
                     .unwrap_or((Vec::new(), Vec::new()));

-                UpdatedRequest {
+                KeptRequest {
                     id: *request_id,
                     blocks,
                     slots,
@@ -359,7 +359,10 @@
             .collect();

         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, updated_requests).await.unwrap()
+        client
+            .filter_batch(id, updated_requests, Vec::new())
+            .await
+            .unwrap()
     }
 }
@@ -374,7 +377,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) -> bool {
                 .block_allocation
                 .as_ref()
                 .map(|block_allocation| {
-                    if entry.current_length > block_allocation.len() as u32 {
+                    if entry.cache_length > block_allocation.len() as u32 {
                         // We need to re-allocate
                         Some(*id)
                     } else {
@@ -424,8 +427,8 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
        entry
            .block_allocation
            .as_mut()
-            .unwrap()
-            .extend(entry.current_length)
+            .expect("We checked that the block allocation exists above")
+            .extend(entry.cache_length)
            .await
    };

@@ -563,6 +566,7 @@ impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
         let v3_finish_reason =
             text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
         let finish_reason = match v3_finish_reason {
+            text_generation_client::v3::FinishReason::Terminated => FinishReason::OutOfResources,
             text_generation_client::v3::FinishReason::Length => FinishReason::Length,
             text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
             text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index b6902c49..2f115aba 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1085,6 +1085,8 @@ pub(crate) enum FinishReason {
     EndOfSequenceToken,
     #[schema(rename = "stop_sequence")]
     StopSequence,
+    #[schema(rename = "out_of_resources")]
+    OutOfResources,
 }

 impl std::fmt::Display for FinishReason {
@@ -1093,6 +1095,7 @@ impl std::fmt::Display for FinishReason {
             FinishReason::Length => write!(f, "length"),
             FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
             FinishReason::StopSequence => write!(f, "stop_sequence"),
+            FinishReason::OutOfResources => write!(f, "out_of_resources"),
         }
     }
 }
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 50a25a50..40a4f100 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -159,7 +159,7 @@ class CausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["CausalLMBatch"]:
         request_ids = [r.id for r in updated_requests]

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index cd7c1f0f..0bd9357f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -398,11 +398,37 @@ class FlashCausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
-    ) -> Optional["FlashCausalLMBatch"]:
-        if len(updated_requests) == 0:
+        self,
+        model: "FlashCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]:
+        if len(kept_requests) == 0:
             raise ValueError("Batch must have at least one request")

+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            do_sample = self.next_token_chooser.do_sample[idx]
+            seed = self.next_token_chooser.seeds[idx]
+
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids,
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            generated_text = GeneratedText(
+                output_text,
+                stopping_criteria.current_tokens,
+                generate_pb2.FINISH_REASON_TERMINATED,
+                seed if do_sample else None,
+            )
+            terminated_generations.append(generated_text)
+
         device = self.input_ids.device

         # New values after filtering
@@ -429,7 +455,7 @@ class FlashCausalLMBatch(Batch):
         num_blocks = 0
         max_blocks = 0

-        for i, request in enumerate(updated_requests):
+        for i, request in enumerate(kept_requests):
             request_id = request.id
             idx = self.requests_idx_mapping[request_id]

@@ -491,7 +517,7 @@ class FlashCausalLMBatch(Batch):
         # Move to GPU
         block_tables_tensor = block_tables_tensor.to(device)

-        return type(self)(
+        filtered_batch = type(self)(
             batch_id=self.batch_id,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
@@ -520,6 +546,7 @@ class FlashCausalLMBatch(Batch):
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
+        return filtered_batch, terminated_generations

     @classmethod
     @tracer.start_as_current_span("concatenate")
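
Note (not part of the patch): a rough usage sketch of the new `filter` contract above. The caller now passes the model so that terminated requests can be decoded into text, and receives those generations alongside the filtered batch. Here `batch`, `model` and `request` are hypothetical stand-ins for the cached FlashCausalLMBatch, its FlashCausalLM and the incoming FilterBatchRequest.

```python
# Sketch only: mirrors the call made by the gRPC FilterBatch handler further below.
filtered_batch, terminated_generations = batch.filter(
    model,
    request.kept_requests,           # stay in the batch and keep generating
    request.terminated_request_ids,  # finish now and return their text
)

for generated_text in terminated_generations:
    # Each entry is a GeneratedText whose finish_reason is FINISH_REASON_TERMINATED
    print(generated_text.text, generated_text.generated_tokens)
```
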
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index fd70ae5d..495a47e5 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -215,7 +215,7 @@ class IdeficsCausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["IdeficsCausalLMBatch"]:
         request_ids = [r.id for r in updated_requests]

diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index c8066aec..0340ca55 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -196,7 +196,7 @@ class MambaBatch(Batch):
     )

     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["MambaBatch"]:
         request_ids = [r.id for r in updated_requests]

diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 1e4f7c2e..77407118 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -167,7 +167,7 @@ class Seq2SeqLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["Seq2SeqLMBatch"]:
         request_ids = [r.id for r in updated_requests]

diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 50c14862..c19f804e 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -3,7 +3,7 @@ import torch

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from transformers import PreTrainedTokenizerBase

@@ -28,7 +28,12 @@ class Batch(ABC):
         raise NotImplementedError

     @abstractmethod
-    def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch":
+    def filter(
+        self,
+        model,
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple["Batch", List[generate_pb2.GeneratedText]]:
         raise NotImplementedError

     @classmethod
@@ -84,7 +89,7 @@ class Generation:
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
     top_tokens: Optional[List[Tokens]]
-    current_length: int
+    cache_length: int

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -101,5 +106,5 @@ class Generation:
                 if self.top_tokens is not None
                 else None
             ),
-            current_length=self.current_length,
+            cache_length=self.cache_length,
         )
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index da7de2d3..e3d0bee8 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -123,7 +123,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["VlmCausalLMBatch"]:
         batch = super().filter(updated_requests)
         batch.pixel_values = None
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index a66c19a0..86df66e7 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -83,10 +83,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.updated_requests)
+        filtered_batch, terminated_generations = batch.filter(
+            self.model, request.kept_requests, request.terminated_request_ids
+        )
         self.cache.set(filtered_batch)

-        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+        return generate_pb2.FilterBatchResponse(
+            batch=filtered_batch.to_pb(), terminated_generations=terminated_generations
+        )

     async def Warmup(self, request, context):
         if self.quantize in {"exl2", "gptq"}:
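
Note (not part of the patch): end to end, a shard-side termination surfaces to API clients as an ordinary completion whose finish reason is the new out_of_resources value: FINISH_REASON_TERMINATED in the v3 protocol is mapped to FinishReason::OutOfResources in the router, which serializes it as "out_of_resources". An illustrative sketch of the resulting response shape, with invented text and token counts:

```python
# Illustrative only: the kind of `details` an API client would see for a request
# that the router terminated early.
terminated_response = {
    "generated_text": "The quick brown fox",
    "details": {
        "finish_reason": "out_of_resources",
        "generated_tokens": 4,
        "seed": None,
    },
}
```
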