feat(server): Support SantaCoder (#26)

This commit is contained in:
OlivierDehaene 2023-01-20 12:24:39 +01:00 committed by GitHub
parent f7ac394935
commit 15511edc01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 320 additions and 78 deletions

View File

@ -25,6 +25,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
- [BLOOMZ](https://huggingface.co/bigscience/bloomz) - [BLOOMZ](https://huggingface.co/bigscience/bloomz)
- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl) - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated) - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
- [SantaCoder](https://huggingface.co/bigcode/santacoder)
Other models are supported on a best effort basis using: Other models are supported on a best effort basis using:

View File

@ -1,4 +1,5 @@
use clap::Parser; use clap::Parser;
use serde_json::Value;
use std::env; use std::env;
use std::io::{BufRead, BufReader, Read}; use std::io::{BufRead, BufReader, Read};
use std::path::Path; use std::path::Path;
@ -11,7 +12,6 @@ use std::thread;
use std::thread::sleep; use std::thread::sleep;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{fs, io}; use std::{fs, io};
use serde_json::Value;
use subprocess::{Popen, PopenConfig, PopenError, Redirection}; use subprocess::{Popen, PopenConfig, PopenError, Redirection};
/// App Configuration /// App Configuration
@ -299,16 +299,12 @@ fn shard_manager(
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
// Useful when running inside a docker container // Useful when running inside a docker container
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") { if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
env.push(( env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
"HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into(),
));
}; };
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") { if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
env.push(( env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
"CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into(),
));
}; };
// Start process // Start process

View File

@ -74,9 +74,10 @@ impl Batcher {
// Await on the response from the background task // Await on the response from the background task
// We can safely unwrap as the background task will never drop the sender // We can safely unwrap as the background task will never drop the sender
response_rx.await.unwrap().map_err( response_rx
|err| InferError::GenerationError(err.to_string()) .await
) .unwrap()
.map_err(|err| InferError::GenerationError(err.to_string()))
} }
} }

View File

@ -94,7 +94,9 @@ fn validation_worker(
) { ) {
// Loop over requests // Loop over requests
while let Some((request, response_tx)) = receiver.blocking_recv() { while let Some((request, response_tx)) = receiver.blocking_recv() {
response_tx.send(validate(request, &tokenizer, max_input_length)).unwrap_or(()) response_tx
.send(validate(request, &tokenizer, max_input_length))
.unwrap_or(())
} }
} }
@ -117,8 +119,9 @@ fn validate(
} }
if request.parameters.stop.len() > MAX_STOP_SEQUENCES { if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
return Err(ValidationError::StopSequence( return Err(ValidationError::StopSequence(
MAX_STOP_SEQUENCES, request.parameters.stop.len(), MAX_STOP_SEQUENCES,
)) request.parameters.stop.len(),
));
} }
// Get the number of tokens in the input // Get the number of tokens in the input
@ -127,14 +130,11 @@ fn validate(
let input_length = inputs.len(); let input_length = inputs.len();
if input_length > max_input_length { if input_length > max_input_length {
Err(ValidationError::InputLength( Err(ValidationError::InputLength(input_length, max_input_length))
input_length,
max_input_length,
))
} else { } else {
Ok((input_length, request)) Ok((input_length, request))
} }
}, }
Err(err) => Err(ValidationError::Tokenizer(err.to_string())), Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
} }
} }

View File

@ -1,7 +1,5 @@
import pytest import pytest
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2 from text_generation.pb import generate_pb2
@ -18,24 +16,3 @@ def default_pb_parameters():
@pytest.fixture @pytest.fixture
def default_pb_stop_parameters(): def default_pb_stop_parameters():
return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10) return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
@pytest.fixture(scope="session")
def gpt2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token_id = 50256
return tokenizer
@pytest.fixture(scope="session")
def mt0_small_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"bigscience/mt0-small", padding_side="left"
)
tokenizer.bos_token_id = 0
return tokenizer

View File

@ -2,12 +2,23 @@ import pytest
import torch import torch
from copy import copy from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2 from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
@pytest.fixture(scope="session")
def default_bloom():
return BLOOM("bigscience/bloom-560m")
@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
@pytest.fixture @pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters): def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
@ -44,11 +55,6 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
) )
@pytest.fixture(scope="session")
def default_bloom():
return BLOOM("bigscience/bloom-560m")
def test_batch_from_pb(default_pb_batch, default_bloom_batch): def test_batch_from_pb(default_pb_batch, default_bloom_batch):
batch = default_bloom_batch batch = default_bloom_batch

View File

@ -2,11 +2,24 @@ import pytest
import torch import torch
from copy import copy from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2 from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLM, CausalLMBatch from text_generation.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM("gpt2")
@pytest.fixture(scope="session")
def gpt2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token_id = 50256
return tokenizer
@pytest.fixture @pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters): def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
@ -39,11 +52,6 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu")) return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM("gpt2")
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
batch = default_causal_lm_batch batch = default_causal_lm_batch

