Fixing the issue with `add_special_tokens` not being passed around.
parent e0069a3a26
commit 2cf1f5c00e
@@ -153,6 +153,8 @@ impl Client {
    }),
    // We truncate the input on the server side to be sure that it has the correct size
    truncate,
    // Most requests will have that
    add_special_tokens: true,
    // Blocks and slots will be set on the server side if we use paged attention
    blocks: vec![],
    slots: vec![],
@@ -221,6 +221,7 @@ impl Health for ShardedClient {
        chunks: vec![Chunk::Text("liveness".into()).into()],
    }),
    truncate: 10,
    add_special_tokens: true,
    prefill_logprobs: false,
    parameters: Some(NextTokenChooserParameters {
        temperature: 1.0,
@@ -149,6 +149,7 @@ impl Client {
    requests.push(Request {
        id: 0,
        inputs,
        add_special_tokens: true,
        input_chunks: Some(Input {
            chunks: input_chunks,
        }),
@@ -222,6 +222,7 @@ impl Health for ShardedClient {
        chunks: vec![Chunk::Text("liveness".into()).into()],
    }),
    truncate: 10,
    add_special_tokens: true,
    prefill_logprobs: false,
    parameters: Some(NextTokenChooserParameters {
        temperature: 1.0,
@@ -387,6 +387,7 @@ impl State {
    }),
    inputs: entry.request.inputs.chunks_to_string(),
    truncate: entry.request.truncate,
    add_special_tokens: entry.request.add_special_tokens,
    parameters: Some(NextTokenChooserParameters::from(
        entry.request.parameters.clone(),
    )),
@@ -148,6 +148,7 @@ async fn prefill(
    }),
    inputs: sequence.clone(),
    truncate: sequence_length,
    add_special_tokens: true,
    parameters: Some(parameters.clone()),
    stopping_parameters: Some(StoppingCriteriaParameters {
        max_new_tokens: decode_length,
@@ -137,6 +137,8 @@ message Request {
    optional string adapter_id = 11;
    /// Prefix length that can be retrieved from the KV cache.
    uint32 prefix_len = 12;
    /// Whether to add special tokens when tokenizing the request
    bool add_special_tokens = 13;
}

message Batch {
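For readers following along from the Python side, the new field behaves like any other scalar on the generated protobuf message. A minimal sketch, assuming the usual `generate_pb2` stubs generated from this proto (the import path and the other field values are illustrative only):

    # Illustrative only: field names/numbers taken from the hunk above; the
    # import path is assumed from the Python server code later in this diff.
    from text_generation_server.pb import generate_pb2

    req = generate_pb2.Request(
        id=0,
        truncate=1024,
        prefix_len=0,
        add_special_tokens=True,  # new field 13: ask the tokenizer to add BOS/EOS for this request
    )
    print(req.add_special_tokens)  # True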
@@ -415,6 +415,7 @@ impl Validation {
    Ok(ValidGenerateRequest {
        inputs,
        input_ids: input_ids.map(Arc::new),
        add_special_tokens: request.add_special_tokens,
        decoder_input_details,
        input_length: input_length as u32,
        truncate: truncate.unwrap_or(self.max_input_length) as u32,
@@ -738,6 +739,7 @@ pub struct ValidGenerateRequest {
    pub input_ids: Option<Arc<Vec<u32>>>,
    pub input_length: u32,
    pub truncate: u32,
    pub add_special_tokens: bool,
    pub decoder_input_details: bool,
    pub parameters: ValidParameters,
    pub stopping_parameters: ValidStoppingParameters,
@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    reshape_and_cache,
    Seqlen,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
@@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        seqlen,
        max_s,
        prefill_cache_indices,
    ):
@@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module):
        # flash attention
        attn_output = attention(
            query,
            torch.select(kv, dim=1, index=0),
            torch.select(kv, dim=1, index=1),
            kv_cache[0],
            kv_cache[1],
            cu_seqlen_prefill,
            max_s,
            seqlen,
            block_tables,
            self.softmax_scale,
            window_size_left=self.max_past,
        )
@@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module):
            self.kv_head_mapping,
            self.softmax_scale,
            block_tables,
            input_lengths,
            seqlen,
            max_s,
        )
@@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        seqlen,
        max_s,
        prefill_cache_indices,
    ):
@@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module):
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            seqlen,
            max_s,
            prefill_cache_indices,
        )
@@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
@@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module):
            kv_cache[i],
            block_tables,
            slots,
            input_lengths,
            seqlen,
            max_s,
            prefill_cache_indices,
        )
@@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
@@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
        elif self.max_past is not None:
            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
            # kernel requires the true values
            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
            seqlen = seqlen.clamp(max=self.max_past_tensor)

        hidden_states = self.model(
            input_ids,
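The Qwen2 hunks above replace the bare `input_lengths` tensor with the `seqlen` object imported from `text_generation_server.layers.attention`, so the sliding-window clamp in decode mode now goes through that object as well. A hypothetical stand-in (not the real `Seqlen` class, whose fields are not shown in this diff) illustrating the shape of the change:

    from dataclasses import dataclass
    import torch

    @dataclass
    class SeqlenSketch:
        # Hypothetical: bundles per-request length metadata so call sites pass one object.
        input_lengths: torch.Tensor

        def clamp(self, max: "int | torch.Tensor") -> "SeqlenSketch":
            # Mirrors `seqlen = seqlen.clamp(max=self.max_past_tensor)` in the hunk above,
            # assuming clamp returns a new object rather than mutating in place.
            return SeqlenSketch(input_lengths=self.input_lengths.clamp(max=max))

    seqlen = SeqlenSketch(input_lengths=torch.tensor([5, 17, 3]))
    seqlen = seqlen.clamp(max=8)   # decode path with a sliding window of 8
    print(seqlen.input_lengths)    # tensor([5, 8, 3])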
@@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            seqlen,
            max_s,
            true_max_s,
            prefill_cache_indices,
@@ -189,15 +189,22 @@ class FlashCausalLMBatch(Batch):
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
        batch_inputs = []
        max_truncation = 0
        max_length = 0
        all_input_ids = []
        batch_size = 0
        for r in requests:
            batch_size += 1
            batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs
            input_ids = tokenizer(
                batch_inputs,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"][0]
            max_length = max(max_length, len(input_ids))
            all_input_ids.append(input_ids)
        return all_input_ids

    @classmethod
    def from_tokenized(
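The rewrite above is where the new flag actually takes effect: tokenization now happens per request, so each request's `truncate` and `add_special_tokens` values reach the tokenizer instead of a single batch-wide `max_truncation`. A standalone illustration of what the flag controls, using the standard `transformers` tokenizer call (the checkpoint name is only a placeholder; any tokenizer with a BOS token shows the same effect, and the exact token ids are examples):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # With special tokens, the tokenizer prepends BOS for this tokenizer family.
    print(tok("Hello", add_special_tokens=True)["input_ids"])   # e.g. [1, 15043]
    # Without it, the caller is trusted to have added any BOS/EOS already,
    # for example because a chat template was applied upstream.
    print(tok("Hello", add_special_tokens=False)["input_ids"])  # e.g. [15043]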
@@ -256,20 +263,17 @@ class FlashCausalLMBatch(Batch):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            if (
                tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]
            # tokenized_input = tokenized_input[-r.truncate :]
            # if (
            #     tokenized_input[0] == tokenizer.bos_token_id
            #     and tokenized_input[1] == tokenizer.bos_token_id
            # ):
            #     tokenized_input = tokenized_input[1:]

            orig_input_length = len(tokenized_input)

            prefix_len = r.prefix_len
            assert prefix_len <= orig_input_length
            if prefix_len == orig_input_length:
                assert prefix_len > 0
                prefix_len -= 1

            prefix_ids.append(tokenized_input[:prefix_len])
            tokenized_input = tokenized_input[prefix_len:]
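The lines commented out above were a workaround: when both the caller (for example via a chat template) and the server-side tokenizer added a BOS token, the batch code trimmed the duplicate after the fact. With `add_special_tokens` now carried on every request, the tokenizer can simply be told not to add one, which is presumably why the workaround is disabled here rather than kept. A small sketch of the retired check, written as a hypothetical standalone helper that mirrors the commented-out lines:

    def drop_duplicate_bos(token_ids: list[int], bos_token_id: int) -> list[int]:
        # Hypothetical helper, not part of the diff: drop one BOS when the
        # sequence starts with two of them, as the old in-line code did.
        if len(token_ids) >= 2 and token_ids[0] == bos_token_id and token_ids[1] == bos_token_id:
            return token_ids[1:]
        return token_ids

    assert drop_duplicate_bos([1, 1, 15043], bos_token_id=1) == [1, 15043]
    assert drop_duplicate_bos([1, 15043], bos_token_id=1) == [1, 15043]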