fix(server): fix decode token (#334)

Fixes #333

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
OlivierDehaene 2023-05-16 23:23:27 +02:00 committed by GitHub
parent dbdc587ddd
commit 5a58226130
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 396 additions and 350 deletions

View File

@ -213,12 +213,13 @@ jobs:
sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }} sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
- name: Install - name: Install
run: | run: |
pip install pytest-xdist
make install-integration-tests make install-integration-tests
- name: Run tests - name: Run tests
run: | run: |
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }} export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
pytest -s -vv integration-tests pytest -s -vv -n 2 --dist loadfile integration-tests
stop-runner: stop-runner:
name: Stop self-hosted EC2 runner name: Stop self-hosted EC2 runner

View File

@ -66,7 +66,8 @@ jobs:
- name: Run server tests - name: Run server tests
run: | run: |
pip install pytest pip install pytest
make python-server-tests export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
pytest -s -vv server/tests
- name: Run Rust fmt - name: Run Rust fmt
run: | run: |
cargo fmt --check cargo fmt --check

View File

@ -31,7 +31,7 @@ update-integration-tests: install-integration-tests
pytest -s -vv --snapshot-update integration-tests pytest -s -vv --snapshot-update integration-tests
python-server-tests: python-server-tests:
HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests
python-client-tests: python-client-tests:
pytest clients/python/tests pytest clients/python/tests

View File

@ -1,3 +1,4 @@
import sys
import subprocess import subprocess
import contextlib import contextlib
import pytest import pytest
@ -7,6 +8,7 @@ import docker
import json import json
import math import math
import time import time
import random
from docker.errors import NotFound from docker.errors import NotFound
from typing import Optional, List, Dict from typing import Optional, List, Dict
@ -205,10 +207,12 @@ def launcher(event_loop):
def local_launcher( def local_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
): ):
port = 9999 port = random.randint(8000, 10_000)
master_port = 19999 master_port = random.randint(10_000, 20_000)
shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server" shard_uds_path = (
f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
)
args = [ args = [
"text-generation-launcher", "text-generation-launcher",
@ -236,7 +240,7 @@ def launcher(event_loop):
process.wait(60) process.wait(60)
launcher_output = process.stdout.read().decode("utf-8") launcher_output = process.stdout.read().decode("utf-8")
print(launcher_output) print(launcher_output, file=sys.stderr)
process.stdout.close() process.stdout.close()
process.stderr.close() process.stderr.close()
@ -245,7 +249,7 @@ def launcher(event_loop):
def docker_launcher( def docker_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
): ):
port = 9999 port = random.randint(8000, 10_000)
args = ["--model-id", model_id, "--env"] args = ["--model-id", model_id, "--env"]
@ -298,7 +302,7 @@ def launcher(event_loop):
pass pass
container_output = container.logs().decode("utf-8") container_output = container.logs().decode("utf-8")
print(container_output) print(container_output, file=sys.stderr)
container.remove() container.remove()

View File

@ -1,92 +1,4 @@
[ [
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -8.6875,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.5546875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 363,
"logprob": -1.5322266,
"special": false,
"text": " for"
},
{
"id": 847,
"logprob": -2.5585938,
"special": false,
"text": " /"
},
{
"id": 2754,
"logprob": -2.265625,
"special": false,
"text": "api"
},
{
"id": 29914,
"logprob": -0.034088135,
"special": false,
"text": "/"
},
{
"id": 29894,
"logprob": -0.96240234,
"special": false,
"text": "v"
},
{
"id": 29896,
"logprob": -0.36816406,
"special": false,
"text": "1"
},
{
"id": 29914,
"logprob": -0.013191223,
"special": false,
"text": "/"
},
{
"id": 16418,
"logprob": -3.15625,
"special": false,
"text": "projects"
},
{
"id": 29914,
"logprob": -0.43774414,
"special": false,
"text": "/"
},
{
"id": 29896,
"logprob": -1.9443359,
"special": false,
"text": "1"
}
]
},
"generated_text": "for /api/v1/projects/1"
},
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
@ -263,6 +175,94 @@
}, },
"generated_text": "for /api/v1/projects/1" "generated_text": "for /api/v1/projects/1"
}, },
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -8.6875,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.5546875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 363,
"logprob": -1.5322266,
"special": false,
"text": " for"
},
{
"id": 847,
"logprob": -2.5585938,
"special": false,
"text": " /"
},
{
"id": 2754,
"logprob": -2.265625,
"special": false,
"text": "api"
},
{
"id": 29914,
"logprob": -0.034088135,
"special": false,
"text": "/"
},
{
"id": 29894,
"logprob": -0.96240234,
"special": false,
"text": "v"
},
{
"id": 29896,
"logprob": -0.36816406,
"special": false,
"text": "1"
},
{
"id": 29914,
"logprob": -0.013191223,
"special": false,
"text": "/"
},
{
"id": 16418,
"logprob": -3.15625,
"special": false,
"text": "projects"
},
{
"id": 29914,
"logprob": -0.43774414,
"special": false,
"text": "/"
},
{
"id": 29896,
"logprob": -1.9443359,
"special": false,
"text": "1"
}
]
},
"generated_text": "for /api/v1/projects/1"
},
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,

