2023-05-16 15:23:27 -06:00
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
from text_generation_server.models import Model
|
|
|
|
|
|
|
|
|
|
|
|
def get_test_model():
    """Return a minimal ``Model`` instance for tokenizer/decoding tests.

    The instance wraps a throwaway ``torch.nn.Linear(1, 1)`` on the CPU in
    float32, with the tokenizer loaded from the ``"huggingface/llama-7b"``
    repository. The tests in this module only use ``model.tokenizer`` and
    ``model.decode_token``.
    """

    class TestModel(Model):
        # Stub overrides. NOTE(review): presumably Model declares these as
        # abstract, so defining them lets the subclass be instantiated.
        # Both raise if invoked.
        def batch_type(self):
            raise NotImplementedError

        def generate_token(self, batch):
            raise NotImplementedError

    llama_tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

    # Arguments are passed positionally, in the same order as before.
    # NOTE(review): the literal False is presumably a requires-padding flag;
    # confirm against Model.__init__.
    return TestModel(
        "test_model_id",
        torch.nn.Linear(1, 1),
        llama_tokenizer,
        False,
        torch.float32,
        torch.device("cpu"),
    )
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.private
def test_decode_streaming_english_spaces():
    """Streaming-decode English token ids and expect the original sentence back.

    The hard-coded ids are first checked against the tokenizer's own encoding
    of the sentence (without special tokens). Then ``decode_token`` is called
    on successively longer prefixes of those ids. The returned text pieces,
    concatenated, must equal the sentence, including its spaces.
    """
    model = get_test_model()
    expected = "Hello here, this is a simple test"
    token_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]

    # Guard: the fixture ids must be exactly what this tokenizer produces.
    encoded = model.tokenizer(expected, add_special_tokens=False)["input_ids"]
    assert token_ids == encoded

    pieces = []
    offset, token_offset = 0, 0
    for end in range(1, len(token_ids) + 1):
        # Each call sees one more id than the last. The two offsets it returns
        # are threaded into the next call.
        piece, offset, token_offset = model.decode_token(
            token_ids[:end], offset, token_offset
        )
        pieces.append(piece)

    assert "".join(pieces) == expected
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.private
def test_decode_streaming_chinese_utf8():
    """Streaming-decode Chinese token ids and expect the original text back.

    ``decode_token`` is called on successively longer prefixes of the ids. The
    concatenation of all returned text pieces must equal the reference string.
    """
    model = get_test_model()
    expected = "我很感谢你的热情"
    # NOTE(review): the small ids (232, 193, 139, ...) are presumably
    # byte-fallback tokens, so several consecutive ids together form one
    # UTF-8 character. Confirm against the tokenizer vocabulary. Unlike the
    # English test, these ids are not checked against the tokenizer's
    # encoding of ``expected``.
    token_ids = [
        30672, 232, 193, 139, 233, 135, 162, 235,
        179, 165, 30919, 30210, 234, 134, 176, 30993,
    ]

    pieces = []
    offset, token_offset = 0, 0
    for end in range(1, len(token_ids) + 1):
        # Feed one more id per call, carrying the returned offsets forward.
        piece, offset, token_offset = model.decode_token(
            token_ids[:end], offset, token_offset
        )
        pieces.append(piece)

    assert "".join(pieces) == expected
|