fix(server): fix escape characters in stop sequence (#155)
This commit is contained in:
parent
9122e7bd9c
commit
3f2542bb6a
|
@ -14,6 +14,15 @@ def test_stop_sequence_criteria():
|
|||
assert not criteria("/test; ")
|
||||
|
||||
|
||||
def test_stop_sequence_criteria_escape():
|
||||
criteria = StopSequenceCriteria("<|stop|>")
|
||||
|
||||
assert not criteria("<")
|
||||
assert not criteria("<|stop")
|
||||
assert criteria("<|stop|>")
|
||||
assert not criteria("<|stop|> ")
|
||||
|
||||
|
||||
def test_stopping_criteria():
|
||||
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
|
||||
assert criteria(65827, "/test") == (False, None)
|
||||
|
|
|
@ -47,12 +47,12 @@ class FastLayerNorm(nn.LayerNorm):
|
|||
|
||||
class FastLinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
||||
|
||||
|
@ -67,10 +67,10 @@ class FastLinear(nn.Linear):
|
|||
|
||||
class FlashMQAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
hidden_size,
|
||||
process_group=None,
|
||||
self,
|
||||
num_heads,
|
||||
hidden_size,
|
||||
process_group=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
|
@ -86,13 +86,13 @@ class FlashMQAttention(torch.nn.Module):
|
|||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
self,
|
||||
hidden_states,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
qkv = self.attn(hidden_states)
|
||||
|
||||
|
@ -162,15 +162,17 @@ class FlashMQAttention(torch.nn.Module):
|
|||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self, act, hidden_size, intermediate_size, process_group=None
|
||||
):
|
||||
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
|
||||
super().__init__()
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(x, approximate="tanh" if act in ["gelu_fast",
|
||||
"gelu_pytorch_tanh"] else None)
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate="tanh"
|
||||
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
if process_group is None:
|
||||
|
@ -188,13 +190,13 @@ class MLP(nn.Module):
|
|||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
act,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
layer_norm_eps,
|
||||
process_group=None,
|
||||
self,
|
||||
num_heads,
|
||||
act,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
layer_norm_eps,
|
||||
process_group=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
@ -212,14 +214,14 @@ class Block(nn.Module):
|
|||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
hidden_states, residual = self.ln_1(hidden_states, residual)
|
||||
|
||||
|
@ -232,9 +234,7 @@ class Block(nn.Module):
|
|||
cu_seqlens_q,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.ln_2(
|
||||
hidden_states, residual
|
||||
)
|
||||
hidden_states, residual = self.ln_2(hidden_states, residual)
|
||||
|
||||
mlp_output = self.mlp(hidden_states)
|
||||
|
||||
|
@ -258,16 +258,16 @@ class FlashSantacoderModel(nn.Module):
|
|||
config.num_attention_heads,
|
||||
config.activation_function,
|
||||
config.hidden_size,
|
||||
config.n_inner if config.n_inner is not None else 4 * config.hidden_size,
|
||||
config.n_inner
|
||||
if config.n_inner is not None
|
||||
else 4 * config.hidden_size,
|
||||
config.layer_norm_epsilon,
|
||||
process_group,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.ln_f = FastLayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.head_size = self.h[0].attn.head_size
|
||||
self.num_heads = self.h[0].attn.num_heads
|
||||
|
@ -281,12 +281,12 @@ class FlashSantacoderModel(nn.Module):
|
|||
layer.mlp.c_proj.transpose_weight()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
):
|
||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||
|
||||
|
@ -335,21 +335,19 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
|
||||
self.transformer = FlashSantacoderModel(config, process_group)
|
||||
|
||||
self.lm_head = FastLinear(
|
||||
config.hidden_size, config.vocab_size, bias=False
|
||||
)
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def post_load_weights(self):
|
||||
self.transformer.post_load_weights()
|
||||
self.lm_head.transpose_weight()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
):
|
||||
hidden_states, present = self.transformer(
|
||||
input_ids, position_ids, cu_seqlens, max_s, past_key_values
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Optional, List
|
|||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||
FlashSantacoderForCausalLM
|
||||
FlashSantacoderForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
weight_files,
|
||||
|
@ -37,8 +37,9 @@ class FlashSantacoder(FlashCausalLM):
|
|||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision,
|
||||
trust_remote_code=True # Needed as the config is not part of Transformers
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=True, # Needed as the config is not part of Transformers
|
||||
)
|
||||
|
||||
# We do not use from_pretrained as we modified the model internal module layout
|
||||
|
@ -65,8 +66,8 @@ class FlashSantacoder(FlashCausalLM):
|
|||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model: FlashSantacoderForCausalLM,
|
||||
filenames: List[Path],
|
||||
model: FlashSantacoderForCausalLM,
|
||||
filenames: List[Path],
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
|
@ -91,7 +92,12 @@ class FlashSantacoder(FlashCausalLM):
|
|||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
if "c_fc.weight" in key or "c_proj.weight" in key or "q_attn.weight" in key or "kv_attn.weight" in key:
|
||||
if (
|
||||
"c_fc.weight" in key
|
||||
or "c_proj.weight" in key
|
||||
or "q_attn.weight" in key
|
||||
or "kv_attn.weight" in key
|
||||
):
|
||||
# Tranpose as we use nn.Linear instead of Conv1D
|
||||
value = value.T
|
||||
|
||||
|
@ -99,11 +105,18 @@ class FlashSantacoder(FlashCausalLM):
|
|||
# Init qkv
|
||||
if "attn.weight" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(model.transformer.head_size * (model.transformer.num_heads + 2), value.shape[1])
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2),
|
||||
value.shape[1],
|
||||
)
|
||||
)
|
||||
elif "attn.bias" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(model.transformer.head_size * (model.transformer.num_heads + 2))
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2)
|
||||
)
|
||||
)
|
||||
|
||||
# Copy to correct slice
|
||||
|
@ -113,11 +126,11 @@ class FlashSantacoder(FlashCausalLM):
|
|||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "kv_attn.weight" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size * model.transformer.num_heads:
|
||||
model.transformer.head_size * model.transformer.num_heads :
|
||||
] = value
|
||||
elif "kv_attn.bias" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size * model.transformer.num_heads:
|
||||
model.transformer.head_size * model.transformer.num_heads :
|
||||
] = value
|
||||
else:
|
||||
if current_parameter_tensor.shape != value.shape:
|
||||
|
|
|
@ -110,6 +110,7 @@ class NextTokenChooser:
|
|||
|
||||
class StopSequenceCriteria:
|
||||
def __init__(self, stop_sequence: str):
|
||||
stop_sequence = re.escape(stop_sequence)
|
||||
self.regex = re.compile(f".*{stop_sequence}$")
|
||||
|
||||
def __call__(self, output: str) -> bool:
|
||||
|
|
Loading…
Reference in New Issue