From ee47973a2f2d2152a6c32b69a93573040271c10b Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 25 Apr 2024 19:41:50 +0200
Subject: [PATCH] Use the generation config. (#1808)

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 router/src/lib.rs                              | 43 ++++++++++-----
 router/src/server.rs                           |  8 ++-
 .../custom_modeling/flash_llama_modeling.py    | 52 -------------------
 .../models/flash_llama.py                      | 14 +++--
 server/text_generation_server/models/model.py  |  7 +++
 server/text_generation_server/utils/tokens.py  | 25 +++++++--
 6 files changed, 74 insertions(+), 75 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index ecd8e2e0..9b9097f6 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -589,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
-    pub role: String,
+    // TODO Modify this to a true enum.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
     pub content: Option<String>,
@@ -623,6 +625,31 @@ impl ChatCompletionChunk {
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
+        let delta = match (delta, tool_calls) {
+            (Some(delta), _) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: Some(delta),
+                tool_calls: None,
+            },
+            (None, Some(tool_calls)) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: None,
+                tool_calls: Some(DeltaToolCall {
+                    index: 0,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tool_calls[0].to_string(),
+                    },
+                }),
+            },
+            (None, None) => ChatCompletionDelta {
+                role: None,
+                content: None,
+                tool_calls: None,
+            },
+        };
         Self {
             id: String::new(),
             object: "text_completion".to_string(),
@@ -631,19 +658,7 @@ impl ChatCompletionChunk {
             system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index: 0,
-                delta: ChatCompletionDelta {
-                    role: "assistant".to_string(),
-                    content: delta,
-                    tool_calls: tool_calls.map(|tc| DeltaToolCall {
-                        index: 0,
-                        id: String::new(),
-                        r#type: "function".to_string(),
-                        function: Function {
-                            name: None,
-                            arguments: tc[0].to_string(),
-                        },
-                    }),
-                },
+                delta,
                 logprobs,
                 finish_reason,
             }],
diff --git a/router/src/server.rs b/router/src/server.rs
index 03d184c3..8657b779 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1103,7 +1103,13 @@ async fn chat_completions(
                     let (content, tool_calls) = if tool_grammar.is_some() {
                         (None, Some(vec![stream_token.token.text]))
                     } else {
-                        (Some(stream_token.token.text), None)
+                        let content = if !stream_token.token.special {
+                            Some(stream_token.token.text)
+                        } else {
+                            None
+                        };
+
+                        (content, None)
                     };

                     event
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 6d796ac3..6fa85d4e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -38,58 +38,6 @@ from text_generation_server.utils.layers import (
 )


-class LlamaConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=32000,
-        hidden_size=4096,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
-        pretraining_tp=1,
-        tie_word_embeddings=False,
-        rope_scaling=None,
-        rope_theta=10000.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
-
-
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 56768942..f3578f88 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -2,14 +2,13 @@ import torch
 import torch.distributed

 from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, GenerationConfig
 from transformers.models.llama import LlamaTokenizer
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
-    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -53,8 +52,17 @@ class FlashLlama(FlashCausalLM):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id, revision=revision, trust_remote_code=trust_remote_code
+            )
+            if isinstance(generation_config.eos_token_id, (list, set)):
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+        except Exception:
+            pass

-        config = LlamaConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index cec9eafa..4f35b0aa 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -27,7 +27,14 @@ class Model(ABC):
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
+
+        # all_special_ids is not set correctly if the rust tokenizer is unpacked
+        # TODO report this to transformers.
+        other_special_ids = {
+            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+        }
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.all_special_ids.update(other_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 8ef1ca0d..22f86b60 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Set, Union
 import math

 import torch
@@ -143,12 +143,22 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
-        eos_token_id: int,
+        eos_token_ids: Optional[Union[Set[int], int]],
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
-        self.eos_token_id = eos_token_id
+        if eos_token_ids is None:
+            eos_token_ids = set()
+        elif isinstance(eos_token_ids, int):
+            eos_token_ids = set([eos_token_ids])
+        elif isinstance(eos_token_ids, set):
+            eos_token_ids = eos_token_ids
+        else:
+            raise RuntimeError(
+                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
+            )
+        self.eos_token_ids = eos_token_ids
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
@@ -160,7 +170,10 @@ class StoppingCriteria:
         if self.current_tokens >= self.max_new_tokens:
             return True, FinishReason.FINISH_REASON_LENGTH

-        if not self.ignore_eos_token and last_token == self.eos_token_id:
+        if isinstance(last_token, torch.Tensor):
+            last_token = last_token.item()
+
+        if not self.ignore_eos_token and last_token in self.eos_token_ids:
             return True, FinishReason.FINISH_REASON_EOS_TOKEN

         if self.stop_sequence_criterias:
@@ -184,8 +197,10 @@ class StoppingCriteria:
         stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
+        # TODO Hack because eos_token_id cannot be what we want.
+        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
         return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,
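
For reviewers, a minimal standalone sketch of the set-based EOS stopping this patch introduces. The class name `EosStoppingSketch` and the token ids are illustrative only, not from the repository; the real logic lives in `StoppingCriteria` in `server/text_generation_server/utils/tokens.py` and additionally unwraps `torch.Tensor` tokens and consults `StopSequenceCriteria`.

```python
# Sketch of the set-based EOS stopping added by this patch (illustrative, not the real class).
# It mirrors the normalization and membership test in StoppingCriteria, dropping the
# torch.Tensor unwrapping and stop-sequence handling to stay self-contained.
from typing import Optional, Set, Union


class EosStoppingSketch:  # hypothetical name, for illustration only
    def __init__(
        self,
        eos_token_ids: Optional[Union[Set[int], int]],
        max_new_tokens: int = 20,
    ):
        # Normalize None / int / set into a set, as the patch does.
        if eos_token_ids is None:
            eos_token_ids = set()
        elif isinstance(eos_token_ids, int):
            eos_token_ids = {eos_token_ids}
        elif not isinstance(eos_token_ids, set):
            raise RuntimeError(
                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
            )
        self.eos_token_ids = eos_token_ids
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, last_token: int) -> bool:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True
        # Membership test replaces the previous `last_token == eos_token_id` equality,
        # so a generation config declaring several EOS ids stops on any of them.
        return last_token in self.eos_token_ids


if __name__ == "__main__":
    # Example ids are made up for the sketch.
    criteria = EosStoppingSketch(eos_token_ids={128001, 128009})
    assert criteria(42) is False      # ordinary token: keep generating
    assert criteria(128009) is True   # any configured EOS id ends generation
```

In the patch itself, the set comes from `GenerationConfig.from_pretrained`: when its `eos_token_id` is a list, `FlashLlama` stashes it on the tokenizer as `_eos_token_ids`, and `StoppingCriteria.from_pb` picks that set up via `getattr`, falling back to `tokenizer.eos_token_id` otherwise.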