fix(server): fix escape characters in stop sequence (#155)
This commit is contained in:
parent
9122e7bd9c
commit
3f2542bb6a
|
@ -14,6 +14,15 @@ def test_stop_sequence_criteria():
|
||||||
assert not criteria("/test; ")
|
assert not criteria("/test; ")
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_sequence_criteria_escape():
|
||||||
|
criteria = StopSequenceCriteria("<|stop|>")
|
||||||
|
|
||||||
|
assert not criteria("<")
|
||||||
|
assert not criteria("<|stop")
|
||||||
|
assert criteria("<|stop|>")
|
||||||
|
assert not criteria("<|stop|> ")
|
||||||
|
|
||||||
|
|
||||||
def test_stopping_criteria():
|
def test_stopping_criteria():
|
||||||
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
|
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
|
||||||
assert criteria(65827, "/test") == (False, None)
|
assert criteria(65827, "/test") == (False, None)
|
||||||
|
|
|
@ -162,15 +162,17 @@ class FlashMQAttention(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(
|
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
|
||||||
self, act, hidden_size, intermediate_size, process_group=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.act = (
|
self.act = (
|
||||||
ACT2FN[act]
|
ACT2FN[act]
|
||||||
if "gelu" not in act
|
if "gelu" not in act
|
||||||
else lambda x: torch.nn.functional.gelu(x, approximate="tanh" if act in ["gelu_fast",
|
else lambda x: torch.nn.functional.gelu(
|
||||||
"gelu_pytorch_tanh"] else None)
|
x,
|
||||||
|
approximate="tanh"
|
||||||
|
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||||
|
else None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if process_group is None:
|
if process_group is None:
|
||||||
|
@ -232,9 +234,7 @@ class Block(nn.Module):
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, residual = self.ln_2(
|
hidden_states, residual = self.ln_2(hidden_states, residual)
|
||||||
hidden_states, residual
|
|
||||||
)
|
|
||||||
|
|
||||||
mlp_output = self.mlp(hidden_states)
|
mlp_output = self.mlp(hidden_states)
|
||||||
|
|
||||||
|
@ -258,16 +258,16 @@ class FlashSantacoderModel(nn.Module):
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
config.activation_function,
|
config.activation_function,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
config.n_inner if config.n_inner is not None else 4 * config.hidden_size,
|
config.n_inner
|
||||||
|
if config.n_inner is not None
|
||||||
|
else 4 * config.hidden_size,
|
||||||
config.layer_norm_epsilon,
|
config.layer_norm_epsilon,
|
||||||
process_group,
|
process_group,
|
||||||
)
|
)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.ln_f = FastLayerNorm(
|
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
config.hidden_size, eps=config.layer_norm_epsilon
|
|
||||||
)
|
|
||||||
|
|
||||||
self.head_size = self.h[0].attn.head_size
|
self.head_size = self.h[0].attn.head_size
|
||||||
self.num_heads = self.h[0].attn.num_heads
|
self.num_heads = self.h[0].attn.num_heads
|
||||||
|
@ -335,9 +335,7 @@ class FlashSantacoderForCausalLM(nn.Module):
|
||||||
|
|
||||||
self.transformer = FlashSantacoderModel(config, process_group)
|
self.transformer = FlashSantacoderModel(config, process_group)
|
||||||
|
|
||||||
self.lm_head = FastLinear(
|
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
config.hidden_size, config.vocab_size, bias=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def post_load_weights(self):
|
def post_load_weights(self):
|
||||||
self.transformer.post_load_weights()
|
self.transformer.post_load_weights()
|
||||||
|
|
|
@ -9,7 +9,7 @@ from typing import Optional, List
|
||||||
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
from text_generation_server.models import FlashCausalLM
|
||||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||||
FlashSantacoderForCausalLM
|
FlashSantacoderForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils import (
|
from text_generation_server.utils import (
|
||||||
weight_files,
|
weight_files,
|
||||||
|
@ -37,8 +37,9 @@ class FlashSantacoder(FlashCausalLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model_id, revision=revision,
|
model_id,
|
||||||
trust_remote_code=True # Needed as the config is not part of Transformers
|
revision=revision,
|
||||||
|
trust_remote_code=True, # Needed as the config is not part of Transformers
|
||||||
)
|
)
|
||||||
|
|
||||||
# We do not use from_pretrained as we modified the model internal module layout
|
# We do not use from_pretrained as we modified the model internal module layout
|
||||||
|
@ -91,7 +92,12 @@ class FlashSantacoder(FlashCausalLM):
|
||||||
current_parameter_tensor = None
|
current_parameter_tensor = None
|
||||||
|
|
||||||
if current_parameter_tensor is not None:
|
if current_parameter_tensor is not None:
|
||||||
if "c_fc.weight" in key or "c_proj.weight" in key or "q_attn.weight" in key or "kv_attn.weight" in key:
|
if (
|
||||||
|
"c_fc.weight" in key
|
||||||
|
or "c_proj.weight" in key
|
||||||
|
or "q_attn.weight" in key
|
||||||
|
or "kv_attn.weight" in key
|
||||||
|
):
|
||||||
# Tranpose as we use nn.Linear instead of Conv1D
|
# Tranpose as we use nn.Linear instead of Conv1D
|
||||||
value = value.T
|
value = value.T
|
||||||
|
|
||||||
|
@ -99,11 +105,18 @@ class FlashSantacoder(FlashCausalLM):
|
||||||
# Init qkv
|
# Init qkv
|
||||||
if "attn.weight" in final_key:
|
if "attn.weight" in final_key:
|
||||||
module._parameters[param_name] = value.new_empty(
|
module._parameters[param_name] = value.new_empty(
|
||||||
(model.transformer.head_size * (model.transformer.num_heads + 2), value.shape[1])
|
(
|
||||||
|
model.transformer.head_size
|
||||||
|
* (model.transformer.num_heads + 2),
|
||||||
|
value.shape[1],
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif "attn.bias" in final_key:
|
elif "attn.bias" in final_key:
|
||||||
module._parameters[param_name] = value.new_empty(
|
module._parameters[param_name] = value.new_empty(
|
||||||
(model.transformer.head_size * (model.transformer.num_heads + 2))
|
(
|
||||||
|
model.transformer.head_size
|
||||||
|
* (model.transformer.num_heads + 2)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copy to correct slice
|
# Copy to correct slice
|
||||||
|
|
|
@ -110,6 +110,7 @@ class NextTokenChooser:
|
||||||
|
|
||||||
class StopSequenceCriteria:
|
class StopSequenceCriteria:
|
||||||
def __init__(self, stop_sequence: str):
|
def __init__(self, stop_sequence: str):
|
||||||
|
stop_sequence = re.escape(stop_sequence)
|
||||||
self.regex = re.compile(f".*{stop_sequence}$")
|
self.regex = re.compile(f".*{stop_sequence}$")
|
||||||
|
|
||||||
def __call__(self, output: str) -> bool:
|
def __call__(self, output: str) -> bool:
|
||||||
|
|
Loading…
Reference in New Issue