View File

@ -1,58 +1,4 @@
[ [
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 6,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
}
],
"seed": null,
"tokens": [
{
"id": 259,
"logprob": -1.3789062,
"special": false,
"text": ""
},
{
"id": 39261,
"logprob": -0.36279297,
"special": false,
"text": "Because"
},
{
"id": 609,
"logprob": -1.0966797,
"special": false,
"text": " it"
},
{
"id": 339,
"logprob": -0.8276367,
"special": false,
"text": " is"
},
{
"id": 16017,
"logprob": -1.6845703,
"special": false,
"text": " blue"
},
{
"id": 1,
"logprob": -0.72753906,
"special": true,
"text": "</s>"
}
]
},
"generated_text": "Because it is blue"
},
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
@ -161,6 +107,60 @@
}, },
"generated_text": "Because it is blue" "generated_text": "Because it is blue"
}, },
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 6,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
}
],
"seed": null,
"tokens": [
{
"id": 259,
"logprob": -1.3789062,
"special": false,
"text": " "
},
{
"id": 39261,
"logprob": -0.36279297,
"special": false,
"text": "Because"
},
{
"id": 609,
"logprob": -1.0966797,
"special": false,
"text": " it"
},
{
"id": 339,
"logprob": -0.8276367,
"special": false,
"text": " is"
},
{
"id": 16017,
"logprob": -1.6845703,
"special": false,
"text": " blue"
},
{
"id": 1,
"logprob": -0.72753906,
"special": true,
"text": "</s>"
}
]
},
"generated_text": "Because it is blue"
},
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,

View File

@ -146,7 +146,7 @@ fn main() -> Result<(), std::io::Error> {
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
}, },
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({ false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub."); tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None } HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
}), }),

View File

@ -2,7 +2,7 @@ include Makefile-transformers
include Makefile-flash-att include Makefile-flash-att
unit-tests: unit-tests:
python -m pytest tests pytest -s -vv -m "not private" tests
gen-server: gen-server:
# Compile protos # Compile protos

View File

@ -0,0 +1,78 @@
import pytest
import torch
from transformers import AutoTokenizer
from text_generation_server.models import Model
def get_test_model():
class TestModel(Model):
def batch_type(self):
raise NotImplementedError
def generate_token(self, batch):
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel(
torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
)
return model
@pytest.mark.private
def test_decode_streaming_english_spaces():
model = get_test_model()
truth = "Hello here, this is a simple test"
all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
assert (
all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
)
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth
@pytest.mark.private
def test_decode_streaming_chinese_utf8():
model = get_test_model()
truth = "我很感谢你的热情"
all_input_ids = [
30672,
232,
193,
139,
233,
135,
162,
235,
179,
165,
30919,
30210,
234,
134,
176,
30993,
]
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth

View File

@ -56,7 +56,7 @@ class BLOOM(CausalLM):
quantize: Optional[str] = None, quantize: Optional[str] = None,
): ):
super(BLOOM, self).__init__( super(BLOOM, self).__init__(
model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1 model_id=model_id, revision=revision, quantize=quantize
) )
@property @property
@ -104,14 +104,13 @@ class BLOOMSharded(BLOOM):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,
device=device, device=device,
decode_buffer=1,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )

View File

