fix(server): add position ids to neox (#126)

Author: OlivierDehaene
Date:   2023-03-15 13:12:49 +01:00 (committed via GitHub)
parent cbd36aa4d1
commit 8ad60b752f
7 changed files with 12 additions and 32 deletions

File: server/Makefile

@@ -1,4 +1,4 @@
-transformers_commit := 517563354a3226ecfc3dca6e7a38012668d7156a
+transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
 
 gen-server:
 	# Compile protos
File: server/text_generation_server/models/__init__.py

@@ -9,7 +9,7 @@ from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
-from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [
@@ -19,7 +19,6 @@ __all__ = [
     "CausalLM",
     "Galactica",
     "GalacticaSharded",
-    "GPTNeox",
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -62,7 +61,7 @@ def get_model(
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_id, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
 
     if config.model_type == "t5":
         if sharded:
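
Why the dedicated GPTNeox wrapper can go away: the transformers commit pinned in the Makefile above makes the NeoX modeling code handle position_ids correctly, so the unsharded path can fall through to the generic CausalLM. Its forward looks roughly like the sketch below (an assumption that mirrors the GPTNeox.forward deleted in gpt_neox.py further down, with position_ids passed through instead of dropped):

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Forward position_ids to the model instead of silently dropping them
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values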

File: server/text_generation_server/models/causal_lm.py

@@ -72,9 +72,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

File: server/text_generation_server/models/galactica.py

@@ -102,9 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

File: server/text_generation_server/models/gpt_neox.py

@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -30,23 +30,7 @@ except Exception as e:
     HAS_BITS_AND_BYTES = False
 
 
-class GPTNeox(CausalLM):
-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        """Overwrite forward to ignore position_ids"""
-
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, outputs.past_key_values
-
-
-class GPTNeoxSharded(GPTNeox):
+class GPTNeoxSharded(CausalLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
@@ -224,6 +208,7 @@ class GPTNeoxSharded(GPTNeox):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )
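
The net effect of this file: GPTNeoxSharded inherits from CausalLM directly and now forwards position_ids to the model. This matters because batched requests are padded, so a token's absolute position cannot be read off its sequence index. A minimal, self-contained sketch of the standard mask-based derivation (an assumption; this is the common transformers pattern, not code from this diff):

    import torch

    # Two left-padded sequences: real positions must start at 0 at the
    # first non-padding token, not at the start of the padded row.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)  # dummy value on padding

    print(position_ids)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])

For a rotary-embedding model like GPT-NeoX, dropping these ids (as the deleted override did) makes padded and unpadded batches compute attention differently, which is the inconsistency the commit title refers to.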

File: server/text_generation_server/models/seq2seq_lm.py

@@ -82,9 +82,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

File: server/text_generation_server/utils/watermark.py

@@ -43,7 +43,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
 
-    def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
+    def _get_greenlist_ids(
+        self, input_ids: torch.LongTensor, max_value: int
+    ) -> list[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
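
The hunk above only re-wraps the signature. For context, a hedged sketch of how a greenlist method like this is typically completed in the Kirchenbauer-style watermark this processor implements; sizing the greenlist via a self.gamma fraction is an assumption about the class, not something shown in this diff:

    import torch

    def _get_greenlist_ids(
        self, input_ids: torch.LongTensor, max_value: int
    ) -> list[int]:
        # seed the rng using the previous tokens/prefix
        self._seed_rng(input_ids)

        # Seeded permutation of the vocabulary; its prefix is the "green" set
        greenlist_size = int(max_value * self.gamma)  # self.gamma: assumed attribute
        vocab_permutation = torch.randperm(max_value, generator=self.rng)
        return vocab_permutation[:greenlist_size].tolist()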