Fixing the issue with `add_special_tokens` not being passed around.

Nicolas Patry 2024-08-27 20:02:35 +02:00
parent e0069a3a26
commit 2cf1f5c00e
10 changed files with 42 additions and 28 deletions
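
For context, `add_special_tokens` is the standard Hugging Face tokenizer flag that decides whether markers such as BOS/EOS (or [CLS]/[SEP]) are inserted around the encoded text. A minimal sketch of what the flag toggles, using `bert-base-uncased` purely because its special tokens are easy to see (the exact ids are illustrative):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")

    # With special tokens (the tokenizer's default): [CLS] ... [SEP] wrap the text.
    print(tok("hello world", add_special_tokens=True)["input_ids"])   # e.g. [101, 7592, 2088, 102]

    # Without special tokens: only the raw word-piece ids remain.
    print(tok("hello world", add_special_tokens=False)["input_ids"])  # e.g. [7592, 2088]

The commit threads this flag from the incoming request through the Rust router and the gRPC protocol down to the Python server's tokenizer call, instead of leaving it implicitly defaulted at each layer.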

View File

@@ -153,6 +153,8 @@ impl Client {
                 }),
                 // We truncate the input on the server side to be sure that it has the correct size
                 truncate,
+                // Most request will have that
+                add_special_tokens: true,
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],

View File

@@ -221,6 +221,7 @@ impl Health for ShardedClient {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
             truncate: 10,
+            add_special_tokens: true,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,

View File

@@ -149,6 +149,7 @@ impl Client {
             requests.push(Request {
                 id: 0,
                 inputs,
+                add_special_tokens: true,
                 input_chunks: Some(Input {
                     chunks: input_chunks,
                 }),

View File

@@ -222,6 +222,7 @@ impl Health for ShardedClient {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
             truncate: 10,
+            add_special_tokens: true,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,

View File

@@ -387,6 +387,7 @@ impl State {
             }),
             inputs: entry.request.inputs.chunks_to_string(),
             truncate: entry.request.truncate,
+            add_special_tokens: entry.request.add_special_tokens,
             parameters: Some(NextTokenChooserParameters::from(
                 entry.request.parameters.clone(),
             )),

View File

@@ -148,6 +148,7 @@ async fn prefill(
             }),
             inputs: sequence.clone(),
             truncate: sequence_length,
+            add_special_tokens: true,
             parameters: Some(parameters.clone()),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: decode_length,

View File

@@ -137,6 +137,8 @@ message Request {
     optional string adapter_id = 11;
     /// Prefix length that can be retrieved from the KV cache.
     uint32 prefix_len = 12;
+    /// Context truncation
+    bool add_special_tokens = 13;
 }
 
 message Batch {
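
Since `add_special_tokens` is a proto3 scalar, a client that predates the field simply omits it and the server reads the zero value `false`, which keeps the change backward compatible at the wire level. A quick sanity check against the generated stubs (a sketch; it assumes the `generate_pb2` module that the Python server imports elsewhere in this commit):

    from text_generation_server.pb import generate_pb2

    # Built without the new field: proto3 scalars fall back to their zero value.
    old_style = generate_pb2.Request(id=0, truncate=10)
    print(old_style.add_special_tokens)  # False

    # A router that sets the field explicitly round-trips it unchanged.
    new_style = generate_pb2.Request(id=0, truncate=10, add_special_tokens=True)
    print(new_style.add_special_tokens)  # True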

View File

@@ -415,6 +415,7 @@ impl Validation {
         Ok(ValidGenerateRequest {
             inputs,
             input_ids: input_ids.map(Arc::new),
+            add_special_tokens: request.add_special_tokens,
             decoder_input_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
@@ -738,6 +739,7 @@ pub struct ValidGenerateRequest {
     pub input_ids: Option<Arc<Vec<u32>>>,
     pub input_length: u32,
     pub truncate: u32,
+    pub add_special_tokens: bool,
     pub decoder_input_details: bool,
     pub parameters: ValidParameters,
     pub stopping_parameters: ValidStoppingParameters,

View File

@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -135,12 +136,10 @@
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -153,7 +152,7 @@
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -240,7 +239,7 @@
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices,
         )
@@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -320,7 +319,7 @@
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 prefill_cache_indices,
             )
@@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -374,7 +373,7 @@
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
@@ -383,7 +382,7 @@
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s,
             prefill_cache_indices,
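
The second half of the commit switches the Qwen2 attention path from a raw `input_lengths` tensor to a `Seqlen` object imported from `text_generation_server.layers.attention` (first hunk above). The definition of `Seqlen` is not part of this diff; the sketch below is only an assumed minimal shape that would satisfy the call sites shown (a per-sequence lengths tensor plus the `clamp` helper used in decode mode), not the actual implementation:

    from dataclasses import dataclass

    import torch


    @dataclass
    class Seqlen:
        """Hypothetical stand-in for TGI's Seqlen wrapper (fields assumed)."""

        input_lengths: torch.Tensor  # per-sequence lengths, shape [batch]

        def clamp(self, max: torch.Tensor) -> "Seqlen":
            # Mirrors `seqlen = seqlen.clamp(max=self.max_past_tensor)` above:
            # cap every length for the paged-attention (sliding window) path.
            return Seqlen(input_lengths=torch.clamp(self.input_lengths, max=max))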

View File

@@ -189,15 +189,22 @@ class FlashCausalLMBatch(Batch):
         cls, requests: Iterable[generate_pb2.Request], tokenizer
     ):
         batch_inputs = []
-        max_truncation = 0
+        max_length = 0
+        all_input_ids = []
+        batch_size = 0
         for r in requests:
+            batch_size += 1
             batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
-            max_truncation = max(max_truncation, r.truncate)
 
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
-        return batch_tokenized_inputs
+            input_ids = tokenizer(
+                batch_inputs,
+                truncation=True,
+                max_length=r.truncate,
+                add_special_tokens=r.add_special_tokens,
+            )["input_ids"][0]
+            max_length = max(max_length, len(input_ids))
+            all_input_ids.append(input_ids)
+        return all_input_ids
 
     @classmethod
     def from_tokenized(
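
Moving from one batched tokenizer call to a per-request call is what lets every request keep its own `truncate` and `add_special_tokens` values; the old code could only apply a single `max_truncation` (and the tokenizer defaults) to the whole batch. A small illustration of the difference, again using `bert-base-uncased` only because it visibly inserts special tokens:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Two requests with different per-request settings, as the new loop allows.
    requests = [
        {"text": "the quick brown fox jumps over the lazy dog",
         "truncate": 4, "add_special_tokens": True},
        {"text": "hello world",
         "truncate": 10, "add_special_tokens": False},
    ]

    all_input_ids = []
    for r in requests:
        ids = tok(
            r["text"],
            truncation=True,
            max_length=r["truncate"],
            add_special_tokens=r["add_special_tokens"],
        )["input_ids"]
        all_input_ids.append(ids)

    # Each entry honours its own truncation length and special-token choice,
    # which a single batched call with one max_length could not express.
    print([len(ids) for ids in all_input_ids])  # e.g. [4, 2]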
@@ -256,20 +263,17 @@
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenized_input[-r.truncate :]
-            if (
-                tokenized_input[0] == tokenizer.bos_token_id
-                and tokenized_input[1] == tokenizer.bos_token_id
-            ):
-                tokenized_input = tokenized_input[1:]
+            # tokenized_input = tokenized_input[-r.truncate :]
+            # if (
+            #     tokenized_input[0] == tokenizer.bos_token_id
+            #     and tokenized_input[1] == tokenizer.bos_token_id
+            # ):
+            #     tokenized_input = tokenized_input[1:]
 
             orig_input_length = len(tokenized_input)
 
             prefix_len = r.prefix_len
             assert prefix_len <= orig_input_length
-            if prefix_len == orig_input_length:
-                assert prefix_len > 0
-                prefix_len -= 1
 
             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]
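
With truncation and special-token handling now done inside the tokenizer call, the remaining per-request code above only has to split each sequence into the part already covered by the KV-cache prefix and the part that still needs prefill. A hedged, standalone restatement of that split (the function name is mine, not TGI's; the example ids are arbitrary):

    from typing import List, Tuple


    def split_cached_prefix(
        tokenized_input: List[int], prefix_len: int
    ) -> Tuple[List[int], List[int]]:
        """Return (cached prefix ids, ids left to prefill)."""
        assert prefix_len <= len(tokenized_input)
        return tokenized_input[:prefix_len], tokenized_input[prefix_len:]


    # Example: 3 of 5 tokens are already present in the KV cache.
    prefix, rest = split_cached_prefix([101, 7592, 2088, 1010, 102], prefix_len=3)
    print(prefix)  # [101, 7592, 2088]
    print(rest)    # [1010, 102]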