@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
# Lengths of all generations present in the batch # Lengths of all generations present in the batch
input_lengths: List[int] input_lengths: List[int]
offsets: List[Optional[int]] prefix_offsets: List[int]
token_offsets: List[Optional[int]] read_offsets: List[int]
# Generation helpers # Generation helpers
next_token_choosers: List[NextTokenChooser] next_token_choosers: List[NextTokenChooser]
@ -70,8 +70,8 @@ class CausalLMBatch(Batch):
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
requests_idx_mapping = {} requests_idx_mapping = {}
# Parse batch # Parse batch
@ -81,8 +81,6 @@ class CausalLMBatch(Batch):
for i, r in enumerate(pb.requests): for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
inputs.append(r.inputs) inputs.append(r.inputs)
offsets.append(None)
token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
@ -102,6 +100,10 @@ class CausalLMBatch(Batch):
truncation=True, truncation=True,
max_length=max_truncation, max_length=max_truncation,
).to(device) ).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(0)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1) input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max() max_input_length = input_lengths.max()
@ -130,8 +132,8 @@ class CausalLMBatch(Batch):
past_key_values=None, past_key_values=None,
all_input_ids=list(all_input_ids), all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(), input_lengths=input_lengths.tolist(),
offsets=offsets, prefix_offsets=prefix_offsets,
token_offsets=token_offsets, read_offsets=read_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(), max_input_length=max_input_length.item(),
@ -151,8 +153,8 @@ class CausalLMBatch(Batch):
# New values after filtering # New values after filtering
requests_idx_mapping = {} requests_idx_mapping = {}
input_lengths = [] input_lengths = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
all_input_ids = [] all_input_ids = []
max_input_length = 0 max_input_length = 0
@ -167,8 +169,8 @@ class CausalLMBatch(Batch):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
keep_indices.append(idx) keep_indices.append(idx)
offsets.append(self.offsets[idx]) prefix_offsets.append(self.prefix_offsets[idx])
token_offsets.append(self.token_offsets[idx]) read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx]) all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx] request_input_length = self.input_lengths[idx]
@ -225,8 +227,8 @@ class CausalLMBatch(Batch):
self.position_ids = position_ids self.position_ids = position_ids
self.all_input_ids = all_input_ids self.all_input_ids = all_input_ids
self.input_lengths = input_lengths self.input_lengths = input_lengths
self.offsets = offsets self.prefix_offsets = prefix_offsets
self.token_offsets = token_offsets self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length self.max_input_length = max_input_length
@ -251,8 +253,8 @@ class CausalLMBatch(Batch):
requests = [] requests = []
requests_idx_mapping = {} requests_idx_mapping = {}
input_lengths = [] input_lengths = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
all_input_ids = [] all_input_ids = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
@ -270,8 +272,8 @@ class CausalLMBatch(Batch):
for i, batch in enumerate(batches): for i, batch in enumerate(batches):
requests.extend(batch.requests) requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets) prefix_offsets.extend(batch.prefix_offsets)
token_offsets.extend(batch.token_offsets) read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
@ -428,8 +430,8 @@ class CausalLMBatch(Batch):
past_key_values=past_key_values, past_key_values=past_key_values,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
input_lengths=input_lengths, input_lengths=input_lengths,
offsets=offsets, prefix_offsets=prefix_offsets,
token_offsets=token_offsets, read_offsets=read_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_input_length=max_input_length, max_input_length=max_input_length,
@ -448,7 +450,6 @@ class CausalLM(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
decode_buffer: int = 3,
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
@ -463,25 +464,25 @@ class CausalLM(Model):
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left" model_id, revision=revision, padding_side="left", truncation_side="left"
) )
self.model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None, device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
).eval() )
tokenizer.pad_token_id = ( tokenizer.pad_token_id = (
self.model.config.pad_token_id model.config.pad_token_id
if self.model.config.pad_token_id is not None if model.config.pad_token_id is not None
else self.model.config.eos_token_id else model.config.eos_token_id
) )
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,
device=device, device=device,
decode_buffer=decode_buffer,
) )
@property @property
@ -528,8 +529,8 @@ class CausalLM(Model):
iterator = zip( iterator = zip(
batch.requests, batch.requests,
batch.input_lengths, batch.input_lengths,
batch.offsets, batch.prefix_offsets,
batch.token_offsets, batch.read_offsets,
logits, logits,
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
@ -540,8 +541,8 @@ class CausalLM(Model):
for i, ( for i, (
request, request,
input_length, input_length,
offset, prefix_offset,
token_offset, read_offset,
logits, logits,
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
@ -559,8 +560,8 @@ class CausalLM(Model):
# Generated token # Generated token
next_token_logprob = logprobs[-1, next_token_id] next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze() next_token_id_squeezed = next_token_id.squeeze()
next_token_text, offset, token_offset = self.decode_token( next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], offset, token_offset all_input_ids[:, 0], prefix_offset, read_offset
) )
# Evaluate stopping criteria # Evaluate stopping criteria
@ -628,8 +629,8 @@ class CausalLM(Model):
batch.input_ids[i, 0] = next_token_id batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset batch.prefix_offsets[i] = prefix_offset
batch.token_offsets[i] = token_offset batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length) batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch # We finished all generations in the batch; there is no next batch

View File

@ -52,8 +52,8 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch # Lengths of all generations present in the batch
input_lengths: List[int] input_lengths: List[int]
offsets: List[Optional[int]] prefix_offsets: List[Optional[int]]
token_offsets: List[Optional[int]] read_offsets: List[Optional[int]]
# Generation helpers # Generation helpers
next_token_choosers: List[NextTokenChooser] next_token_choosers: List[NextTokenChooser]
@ -82,8 +82,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen = 0 max_seqlen = 0
input_lengths = [] input_lengths = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
all_input_ids = [] all_input_ids = []
requests_idx_mapping = {} requests_idx_mapping = {}
@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen = max(max_seqlen, input_length) max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length) input_lengths.append(input_length)
offsets.append(None) prefix_offsets.append(0)
token_offsets.append(None) read_offsets.append(input_length)
all_input_ids.append(tokenized_input) all_input_ids.append(tokenized_input)
@ -151,8 +151,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
past_key_values=None, past_key_values=None,
input_lengths=input_lengths, input_lengths=input_lengths,
offsets=offsets, prefix_offsets=prefix_offsets,
token_offsets=token_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=[], all_input_ids_tensor=[],
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
@ -190,8 +190,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = [] all_input_ids_tensor = []
input_lengths = [] input_lengths = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
@ -222,8 +222,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor.append(self.all_input_ids_tensor[idx]) all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length) input_lengths.append(request_input_length)
offsets.append(self.offsets[idx]) prefix_offsets.append(self.prefix_offsets[idx])
token_offsets.append(self.token_offsets[idx]) read_offsets.append(self.read_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx]) next_token_choosers.append(self.next_token_choosers[idx])
@ -269,8 +269,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
past_key_values=past_key_values, past_key_values=past_key_values,
input_lengths=input_lengths, input_lengths=input_lengths,
offsets=offsets, prefix_offsets=prefix_offsets,
token_offsets=token_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
@ -302,8 +302,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = [] all_input_ids_tensor = []
input_lengths = [] input_lengths = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
@ -347,8 +347,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor.extend(batch.all_input_ids_tensor) all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets) prefix_offsets.extend(batch.prefix_offsets)
token_offsets.extend(batch.token_offsets) read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
@ -374,8 +374,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
past_key_values=past_key_values, past_key_values=past_key_values,
input_lengths=input_lengths, input_lengths=input_lengths,
offsets=offsets, prefix_offsets=prefix_offsets,
token_offsets=token_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
@ -394,7 +394,6 @@ class FlashCausalLM(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
decode_buffer: int = 3,
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
@ -405,23 +404,19 @@ class FlashCausalLM(Model):
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left" model_id, revision=revision, padding_side="left", truncation_side="left"
) )
self.model = ( model = model_cls.from_pretrained(
model_cls.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
) ).to(device)
.eval()
.to(device)
)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
device=device, device=device,
decode_buffer=decode_buffer,
) )
@property @property
@ -645,8 +640,8 @@ class FlashCausalLM(Model):
iterator = zip( iterator = zip(
batch.requests, batch.requests,
batch.input_lengths, batch.input_lengths,
batch.offsets, batch.prefix_offsets,
batch.token_offsets, batch.read_offsets,
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
@ -659,8 +654,8 @@ class FlashCausalLM(Model):
for i, ( for i, (
request, request,
input_length, input_length,
offset, prefix_offset,
token_offset, read_offset,
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
all_input_ids, all_input_ids,
@ -675,10 +670,10 @@ class FlashCausalLM(Model):
all_input_ids.append(next_token_id) all_input_ids.append(next_token_id)
# Generated token # Generated token
next_token_text, offset, token_offset = self.decode_token( next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids, all_input_ids,
offset, prefix_offset,
token_offset, read_offset,
) )
# Evaluate stopping criteria # Evaluate stopping criteria
@ -744,8 +739,8 @@ class FlashCausalLM(Model):
# Update values # Update values
batch.input_lengths[i] = new_input_length batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset batch.prefix_offsets[i] = prefix_offset
batch.token_offsets[i] = token_offset batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
batch.max_seqlen = batch.max_seqlen + 1 batch.max_seqlen = batch.max_seqlen + 1
cumulative_length += input_length cumulative_length += input_length

View File

@ -64,9 +64,9 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(config) model = FlashLlamaForCausalLM(config)
self.load_weights(model, filenames, quantize, device, dtype) self.load_weights(model, filenames, quantize, device, dtype)
self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
@ -189,9 +189,9 @@ class FlashLlamaSharded(FlashLlama):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,

View File

@ -73,9 +73,9 @@ class FlashNeoXSharded(FlashNeoX):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,

View File

@ -67,14 +67,13 @@ class FlashSantacoder(FlashCausalLM):
dtype, dtype,
config.architectures[0].startswith("GPT2"), config.architectures[0].startswith("GPT2"),
) )
self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
device=device, device=device,
decode_buffer=1,
) )
@staticmethod @staticmethod
@ -213,16 +212,15 @@ class FlashSantacoderSharded(FlashSantacoder):
world_size=world_size, world_size=world_size,
transpose=config.architectures[0].startswith("GPT2"), transpose=config.architectures[0].startswith("GPT2"),
) )
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
decode_buffer=1,
) )
@staticmethod @staticmethod

