fix(server): fix decode token (#334)
Fixes #333 --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
dbdc587ddd
commit
5a58226130
|
@ -213,12 +213,13 @@ jobs:
|
||||||
sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
|
sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
|
||||||
- name: Install
|
- name: Install
|
||||||
run: |
|
run: |
|
||||||
|
pip install pytest-xdist
|
||||||
make install-integration-tests
|
make install-integration-tests
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
|
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
|
||||||
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||||
pytest -s -vv integration-tests
|
pytest -s -vv -n 2 --dist loadfile integration-tests
|
||||||
|
|
||||||
stop-runner:
|
stop-runner:
|
||||||
name: Stop self-hosted EC2 runner
|
name: Stop self-hosted EC2 runner
|
||||||
|
|
|
@ -66,7 +66,8 @@ jobs:
|
||||||
- name: Run server tests
|
- name: Run server tests
|
||||||
run: |
|
run: |
|
||||||
pip install pytest
|
pip install pytest
|
||||||
make python-server-tests
|
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||||
|
pytest -s -vv server/tests
|
||||||
- name: Run Rust fmt
|
- name: Run Rust fmt
|
||||||
run: |
|
run: |
|
||||||
cargo fmt --check
|
cargo fmt --check
|
||||||
|
|
2
Makefile
2
Makefile
|
@ -31,7 +31,7 @@ update-integration-tests: install-integration-tests
|
||||||
pytest -s -vv --snapshot-update integration-tests
|
pytest -s -vv --snapshot-update integration-tests
|
||||||
|
|
||||||
python-server-tests:
|
python-server-tests:
|
||||||
HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests
|
HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests
|
||||||
|
|
||||||
python-client-tests:
|
python-client-tests:
|
||||||
pytest clients/python/tests
|
pytest clients/python/tests
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
import contextlib
|
import contextlib
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -7,6 +8,7 @@ import docker
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
import random
|
||||||
|
|
||||||
from docker.errors import NotFound
|
from docker.errors import NotFound
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
|
@ -205,10 +207,12 @@ def launcher(event_loop):
|
||||||
def local_launcher(
|
def local_launcher(
|
||||||
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
||||||
):
|
):
|
||||||
port = 9999
|
port = random.randint(8000, 10_000)
|
||||||
master_port = 19999
|
master_port = random.randint(10_000, 20_000)
|
||||||
|
|
||||||
shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
|
shard_uds_path = (
|
||||||
|
f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
|
||||||
|
)
|
||||||
|
|
||||||
args = [
|
args = [
|
||||||
"text-generation-launcher",
|
"text-generation-launcher",
|
||||||
|
@ -236,7 +240,7 @@ def launcher(event_loop):
|
||||||
process.wait(60)
|
process.wait(60)
|
||||||
|
|
||||||
launcher_output = process.stdout.read().decode("utf-8")
|
launcher_output = process.stdout.read().decode("utf-8")
|
||||||
print(launcher_output)
|
print(launcher_output, file=sys.stderr)
|
||||||
|
|
||||||
process.stdout.close()
|
process.stdout.close()
|
||||||
process.stderr.close()
|
process.stderr.close()
|
||||||
|
@ -245,7 +249,7 @@ def launcher(event_loop):
|
||||||
def docker_launcher(
|
def docker_launcher(
|
||||||
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
||||||
):
|
):
|
||||||
port = 9999
|
port = random.randint(8000, 10_000)
|
||||||
|
|
||||||
args = ["--model-id", model_id, "--env"]
|
args = ["--model-id", model_id, "--env"]
|
||||||
|
|
||||||
|
@ -298,7 +302,7 @@ def launcher(event_loop):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
container_output = container.logs().decode("utf-8")
|
container_output = container.logs().decode("utf-8")
|
||||||
print(container_output)
|
print(container_output, file=sys.stderr)
|
||||||
|
|
||||||
container.remove()
|
container.remove()
|
||||||
|
|
||||||
|
|
|
@ -1,92 +1,4 @@
|
||||||
[
|
[
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 1,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "<s>"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 4321,
|
|
||||||
"logprob": -8.6875,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2009,
|
|
||||||
"logprob": -11.5546875,
|
|
||||||
"text": "request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 363,
|
|
||||||
"logprob": -1.5322266,
|
|
||||||
"special": false,
|
|
||||||
"text": " for"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 847,
|
|
||||||
"logprob": -2.5585938,
|
|
||||||
"special": false,
|
|
||||||
"text": " /"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2754,
|
|
||||||
"logprob": -2.265625,
|
|
||||||
"special": false,
|
|
||||||
"text": "api"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29914,
|
|
||||||
"logprob": -0.034088135,
|
|
||||||
"special": false,
|
|
||||||
"text": "/"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29894,
|
|
||||||
"logprob": -0.96240234,
|
|
||||||
"special": false,
|
|
||||||
"text": "v"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29896,
|
|
||||||
"logprob": -0.36816406,
|
|
||||||
"special": false,
|
|
||||||
"text": "1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29914,
|
|
||||||
"logprob": -0.013191223,
|
|
||||||
"special": false,
|
|
||||||
"text": "/"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 16418,
|
|
||||||
"logprob": -3.15625,
|
|
||||||
"special": false,
|
|
||||||
"text": "projects"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29914,
|
|
||||||
"logprob": -0.43774414,
|
|
||||||
"special": false,
|
|
||||||
"text": "/"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29896,
|
|
||||||
"logprob": -1.9443359,
|
|
||||||
"special": false,
|
|
||||||
"text": "1"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"generated_text": "for /api/v1/projects/1"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
"best_of_sequences": null,
|
"best_of_sequences": null,
|
||||||
|
@ -263,6 +175,94 @@
|
||||||
},
|
},
|
||||||
"generated_text": "for /api/v1/projects/1"
|
"generated_text": "for /api/v1/projects/1"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4321,
|
||||||
|
"logprob": -8.6875,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -11.5546875,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 363,
|
||||||
|
"logprob": -1.5322266,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 847,
|
||||||
|
"logprob": -2.5585938,
|
||||||
|
"special": false,
|
||||||
|
"text": " /"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2754,
|
||||||
|
"logprob": -2.265625,
|
||||||
|
"special": false,
|
||||||
|
"text": "api"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29914,
|
||||||
|
"logprob": -0.034088135,
|
||||||
|
"special": false,
|
||||||
|
"text": "/"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29894,
|
||||||
|
"logprob": -0.96240234,
|
||||||
|
"special": false,
|
||||||
|
"text": "v"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -0.36816406,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29914,
|
||||||
|
"logprob": -0.013191223,
|
||||||
|
"special": false,
|
||||||
|
"text": "/"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16418,
|
||||||
|
"logprob": -3.15625,
|
||||||
|
"special": false,
|
||||||
|
"text": "projects"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29914,
|
||||||
|
"logprob": -0.43774414,
|
||||||
|
"special": false,
|
||||||
|
"text": "/"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -1.9443359,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "for /api/v1/projects/1"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
"best_of_sequences": null,
|
"best_of_sequences": null,
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
"id": 926,
|
"id": 926,
|
||||||
"logprob": -4.3554688,
|
"logprob": -4.3554688,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "To"
|
"text": " To"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 18295,
|
"id": 18295,
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
"id": 16017,
|
"id": 16017,
|
||||||
"logprob": -1.3505859,
|
"logprob": -1.3505859,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "blue"
|
"text": " blue"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 20495,
|
"id": 20495,
|
||||||
|
|
|
@ -1,58 +1,4 @@
|
||||||
[
|
[
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "eos_token",
|
|
||||||
"generated_tokens": 6,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 0,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "<pad>"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 259,
|
|
||||||
"logprob": -1.3789062,
|
|
||||||
"special": false,
|
|
||||||
"text": ""
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 39261,
|
|
||||||
"logprob": -0.36279297,
|
|
||||||
"special": false,
|
|
||||||
"text": "Because"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 609,
|
|
||||||
"logprob": -1.0966797,
|
|
||||||
"special": false,
|
|
||||||
"text": " it"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 339,
|
|
||||||
"logprob": -0.8276367,
|
|
||||||
"special": false,
|
|
||||||
"text": " is"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 16017,
|
|
||||||
"logprob": -1.6845703,
|
|
||||||
"special": false,
|
|
||||||
"text": " blue"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1,
|
|
||||||
"logprob": -0.72753906,
|
|
||||||
"special": true,
|
|
||||||
"text": "</s>"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"generated_text": "Because it is blue"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
"best_of_sequences": null,
|
"best_of_sequences": null,
|
||||||
|
@ -71,7 +17,7 @@
|
||||||
"id": 259,
|
"id": 259,
|
||||||
"logprob": -1.3798828,
|
"logprob": -1.3798828,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ""
|
"text": " "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 39261,
|
"id": 39261,
|
||||||
|
@ -125,7 +71,7 @@
|
||||||
"id": 259,
|
"id": 259,
|
||||||
"logprob": -1.3789062,
|
"logprob": -1.3789062,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ""
|
"text": " "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 39261,
|
"id": 39261,
|
||||||
|
@ -179,7 +125,61 @@
|
||||||
"id": 259,
|
"id": 259,
|
||||||
"logprob": -1.3789062,
|
"logprob": -1.3789062,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ""
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 39261,
|
||||||
|
"logprob": -0.36279297,
|
||||||
|
"special": false,
|
||||||
|
"text": "Because"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 609,
|
||||||
|
"logprob": -1.0966797,
|
||||||
|
"special": false,
|
||||||
|
"text": " it"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 339,
|
||||||
|
"logprob": -0.8276367,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16017,
|
||||||
|
"logprob": -1.6845703,
|
||||||
|
"special": false,
|
||||||
|
"text": " blue"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": -0.72753906,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "Because it is blue"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 6,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<pad>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 259,
|
||||||
|
"logprob": -1.3789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 39261,
|
"id": 39261,
|
||||||
|
|
|
@ -146,7 +146,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
sha: None,
|
sha: None,
|
||||||
pipeline_tag: None,
|
pipeline_tag: None,
|
||||||
},
|
},
|
||||||
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
|
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
|
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -2,7 +2,7 @@ include Makefile-transformers
|
||||||
include Makefile-flash-att
|
include Makefile-flash-att
|
||||||
|
|
||||||
unit-tests:
|
unit-tests:
|
||||||
python -m pytest tests
|
pytest -s -vv -m "not private" tests
|
||||||
|
|
||||||
gen-server:
|
gen-server:
|
||||||
# Compile protos
|
# Compile protos
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from text_generation_server.models import Model
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_model():
|
||||||
|
class TestModel(Model):
|
||||||
|
def batch_type(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def generate_token(self, batch):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
|
||||||
|
|
||||||
|
model = TestModel(
|
||||||
|
torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.private
|
||||||
|
def test_decode_streaming_english_spaces():
|
||||||
|
model = get_test_model()
|
||||||
|
truth = "Hello here, this is a simple test"
|
||||||
|
all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
|
||||||
|
assert (
|
||||||
|
all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded_text = ""
|
||||||
|
offset = 0
|
||||||
|
token_offset = 0
|
||||||
|
for i in range(len(all_input_ids)):
|
||||||
|
text, offset, token_offset = model.decode_token(
|
||||||
|
all_input_ids[: i + 1], offset, token_offset
|
||||||
|
)
|
||||||
|
decoded_text += text
|
||||||
|
|
||||||
|
assert decoded_text == truth
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.private
|
||||||
|
def test_decode_streaming_chinese_utf8():
|
||||||
|
model = get_test_model()
|
||||||
|
truth = "我很感谢你的热情"
|
||||||
|
all_input_ids = [
|
||||||
|
30672,
|
||||||
|
232,
|
||||||
|
193,
|
||||||
|
139,
|
||||||
|
233,
|
||||||
|
135,
|
||||||
|
162,
|
||||||
|
235,
|
||||||
|
179,
|
||||||
|
165,
|
||||||
|
30919,
|
||||||
|
30210,
|
||||||
|
234,
|
||||||
|
134,
|
||||||
|
176,
|
||||||
|
30993,
|
||||||
|
]
|
||||||
|
|
||||||
|
decoded_text = ""
|
||||||
|
offset = 0
|
||||||
|
token_offset = 0
|
||||||
|
for i in range(len(all_input_ids)):
|
||||||
|
text, offset, token_offset = model.decode_token(
|
||||||
|
all_input_ids[: i + 1], offset, token_offset
|
||||||
|
)
|
||||||
|
decoded_text += text
|
||||||
|
|
||||||
|
assert decoded_text == truth
|
|
@ -149,7 +149,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
|
||||||
assert all([generation.generated_text is None for generation in generations])
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
assert all([generation.token_id.item() == 259 for generation in generations])
|
assert all([generation.token_id.item() == 259 for generation in generations])
|
||||||
assert all([generation.token_text == "" for generation in generations])
|
assert all([generation.token_text == " " for generation in generations])
|
||||||
assert generations[0].request_id == 0
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ class BLOOM(CausalLM):
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super(BLOOM, self).__init__(
|
super(BLOOM, self).__init__(
|
||||||
model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
|
model_id=model_id, revision=revision, quantize=quantize
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -104,14 +104,13 @@ class BLOOMSharded(BLOOM):
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
self.model = model.eval()
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
decode_buffer=1,
|
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
|
|
|
@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
|
||||||
|
|
||||||
# Lengths of all generations present in the batch
|
# Lengths of all generations present in the batch
|
||||||
input_lengths: List[int]
|
input_lengths: List[int]
|
||||||
offsets: List[Optional[int]]
|
prefix_offsets: List[int]
|
||||||
token_offsets: List[Optional[int]]
|
read_offsets: List[int]
|
||||||
|
|
||||||
# Generation helpers
|
# Generation helpers
|
||||||
next_token_choosers: List[NextTokenChooser]
|
next_token_choosers: List[NextTokenChooser]
|
||||||
|
@ -70,8 +70,8 @@ class CausalLMBatch(Batch):
|
||||||
inputs = []
|
inputs = []
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
|
@ -81,8 +81,6 @@ class CausalLMBatch(Batch):
|
||||||
for i, r in enumerate(pb.requests):
|
for i, r in enumerate(pb.requests):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
inputs.append(r.inputs)
|
inputs.append(r.inputs)
|
||||||
offsets.append(None)
|
|
||||||
token_offsets.append(None)
|
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
|
@ -102,6 +100,10 @@ class CausalLMBatch(Batch):
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_truncation,
|
max_length=max_truncation,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
for _ in pb.requests:
|
||||||
|
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||||
|
prefix_offsets.append(0)
|
||||||
|
read_offsets.append(input_len)
|
||||||
|
|
||||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||||
max_input_length = input_lengths.max()
|
max_input_length = input_lengths.max()
|
||||||
|
@ -130,8 +132,8 @@ class CausalLMBatch(Batch):
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
all_input_ids=list(all_input_ids),
|
all_input_ids=list(all_input_ids),
|
||||||
input_lengths=input_lengths.tolist(),
|
input_lengths=input_lengths.tolist(),
|
||||||
offsets=offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
token_offsets=token_offsets,
|
read_offsets=read_offsets,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
max_input_length=max_input_length.item(),
|
max_input_length=max_input_length.item(),
|
||||||
|
@ -151,8 +153,8 @@ class CausalLMBatch(Batch):
|
||||||
# New values after filtering
|
# New values after filtering
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
max_input_length = 0
|
max_input_length = 0
|
||||||
|
|
||||||
|
@ -167,8 +169,8 @@ class CausalLMBatch(Batch):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
keep_indices.append(idx)
|
keep_indices.append(idx)
|
||||||
|
|
||||||
offsets.append(self.offsets[idx])
|
prefix_offsets.append(self.prefix_offsets[idx])
|
||||||
token_offsets.append(self.token_offsets[idx])
|
read_offsets.append(self.read_offsets[idx])
|
||||||
all_input_ids.append(self.all_input_ids[idx])
|
all_input_ids.append(self.all_input_ids[idx])
|
||||||
|
|
||||||
request_input_length = self.input_lengths[idx]
|
request_input_length = self.input_lengths[idx]
|
||||||
|
@ -225,8 +227,8 @@ class CausalLMBatch(Batch):
|
||||||
self.position_ids = position_ids
|
self.position_ids = position_ids
|
||||||
self.all_input_ids = all_input_ids
|
self.all_input_ids = all_input_ids
|
||||||
self.input_lengths = input_lengths
|
self.input_lengths = input_lengths
|
||||||
self.offsets = offsets
|
self.prefix_offsets = prefix_offsets
|
||||||
self.token_offsets = token_offsets
|
self.read_offsets = read_offsets
|
||||||
self.next_token_choosers = next_token_choosers
|
self.next_token_choosers = next_token_choosers
|
||||||
self.stopping_criterias = stopping_criterias
|
self.stopping_criterias = stopping_criterias
|
||||||
self.max_input_length = max_input_length
|
self.max_input_length = max_input_length
|
||||||
|
@ -251,8 +253,8 @@ class CausalLMBatch(Batch):
|
||||||
requests = []
|
requests = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
|
@ -270,8 +272,8 @@ class CausalLMBatch(Batch):
|
||||||
for i, batch in enumerate(batches):
|
for i, batch in enumerate(batches):
|
||||||
requests.extend(batch.requests)
|
requests.extend(batch.requests)
|
||||||
input_lengths.extend(batch.input_lengths)
|
input_lengths.extend(batch.input_lengths)
|
||||||
offsets.extend(batch.offsets)
|
prefix_offsets.extend(batch.prefix_offsets)
|
||||||
token_offsets.extend(batch.token_offsets)
|
read_offsets.extend(batch.read_offsets)
|
||||||
all_input_ids.extend(batch.all_input_ids)
|
all_input_ids.extend(batch.all_input_ids)
|
||||||
next_token_choosers.extend(batch.next_token_choosers)
|
next_token_choosers.extend(batch.next_token_choosers)
|
||||||
stopping_criterias.extend(batch.stopping_criterias)
|
stopping_criterias.extend(batch.stopping_criterias)
|
||||||
|
@ -428,8 +430,8 @@ class CausalLMBatch(Batch):
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
offsets=offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
token_offsets=token_offsets,
|
read_offsets=read_offsets,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
max_input_length=max_input_length,
|
max_input_length=max_input_length,
|
||||||
|
@ -448,7 +450,6 @@ class CausalLM(Model):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
decode_buffer: int = 3,
|
|
||||||
):
|
):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
@ -463,25 +464,25 @@ class CausalLM(Model):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||||
)
|
)
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map="auto" if torch.cuda.is_available() else None,
|
device_map="auto" if torch.cuda.is_available() else None,
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
).eval()
|
)
|
||||||
tokenizer.pad_token_id = (
|
tokenizer.pad_token_id = (
|
||||||
self.model.config.pad_token_id
|
model.config.pad_token_id
|
||||||
if self.model.config.pad_token_id is not None
|
if model.config.pad_token_id is not None
|
||||||
else self.model.config.eos_token_id
|
else model.config.eos_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
decode_buffer=decode_buffer,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -528,8 +529,8 @@ class CausalLM(Model):
|
||||||
iterator = zip(
|
iterator = zip(
|
||||||
batch.requests,
|
batch.requests,
|
||||||
batch.input_lengths,
|
batch.input_lengths,
|
||||||
batch.offsets,
|
batch.prefix_offsets,
|
||||||
batch.token_offsets,
|
batch.read_offsets,
|
||||||
logits,
|
logits,
|
||||||
batch.next_token_choosers,
|
batch.next_token_choosers,
|
||||||
batch.stopping_criterias,
|
batch.stopping_criterias,
|
||||||
|
@ -540,8 +541,8 @@ class CausalLM(Model):
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
input_length,
|
input_length,
|
||||||
offset,
|
prefix_offset,
|
||||||
token_offset,
|
read_offset,
|
||||||
logits,
|
logits,
|
||||||
next_token_chooser,
|
next_token_chooser,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
|
@ -559,8 +560,8 @@ class CausalLM(Model):
|
||||||
# Generated token
|
# Generated token
|
||||||
next_token_logprob = logprobs[-1, next_token_id]
|
next_token_logprob = logprobs[-1, next_token_id]
|
||||||
next_token_id_squeezed = next_token_id.squeeze()
|
next_token_id_squeezed = next_token_id.squeeze()
|
||||||
next_token_text, offset, token_offset = self.decode_token(
|
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||||
all_input_ids[:, 0], offset, token_offset
|
all_input_ids[:, 0], prefix_offset, read_offset
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
# Evaluate stopping criteria
|
||||||
|
@ -628,8 +629,8 @@ class CausalLM(Model):
|
||||||
batch.input_ids[i, 0] = next_token_id
|
batch.input_ids[i, 0] = next_token_id
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
batch.input_lengths[i] = new_input_length
|
batch.input_lengths[i] = new_input_length
|
||||||
batch.offsets[i] = offset
|
batch.prefix_offsets[i] = prefix_offset
|
||||||
batch.token_offsets[i] = token_offset
|
batch.read_offsets[i] = read_offset
|
||||||
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||||
|
|
||||||
# We finished all generations in the batch; there is no next batch
|
# We finished all generations in the batch; there is no next batch
|
||||||
|
|
|
@ -52,8 +52,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
|
|
||||||
# Lengths of all generations present in the batch
|
# Lengths of all generations present in the batch
|
||||||
input_lengths: List[int]
|
input_lengths: List[int]
|
||||||
offsets: List[Optional[int]]
|
prefix_offsets: List[Optional[int]]
|
||||||
token_offsets: List[Optional[int]]
|
read_offsets: List[Optional[int]]
|
||||||
|
|
||||||
# Generation helpers
|
# Generation helpers
|
||||||
next_token_choosers: List[NextTokenChooser]
|
next_token_choosers: List[NextTokenChooser]
|
||||||
|
@ -82,8 +82,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_seqlen = 0
|
max_seqlen = 0
|
||||||
|
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
|
@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_seqlen = max(max_seqlen, input_length)
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
input_lengths.append(input_length)
|
input_lengths.append(input_length)
|
||||||
|
|
||||||
offsets.append(None)
|
prefix_offsets.append(0)
|
||||||
token_offsets.append(None)
|
read_offsets.append(input_length)
|
||||||
|
|
||||||
all_input_ids.append(tokenized_input)
|
all_input_ids.append(tokenized_input)
|
||||||
|
|
||||||
|
@ -151,8 +151,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
offsets=offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
token_offsets=token_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=[],
|
all_input_ids_tensor=[],
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
|
@ -190,8 +190,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
all_input_ids_tensor = []
|
all_input_ids_tensor = []
|
||||||
|
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
|
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
|
@ -222,8 +222,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
|
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
|
||||||
|
|
||||||
input_lengths.append(request_input_length)
|
input_lengths.append(request_input_length)
|
||||||
offsets.append(self.offsets[idx])
|
prefix_offsets.append(self.prefix_offsets[idx])
|
||||||
token_offsets.append(self.token_offsets[idx])
|
read_offsets.append(self.read_offsets[idx])
|
||||||
|
|
||||||
next_token_choosers.append(self.next_token_choosers[idx])
|
next_token_choosers.append(self.next_token_choosers[idx])
|
||||||
|
|
||||||
|
@ -269,8 +269,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
offsets=offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
token_offsets=token_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
|
@ -302,8 +302,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
all_input_ids_tensor = []
|
all_input_ids_tensor = []
|
||||||
|
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
|
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
|
@ -347,8 +347,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
|
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
|
||||||
|
|
||||||
input_lengths.extend(batch.input_lengths)
|
input_lengths.extend(batch.input_lengths)
|
||||||
offsets.extend(batch.offsets)
|
prefix_offsets.extend(batch.prefix_offsets)
|
||||||
token_offsets.extend(batch.token_offsets)
|
read_offsets.extend(batch.read_offsets)
|
||||||
|
|
||||||
next_token_choosers.extend(batch.next_token_choosers)
|
next_token_choosers.extend(batch.next_token_choosers)
|
||||||
stopping_criterias.extend(batch.stopping_criterias)
|
stopping_criterias.extend(batch.stopping_criterias)
|
||||||
|
@ -374,8 +374,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
offsets=offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
token_offsets=token_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
|
@ -394,7 +394,6 @@ class FlashCausalLM(Model):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
decode_buffer: int = 3,
|
|
||||||
):
|
):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
@ -405,23 +404,19 @@ class FlashCausalLM(Model):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||||
)
|
)
|
||||||
self.model = (
|
model = model_cls.from_pretrained(
|
||||||
model_cls.from_pretrained(
|
|
||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
)
|
).to(device)
|
||||||
.eval()
|
|
||||||
.to(device)
|
|
||||||
)
|
|
||||||
|
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
decode_buffer=decode_buffer,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -645,8 +640,8 @@ class FlashCausalLM(Model):
|
||||||
iterator = zip(
|
iterator = zip(
|
||||||
batch.requests,
|
batch.requests,
|
||||||
batch.input_lengths,
|
batch.input_lengths,
|
||||||
batch.offsets,
|
batch.prefix_offsets,
|
||||||
batch.token_offsets,
|
batch.read_offsets,
|
||||||
batch.next_token_choosers,
|
batch.next_token_choosers,
|
||||||
batch.stopping_criterias,
|
batch.stopping_criterias,
|
||||||
batch.all_input_ids,
|
batch.all_input_ids,
|
||||||
|
@ -659,8 +654,8 @@ class FlashCausalLM(Model):
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
input_length,
|
input_length,
|
||||||
offset,
|
prefix_offset,
|
||||||
token_offset,
|
read_offset,
|
||||||
next_token_chooser,
|
next_token_chooser,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
all_input_ids,
|
all_input_ids,
|
||||||
|
@ -675,10 +670,10 @@ class FlashCausalLM(Model):
|
||||||
all_input_ids.append(next_token_id)
|
all_input_ids.append(next_token_id)
|
||||||
|
|
||||||
# Generated token
|
# Generated token
|
||||||
next_token_text, offset, token_offset = self.decode_token(
|
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||||
all_input_ids,
|
all_input_ids,
|
||||||
offset,
|
prefix_offset,
|
||||||
token_offset,
|
read_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
# Evaluate stopping criteria
|
||||||
|
@ -744,8 +739,8 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
batch.input_lengths[i] = new_input_length
|
batch.input_lengths[i] = new_input_length
|
||||||
batch.offsets[i] = offset
|
batch.prefix_offsets[i] = prefix_offset
|
||||||
batch.token_offsets[i] = token_offset
|
batch.read_offsets[i] = read_offset
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
batch.max_seqlen = batch.max_seqlen + 1
|
batch.max_seqlen = batch.max_seqlen + 1
|
||||||
cumulative_length += input_length
|
cumulative_length += input_length
|
||||||
|
|
|
@ -64,9 +64,9 @@ class FlashLlama(FlashCausalLM):
|
||||||
model = FlashLlamaForCausalLM(config)
|
model = FlashLlamaForCausalLM(config)
|
||||||
|
|
||||||
self.load_weights(model, filenames, quantize, device, dtype)
|
self.load_weights(model, filenames, quantize, device, dtype)
|
||||||
self.model = model.eval().to(device)
|
|
||||||
|
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
|
model=model.to(device),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
@ -189,9 +189,9 @@ class FlashLlamaSharded(FlashLlama):
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
self.model = model.eval().to(device)
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
|
model=model.to(device),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
|
@ -73,9 +73,9 @@ class FlashNeoXSharded(FlashNeoX):
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
self.model = model.eval().to(device)
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
|
model=model.to(device),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
|
@ -67,14 +67,13 @@ class FlashSantacoder(FlashCausalLM):
|
||||||
dtype,
|
dtype,
|
||||||
config.architectures[0].startswith("GPT2"),
|
config.architectures[0].startswith("GPT2"),
|
||||||
)
|
)
|
||||||
self.model = model.eval().to(device)
|
|
||||||
|
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
|
model=model.to(device),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
decode_buffer=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -213,16 +212,15 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
transpose=config.architectures[0].startswith("GPT2"),
|
transpose=config.architectures[0].startswith("GPT2"),
|
||||||
)
|
)
|
||||||
self.model = model.eval().to(device)
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
|
model=model.to(device),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
decode_buffer=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -94,8 +94,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
inputs = []
|
inputs = []
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
|
@ -106,8 +106,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
||||||
inputs.append(escape_custom_split_sequence(r.inputs))
|
inputs.append(escape_custom_split_sequence(r.inputs))
|
||||||
offsets.append(None)
|
|
||||||
token_offsets.append(None)
|
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
|
@ -127,6 +125,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_truncation,
|
max_length=max_truncation,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
for _ in pb.requests:
|
||||||
|
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||||
|
prefix_offsets.append(0)
|
||||||
|
read_offsets.append(input_len)
|
||||||
|
|
||||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||||
max_input_length = input_lengths.max()
|
max_input_length = input_lengths.max()
|
||||||
|
@ -155,8 +157,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
all_input_ids=list(all_input_ids),
|
all_input_ids=list(all_input_ids),
|
||||||
input_lengths=input_lengths.tolist(),
|
input_lengths=input_lengths.tolist(),
|
||||||
offsets=offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
token_offsets=token_offsets,
|
read_offsets=read_offsets,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
max_input_length=max_input_length.item(),
|
max_input_length=max_input_length.item(),
|
||||||
|
@ -231,9 +233,9 @@ class GalacticaSharded(Galactica):
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
self.model = model.eval()
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
|
@ -70,9 +70,9 @@ class GPTNeoxSharded(CausalLM):
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
self.model = model.eval()
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
|
@ -13,23 +13,20 @@ B = TypeVar("B", bound=Batch)
|
||||||
class Model(ABC):
|
class Model(ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
decode_buffer: int = 3,
|
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
):
|
):
|
||||||
if decode_buffer < 1:
|
self.model = model.eval()
|
||||||
raise ValueError("decode_buffer must be >= 1")
|
|
||||||
|
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.all_special_ids = set(tokenizer.all_special_ids)
|
self.all_special_ids = set(tokenizer.all_special_ids)
|
||||||
self.requires_padding = requires_padding
|
self.requires_padding = requires_padding
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
self.decode_buffer = decode_buffer
|
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.check_initialized()
|
self.check_initialized()
|
||||||
|
@ -54,52 +51,29 @@ class Model(ABC):
|
||||||
def decode_token(
|
def decode_token(
|
||||||
self,
|
self,
|
||||||
all_input_ids: List[int],
|
all_input_ids: List[int],
|
||||||
offset: Optional[int] = None,
|
prefix_offset: int = 0,
|
||||||
token_offset: Optional[int] = None,
|
read_offset: int = 0,
|
||||||
) -> Tuple[str, Optional[int], Optional[int]]:
|
) -> Tuple[str, int, int]:
|
||||||
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
|
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
|
||||||
if all_input_ids[-1] in self.all_special_ids:
|
|
||||||
return (
|
# The prefix text is necessary only to defeat cleanup algorithms in the decode
|
||||||
self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
|
# which decide to add a space or not depending on the surrounding ids.
|
||||||
None,
|
prefix_text = self.tokenizer.decode(
|
||||||
None,
|
all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
|
||||||
|
)
|
||||||
|
new_text = self.tokenizer.decode(
|
||||||
|
all_input_ids[prefix_offset:], skip_special_tokens=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if token_offset is None:
|
if len(new_text) > len(prefix_text) and not new_text.endswith("<EFBFBD>"):
|
||||||
token_offset = len(all_input_ids) - self.decode_buffer
|
# utf-8 char at the end means it's a potential unfinished byte sequence
|
||||||
# left token buffer
|
# from byte fallback tokenization.
|
||||||
if self.decode_buffer > 1:
|
# If it's in the middle, it's probably a real invalid id generated
|
||||||
# Decode token_offset token minus last one and token_offset tokens
|
# by the model
|
||||||
raw_texts = self.tokenizer.batch_decode(
|
new_text = new_text[len(prefix_text) :]
|
||||||
[all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
|
return new_text, read_offset, len(all_input_ids)
|
||||||
skip_special_tokens=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# default offset is only the last token
|
|
||||||
offset = len(raw_texts[0])
|
|
||||||
sequence_text = raw_texts[1]
|
|
||||||
else:
|
else:
|
||||||
# Only decode the last token without using a token buffer
|
return "", prefix_offset, read_offset
|
||||||
sequence_text = self.tokenizer.decode(
|
|
||||||
all_input_ids[-1], skip_special_tokens=False
|
|
||||||
)
|
|
||||||
# no offset in this case
|
|
||||||
offset = 0
|
|
||||||
else:
|
|
||||||
assert offset is not None
|
|
||||||
sequence_text = self.tokenizer.decode(
|
|
||||||
all_input_ids[token_offset:],
|
|
||||||
skip_special_tokens=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get text
|
|
||||||
token_text = sequence_text[offset:]
|
|
||||||
|
|
||||||
# if text is utf-8
|
|
||||||
if token_text and token_text[-1] != "<EFBFBD>":
|
|
||||||
return token_text, None, None
|
|
||||||
else:
|
|
||||||
return "", offset, token_offset
|
|
||||||
|
|
||||||
def check_initialized(self):
|
def check_initialized(self):
|
||||||
uninitialized_parameters = []
|
uninitialized_parameters = []
|
||||||
|
|
|
@ -86,9 +86,9 @@ class OPTSharded(OPT):
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
self.model = model.eval()
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
|
@ -46,24 +46,20 @@ class SantaCoder(CausalLM):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = (
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=True, # required
|
trust_remote_code=True, # required
|
||||||
)
|
).to(device)
|
||||||
.to(device)
|
|
||||||
.eval()
|
|
||||||
)
|
|
||||||
|
|
||||||
super(CausalLM, self).__init__(
|
super(CausalLM, self).__init__(
|
||||||
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
decode_buffer=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def decode(self, generated_ids: List[int]) -> str:
|
def decode(self, generated_ids: List[int]) -> str:
|
||||||
|
|
|
@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
|
||||||
# Lengths of all generations present in the batch
|
# Lengths of all generations present in the batch
|
||||||
input_lengths: List[int]
|
input_lengths: List[int]
|
||||||
decoder_input_lengths: List[int]
|
decoder_input_lengths: List[int]
|
||||||
offsets: List[Optional[int]]
|
prefix_offsets: List[int]
|
||||||
token_offsets: List[Optional[int]]
|
read_offsets: List[int]
|
||||||
|
|
||||||
# Generation helpers
|
# Generation helpers
|
||||||
next_token_choosers: List[NextTokenChooser]
|
next_token_choosers: List[NextTokenChooser]
|
||||||
|
@ -79,8 +79,8 @@ class Seq2SeqLMBatch(Batch):
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
|
|
||||||
decoder_input_lengths = []
|
decoder_input_lengths = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
|
@ -91,8 +91,6 @@ class Seq2SeqLMBatch(Batch):
|
||||||
inputs.append(r.inputs)
|
inputs.append(r.inputs)
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
decoder_input_lengths.append(1)
|
decoder_input_lengths.append(1)
|
||||||
offsets.append(None)
|
|
||||||
token_offsets.append(None)
|
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
|
@ -123,6 +121,9 @@ class Seq2SeqLMBatch(Batch):
|
||||||
.repeat(len(pb.requests))
|
.repeat(len(pb.requests))
|
||||||
.view(-1, 1)
|
.view(-1, 1)
|
||||||
)
|
)
|
||||||
|
for _ in pb.requests:
|
||||||
|
prefix_offsets.append(0)
|
||||||
|
read_offsets.append(1)
|
||||||
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
|
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
|
||||||
|
|
||||||
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
||||||
|
@ -140,8 +141,8 @@ class Seq2SeqLMBatch(Batch):
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
input_lengths=input_lengths.tolist(),
|
input_lengths=input_lengths.tolist(),
|
||||||
decoder_input_lengths=decoder_input_lengths,
|
decoder_input_lengths=decoder_input_lengths,
|
||||||
offsets=offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
token_offsets=token_offsets,
|
read_offsets=read_offsets,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
max_input_length=max_input_length.item(),
|
max_input_length=max_input_length.item(),
|
||||||
|
@ -165,8 +166,8 @@ class Seq2SeqLMBatch(Batch):
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
decoder_input_lengths = []
|
decoder_input_lengths = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
|
|
||||||
all_decoder_input_ids = []
|
all_decoder_input_ids = []
|
||||||
|
|
||||||
|
@ -184,8 +185,8 @@ class Seq2SeqLMBatch(Batch):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
keep_indices.append(idx)
|
keep_indices.append(idx)
|
||||||
|
|
||||||
offsets.append(self.offsets[idx])
|
prefix_offsets.append(self.prefix_offsets[idx])
|
||||||
token_offsets.append(self.token_offsets[idx])
|
read_offsets.append(self.read_offsets[idx])
|
||||||
|
|
||||||
all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
|
all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
|
||||||
|
|
||||||
|
@ -248,8 +249,8 @@ class Seq2SeqLMBatch(Batch):
|
||||||
self.all_decoder_input_ids = all_decoder_input_ids
|
self.all_decoder_input_ids = all_decoder_input_ids
|
||||||
self.input_lengths = input_lengths
|
self.input_lengths = input_lengths
|
||||||
self.decoder_input_lengths = decoder_input_lengths
|
self.decoder_input_lengths = decoder_input_lengths
|
||||||
self.offsets = offsets
|
self.prefix_offsets = prefix_offsets
|
||||||
self.token_offsets = token_offsets
|
self.read_offsets = read_offsets
|
||||||
self.next_token_choosers = next_token_choosers
|
self.next_token_choosers = next_token_choosers
|
||||||
self.stopping_criterias = stopping_criterias
|
self.stopping_criterias = stopping_criterias
|
||||||
self.max_input_length = max_input_length
|
self.max_input_length = max_input_length
|
||||||
|
@ -283,8 +284,8 @@ class Seq2SeqLMBatch(Batch):
|
||||||
all_decoder_input_ids = []
|
all_decoder_input_ids = []
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
decoder_input_lengths = []
|
decoder_input_lengths = []
|
||||||
offsets = []
|
prefix_offsets = []
|
||||||
token_offsets = []
|
read_offsets = []
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
max_tokens = 0
|
max_tokens = 0
|
||||||
|
@ -306,8 +307,8 @@ class Seq2SeqLMBatch(Batch):
|
||||||
all_decoder_input_ids.extend(batch.all_decoder_input_ids)
|
all_decoder_input_ids.extend(batch.all_decoder_input_ids)
|
||||||
input_lengths.extend(batch.input_lengths)
|
input_lengths.extend(batch.input_lengths)
|
||||||
decoder_input_lengths.extend(batch.decoder_input_lengths)
|
decoder_input_lengths.extend(batch.decoder_input_lengths)
|
||||||
offsets.extend(batch.offsets)
|
prefix_offsets.extend(batch.prefix_offsets)
|
||||||
token_offsets.extend(batch.token_offsets)
|
read_offsets.extend(batch.read_offsets)
|
||||||
next_token_choosers.extend(batch.next_token_choosers)
|
next_token_choosers.extend(batch.next_token_choosers)
|
||||||
stopping_criterias.extend(batch.stopping_criterias)
|
stopping_criterias.extend(batch.stopping_criterias)
|
||||||
|
|
||||||
|
@ -482,8 +483,8 @@ class Seq2SeqLMBatch(Batch):
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
decoder_input_lengths=decoder_input_lengths,
|
decoder_input_lengths=decoder_input_lengths,
|
||||||
offsets=offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
token_offsets=token_offsets,
|
read_offsets=read_offsets,
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
max_input_length=max_input_length,
|
max_input_length=max_input_length,
|
||||||
|
@ -502,7 +503,6 @@ class Seq2SeqLM(Model):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
decode_buffer: int = 3,
|
|
||||||
):
|
):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
@ -514,24 +514,24 @@ class Seq2SeqLM(Model):
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map="auto" if torch.cuda.is_available() else None,
|
device_map="auto" if torch.cuda.is_available() else None,
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
).eval()
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||||
)
|
)
|
||||||
tokenizer.bos_token_id = self.model.config.decoder_start_token_id
|
tokenizer.bos_token_id = model.config.decoder_start_token_id
|
||||||
|
|
||||||
super(Seq2SeqLM, self).__init__(
|
super(Seq2SeqLM, self).__init__(
|
||||||
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
decode_buffer=decode_buffer,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -608,8 +608,8 @@ class Seq2SeqLM(Model):
|
||||||
iterator = zip(
|
iterator = zip(
|
||||||
batch.requests,
|
batch.requests,
|
||||||
batch.input_lengths,
|
batch.input_lengths,
|
||||||
batch.offsets,
|
batch.prefix_offsets,
|
||||||
batch.token_offsets,
|
batch.read_offsets,
|
||||||
batch.decoder_input_lengths,
|
batch.decoder_input_lengths,
|
||||||
logits,
|
logits,
|
||||||
batch.next_token_choosers,
|
batch.next_token_choosers,
|
||||||
|
@ -621,8 +621,8 @@ class Seq2SeqLM(Model):
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
input_length,
|
input_length,
|
||||||
offset,
|
prefix_offset,
|
||||||
token_offset,
|
read_offset,
|
||||||
decoder_input_length,
|
decoder_input_length,
|
||||||
logits,
|
logits,
|
||||||
next_token_chooser,
|
next_token_chooser,
|
||||||
|
@ -643,8 +643,8 @@ class Seq2SeqLM(Model):
|
||||||
# Generated token
|
# Generated token
|
||||||
next_token_logprob = logprobs[-1, next_token_id]
|
next_token_logprob = logprobs[-1, next_token_id]
|
||||||
next_token_id_squeezed = next_token_id.squeeze()
|
next_token_id_squeezed = next_token_id.squeeze()
|
||||||
next_token_text, offset, token_offset = self.decode_token(
|
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||||
all_decoder_input_ids, offset, token_offset
|
all_decoder_input_ids, prefix_offset, read_offset
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
# Evaluate stopping criteria
|
||||||
|
@ -702,8 +702,8 @@ class Seq2SeqLM(Model):
|
||||||
batch.all_decoder_input_ids[i] = all_decoder_input_ids
|
batch.all_decoder_input_ids[i] = all_decoder_input_ids
|
||||||
batch.input_lengths[i] = input_length
|
batch.input_lengths[i] = input_length
|
||||||
batch.decoder_input_lengths[i] = new_decoder_input_length
|
batch.decoder_input_lengths[i] = new_decoder_input_length
|
||||||
batch.offsets[i] = offset
|
batch.prefix_offsets[i] = prefix_offset
|
||||||
batch.token_offsets[i] = token_offset
|
batch.read_offsets[i] = read_offset
|
||||||
batch.max_input_length = max(batch.max_input_length, input_length)
|
batch.max_input_length = max(batch.max_input_length, input_length)
|
||||||
batch.max_decoder_input_length = max(
|
batch.max_decoder_input_length = max(
|
||||||
batch.max_decoder_input_length, new_decoder_input_length
|
batch.max_decoder_input_length, new_decoder_input_length
|
||||||
|
|
|
@ -16,9 +16,6 @@ from text_generation_server.utils import (
|
||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.layers import (
|
|
||||||
FastLinear,
|
|
||||||
)
|
|
||||||
from transformers.models.t5.parallel_layers import (
|
from transformers.models.t5.parallel_layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -73,9 +70,9 @@ class T5Sharded(Seq2SeqLM):
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
self.model = model.eval()
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(Seq2SeqLM, self).__init__(
|
super(Seq2SeqLM, self).__init__(
|
||||||
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
Loading…
Reference in New Issue