fix(server): add position ids to neox (#126)

parent cbd36aa4d1
commit 8ad60b752f

@@ -1,4 +1,4 @@
-transformers_commit := 517563354a3226ecfc3dca6e7a38012668d7156a
+transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
 
 gen-server:
 	# Compile protos

@@ -9,7 +9,7 @@ from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
-from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [

@@ -19,7 +19,6 @@ __all__ = [
     "CausalLM",
     "Galactica",
     "GalacticaSharded",
-    "GPTNeox",
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",

@@ -62,7 +61,7 @@ def get_model(
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_id, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
 
     if config.model_type == "t5":
         if sharded:

@@ -72,9 +72,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -102,9 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 from accelerate import init_empty_weights
 from safetensors import safe_open

@@ -30,23 +30,7 @@ except Exception as e:
     HAS_BITS_AND_BYTES = False
 
 
-class GPTNeox(CausalLM):
-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        """Overwrite forward to ignore position_ids"""
-
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, outputs.past_key_values
-
-
-class GPTNeoxSharded(GPTNeox):
+class GPTNeoxSharded(CausalLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):

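Note on the removal above: the deleted GPTNeox class overrode forward only to drop position_ids, so unsharded GPT-NeoX checkpoints are now served through the generic CausalLM path (see the get_model hunk earlier in this diff), and GPTNeoxSharded inherits from CausalLM directly. A minimal sketch of the inherited forward this relies on, inferred from the removed override rather than copied from the repository, so the exact body and signature are assumptions:

# Sketch only: roughly what the inherited CausalLM.forward is expected to do
# now that the position_ids-dropping override is gone. Inferred from the
# removed code above; the real method in the repository may differ.
class CausalLM:
    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,  # previously discarded by the deleted override
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values
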
@@ -224,6 +208,7 @@ class GPTNeoxSharded(GPTNeox):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )

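The hunk above forwards position_ids to the underlying GPT-NeoX model. This diff does not show how the caller builds them; a common recipe (illustrative only, not taken from this repository) derives them from the attention mask so that left padding does not shift token positions:

import torch

# Illustrative only: position_ids derived from an attention mask that uses
# left padding. The server's real batching code may differ.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
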
@@ -82,9 +82,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -43,7 +43,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
 
-    def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
+    def _get_greenlist_ids(
+        self, input_ids: torch.LongTensor, max_value: int
+    ) -> list[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
 