Fixing the issue with `add_special_tokens` not being passed around.

This commit is contained in:
Nicolas Patry 2024-08-27 20:02:35 +02:00
parent e0069a3a26
commit 2cf1f5c00e
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
10 changed files with 42 additions and 28 deletions

View File

@ -153,6 +153,8 @@ impl Client {
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Most request will have that
add_special_tokens: true,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],

View File

@ -221,6 +221,7 @@ impl Health for ShardedClient {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
add_special_tokens: true,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,

View File

@ -149,6 +149,7 @@ impl Client {
requests.push(Request {
id: 0,
inputs,
add_special_tokens: true,
input_chunks: Some(Input {
chunks: input_chunks,
}),

View File

@ -222,6 +222,7 @@ impl Health for ShardedClient {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
add_special_tokens: true,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,

View File

@ -387,6 +387,7 @@ impl State {
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
add_special_tokens: entry.request.add_special_tokens,
parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(),
)),

View File

@ -148,6 +148,7 @@ async fn prefill(
}),
inputs: sequence.clone(),
truncate: sequence_length,
add_special_tokens: true,
parameters: Some(parameters.clone()),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: decode_length,

View File

@ -137,6 +137,8 @@ message Request {
optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
/// Context truncation
bool add_special_tokens = 13;
}
message Batch {

View File

@ -415,6 +415,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
input_ids: input_ids.map(Arc::new),
add_special_tokens: request.add_special_tokens,
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
@ -738,6 +739,7 @@ pub struct ValidGenerateRequest {
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32,
pub truncate: u32,
pub add_special_tokens: bool,
pub decoder_input_details: bool,
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,

View File

@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,

View File

@ -189,15 +189,22 @@ class FlashCausalLMBatch(Batch):
cls, requests: Iterable[generate_pb2.Request], tokenizer
):
batch_inputs = []
max_truncation = 0
max_length = 0
all_input_ids = []
batch_size = 0
for r in requests:
batch_size += 1
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
return batch_tokenized_inputs
input_ids = tokenizer(
batch_inputs,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"][0]
max_length = max(max_length, len(input_ids))
all_input_ids.append(input_ids)
return all_input_ids
@classmethod
def from_tokenized(
@ -256,20 +263,17 @@ class FlashCausalLMBatch(Batch):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate :]
if (
tokenized_input[0] == tokenizer.bos_token_id
and tokenized_input[1] == tokenizer.bos_token_id
):
tokenized_input = tokenized_input[1:]
# tokenized_input = tokenized_input[-r.truncate :]
# if (
# tokenized_input[0] == tokenizer.bos_token_id
# and tokenized_input[1] == tokenizer.bos_token_id
# ):
# tokenized_input = tokenized_input[1:]
orig_input_length = len(tokenized_input)
prefix_len = r.prefix_len
assert prefix_len <= orig_input_length
if prefix_len == orig_input_length:
assert prefix_len > 0
prefix_len -= 1
prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:]