View File

@ -0,0 +1,93 @@
import pytest
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.santacoder import SantaCoder
@pytest.fixture(scope="session")
def default_santacoder():
return SantaCoder("bigcode/santacoder")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
input_length=1,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_length=5,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_fim_pb_batch(default_fim_pb_request):
return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
batch = CausalLMBatch.from_pb(
default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
)
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output_text == "def test_get_all_users_with_"
assert generated_texts[0].request == batch.requests[0]
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
assert (
generated_texts[0].generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)
def test_fim_santacoder_generate_token_completion(
default_santacoder, default_fim_pb_batch
):
batch = CausalLMBatch.from_pb(
default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
)
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert (
generated_texts[0].output_text
== """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
)
assert generated_texts[0].request == batch.requests[0]
assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
assert (
generated_texts[0].generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)

View File

@ -3,10 +3,26 @@ import torch
from copy import copy from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2 from text_generation.pb import generate_pb2
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture(scope="session")
def mt0_small_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"bigscience/mt0-small", padding_side="left"
)
tokenizer.bos_token_id = 0
return tokenizer
@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small")
@pytest.fixture @pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters): def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
@ -41,11 +57,6 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu")) return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small")
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
batch = default_seq2seq_lm_batch batch = default_seq2seq_lm_batch
sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) sequence_length = len(default_seq2seq_lm_batch.input_ids[0])

View File

@ -1,10 +1,28 @@
import torch
from text_generation.models.model import Model from text_generation.models.model import Model
from text_generation.models.causal_lm import CausalLM from text_generation.models.causal_lm import CausalLM
from text_generation.models.bloom import BLOOM, BLOOMSharded from text_generation.models.bloom import BLOOM, BLOOMSharded
from text_generation.models.seq2seq_lm import Seq2SeqLM from text_generation.models.seq2seq_lm import Seq2SeqLM
from text_generation.models.galactica import Galactica, GalacticaSharded from text_generation.models.galactica import Galactica, GalacticaSharded
from text_generation.models.santacoder import SantaCoder
__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM", "get_model"] __all__ = [
"Model",
"BLOOM",
"BLOOMSharded",
"CausalLM",
"Seq2SeqLM",
"SantaCoder",
"get_model",
]
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
@ -18,6 +36,8 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
return GalacticaSharded(model_name, quantize=quantize) return GalacticaSharded(model_name, quantize=quantize)
else: else:
return Galactica(model_name, quantize=quantize) return Galactica(model_name, quantize=quantize)
elif "santacoder" in model_name:
return SantaCoder(model_name, quantize)
else: else:
if sharded: if sharded:
raise ValueError("sharded is not supported for AutoModel") raise ValueError("sharded is not supported for AutoModel")

View File

@ -5,7 +5,12 @@ from typing import List, Optional, Type
from accelerate import init_empty_weights from accelerate import init_empty_weights
from safetensors import safe_open from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
PreTrainedTokenizerBase,
)
from transformers.models.bloom.parallel_layers import ( from transformers.models.bloom.parallel_layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
@ -34,7 +39,10 @@ torch.manual_seed(0)
class BloomCausalLMBatch(CausalLMBatch): class BloomCausalLMBatch(CausalLMBatch):
@classmethod @classmethod
def from_pb( def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "CausalLMBatch": ) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb( batch = super(BloomCausalLMBatch, cls).from_pb(
pb=pb, tokenizer=tokenizer, device=device pb=pb, tokenizer=tokenizer, device=device
@ -70,13 +78,6 @@ class BLOOMSharded(BLOOM):
) )
config.pad_token_id = 3 config.pad_token_id = 3
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
# Only download weights for small models # Only download weights for small models
if self.master and model_name == "bigscience/bloom-560m": if self.master and model_name == "bigscience/bloom-560m":
download_weights(model_name, extension=".safetensors") download_weights(model_name, extension=".safetensors")

View File

@ -47,7 +47,10 @@ class CausalLMBatch(Batch):
@classmethod @classmethod
def from_pb( def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "CausalLMBatch": ) -> "CausalLMBatch":
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
@ -71,6 +74,7 @@ class CausalLMBatch(Batch):
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=pad_to_multiple_of, pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False,
).to(device) ).to(device)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@ -253,6 +257,11 @@ class CausalLM(Model):
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch return CausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
)
def forward( def forward(
self, input_ids, attention_mask, past_key_values: Optional = None self, input_ids, attention_mask, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@ -338,11 +347,11 @@ class CausalLM(Model):
), ),
) )
if stop: if stop:
# Decode all tokens # Decode generated tokens
output_text = self.tokenizer.decode( generated_text = self.decode(
all_input_ids.squeeze(-1), skip_special_tokens=True, all_input_ids[-stopping_criteria.current_tokens :, 0]
cleanup_tokenization_spaces=False
) )
output_text = request.inputs + generated_text
# Slice with input_length to remove padding # Slice with input_length to remove padding
token_ids = all_input_ids[-new_input_length:] token_ids = all_input_ids[-new_input_length:]
tokens = self.tokenizer.batch_decode(token_ids) tokens = self.tokenizer.batch_decode(token_ids)

View File

@ -6,7 +6,12 @@ from typing import List, Optional, Type
from accelerate import init_empty_weights from accelerate import init_empty_weights
from safetensors import safe_open from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
PreTrainedTokenizerBase,
)
from transformers.models.opt.parallel_layers import ( from transformers.models.opt.parallel_layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
@ -82,7 +87,10 @@ def escape_custom_split_sequence(text):
class GalacticaCausalLMBatch(CausalLMBatch): class GalacticaCausalLMBatch(CausalLMBatch):
@classmethod @classmethod
def from_pb( def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "GalacticaCausalLMBatch": ) -> "GalacticaCausalLMBatch":
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
@ -99,8 +107,14 @@ class GalacticaCausalLMBatch(CausalLMBatch):
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
) )
# Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 inputs,
return_tensors="pt",
padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False,
).to(device) ).to(device)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@ -124,6 +138,12 @@ class Galactica(CausalLM):
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch return GalacticaCausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
)
class GalacticaSharded(Galactica): class GalacticaSharded(Galactica):
def __init__(self, model_name: str, quantize: bool = False): def __init__(self, model_name: str, quantize: bool = False):

View File

@ -0,0 +1,87 @@
import torch
import torch.distributed
from typing import Optional, List, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation.models import CausalLM
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"
class SantaCoder(CausalLM):
def __init__(self, model_name: str, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
EOD,
FIM_PREFIX,
FIM_MIDDLE,
FIM_SUFFIX,
FIM_PAD,
],
"pad_token": EOD,
}
)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
trust_remote_code=True, # required
).eval()
super(CausalLM, self).__init__(
tokenizer=tokenizer,
device=device,
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# FIXME: current forward with past is bugged for bigcode/santacoder because past_key_values does not have
# the correct shape ([batch_size, D, seq_length] instead of [batch_size, seq_length D]
# this leads to position_ids being wrong
input_length = input_ids.shape[-1]
past_key_values_length = (
0 if past_key_values is None else past_key_values[0][0].shape[-1]
)
position_ids = torch.arange(
past_key_values_length,
input_length + past_key_values_length,
dtype=torch.long,
device=input_ids.device,
).view(1, input_length)
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
use_cache=True,
)
return outputs.logits, outputs.past_key_values

View File

@ -51,7 +51,10 @@ class Seq2SeqLMBatch(Batch):
@classmethod @classmethod
def from_pb( def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "Seq2SeqLMBatch": ) -> "Seq2SeqLMBatch":
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch""" """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = [] inputs = []
@ -83,6 +86,7 @@ class Seq2SeqLMBatch(Batch):
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=pad_to_multiple_of, pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False,
).to(device) ).to(device)
# Convert decoder_input_ids to torch tensor of size [batch_size, 1] # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1) decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
@ -318,6 +322,9 @@ class Seq2SeqLM(Model):
def batch_type(self) -> Type[Seq2SeqLMBatch]: def batch_type(self) -> Type[Seq2SeqLMBatch]:
return Seq2SeqLMBatch return Seq2SeqLMBatch
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
def forward( def forward(
self, self,
input_ids, input_ids,
@ -438,7 +445,7 @@ class Seq2SeqLM(Model):
# Slice with decoder_input_length to remove padding # Slice with decoder_input_length to remove padding
# Decode all tokens # Decode all tokens
token_ids = decoder_input_ids[-new_decoder_input_length:] token_ids = decoder_input_ids[-new_decoder_input_length:]
output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True) output_text = self.decode(token_ids)
tokens = self.tokenizer.batch_decode(token_ids) tokens = self.tokenizer.batch_decode(token_ids)
# Add NaN for the bos token # Add NaN for the bos token
logprobs = [float("nan")] + decoder_logprobs[ logprobs = [float("nan")] + decoder_logprobs[

View File

@ -17,7 +17,10 @@ class Batch(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_pb( def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "Batch": ) -> "Batch":
raise NotImplementedError raise NotImplementedError

View File

@ -114,7 +114,9 @@ class StoppingCriteria:
@classmethod @classmethod
def from_pb( def from_pb(
cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria": ) -> "StoppingCriteria":
stop_sequence_criterias = [ stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences StopSequenceCriteria(sequence) for sequence in pb.stop_sequences