View File

@ -94,8 +94,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
requests_idx_mapping = {} requests_idx_mapping = {}
# Parse batch # Parse batch
@ -106,8 +106,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic # Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs)) inputs.append(escape_custom_split_sequence(r.inputs))
offsets.append(None)
token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
@ -127,6 +125,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
truncation=True, truncation=True,
max_length=max_truncation, max_length=max_truncation,
).to(device) ).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(0)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1) input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max() max_input_length = input_lengths.max()
@ -155,8 +157,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
past_key_values=None, past_key_values=None,
all_input_ids=list(all_input_ids), all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(), input_lengths=input_lengths.tolist(),
offsets=offsets, prefix_offsets=prefix_offsets,
token_offsets=token_offsets, read_offsets=read_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(), max_input_length=max_input_length.item(),
@ -231,9 +233,9 @@ class GalacticaSharded(Galactica):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,

View File

@ -70,9 +70,9 @@ class GPTNeoxSharded(CausalLM):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,

View File

@ -13,23 +13,20 @@ B = TypeVar("B", bound=Batch)
class Model(ABC): class Model(ABC):
def __init__( def __init__(
self, self,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
requires_padding: bool, requires_padding: bool,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
decode_buffer: int = 3,
rank: int = 0, rank: int = 0,
world_size: int = 1, world_size: int = 1,
): ):
if decode_buffer < 1: self.model = model.eval()
raise ValueError("decode_buffer must be >= 1")
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids) self.all_special_ids = set(tokenizer.all_special_ids)
self.requires_padding = requires_padding self.requires_padding = requires_padding
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.decode_buffer = decode_buffer
self.rank = rank self.rank = rank
self.world_size = world_size self.world_size = world_size
self.check_initialized() self.check_initialized()
@ -54,52 +51,29 @@ class Model(ABC):
def decode_token( def decode_token(
self, self,
all_input_ids: List[int], all_input_ids: List[int],
offset: Optional[int] = None, prefix_offset: int = 0,
token_offset: Optional[int] = None, read_offset: int = 0,
) -> Tuple[str, Optional[int], Optional[int]]: ) -> Tuple[str, int, int]:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers""" """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
if all_input_ids[-1] in self.all_special_ids:
return ( # The prefix text is necessary only to defeat cleanup algorithms in the decode
self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False), # which decide to add a space or not depending on the surrounding ids.
None, prefix_text = self.tokenizer.decode(
None, all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
)
new_text = self.tokenizer.decode(
all_input_ids[prefix_offset:], skip_special_tokens=False
) )
if token_offset is None: if len(new_text) > len(prefix_text) and not new_text.endswith("<EFBFBD>"):
token_offset = len(all_input_ids) - self.decode_buffer # utf-8 char at the end means it's a potential unfinished byte sequence
# left token buffer # from byte fallback tokenization.
if self.decode_buffer > 1: # If it's in the middle, it's probably a real invalid id generated
# Decode token_offset token minus last one and token_offset tokens # by the model
raw_texts = self.tokenizer.batch_decode( new_text = new_text[len(prefix_text) :]
[all_input_ids[token_offset:-1], all_input_ids[token_offset:]], return new_text, read_offset, len(all_input_ids)
skip_special_tokens=False,
)
# default offset is only the last token
offset = len(raw_texts[0])
sequence_text = raw_texts[1]
else: else:
# Only decode the last token without using a token buffer return "", prefix_offset, read_offset
sequence_text = self.tokenizer.decode(
all_input_ids[-1], skip_special_tokens=False
)
# no offset in this case
offset = 0
else:
assert offset is not None
sequence_text = self.tokenizer.decode(
all_input_ids[token_offset:],
skip_special_tokens=False,
)
# get text
token_text = sequence_text[offset:]
# if text is utf-8
if token_text and token_text[-1] != "<EFBFBD>":
return token_text, None, None
else:
return "", offset, token_offset
def check_initialized(self): def check_initialized(self):
uninitialized_parameters = [] uninitialized_parameters = []

