fix(server): add position ids to neox (#126)

OlivierDehaene 2023-03-15 13:12:49 +01:00 committed by GitHub
parent cbd36aa4d1
commit 8ad60b752f
7 changed files with 12 additions and 32 deletions

View File

@@ -1,4 +1,4 @@
-transformers_commit := 517563354a3226ecfc3dca6e7a38012668d7156a
+transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
 
 gen-server:
 	# Compile protos

View File

@@ -9,7 +9,7 @@ from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
-from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [
@@ -19,7 +19,6 @@ __all__ = [
     "CausalLM",
     "Galactica",
     "GalacticaSharded",
-    "GPTNeox",
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -62,7 +61,7 @@ def get_model(
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_id, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
 
     if config.model_type == "t5":
         if sharded:
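Net effect of the three hunks above: the model registry drops the dedicated GPTNeox class, so a sharded NeoX deployment still uses GPTNeoxSharded while the single-GPU path falls back to the generic CausalLM. A rough sketch of the resulting branch of get_model (the enclosing model_type check is not visible in the hunk, so this helper is hypothetical):

    from text_generation_server.models import CausalLM, GPTNeoxSharded

    def load_gpt_neox(model_id: str, revision: str, sharded: bool, quantize: bool):
        # Hypothetical condensation of get_model's NeoX branch, for illustration only.
        if sharded:
            # Tensor-parallel weights still need the dedicated sharded implementation.
            return GPTNeoxSharded(model_id, revision, quantize=quantize)
        # Single GPU: the custom GPTNeox wrapper is gone, so the stock CausalLM is used.
        # Its forward already passes position_ids, which is what this commit wants NeoX
        # to receive (the removed wrapper existed only to drop them).
        return CausalLM(model_id, revision, quantize=quantize)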

View File

@@ -72,9 +72,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

View File

@@ -102,9 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

View File

@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -30,23 +30,7 @@ except Exception as e:
     HAS_BITS_AND_BYTES = False
 
 
-class GPTNeox(CausalLM):
-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        """Overwrite forward to ignore position_ids"""
-
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, outputs.past_key_values
-
-
-class GPTNeoxSharded(GPTNeox):
+class GPTNeoxSharded(CausalLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
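The removed GPTNeox wrapper existed only to strip position_ids before calling the underlying model, which was needed while the pinned transformers revision's NeoX forward did not accept them; that is presumably also why this commit bumps transformers_commit in the Makefile. Without explicit position_ids the model derives positions from the sequence offset alone, which is wrong for left-padded batches. A hedged illustration (the cumsum construction mirrors how padding-aware positions are typically built on the server side; it is not a line from this diff):

    import torch

    # Two sequences, the first left-padded by two tokens.
    attention_mask = torch.tensor([[0, 0, 1, 1],
                                   [1, 1, 1, 1]])

    # Padding-aware positions: cumulative count of real tokens minus one, padded slots clamped.
    position_ids = attention_mask.cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    print(position_ids)
    # tensor([[1, 1, 0, 1],
    #         [0, 1, 2, 3]])

    # A forward call without position_ids would effectively use [0, 1, 2, 3] for both rows,
    # shifting the rotary embeddings of the padded sequence by the padding length.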
@@ -224,6 +208,7 @@ class GPTNeoxSharded(GPTNeox):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )
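GPTNeoxSharded keeps its own forward override (the logits sharded across ranks still have to be gathered before sampling), so position_ids must be threaded through explicitly; that is the one-line addition above. On the calling side the positions live in the batch state and advance by one per decoding step. A hedged sketch of that flow, assuming a transformers build whose NeoX forward accepts position_ids (the checkpoint name is just an example):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")  # any NeoX-style checkpoint

    input_ids = torch.tensor([[3, 4, 5]])
    attention_mask = torch.ones_like(input_ids)
    position_ids = attention_mask.cumsum(-1) - 1

    # Prefill: prompt positions are passed explicitly rather than recomputed internally.
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=True,
    )

    # Incremental decoding: one new token, one new position, plus the cached keys/values.
    next_token = outputs.logits[:, -1:].argmax(-1)
    outputs = model(
        input_ids=next_token,
        attention_mask=torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1),
        position_ids=position_ids[:, -1:] + 1,
        past_key_values=outputs.past_key_values,
        use_cache=True,
    )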

View File

@@ -82,9 +82,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

View File

@@ -43,7 +43,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
 
-    def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
+    def _get_greenlist_ids(
+        self, input_ids: torch.LongTensor, max_value: int
+    ) -> list[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
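The watermark hunk is only a line-length reflow of the _get_greenlist_ids signature; behaviour is unchanged. For context, in the Kirchenbauer et al. watermarking scheme this method seeds an RNG from the previous token (the _seed_rng call above) and keeps a fixed fraction of the vocabulary as the "greenlist" whose logits get boosted. A standalone sketch of that idea; the gamma fraction and the permutation step are assumptions based on the scheme, not lines from this file:

    import torch

    def greenlist_ids(prev_token: int, vocab_size: int, gamma: float, hash_key: int) -> torch.Tensor:
        # Seed the generator from the previous token, mirroring _seed_rng in the hunk above.
        rng = torch.Generator()
        rng.manual_seed(hash_key * prev_token)
        # Keep a reproducible gamma-fraction of the vocabulary as the greenlist.
        permutation = torch.randperm(vocab_size, generator=rng)
        return permutation[: int(gamma * vocab_size)]

    # The same (prev_token, hash_key) pair always yields the same greenlist, which is
    # what makes the watermark detectable later.
    print(greenlist_ids(prev_token=42, vocab_size=50432, gamma=0.25, hash_key=15485863))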