From 15511edc01a0725d374840f0e77d085eb5821483 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 20 Jan 2023 12:24:39 +0100 Subject: [PATCH] feat(server): Support SantaCoder (#26) --- README.md | 1 + launcher/src/main.rs | 10 +-- router/src/batcher.rs | 7 +- router/src/validation.rs | 16 ++-- server/tests/conftest.py | 23 ----- server/tests/models/test_bloom.py | 16 ++-- server/tests/models/test_causal_lm.py | 18 ++-- server/tests/models/test_santacoder.py | 93 +++++++++++++++++++++ server/tests/models/test_seq2seq_lm.py | 21 +++-- server/text_generation/models/__init__.py | 22 ++++- server/text_generation/models/bloom.py | 19 +++-- server/text_generation/models/causal_lm.py | 19 +++-- server/text_generation/models/galactica.py | 26 +++++- server/text_generation/models/santacoder.py | 87 +++++++++++++++++++ server/text_generation/models/seq2seq_lm.py | 11 ++- server/text_generation/models/types.py | 5 +- server/text_generation/utils.py | 4 +- 17 files changed, 320 insertions(+), 78 deletions(-) create mode 100644 server/tests/models/test_santacoder.py create mode 100644 server/text_generation/models/santacoder.py diff --git a/README.md b/README.md index 8496ce8..2635f64 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets. - [BLOOMZ](https://huggingface.co/bigscience/bloomz) - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl) - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated) +- [SantaCoder](https://huggingface.co/bigcode/santacoder) Other models are supported on a best effort basis using: diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 51ddad2..f897d2c 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,4 +1,5 @@ use clap::Parser; +use serde_json::Value; use std::env; use std::io::{BufRead, BufReader, Read}; use std::path::Path; @@ -11,7 +12,6 @@ use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; use std::{fs, io}; -use serde_json::Value; use subprocess::{Popen, PopenConfig, PopenError, Redirection}; /// App Configuration @@ -299,16 +299,12 @@ fn shard_manager( // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard // Useful when running inside a docker container if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") { - env.push(( - "HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into(), - )); + env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") { - env.push(( - "CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into(), - )); + env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into())); }; // Start process diff --git a/router/src/batcher.rs b/router/src/batcher.rs index 90ee409..ee83d89 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -74,9 +74,10 @@ impl Batcher { // Await on the response from the background task // We can safely unwrap as the background task will never drop the sender - response_rx.await.unwrap().map_err( - |err| InferError::GenerationError(err.to_string()) - ) + response_rx + .await + .unwrap() + .map_err(|err| InferError::GenerationError(err.to_string())) } } diff --git a/router/src/validation.rs b/router/src/validation.rs index 4a9d0c2..aabc82a 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -94,7 +94,9 @@ fn validation_worker( ) { // Loop over requests while let 
Some((request, response_tx)) = receiver.blocking_recv() { - response_tx.send(validate(request, &tokenizer, max_input_length)).unwrap_or(()) + response_tx + .send(validate(request, &tokenizer, max_input_length)) + .unwrap_or(()) } } @@ -117,8 +119,9 @@ fn validate( } if request.parameters.stop.len() > MAX_STOP_SEQUENCES { return Err(ValidationError::StopSequence( - MAX_STOP_SEQUENCES, request.parameters.stop.len(), - )) + MAX_STOP_SEQUENCES, + request.parameters.stop.len(), + )); } // Get the number of tokens in the input @@ -127,14 +130,11 @@ fn validate( let input_length = inputs.len(); if input_length > max_input_length { - Err(ValidationError::InputLength( - input_length, - max_input_length, - )) + Err(ValidationError::InputLength(input_length, max_input_length)) } else { Ok((input_length, request)) } - }, + } Err(err) => Err(ValidationError::Tokenizer(err.to_string())), } } diff --git a/server/tests/conftest.py b/server/tests/conftest.py index eb72b8a..e0ed76b 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,7 +1,5 @@ import pytest -from transformers import AutoTokenizer - from text_generation.pb import generate_pb2 @@ -18,24 +16,3 @@ def default_pb_parameters(): @pytest.fixture def default_pb_stop_parameters(): return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10) - - -@pytest.fixture(scope="session") -def bloom_560m_tokenizer(): - return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left") - - -@pytest.fixture(scope="session") -def gpt2_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left") - tokenizer.pad_token_id = 50256 - return tokenizer - - -@pytest.fixture(scope="session") -def mt0_small_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained( - "bigscience/mt0-small", padding_side="left" - ) - tokenizer.bos_token_id = 0 - return tokenizer diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 2a6e670..1a788ce 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -2,12 +2,23 @@ import pytest import torch from copy import copy +from transformers import AutoTokenizer from text_generation.pb import generate_pb2 from text_generation.models.causal_lm import CausalLMBatch from text_generation.models.bloom import BloomCausalLMBatch, BLOOM +@pytest.fixture(scope="session") +def default_bloom(): + return BLOOM("bigscience/bloom-560m") + + +@pytest.fixture(scope="session") +def bloom_560m_tokenizer(): + return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left") + + @pytest.fixture def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( @@ -44,11 +55,6 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer) ) -@pytest.fixture(scope="session") -def default_bloom(): - return BLOOM("bigscience/bloom-560m") - - def test_batch_from_pb(default_pb_batch, default_bloom_batch): batch = default_bloom_batch diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 683d9fd..bedb65b 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -2,11 +2,24 @@ import pytest import torch from copy import copy +from transformers import AutoTokenizer from text_generation.pb import generate_pb2 from text_generation.models.causal_lm import CausalLM, CausalLMBatch +@pytest.fixture(scope="session") +def default_causal_lm(): + return CausalLM("gpt2") + + 
+@pytest.fixture(scope="session") +def gpt2_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left") + tokenizer.pad_token_id = 50256 + return tokenizer + + @pytest.fixture def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( @@ -39,11 +52,6 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer): return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu")) -@pytest.fixture(scope="session") -def default_causal_lm(): - return CausalLM("gpt2") - - def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): batch = default_causal_lm_batch diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py new file mode 100644 index 0000000..c3a8375 --- /dev/null +++ b/server/tests/models/test_santacoder.py @@ -0,0 +1,93 @@ +import pytest + +from text_generation.pb import generate_pb2 +from text_generation.models.causal_lm import CausalLMBatch +from text_generation.models.santacoder import SantaCoder + + +@pytest.fixture(scope="session") +def default_santacoder(): + return SantaCoder("bigcode/santacoder") + + +@pytest.fixture +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="def", + input_length=1, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_pb_batch(default_pb_request): + return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) + + +@pytest.fixture +def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="<fim-prefix>def<fim-suffix>world<fim-middle>", + input_length=5, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_fim_pb_batch(default_fim_pb_request): + return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1) + + +def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch): + batch = CausalLMBatch.from_pb( + default_pb_batch, default_santacoder.tokenizer, default_santacoder.device + ) + next_batch = batch + + for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): + generated_texts, next_batch = default_santacoder.generate_token(next_batch) + assert generated_texts == [] + + generated_texts, next_batch = default_santacoder.generate_token(next_batch) + assert next_batch is None + + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "def test_get_all_users_with_" + assert generated_texts[0].request == batch.requests[0] + assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) + assert ( + generated_texts[0].generated_tokens + == batch.stopping_criterias[0].max_new_tokens + ) + + +def test_fim_santacoder_generate_token_completion( + default_santacoder, default_fim_pb_batch +): + batch = CausalLMBatch.from_pb( + default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device + ) + next_batch = batch + + for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): + generated_texts, next_batch = default_santacoder.generate_token(next_batch) + assert generated_texts == [] + + generated_texts, next_batch = default_santacoder.generate_token(next_batch) + assert next_batch is None + + assert len(generated_texts) == 1 + assert ( + generated_texts[0].output_text + == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value""" + ) + assert generated_texts[0].request ==
batch.requests[0] + assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) + assert ( + generated_texts[0].generated_tokens + == batch.stopping_criterias[0].max_new_tokens + ) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index f1b11bc..de1a482 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -3,10 +3,26 @@ import torch from copy import copy +from transformers import AutoTokenizer + from text_generation.pb import generate_pb2 from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch +@pytest.fixture(scope="session") +def mt0_small_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained( + "bigscience/mt0-small", padding_side="left" + ) + tokenizer.bos_token_id = 0 + return tokenizer + + +@pytest.fixture(scope="session") +def default_seq2seq_lm(): + return Seq2SeqLM("bigscience/mt0-small") + + @pytest.fixture def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( @@ -41,11 +57,6 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu")) -@pytest.fixture(scope="session") -def default_seq2seq_lm(): - return Seq2SeqLM("bigscience/mt0-small") - - def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): batch = default_seq2seq_lm_batch sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py index b615eb7..41d7381 100644 --- a/server/text_generation/models/__init__.py +++ b/server/text_generation/models/__init__.py @@ -1,10 +1,28 @@ +import torch + from text_generation.models.model import Model from text_generation.models.causal_lm import CausalLM from text_generation.models.bloom import BLOOM, BLOOMSharded from text_generation.models.seq2seq_lm import Seq2SeqLM from text_generation.models.galactica import Galactica, GalacticaSharded +from text_generation.models.santacoder import SantaCoder -__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM", "get_model"] +__all__ = [ + "Model", + "BLOOM", + "BLOOMSharded", + "CausalLM", + "Seq2SeqLM", + "SantaCoder", + "get_model", +] + +# The flag below controls whether to allow TF32 on matmul. This flag defaults to False +# in PyTorch 1.12 and later. +torch.backends.cuda.matmul.allow_tf32 = True + +# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
+torch.backends.cudnn.allow_tf32 = True def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: @@ -18,6 +36,8 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: return GalacticaSharded(model_name, quantize=quantize) else: return Galactica(model_name, quantize=quantize) + elif "santacoder" in model_name: + return SantaCoder(model_name, quantize) else: if sharded: raise ValueError("sharded is not supported for AutoModel") diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index 1135e56..375fff4 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -5,7 +5,12 @@ from typing import List, Optional, Type from accelerate import init_empty_weights from safetensors import safe_open -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + AutoConfig, + PreTrainedTokenizerBase, +) from transformers.models.bloom.parallel_layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -34,7 +39,10 @@ torch.manual_seed(0) class BloomCausalLMBatch(CausalLMBatch): @classmethod def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, ) -> "CausalLMBatch": batch = super(BloomCausalLMBatch, cls).from_pb( pb=pb, tokenizer=tokenizer, device=device @@ -70,13 +78,6 @@ class BLOOMSharded(BLOOM): ) config.pad_token_id = 3 - # The flag below controls whether to allow TF32 on matmul. This flag defaults to False - # in PyTorch 1.12 and later. - torch.backends.cuda.matmul.allow_tf32 = True - - # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
- torch.backends.cudnn.allow_tf32 = True - # Only download weights for small models if self.master and model_name == "bigscience/bloom-560m": download_weights(model_name, extension=".safetensors") diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 6bebcc3..93f3520 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -47,7 +47,10 @@ class CausalLMBatch(Batch): @classmethod def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, ) -> "CausalLMBatch": inputs = [] next_token_choosers = [] @@ -71,6 +74,7 @@ class CausalLMBatch(Batch): return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of, + return_token_type_ids=False, ).to(device) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) @@ -253,6 +257,11 @@ class CausalLM(Model): def batch_type(self) -> Type[CausalLMBatch]: return CausalLMBatch + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False + ) + def forward( self, input_ids, attention_mask, past_key_values: Optional = None ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: @@ -338,11 +347,11 @@ class CausalLM(Model): ), ) if stop: - # Decode all tokens - output_text = self.tokenizer.decode( - all_input_ids.squeeze(-1), skip_special_tokens=True, - cleanup_tokenization_spaces=False + # Decode generated tokens + generated_text = self.decode( + all_input_ids[-stopping_criteria.current_tokens :, 0] ) + output_text = request.inputs + generated_text # Slice with input_length to remove padding token_ids = all_input_ids[-new_input_length:] tokens = self.tokenizer.batch_decode(token_ids) diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index 76a1b1a..26fb3dd 100644 --- a/server/text_generation/models/galactica.py +++ b/server/text_generation/models/galactica.py @@ -6,7 +6,12 @@ from typing import List, Optional, Type from accelerate import init_empty_weights from safetensors import safe_open -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + AutoConfig, + PreTrainedTokenizerBase, +) from transformers.models.opt.parallel_layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -82,7 +87,10 @@ def escape_custom_split_sequence(text): class GalacticaCausalLMBatch(CausalLMBatch): @classmethod def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, ) -> "GalacticaCausalLMBatch": inputs = [] next_token_choosers = [] @@ -99,8 +107,14 @@ class GalacticaCausalLMBatch(CausalLMBatch): StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) + # Tokenize batch + pad_to_multiple_of = 8 if device.type == "cuda" else None tokenized_inputs = tokenizer( - inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 + inputs, + return_tensors="pt", + padding=True, + pad_to_multiple_of=pad_to_multiple_of, + return_token_type_ids=False, ).to(device) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) @@ -124,6 +138,12 @@ class Galactica(CausalLM): def batch_type(self) -> Type[CausalLMBatch]: 
return GalacticaCausalLMBatch + def decode(self, generated_ids: List[int]) -> str: + # Do not skip special tokens as they are used for custom parsing rules of the generated text + return self.tokenizer.decode( + generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False + ) + class GalacticaSharded(Galactica): def __init__(self, model_name: str, quantize: bool = False): diff --git a/server/text_generation/models/santacoder.py b/server/text_generation/models/santacoder.py new file mode 100644 index 0000000..e1d8e6a --- /dev/null +++ b/server/text_generation/models/santacoder.py @@ -0,0 +1,87 @@ +import torch +import torch.distributed + +from typing import Optional, List, Tuple +from transformers import AutoTokenizer, AutoModelForCausalLM + +from text_generation.models import CausalLM + +FIM_PREFIX = "<fim-prefix>" +FIM_MIDDLE = "<fim-middle>" +FIM_SUFFIX = "<fim-suffix>" +FIM_PAD = "<fim-pad>" +EOD = "<|endoftext|>" + + +class SantaCoder(CausalLM): + def __init__(self, model_name: str, quantize=False): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 + + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + tokenizer.add_special_tokens( + { + "additional_special_tokens": [ + EOD, + FIM_PREFIX, + FIM_MIDDLE, + FIM_SUFFIX, + FIM_PAD, + ], + "pad_token": EOD, + } + ) + + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=dtype, + device_map="auto" if torch.cuda.is_available() else None, + load_in_8bit=quantize, + trust_remote_code=True, # required + ).eval() + + super(CausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + def decode(self, generated_ids: List[int]) -> str: + # Do not skip special tokens as they are used for custom parsing rules of the generated text + return self.tokenizer.decode( + generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False + ) + + def forward( + self, input_ids, attention_mask, past_key_values: Optional = None + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # FIXME: current forward with past is bugged for bigcode/santacoder because past_key_values does not have + # the correct shape ([batch_size, D, seq_length] instead of [batch_size, seq_length, D]) + # this leads to position_ids being wrong + + input_length = input_ids.shape[-1] + past_key_values_length = ( + 0 if past_key_values is None else past_key_values[0][0].shape[-1] + ) + position_ids = torch.arange( + past_key_values_length, + input_length + past_key_values_length, + dtype=torch.long, + device=input_ids.device, + ).view(1, input_length) + + # Model Forward + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + use_cache=True, + ) + return outputs.logits, outputs.past_key_values diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index c561aeb..2980c74 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -51,7 +51,10 @@ class Seq2SeqLMBatch(Batch): @classmethod def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, ) -> "Seq2SeqLMBatch": """Convert a
text_generation.v1.Batch protobuf to a Seq2SeqLMBatch""" inputs = [] @@ -83,6 +86,7 @@ class Seq2SeqLMBatch(Batch): return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of, + return_token_type_ids=False, ).to(device) # Convert decoder_input_ids to torch tensor of size [batch_size, 1] decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1) @@ -318,6 +322,9 @@ class Seq2SeqLM(Model): def batch_type(self) -> Type[Seq2SeqLMBatch]: return Seq2SeqLMBatch + def decode(self, decoder_ids: List[int]) -> str: + return self.tokenizer.decode(decoder_ids, skip_special_tokens=True) + def forward( self, input_ids, @@ -438,7 +445,7 @@ class Seq2SeqLM(Model): # Slice with decoder_input_length to remove padding # Decode all tokens token_ids = decoder_input_ids[-new_decoder_input_length:] - output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True) + output_text = self.decode(token_ids) tokens = self.tokenizer.batch_decode(token_ids) # Add NaN for the bos token logprobs = [float("nan")] + decoder_logprobs[ diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index fa0dc9a..6bf64e0 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -17,7 +17,10 @@ class Batch(ABC): @classmethod @abstractmethod def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, ) -> "Batch": raise NotImplementedError diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index b0b5b07..1ddeed6 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -114,7 +114,9 @@ class StoppingCriteria: @classmethod def from_pb( - cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase + cls, + pb: generate_pb2.StoppingCriteriaParameters, + tokenizer: PreTrainedTokenizerBase, ) -> "StoppingCriteria": stop_sequence_criterias = [ StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
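
A minimal usage sketch, assuming the server package is installed and the bigcode/santacoder weights are reachable; it exercises the dispatch added in server/text_generation/models/__init__.py and the SantaCoder.decode override, which keeps FIM special tokens in the output:

```python
from text_generation.models import SantaCoder, get_model

# "santacoder" in the model name routes to the new SantaCoder class.
model = get_model("bigcode/santacoder", sharded=False, quantize=False)
assert isinstance(model, SantaCoder)

# decode() uses skip_special_tokens=False, so FIM markers survive a round trip.
ids = model.tokenizer("<fim-prefix>def<fim-suffix>world<fim-middle>")["input_ids"]
print(model.decode(ids))  # the <fim-...> tokens remain in the decoded text
```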