View File

@ -86,9 +86,9 @@ class OPTSharded(OPT):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,

View File

@ -46,24 +46,20 @@ class SantaCoder(CausalLM):
} }
) )
self.model = ( model = AutoModelForCausalLM.from_pretrained(
AutoModelForCausalLM.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=True, # required trust_remote_code=True, # required
) ).to(device)
.to(device)
.eval()
)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,
device=device, device=device,
decode_buffer=1,
) )
def decode(self, generated_ids: List[int]) -> str: def decode(self, generated_ids: List[int]) -> str:

View File

@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
# Lengths of all generations present in the batch # Lengths of all generations present in the batch
input_lengths: List[int] input_lengths: List[int]
decoder_input_lengths: List[int] decoder_input_lengths: List[int]
offsets: List[Optional[int]] prefix_offsets: List[int]
token_offsets: List[Optional[int]] read_offsets: List[int]
# Generation helpers # Generation helpers
next_token_choosers: List[NextTokenChooser] next_token_choosers: List[NextTokenChooser]
@ -79,8 +79,8 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias = [] stopping_criterias = []
decoder_input_lengths = [] decoder_input_lengths = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
requests_idx_mapping = {} requests_idx_mapping = {}
# Parse batch # Parse batch
@ -91,8 +91,6 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs) inputs.append(r.inputs)
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1) decoder_input_lengths.append(1)
offsets.append(None)
token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
@ -123,6 +121,9 @@ class Seq2SeqLMBatch(Batch):
.repeat(len(pb.requests)) .repeat(len(pb.requests))
.view(-1, 1) .view(-1, 1)
) )
for _ in pb.requests:
prefix_offsets.append(0)
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1) all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens max_tokens = len(inputs) * max_input_length + max_decode_tokens
@ -140,8 +141,8 @@ class Seq2SeqLMBatch(Batch):
past_key_values=None, past_key_values=None,
input_lengths=input_lengths.tolist(), input_lengths=input_lengths.tolist(),
decoder_input_lengths=decoder_input_lengths, decoder_input_lengths=decoder_input_lengths,
offsets=offsets, prefix_offsets=prefix_offsets,
token_offsets=token_offsets, read_offsets=read_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(), max_input_length=max_input_length.item(),
@ -165,8 +166,8 @@ class Seq2SeqLMBatch(Batch):
requests_idx_mapping = {} requests_idx_mapping = {}
input_lengths = [] input_lengths = []
decoder_input_lengths = [] decoder_input_lengths = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
all_decoder_input_ids = [] all_decoder_input_ids = []
@ -184,8 +185,8 @@ class Seq2SeqLMBatch(Batch):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
keep_indices.append(idx) keep_indices.append(idx)
offsets.append(self.offsets[idx]) prefix_offsets.append(self.prefix_offsets[idx])
token_offsets.append(self.token_offsets[idx]) read_offsets.append(self.read_offsets[idx])
all_decoder_input_ids.append(self.all_decoder_input_ids[idx]) all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
@ -248,8 +249,8 @@ class Seq2SeqLMBatch(Batch):
self.all_decoder_input_ids = all_decoder_input_ids self.all_decoder_input_ids = all_decoder_input_ids
self.input_lengths = input_lengths self.input_lengths = input_lengths
self.decoder_input_lengths = decoder_input_lengths self.decoder_input_lengths = decoder_input_lengths
self.offsets = offsets self.prefix_offsets = prefix_offsets
self.token_offsets = token_offsets self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length self.max_input_length = max_input_length
@ -283,8 +284,8 @@ class Seq2SeqLMBatch(Batch):
all_decoder_input_ids = [] all_decoder_input_ids = []
input_lengths = [] input_lengths = []
decoder_input_lengths = [] decoder_input_lengths = []
offsets = [] prefix_offsets = []
token_offsets = [] read_offsets = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
max_tokens = 0 max_tokens = 0
@ -306,8 +307,8 @@ class Seq2SeqLMBatch(Batch):
all_decoder_input_ids.extend(batch.all_decoder_input_ids) all_decoder_input_ids.extend(batch.all_decoder_input_ids)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths)
offsets.extend(batch.offsets) prefix_offsets.extend(batch.prefix_offsets)
token_offsets.extend(batch.token_offsets) read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
@ -482,8 +483,8 @@ class Seq2SeqLMBatch(Batch):
past_key_values=past_key_values, past_key_values=past_key_values,
input_lengths=input_lengths, input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths, decoder_input_lengths=decoder_input_lengths,
offsets=offsets, prefix_offsets=prefix_offsets,
token_offsets=token_offsets, read_offsets=read_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
max_input_length=max_input_length, max_input_length=max_input_length,
@ -502,7 +503,6 @@ class Seq2SeqLM(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
decode_buffer: int = 3,
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
@ -514,24 +514,24 @@ class Seq2SeqLM(Model):
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
self.model = AutoModelForSeq2SeqLM.from_pretrained( model = AutoModelForSeq2SeqLM.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None, device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
).eval() )
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left" model_id, revision=revision, padding_side="left", truncation_side="left"
) )
tokenizer.bos_token_id = self.model.config.decoder_start_token_id tokenizer.bos_token_id = model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__( super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,
device=device, device=device,
decode_buffer=decode_buffer,
) )
@property @property
@ -608,8 +608,8 @@ class Seq2SeqLM(Model):
iterator = zip( iterator = zip(
batch.requests, batch.requests,
batch.input_lengths, batch.input_lengths,
batch.offsets, batch.prefix_offsets,
batch.token_offsets, batch.read_offsets,
batch.decoder_input_lengths, batch.decoder_input_lengths,
logits, logits,
batch.next_token_choosers, batch.next_token_choosers,
@ -621,8 +621,8 @@ class Seq2SeqLM(Model):
for i, ( for i, (
request, request,
input_length, input_length,
offset, prefix_offset,
token_offset, read_offset,
decoder_input_length, decoder_input_length,
logits, logits,
next_token_chooser, next_token_chooser,
@ -643,8 +643,8 @@ class Seq2SeqLM(Model):
# Generated token # Generated token
next_token_logprob = logprobs[-1, next_token_id] next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze() next_token_id_squeezed = next_token_id.squeeze()
next_token_text, offset, token_offset = self.decode_token( next_token_text, prefix_offset, read_offset = self.decode_token(
all_decoder_input_ids, offset, token_offset all_decoder_input_ids, prefix_offset, read_offset
) )
# Evaluate stopping criteria # Evaluate stopping criteria
@ -702,8 +702,8 @@ class Seq2SeqLM(Model):
batch.all_decoder_input_ids[i] = all_decoder_input_ids batch.all_decoder_input_ids[i] = all_decoder_input_ids
batch.input_lengths[i] = input_length batch.input_lengths[i] = input_length
batch.decoder_input_lengths[i] = new_decoder_input_length batch.decoder_input_lengths[i] = new_decoder_input_length
batch.offsets[i] = offset batch.prefix_offsets[i] = prefix_offset
batch.token_offsets[i] = token_offset batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, input_length) batch.max_input_length = max(batch.max_input_length, input_length)
batch.max_decoder_input_length = max( batch.max_decoder_input_length = max(
batch.max_decoder_input_length, new_decoder_input_length batch.max_decoder_input_length, new_decoder_input_length

View File

@ -16,9 +16,6 @@ from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
) )
from text_generation_server.utils.layers import (
FastLinear,
)
from transformers.models.t5.parallel_layers import ( from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -73,9 +70,9 @@ class T5Sharded(Seq2SeqLM):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__( super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,