Fixing the issue with `add_special_tokens` not being passed around.
parent e0069a3a26
commit 2cf1f5c00e
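For context, the `add_special_tokens` flag that this commit threads through the request path corresponds to the standard Hugging Face tokenizer argument of the same name: when it is lost on the way to the server, every prompt is tokenized with the tokenizer's default instead of the caller's choice. A minimal sketch of the difference, assuming a placeholder checkpoint whose tokenizer defines a BOS token (the printed ids are illustrative):

from transformers import AutoTokenizer

# Placeholder checkpoint (assumption): any tokenizer with a BOS/EOS special token shows the effect.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

with_special = tok("Hello world", add_special_tokens=True)["input_ids"]
without_special = tok("Hello world", add_special_tokens=False)["input_ids"]

# With the flag, the tokenizer prepends BOS; without it, only the raw text ids remain,
# so prefix-length bookkeeping downstream sees a different sequence.
print(with_special)     # e.g. [1, 15043, 3186]
print(without_special)  # e.g. [15043, 3186]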
@@ -153,6 +153,8 @@ impl Client {
                 }),
                 // We truncate the input on the server side to be sure that it has the correct size
                 truncate,
+                // Most request will have that
+                add_special_tokens: true,
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
@@ -221,6 +221,7 @@ impl Health for ShardedClient {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
             truncate: 10,
+            add_special_tokens: true,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -149,6 +149,7 @@ impl Client {
             requests.push(Request {
                 id: 0,
                 inputs,
+                add_special_tokens: true,
                 input_chunks: Some(Input {
                     chunks: input_chunks,
                 }),
@@ -222,6 +222,7 @@ impl Health for ShardedClient {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
             truncate: 10,
+            add_special_tokens: true,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -387,6 +387,7 @@ impl State {
                 }),
                 inputs: entry.request.inputs.chunks_to_string(),
                 truncate: entry.request.truncate,
+                add_special_tokens: entry.request.add_special_tokens,
                 parameters: Some(NextTokenChooserParameters::from(
                     entry.request.parameters.clone(),
                 )),
@@ -148,6 +148,7 @@ async fn prefill(
            }),
            inputs: sequence.clone(),
            truncate: sequence_length,
+           add_special_tokens: true,
            parameters: Some(parameters.clone()),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: decode_length,
@@ -137,6 +137,8 @@ message Request {
     optional string adapter_id = 11;
     /// Prefix length that can be retrieved from the KV cache.
     uint32 prefix_len = 12;
+    /// Context truncation
+    bool add_special_tokens = 13;
 }
 
 message Batch {
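On the Python server, the new proto field surfaces on the generated `generate_pb2.Request` message. A rough sketch of building a request that carries the flag; the import path follows the server's generated stubs and only a subset of the message's fields is shown:

from text_generation_server.pb import generate_pb2

# Sketch only: a real Request also carries input_chunks, sampling parameters,
# stopping criteria, blocks/slots, etc.
req = generate_pb2.Request(
    id=0,
    inputs="Hello world",
    truncate=10,
    prefix_len=0,
    add_special_tokens=True,  # field 13 introduced by this commit
)
print(req.add_special_tokens)  # True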
@@ -415,6 +415,7 @@ impl Validation {
         Ok(ValidGenerateRequest {
             inputs,
             input_ids: input_ids.map(Arc::new),
+            add_special_tokens: request.add_special_tokens,
             decoder_input_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
@@ -738,6 +739,7 @@ pub struct ValidGenerateRequest {
     pub input_ids: Option<Arc<Vec<u32>>>,
     pub input_length: u32,
     pub truncate: u32,
+    pub add_special_tokens: bool,
    pub decoder_input_details: bool,
     pub parameters: ValidParameters,
     pub stopping_parameters: ValidStoppingParameters,
@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
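The remaining hunks in this file replace the bare `input_lengths` tensor with the `Seqlen` object imported above. Its definition is not part of this diff; purely as a reading aid for the hunks below, a hypothetical stand-in that supports the clamping used in decode mode might look like this:

from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    # Hypothetical stand-in: the real class lives in
    # text_generation_server.layers.attention and may carry more state
    # (cumulative prefill lengths, max length, ...).
    input_lengths: torch.Tensor

    def clamp(self, max):
        # Mirrors `seqlen = seqlen.clamp(max=self.max_past_tensor)` below:
        # cap every per-request length without mutating the original object.
        return Seqlen(input_lengths=self.input_lengths.clamp(max=max))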
@@ -104,7 +105,7 @@ class Qwen2Attention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -135,12 +136,10 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -153,7 +152,7 @@ class Qwen2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )
 
@@ -225,7 +224,7 @@ class Qwen2Layer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -240,7 +239,7 @@ class Qwen2Layer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices,
         )
@@ -296,7 +295,7 @@ class Qwen2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -320,7 +319,7 @@ class Qwen2Model(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 prefill_cache_indices,
             )
@@ -361,7 +360,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -374,7 +373,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
@@ -383,7 +382,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s,
             prefill_cache_indices,
@@ -189,15 +189,22 @@ class FlashCausalLMBatch(Batch):
         cls, requests: Iterable[generate_pb2.Request], tokenizer
     ):
         batch_inputs = []
-        max_truncation = 0
+        max_length = 0
+        all_input_ids = []
+        batch_size = 0
         for r in requests:
+            batch_size += 1
             batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
-            max_truncation = max(max_truncation, r.truncate)
 
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
-        return batch_tokenized_inputs
+            input_ids = tokenizer(
+                batch_inputs,
+                truncation=True,
+                max_length=r.truncate,
+                add_special_tokens=r.add_special_tokens,
+            )["input_ids"][0]
+            max_length = max(max_length, len(input_ids))
+            all_input_ids.append(input_ids)
+        return all_input_ids
 
     @classmethod
     def from_tokenized(
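The practical change in the hunk above: the old path tokenized the whole batch once, with a single shared `max_truncation` and the tokenizer's default special-token handling, while the new path tokenizes inside the loop so each request's own `truncate` and `add_special_tokens` are honored. A toy illustration of the per-request behavior (tokenizer checkpoint and prompts are placeholders):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # placeholder checkpoint

# Two requests with different truncation budgets and special-token choices;
# a single shared max_length would have cut both to the same size.
requests = [
    {"text": "a short prompt", "truncate": 4, "add_special_tokens": True},
    {"text": "a much longer prompt that needs more room", "truncate": 32, "add_special_tokens": False},
]

all_input_ids = []
for r in requests:
    ids = tok(
        r["text"],
        truncation=True,
        max_length=r["truncate"],
        add_special_tokens=r["add_special_tokens"],
    )["input_ids"]
    all_input_ids.append(ids)

print([len(ids) for ids in all_input_ids])  # each length capped by its own truncate value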
@@ -256,20 +263,17 @@ class FlashCausalLMBatch(Batch):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenized_input[-r.truncate :]
-            if (
-                tokenized_input[0] == tokenizer.bos_token_id
-                and tokenized_input[1] == tokenizer.bos_token_id
-            ):
-                tokenized_input = tokenized_input[1:]
+            # tokenized_input = tokenized_input[-r.truncate :]
+            # if (
+            #     tokenized_input[0] == tokenizer.bos_token_id
+            #     and tokenized_input[1] == tokenizer.bos_token_id
+            # ):
+            #     tokenized_input = tokenized_input[1:]
 
             orig_input_length = len(tokenized_input)
 
             prefix_len = r.prefix_len
             assert prefix_len <= orig_input_length
-            if prefix_len == orig_input_length:
-                assert prefix_len > 0
-                prefix_len -= 1
 
             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]