diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 79b3c777..9305b8e7 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -213,12 +213,13 @@ jobs:
sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
- name: Install
run: |
+ pip install pytest-xdist
make install-integration-tests
- name: Run tests
run: |
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
- pytest -s -vv integration-tests
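+ # --dist loadfile groups tests by file, so tests that share a model fixture stay on the same worker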
+ pytest -s -vv -n 2 --dist loadfile integration-tests
stop-runner:
name: Stop self-hosted EC2 runner
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index d9858a3b..7e5ba52c 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -66,7 +66,8 @@ jobs:
- name: Run server tests
run: |
pip install pytest
- make python-server-tests
+ export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+ pytest -s -vv server/tests
- name: Run Rust fmt
run: |
cargo fmt --check
diff --git a/Makefile b/Makefile
index 29c318fa..7309aaee 100644
--- a/Makefile
+++ b/Makefile
@@ -31,7 +31,7 @@ update-integration-tests: install-integration-tests
pytest -s -vv --snapshot-update integration-tests
python-server-tests:
- HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests
+ HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests
python-client-tests:
pytest clients/python/tests
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index ba1abca9..3086ecda 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -1,3 +1,4 @@
+import sys
import subprocess
import contextlib
import pytest
@@ -7,6 +8,7 @@ import docker
import json
import math
import time
+import random
from docker.errors import NotFound
from typing import Optional, List, Dict
@@ -205,10 +207,12 @@ def launcher(event_loop):
def local_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
):
- port = 9999
- master_port = 19999
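+ # pick ports at random so parallel test workers do not collide on a fixed port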
+ port = random.randint(8000, 10_000)
+ master_port = random.randint(10_000, 20_000)
- shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
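+ # make the unix socket path unique per (model, num_shard, quantize) combination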
+ shard_uds_path = (
+ f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
+ )
args = [
"text-generation-launcher",
@@ -236,7 +240,7 @@ def launcher(event_loop):
process.wait(60)
launcher_output = process.stdout.read().decode("utf-8")
- print(launcher_output)
+ print(launcher_output, file=sys.stderr)
process.stdout.close()
process.stderr.close()
@@ -245,7 +249,7 @@ def launcher(event_loop):
def docker_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
):
- port = 9999
+ port = random.randint(8000, 10_000)
args = ["--model-id", model_id, "--env"]
@@ -298,7 +302,7 @@ def launcher(event_loop):
pass
container_output = container.logs().decode("utf-8")
- print(container_output)
+ print(container_output, file=sys.stderr)
container.remove()
diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json
index 5a8ba217..9bbb5322 100644
--- a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json
+++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json
@@ -1,92 +1,4 @@
[
- {
- "details": {
- "best_of_sequences": null,
- "finish_reason": "length",
- "generated_tokens": 10,
- "prefill": [
- {
- "id": 1,
- "logprob": null,
- "text": ""
- },
- {
- "id": 4321,
- "logprob": -8.6875,
- "text": "Test"
- },
- {
- "id": 2009,
- "logprob": -11.5546875,
- "text": "request"
- }
- ],
- "seed": null,
- "tokens": [
- {
- "id": 363,
- "logprob": -1.5322266,
- "special": false,
- "text": " for"
- },
- {
- "id": 847,
- "logprob": -2.5585938,
- "special": false,
- "text": " /"
- },
- {
- "id": 2754,
- "logprob": -2.265625,
- "special": false,
- "text": "api"
- },
- {
- "id": 29914,
- "logprob": -0.034088135,
- "special": false,
- "text": "/"
- },
- {
- "id": 29894,
- "logprob": -0.96240234,
- "special": false,
- "text": "v"
- },
- {
- "id": 29896,
- "logprob": -0.36816406,
- "special": false,
- "text": "1"
- },
- {
- "id": 29914,
- "logprob": -0.013191223,
- "special": false,
- "text": "/"
- },
- {
- "id": 16418,
- "logprob": -3.15625,
- "special": false,
- "text": "projects"
- },
- {
- "id": 29914,
- "logprob": -0.43774414,
- "special": false,
- "text": "/"
- },
- {
- "id": 29896,
- "logprob": -1.9443359,
- "special": false,
- "text": "1"
- }
- ]
- },
- "generated_text": "for /api/v1/projects/1"
- },
{
"details": {
"best_of_sequences": null,
@@ -263,6 +175,94 @@
},
"generated_text": "for /api/v1/projects/1"
},
+ {
+ "details": {
+ "best_of_sequences": null,
+ "finish_reason": "length",
+ "generated_tokens": 10,
+ "prefill": [
+ {
+ "id": 1,
+ "logprob": null,
+ "text": ""
+ },
+ {
+ "id": 4321,
+ "logprob": -8.6875,
+ "text": "Test"
+ },
+ {
+ "id": 2009,
+ "logprob": -11.5546875,
+ "text": "request"
+ }
+ ],
+ "seed": null,
+ "tokens": [
+ {
+ "id": 363,
+ "logprob": -1.5322266,
+ "special": false,
+ "text": " for"
+ },
+ {
+ "id": 847,
+ "logprob": -2.5585938,
+ "special": false,
+ "text": " /"
+ },
+ {
+ "id": 2754,
+ "logprob": -2.265625,
+ "special": false,
+ "text": "api"
+ },
+ {
+ "id": 29914,
+ "logprob": -0.034088135,
+ "special": false,
+ "text": "/"
+ },
+ {
+ "id": 29894,
+ "logprob": -0.96240234,
+ "special": false,
+ "text": "v"
+ },
+ {
+ "id": 29896,
+ "logprob": -0.36816406,
+ "special": false,
+ "text": "1"
+ },
+ {
+ "id": 29914,
+ "logprob": -0.013191223,
+ "special": false,
+ "text": "/"
+ },
+ {
+ "id": 16418,
+ "logprob": -3.15625,
+ "special": false,
+ "text": "projects"
+ },
+ {
+ "id": 29914,
+ "logprob": -0.43774414,
+ "special": false,
+ "text": "/"
+ },
+ {
+ "id": 29896,
+ "logprob": -1.9443359,
+ "special": false,
+ "text": "1"
+ }
+ ]
+ },
+ "generated_text": "for /api/v1/projects/1"
+ },
{
"details": {
"best_of_sequences": null,
diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json
index 2a26e3db..c1cd24cd 100644
--- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json
+++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json
@@ -16,7 +16,7 @@
"id": 926,
"logprob": -4.3554688,
"special": false,
- "text": "To"
+ "text": " To"
},
{
"id": 18295,
diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json
index fd77252d..3e9f3d73 100644
--- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json
+++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json
@@ -16,7 +16,7 @@
"id": 16017,
"logprob": -1.3505859,
"special": false,
- "text": "blue"
+ "text": " blue"
},
{
"id": 20495,
diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json
index c9e552b6..c0834ae1 100644
--- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json
+++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json
@@ -1,58 +1,4 @@
[
- {
- "details": {
- "best_of_sequences": null,
- "finish_reason": "eos_token",
- "generated_tokens": 6,
- "prefill": [
- {
- "id": 0,
- "logprob": null,
- "text": ""
- }
- ],
- "seed": null,
- "tokens": [
- {
- "id": 259,
- "logprob": -1.3789062,
- "special": false,
- "text": ""
- },
- {
- "id": 39261,
- "logprob": -0.36279297,
- "special": false,
- "text": "Because"
- },
- {
- "id": 609,
- "logprob": -1.0966797,
- "special": false,
- "text": " it"
- },
- {
- "id": 339,
- "logprob": -0.8276367,
- "special": false,
- "text": " is"
- },
- {
- "id": 16017,
- "logprob": -1.6845703,
- "special": false,
- "text": " blue"
- },
- {
- "id": 1,
- "logprob": -0.72753906,
- "special": true,
- "text": ""
- }
- ]
- },
- "generated_text": "Because it is blue"
- },
{
"details": {
"best_of_sequences": null,
@@ -71,7 +17,7 @@
"id": 259,
"logprob": -1.3798828,
"special": false,
- "text": ""
+ "text": " "
},
{
"id": 39261,
@@ -125,7 +71,7 @@
"id": 259,
"logprob": -1.3789062,
"special": false,
- "text": ""
+ "text": " "
},
{
"id": 39261,
@@ -179,7 +125,61 @@
"id": 259,
"logprob": -1.3789062,
"special": false,
- "text": ""
+ "text": " "
+ },
+ {
+ "id": 39261,
+ "logprob": -0.36279297,
+ "special": false,
+ "text": "Because"
+ },
+ {
+ "id": 609,
+ "logprob": -1.0966797,
+ "special": false,
+ "text": " it"
+ },
+ {
+ "id": 339,
+ "logprob": -0.8276367,
+ "special": false,
+ "text": " is"
+ },
+ {
+ "id": 16017,
+ "logprob": -1.6845703,
+ "special": false,
+ "text": " blue"
+ },
+ {
+ "id": 1,
+ "logprob": -0.72753906,
+ "special": true,
+ "text": ""
+ }
+ ]
+ },
+ "generated_text": "Because it is blue"
+ },
+ {
+ "details": {
+ "best_of_sequences": null,
+ "finish_reason": "eos_token",
+ "generated_tokens": 6,
+ "prefill": [
+ {
+ "id": 0,
+ "logprob": null,
+ "text": ""
+ }
+ ],
+ "seed": null,
+ "tokens": [
+ {
+ "id": 259,
+ "logprob": -1.3789062,
+ "special": false,
+ "text": " "
},
{
"id": 39261,
diff --git a/router/src/main.rs b/router/src/main.rs
index 5ad49003..82bf6ba8 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -146,7 +146,7 @@ fn main() -> Result<(), std::io::Error> {
sha: None,
pipeline_tag: None,
},
- false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
+ false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
}),
diff --git a/server/Makefile b/server/Makefile
index 150d7e4a..6eb56c75 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -2,7 +2,7 @@ include Makefile-transformers
include Makefile-flash-att
unit-tests:
- python -m pytest tests
+ pytest -s -vv -m "not private" tests
gen-server:
# Compile protos
diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py
new file mode 100644
index 00000000..32bcd45f
--- /dev/null
+++ b/server/tests/models/test_model.py
@@ -0,0 +1,78 @@
+import pytest
+import torch
+
+from transformers import AutoTokenizer
+
+from text_generation_server.models import Model
+
+
+def get_test_model():
+ class TestModel(Model):
+ def batch_type(self):
+ raise NotImplementedError
+
+ def generate_token(self, batch):
+ raise NotImplementedError
+
+ tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
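+ # this checkpoint requires an access token, hence the "private" marker on these tests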
+
+ model = TestModel(
+ torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
+ )
+ return model
+
+
+@pytest.mark.private
+def test_decode_streaming_english_spaces():
+ model = get_test_model()
+ truth = "Hello here, this is a simple test"
+ all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
+ assert (
+ all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
+ )
+
+ decoded_text = ""
+ offset = 0
+ token_offset = 0
+ for i in range(len(all_input_ids)):
+ text, offset, token_offset = model.decode_token(
+ all_input_ids[: i + 1], offset, token_offset
+ )
+ decoded_text += text
+
+ assert decoded_text == truth
+
+
+@pytest.mark.private
+def test_decode_streaming_chinese_utf8():
+ model = get_test_model()
+ truth = "我很感谢你的热情"
+ all_input_ids = [
+ 30672,
+ 232,
+ 193,
+ 139,
+ 233,
+ 135,
+ 162,
+ 235,
+ 179,
+ 165,
+ 30919,
+ 30210,
+ 234,
+ 134,
+ 176,
+ 30993,
+ ]
+
+ decoded_text = ""
+ offset = 0
+ token_offset = 0
+ for i in range(len(all_input_ids)):
+ text, offset, token_offset = model.decode_token(
+ all_input_ids[: i + 1], offset, token_offset
+ )
+ decoded_text += text
+
+ assert decoded_text == truth
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 65dafa50..ba769e75 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -149,7 +149,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 259 for generation in generations])
- assert all([generation.token_text == "" for generation in generations])
+ assert all([generation.token_text == " " for generation in generations])
assert generations[0].request_id == 0
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 9029e954..1f324f77 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -56,7 +56,7 @@ class BLOOM(CausalLM):
quantize: Optional[str] = None,
):
super(BLOOM, self).__init__(
- model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
+ model_id=model_id, revision=revision, quantize=quantize
)
@property
@@ -104,14 +104,13 @@ class BLOOMSharded(BLOOM):
rank=rank,
world_size=world_size,
)
- self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
+ model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
- decode_buffer=1,
rank=rank,
world_size=world_size,
)
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 0d521ac4..9d8ae254 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
- offsets: List[Optional[int]]
- token_offsets: List[Optional[int]]
+ prefix_offsets: List[int]
+ read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
@@ -70,8 +70,8 @@ class CausalLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
requests_idx_mapping = {}
# Parse batch
@@ -81,8 +81,6 @@ class CausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
- offsets.append(None)
- token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@@ -102,6 +100,10 @@ class CausalLMBatch(Batch):
truncation=True,
max_length=max_truncation,
).to(device)
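+ # prefix_offsets/read_offsets drive streaming detokenization in Model.decode_token:
+ # ids before read_offset are already surfaced as text, and the window starting at
+ # prefix_offset is re-decoded each step so tokenizer spacing stays stable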
+ for _ in pb.requests:
+ input_len = tokenized_inputs["input_ids"].shape[1]
+ prefix_offsets.append(0)
+ read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
@@ -130,8 +132,8 @@ class CausalLMBatch(Batch):
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
- offsets=offsets,
- token_offsets=token_offsets,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
@@ -151,8 +153,8 @@ class CausalLMBatch(Batch):
# New values after filtering
requests_idx_mapping = {}
input_lengths = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
all_input_ids = []
max_input_length = 0
@@ -167,8 +169,8 @@ class CausalLMBatch(Batch):
requests_idx_mapping[r.id] = i
keep_indices.append(idx)
- offsets.append(self.offsets[idx])
- token_offsets.append(self.token_offsets[idx])
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
@@ -225,8 +227,8 @@ class CausalLMBatch(Batch):
self.position_ids = position_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
- self.offsets = offsets
- self.token_offsets = token_offsets
+ self.prefix_offsets = prefix_offsets
+ self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length
@@ -251,8 +253,8 @@ class CausalLMBatch(Batch):
requests = []
requests_idx_mapping = {}
input_lengths = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
@@ -270,8 +272,8 @@ class CausalLMBatch(Batch):
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
- offsets.extend(batch.offsets)
- token_offsets.extend(batch.token_offsets)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@@ -428,8 +430,8 @@ class CausalLMBatch(Batch):
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
- offsets=offsets,
- token_offsets=token_offsets,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
@@ -448,7 +450,6 @@ class CausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
- decode_buffer: int = 3,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -463,25 +464,25 @@ class CausalLM(Model):
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
- self.model = AutoModelForCausalLM.from_pretrained(
+ model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes",
- ).eval()
+ )
tokenizer.pad_token_id = (
- self.model.config.pad_token_id
- if self.model.config.pad_token_id is not None
- else self.model.config.eos_token_id
+ model.config.pad_token_id
+ if model.config.pad_token_id is not None
+ else model.config.eos_token_id
)
super(CausalLM, self).__init__(
+ model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
- decode_buffer=decode_buffer,
)
@property
@@ -528,8 +529,8 @@ class CausalLM(Model):
iterator = zip(
batch.requests,
batch.input_lengths,
- batch.offsets,
- batch.token_offsets,
+ batch.prefix_offsets,
+ batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
@@ -540,8 +541,8 @@ class CausalLM(Model):
for i, (
request,
input_length,
- offset,
- token_offset,
+ prefix_offset,
+ read_offset,
logits,
next_token_chooser,
stopping_criteria,
@@ -559,8 +560,8 @@ class CausalLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
- next_token_text, offset, token_offset = self.decode_token(
- all_input_ids[:, 0], offset, token_offset
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
@@ -628,8 +629,8 @@ class CausalLM(Model):
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
- batch.offsets[i] = offset
- batch.token_offsets[i] = token_offset
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 0a9fccca..aee0480d 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -52,8 +52,8 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
- offsets: List[Optional[int]]
- token_offsets: List[Optional[int]]
+ prefix_offsets: List[Optional[int]]
+ read_offsets: List[Optional[int]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
@@ -82,8 +82,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen = 0
input_lengths = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
all_input_ids = []
requests_idx_mapping = {}
@@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
- offsets.append(None)
- token_offsets.append(None)
+ prefix_offsets.append(0)
+ read_offsets.append(input_length)
all_input_ids.append(tokenized_input)
@@ -151,8 +151,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen,
past_key_values=None,
input_lengths=input_lengths,
- offsets=offsets,
- token_offsets=token_offsets,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=[],
next_token_choosers=next_token_choosers,
@@ -190,8 +190,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = []
input_lengths = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
next_token_choosers = []
stopping_criterias = []
@@ -222,8 +222,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length)
- offsets.append(self.offsets[idx])
- token_offsets.append(self.token_offsets[idx])
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
@@ -269,8 +269,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
- offsets=offsets,
- token_offsets=token_offsets,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
@@ -302,8 +302,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = []
input_lengths = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
next_token_choosers = []
stopping_criterias = []
@@ -347,8 +347,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths)
- offsets.extend(batch.offsets)
- token_offsets.extend(batch.token_offsets)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@@ -374,8 +374,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
- offsets=offsets,
- token_offsets=token_offsets,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
@@ -394,7 +394,6 @@ class FlashCausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
- decode_buffer: int = 3,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -405,23 +404,19 @@ class FlashCausalLM(Model):
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
- self.model = (
- model_cls.from_pretrained(
- model_id,
- revision=revision,
- torch_dtype=dtype,
- load_in_8bit=quantize == "bitsandbytes",
- )
- .eval()
- .to(device)
- )
+ model = model_cls.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=dtype,
+ load_in_8bit=quantize == "bitsandbytes",
+ ).to(device)
super(FlashCausalLM, self).__init__(
+ model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
- decode_buffer=decode_buffer,
)
@property
@@ -645,8 +640,8 @@ class FlashCausalLM(Model):
iterator = zip(
batch.requests,
batch.input_lengths,
- batch.offsets,
- batch.token_offsets,
+ batch.prefix_offsets,
+ batch.read_offsets,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
@@ -659,8 +654,8 @@ class FlashCausalLM(Model):
for i, (
request,
input_length,
- offset,
- token_offset,
+ prefix_offset,
+ read_offset,
next_token_chooser,
stopping_criteria,
all_input_ids,
@@ -675,10 +670,10 @@ class FlashCausalLM(Model):
all_input_ids.append(next_token_id)
# Generated token
- next_token_text, offset, token_offset = self.decode_token(
+ next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
- offset,
- token_offset,
+ prefix_offset,
+ read_offset,
)
# Evaluate stopping criteria
@@ -744,8 +739,8 @@ class FlashCausalLM(Model):
# Update values
batch.input_lengths[i] = new_input_length
- batch.offsets[i] = offset
- batch.token_offsets[i] = token_offset
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
batch.max_seqlen = batch.max_seqlen + 1
cumulative_length += input_length
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index b775bd79..ebdbe206 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -64,9 +64,9 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(config)
self.load_weights(model, filenames, quantize, device, dtype)
- self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__(
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
@@ -189,9 +189,9 @@ class FlashLlamaSharded(FlashLlama):
rank=rank,
world_size=world_size,
)
- self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 0924f107..cac40bab 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -73,9 +73,9 @@ class FlashNeoXSharded(FlashNeoX):
rank=rank,
world_size=world_size,
)
- self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 031a67eb..5dc31309 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -67,14 +67,13 @@ class FlashSantacoder(FlashCausalLM):
dtype,
config.architectures[0].startswith("GPT2"),
)
- self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__(
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
- decode_buffer=1,
)
@staticmethod
@@ -213,16 +212,15 @@ class FlashSantacoderSharded(FlashSantacoder):
world_size=world_size,
transpose=config.architectures[0].startswith("GPT2"),
)
- self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
+ model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
- decode_buffer=1,
)
@staticmethod
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index d1e5e841..24c37c19 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -94,8 +94,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
inputs = []
next_token_choosers = []
stopping_criterias = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
requests_idx_mapping = {}
# Parse batch
@@ -106,8 +106,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
- offsets.append(None)
- token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@@ -127,6 +125,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
truncation=True,
max_length=max_truncation,
).to(device)
+ for _ in pb.requests:
+ input_len = tokenized_inputs["input_ids"].shape[1]
+ prefix_offsets.append(0)
+ read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
@@ -155,8 +157,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
- offsets=offsets,
- token_offsets=token_offsets,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
@@ -231,9 +233,9 @@ class GalacticaSharded(Galactica):
rank=rank,
world_size=world_size,
)
- self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
+ model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index f95e5be2..a10dfcb8 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -70,9 +70,9 @@ class GPTNeoxSharded(CausalLM):
rank=rank,
world_size=world_size,
)
- self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
+ model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 03f14013..29bad321 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -13,23 +13,20 @@ B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(
self,
+ model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
requires_padding: bool,
dtype: torch.dtype,
device: torch.device,
- decode_buffer: int = 3,
rank: int = 0,
world_size: int = 1,
):
- if decode_buffer < 1:
- raise ValueError("decode_buffer must be >= 1")
-
+ self.model = model.eval()
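+ # eval mode is applied once here, so individual model classes no longer call .eval() themselves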
self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids)
self.requires_padding = requires_padding
self.dtype = dtype
self.device = device
- self.decode_buffer = decode_buffer
self.rank = rank
self.world_size = world_size
self.check_initialized()
@@ -54,52 +51,29 @@ class Model(ABC):
def decode_token(
self,
all_input_ids: List[int],
- offset: Optional[int] = None,
- token_offset: Optional[int] = None,
- ) -> Tuple[str, Optional[int], Optional[int]]:
+ prefix_offset: int = 0,
+ read_offset: int = 0,
+ ) -> Tuple[str, int, int]:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
- if all_input_ids[-1] in self.all_special_ids:
- return (
- self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
- None,
- None,
- )
- if token_offset is None:
- token_offset = len(all_input_ids) - self.decode_buffer
- # left token buffer
- if self.decode_buffer > 1:
- # Decode token_offset token minus last one and token_offset tokens
- raw_texts = self.tokenizer.batch_decode(
- [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
- skip_special_tokens=False,
- )
+ # The prefix text is necessary only to defeat cleanup algorithms in the decode step,
+ # which decide whether or not to add a space depending on the surrounding ids.
+ prefix_text = self.tokenizer.decode(
+ all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
+ )
+ new_text = self.tokenizer.decode(
+ all_input_ids[prefix_offset:], skip_special_tokens=False
+ )
- # default offset is only the last token
- offset = len(raw_texts[0])
- sequence_text = raw_texts[1]
- else:
- # Only decode the last token without using a token buffer
- sequence_text = self.tokenizer.decode(
- all_input_ids[-1], skip_special_tokens=False
- )
- # no offset in this case
- offset = 0
+ if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+ # A trailing "�" means the new text may end in an unfinished byte sequence
+ # from byte-fallback tokenization, so it is held back until more tokens arrive.
+ # A "�" in the middle is most likely a genuine invalid id generated by the
+ # model, so it is emitted as-is.
+ new_text = new_text[len(prefix_text) :]
+ return new_text, read_offset, len(all_input_ids)
else:
- assert offset is not None
- sequence_text = self.tokenizer.decode(
- all_input_ids[token_offset:],
- skip_special_tokens=False,
- )
-
- # get text
- token_text = sequence_text[offset:]
-
- # if text is utf-8
- if token_text and token_text[-1] != "�":
- return token_text, None, None
- else:
- return "", offset, token_offset
+ return "", prefix_offset, read_offset
def check_initialized(self):
uninitialized_parameters = []
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 093cf70a..fdae795b 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -86,9 +86,9 @@ class OPTSharded(OPT):
rank=rank,
world_size=world_size,
)
- self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
+ model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 4bd56de1..23f89f48 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -46,24 +46,20 @@ class SantaCoder(CausalLM):
}
)
- self.model = (
- AutoModelForCausalLM.from_pretrained(
- model_id,
- revision=revision,
- torch_dtype=dtype,
- load_in_8bit=quantize == "bitsandbytes",
- trust_remote_code=True, # required
- )
- .to(device)
- .eval()
- )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=dtype,
+ load_in_8bit=quantize == "bitsandbytes",
+ trust_remote_code=True, # required
+ ).to(device)
super(CausalLM, self).__init__(
+ model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
- decode_buffer=1,
)
def decode(self, generated_ids: List[int]) -> str:
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 84854f5d..4f55b22f 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
decoder_input_lengths: List[int]
- offsets: List[Optional[int]]
- token_offsets: List[Optional[int]]
+ prefix_offsets: List[int]
+ read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
@@ -79,8 +79,8 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias = []
decoder_input_lengths = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
requests_idx_mapping = {}
# Parse batch
@@ -91,8 +91,6 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
- offsets.append(None)
- token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@@ -123,6 +121,9 @@ class Seq2SeqLMBatch(Batch):
.repeat(len(pb.requests))
.view(-1, 1)
)
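+ # generation starts from a single decoder start token, so read_offset begins at 1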
+ for _ in pb.requests:
+ prefix_offsets.append(0)
+ read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
@@ -140,8 +141,8 @@ class Seq2SeqLMBatch(Batch):
past_key_values=None,
input_lengths=input_lengths.tolist(),
decoder_input_lengths=decoder_input_lengths,
- offsets=offsets,
- token_offsets=token_offsets,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
@@ -165,8 +166,8 @@ class Seq2SeqLMBatch(Batch):
requests_idx_mapping = {}
input_lengths = []
decoder_input_lengths = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
all_decoder_input_ids = []
@@ -184,8 +185,8 @@ class Seq2SeqLMBatch(Batch):
requests_idx_mapping[r.id] = i
keep_indices.append(idx)
- offsets.append(self.offsets[idx])
- token_offsets.append(self.token_offsets[idx])
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
@@ -248,8 +249,8 @@ class Seq2SeqLMBatch(Batch):
self.all_decoder_input_ids = all_decoder_input_ids
self.input_lengths = input_lengths
self.decoder_input_lengths = decoder_input_lengths
- self.offsets = offsets
- self.token_offsets = token_offsets
+ self.prefix_offsets = prefix_offsets
+ self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length
@@ -283,8 +284,8 @@ class Seq2SeqLMBatch(Batch):
all_decoder_input_ids = []
input_lengths = []
decoder_input_lengths = []
- offsets = []
- token_offsets = []
+ prefix_offsets = []
+ read_offsets = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
@@ -306,8 +307,8 @@ class Seq2SeqLMBatch(Batch):
all_decoder_input_ids.extend(batch.all_decoder_input_ids)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
- offsets.extend(batch.offsets)
- token_offsets.extend(batch.token_offsets)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@@ -482,8 +483,8 @@ class Seq2SeqLMBatch(Batch):
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
- offsets=offsets,
- token_offsets=token_offsets,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
@@ -502,7 +503,6 @@ class Seq2SeqLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
- decode_buffer: int = 3,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -514,24 +514,24 @@ class Seq2SeqLM(Model):
device = torch.device("cpu")
dtype = torch.float32
- self.model = AutoModelForSeq2SeqLM.from_pretrained(
+ model = AutoModelForSeq2SeqLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes",
- ).eval()
+ )
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
- tokenizer.bos_token_id = self.model.config.decoder_start_token_id
+ tokenizer.bos_token_id = model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__(
+ model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
- decode_buffer=decode_buffer,
)
@property
@@ -608,8 +608,8 @@ class Seq2SeqLM(Model):
iterator = zip(
batch.requests,
batch.input_lengths,
- batch.offsets,
- batch.token_offsets,
+ batch.prefix_offsets,
+ batch.read_offsets,
batch.decoder_input_lengths,
logits,
batch.next_token_choosers,
@@ -621,8 +621,8 @@ class Seq2SeqLM(Model):
for i, (
request,
input_length,
- offset,
- token_offset,
+ prefix_offset,
+ read_offset,
decoder_input_length,
logits,
next_token_chooser,
@@ -643,8 +643,8 @@ class Seq2SeqLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
- next_token_text, offset, token_offset = self.decode_token(
- all_decoder_input_ids, offset, token_offset
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_decoder_input_ids, prefix_offset, read_offset
)
# Evaluate stopping criteria
@@ -702,8 +702,8 @@ class Seq2SeqLM(Model):
batch.all_decoder_input_ids[i] = all_decoder_input_ids
batch.input_lengths[i] = input_length
batch.decoder_input_lengths[i] = new_decoder_input_length
- batch.offsets[i] = offset
- batch.token_offsets[i] = token_offset
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, input_length)
batch.max_decoder_input_length = max(
batch.max_decoder_input_length, new_decoder_input_length
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 8e3826a4..b1ba2432 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -16,9 +16,6 @@ from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)
-from text_generation_server.utils.layers import (
- FastLinear,
-)
from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -73,9 +70,9 @@ class T5Sharded(Seq2SeqLM):
rank=rank,
world_size=world_size,
)
- self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__(
+ model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,