fix(server): add position ids to neox (#126)

parent  cbd36aa4d1
commit  8ad60b752f
@@ -1,4 +1,4 @@
-transformers_commit := 517563354a3226ecfc3dca6e7a38012668d7156a
+transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
 
 gen-server:
 	# Compile protos
@@ -9,7 +9,7 @@ from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
-from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [
@@ -19,7 +19,6 @@ __all__ = [
     "CausalLM",
     "Galactica",
     "GalacticaSharded",
-    "GPTNeox",
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -62,7 +61,7 @@ def get_model(
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_id, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
 
     if config.model_type == "t5":
         if sharded:
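Note: with the dedicated GPTNeox wrapper gone, unsharded gpt_neox models fall back to the generic CausalLM path. That only works because the base class forward passes position_ids through to the model instead of dropping them. A minimal sketch of that generic forward, inferred from the override removed further below (the exact body in the repository may differ):

```python
from typing import List, Optional, Tuple

import torch


class CausalLM:
    # Sketch of the base forward the unsharded NeoX path now relies on;
    # self.model is assumed to be a transformers causal LM instance.
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values
```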
@@ -72,9 +72,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -102,9 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -30,23 +30,7 @@ except Exception as e:
     HAS_BITS_AND_BYTES = False
 
 
-class GPTNeox(CausalLM):
-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        """Overwrite forward to ignore position_ids"""
-
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, outputs.past_key_values
-
-
-class GPTNeoxSharded(GPTNeox):
+class GPTNeoxSharded(CausalLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
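The removed override dropped position_ids by design ("ignore position_ids"), which is unsafe for NeoX: its rotary embeddings rotate queries and keys by an angle tied to each token's absolute position, and with left-padded batches the positions a model infers from raw sequence indices are shifted by the pad length. A common way to derive correct positions from the attention mask (a general sketch, not code from this repository):

```python
import torch

# Positions count only real tokens: cumulative sum of the mask, shifted to
# start at 0. Pad slots get a dummy position (they are masked out anyway).
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```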
@@ -224,6 +208,7 @@ class GPTNeoxSharded(GPTNeox):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )
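This one-line addition is the actual fix: the sharded forward now threads position_ids into the model. It matters most during incremental decoding, where only the newest token is fed in and the model cannot recover its position on its own. A sketch of the idea, assuming the attention mask already covers the new token's slot:

```python
import torch

# One request with one left-pad slot and five real tokens (including the
# slot for the token being decoded this step).
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1]])
# 0-based position of the current token = number of real tokens - 1.
position_ids = attention_mask.sum(dim=-1, keepdim=True) - 1
print(position_ids)  # tensor([[4]])
```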
@@ -82,9 +82,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -43,7 +43,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
 
-    def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
+    def _get_greenlist_ids(
+        self, input_ids: torch.LongTensor, max_value: int
+    ) -> list[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
 
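For context, _get_greenlist_ids belongs to the watermarking scheme of Kirchenbauer et al. (2023): the generator is seeded from the previous token (see _seed_rng above), then a deterministic permutation of the vocabulary picks the "green" token subset whose logits get boosted. A hedged sketch of how the body plausibly continues; self.gamma (the green fraction) is an assumed attribute:

```python
import torch


def _get_greenlist_ids(
    self, input_ids: torch.LongTensor, max_value: int
) -> list[int]:
    # seed the rng using the previous tokens/prefix
    self._seed_rng(input_ids)

    # self.gamma is assumed: the fraction of the vocabulary marked "green".
    greenlist_size = int(max_value * self.gamma)
    # Deterministic permutation driven by the seeded generator, so the same
    # prefix always yields the same greenlist.
    vocab_permutation = torch.randperm(
        max_value, device=input_ids.device, generator=self.rng
    )
    return vocab_permutation[:greenlist_size].tolist()
```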