Use the generation config. (#1808)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2024-04-25 19:41:50 +02:00 committed by GitHub
parent eb08b9faef
commit ee47973a2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 74 additions and 75 deletions

View File

@ -589,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")]
pub role: String,
// TODO Modify this to a true enum.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")]
pub content: Option<String>,
@ -623,6 +625,31 @@ impl ChatCompletionChunk {
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
) -> Self {
let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: Some(delta),
tool_calls: None,
},
(None, Some(tool_calls)) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tool_calls[0].to_string(),
},
}),
},
(None, None) => ChatCompletionDelta {
role: None,
content: None,
tool_calls: None,
},
};
Self {
id: String::new(),
object: "text_completion".to_string(),
@ -631,19 +658,7 @@ impl ChatCompletionChunk {
system_fingerprint,
choices: vec![ChatCompletionChoice {
index: 0,
delta: ChatCompletionDelta {
role: "assistant".to_string(),
content: delta,
tool_calls: tool_calls.map(|tc| DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tc[0].to_string(),
},
}),
},
delta,
logprobs,
finish_reason,
}],

View File

@ -1103,7 +1103,13 @@ async fn chat_completions(
let (content, tool_calls) = if tool_grammar.is_some() {
(None, Some(vec![stream_token.token.text]))
} else {
(Some(stream_token.token.text), None)
let content = if !stream_token.token.special {
Some(stream_token.token.text)
} else {
None
};
(content, None)
};
event

View File

@ -38,58 +38,6 @@ from text_generation_server.utils.layers import (
)
class LlamaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)

View File

@ -2,14 +2,13 @@ import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.llama import LlamaTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
LlamaConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
@ -53,8 +52,17 @@ class FlashLlama(FlashCausalLM):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = LlamaConfig.from_pretrained(
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize

View File

@ -27,7 +27,14 @@ class Model(ABC):
):
self.model = model.eval()
self.tokenizer = tokenizer
# all_special_ids is not set correctly if the rust tokenizer is unpacked
# TODO report this to transformers.
other_special_ids = {
id for id, token in tokenizer.added_tokens_decoder.items() if token.special
}
self.all_special_ids = set(tokenizer.all_special_ids)
self.all_special_ids.update(other_special_ids)
self.requires_padding = requires_padding
self.dtype = dtype
self.device = device

View File

@ -1,5 +1,5 @@
import re
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Set, Union
import math
import torch
@ -143,12 +143,22 @@ class StopSequenceCriteria:
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
eos_token_ids: Optional[Union[Set[int], int]],
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
):
self.eos_token_id = eos_token_id
if eos_token_ids is None:
eos_token_ids = set()
elif isinstance(eos_token_ids, int):
eos_token_ids = set([eos_token_ids])
elif isinstance(eos_token_ids, set):
eos_token_ids = eos_token_ids
else:
raise RuntimeError(
f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
)
self.eos_token_ids = eos_token_ids
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
@ -160,7 +170,10 @@ class StoppingCriteria:
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if not self.ignore_eos_token and last_token == self.eos_token_id:
if isinstance(last_token, torch.Tensor):
last_token = last_token.item()
if not self.ignore_eos_token and last_token in self.eos_token_ids:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
if self.stop_sequence_criterias:
@ -184,8 +197,10 @@ class StoppingCriteria:
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
# TODO Hack because eos_token_id cannot be what we want.
eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
return StoppingCriteria(
tokenizer.eos_token_id,
eos_token_id,
stop_sequence_criterias,
pb.max_new_tokens,
pb.ignore_eos_token,