feat(server): support vectorized warpers in flash causal lm (#317)

Co-authored-by: Joel Lamy-Poirier <joel.lamy-poirier@servicenow.com>
This commit is contained in:
OlivierDehaene 2023-05-26 12:30:27 +02:00 committed by GitHub
parent 951930fbff
commit 62f91f78ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1151 additions and 343 deletions

View File

@ -34,65 +34,65 @@
"tokens": [
{
"id": 408,
"logprob": -1.9267578,
"logprob": -0.07891846,
"special": false,
"text": " que"
},
{
"id": 20288,
"logprob": -2.9257812,
"special": false,
"text": " l'on"
},
{
"id": 22255,
"logprob": -2.8964844,
"special": false,
"text": " trouve"
},
{
"id": 1622,
"logprob": -1.1083984,
"special": false,
"text": " une"
},
{
"id": 187079,
"logprob": -7.796875,
"special": false,
"text": " posture"
},
{
"id": 501,
"logprob": -5.390625,
"special": false,
"text": " par"
},
{
"id": 8741,
"logprob": -0.34936523,
"special": false,
"text": " rapport"
},
{
"id": 693,
"logprob": 0.0,
"special": false,
"text": " à"
},
{
"id": 366,
"logprob": -2.3378906,
"logprob": -1.2939453,
"special": false,
"text": " la"
},
{
"id": 36503,
"logprob": -3.6640625,
"id": 8769,
"logprob": -0.3708496,
"special": false,
"text": " pratique"
"text": " personne"
},
{
"id": 1479,
"logprob": -2.2871094,
"special": false,
"text": " qui"
},
{
"id": 2997,
"logprob": -0.8671875,
"special": false,
"text": " vous"
},
{
"id": 35977,
"logprob": -1.5097656,
"special": false,
"text": " suit"
},
{
"id": 21558,
"logprob": -0.07891846,
"special": false,
"text": " ait"
},
{
"id": 447,
"logprob": -0.12695312,
"special": false,
"text": " un"
},
{
"id": 78606,
"logprob": -2.21875,
"special": false,
"text": " profil"
},
{
"id": 3899,
"logprob": -1.3535156,
"special": false,
"text": " bien"
}
]
},
"generated_text": "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique"
"generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien"
}

View File

@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"finish_reason": "stop_sequence",
"generated_tokens": 5,
"prefill": [
{
"id": 1,
@ -24,65 +24,35 @@
"tokens": [
{
"id": 5229,
"logprob": -3.3085938,
"logprob": -2.5683594,
"special": false,
"text": " failed"
},
{
"id": 363,
"logprob": -3.984375,
"special": false,
"text": " for"
},
{
"id": 5641,
"logprob": -6.53125,
"special": false,
"text": " IP"
},
{
"id": 16428,
"logprob": -3.1835938,
"special": false,
"text": " Address"
},
{
"id": 29901,
"logprob": -1.2324219,
"logprob": -0.45336914,
"special": false,
"text": ":"
},
{
"id": 525,
"logprob": -2.6855469,
"id": 4829,
"logprob": -1.8408203,
"special": false,
"text": " '"
"text": " Error"
},
{
"id": 8516,
"logprob": -7.1601562,
"id": 297,
"logprob": -1.0556641,
"special": false,
"text": "None"
"text": " in"
},
{
"id": 4286,
"logprob": -2.4433594,
"id": 1243,
"logprob": 0.0,
"special": false,
"text": "'."
},
{
"id": 13,
"logprob": -0.06530762,
"special": false,
"text": "\n"
},
{
"id": 294,
"logprob": -7.953125,
"special": false,
"text": "as"
"text": " test"
}
]
},
"generated_text": "Test requestfailed for IP Address: 'None'.\nas"
"generated_text": "Test requestfailed: Error in test"
}

View File

@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 12,
"finish_reason": "length",
"generated_tokens": 60,
"prefill": [
{
"id": 589,
@ -29,77 +29,365 @@
"tokens": [
{
"id": 2262,
"logprob": -0.7451172,
"logprob": -0.042999268,
"special": false,
"text": "():"
},
{
"id": 284,
"logprob": -0.21325684,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 5741,
"logprob": -5.734375,
"special": false,
"text": " logging"
},
{
"id": 32,
"id": 1459,
"logprob": 0.0,
"special": false,
"text": "."
"text": " print"
},
{
"id": 1338,
"logprob": -0.3232422,
"id": 440,
"logprob": 0.0,
"special": false,
"text": "info"
},
{
"id": 463,
"logprob": -1.0380859,
"special": false,
"text": "('"
"text": "(\""
},
{
"id": 8279,
"logprob": -0.8378906,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 30,
"logprob": -1.9501953,
"special": false,
"text": ","
},
{
"id": 10896,
"logprob": -1.3476562,
"logprob": -0.3659668,
"special": false,
"text": " World"
},
{
"id": 683,
"logprob": -1.796875,
"id": 657,
"logprob": -0.49804688,
"special": false,
"text": "')"
"text": "\")"
},
{
"id": 203,
"logprob": -0.9873047,
"logprob": -0.11279297,
"special": false,
"text": "\n"
},
{
"id": 0,
"logprob": -0.7495117,
"special": true,
"text": "<|endoftext|>"
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 589,
"logprob": -0.20141602,
"special": false,
"text": "def"
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7656,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 426,
"logprob": -0.051635742,
"special": false,
"text": "name"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 426,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 711,
"logprob": 0.0,
"special": false,
"text": "):"
},
{
"id": 284,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 440,
"logprob": -0.16027832,
"special": false,
"text": "(\""
},
{
"id": 8279,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 313,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 636,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 27,
"logprob": 0.0,
"special": false,
"text": ")"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 589,
"logprob": 0.0,
"special": false,
"text": "def"
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7656,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 426,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 381,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 426,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 30,
"logprob": 0.0,
"special": false,
"text": ","
},
{
"id": 11442,
"logprob": 0.0,
"special": false,
"text": " age"
},
{
"id": 711,
"logprob": 0.0,
"special": false,
"text": "):"
},
{
"id": 284,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 440,
"logprob": 0.0,
"special": false,
"text": "(\""
},
{
"id": 8279,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 313,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 636,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 313,
"logprob": -0.6328125,
"special": false,
"text": " \""
},
{
"id": 313,
"logprob": -1.7011719,
"special": false,
"text": " \""
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 596,
"logprob": 0.0,
"special": false,
"text": " str"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 381,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 490,
"logprob": 0.0,
"special": false,
"text": "))"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 589,
"logprob": 0.0,
"special": false,
"text": "def"
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
}
]
},
"generated_text": "():\n logging.info('Hello, World')\n<|endoftext|>"
"generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print"
}

View File

@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"finish_reason": "eos_token",
"generated_tokens": 9,
"prefill": [
{
"id": 0,
@ -14,65 +14,59 @@
"tokens": [
{
"id": 16017,
"logprob": -1.3505859,
"logprob": -0.30908203,
"special": false,
"text": " blue"
},
{
"id": 20495,
"logprob": -0.50439453,
"logprob": 0.0,
"special": false,
"text": " sky"
},
{
"id": 259,
"logprob": -1.2011719,
"logprob": -0.28271484,
"special": false,
"text": " "
},
{
"id": 15484,
"logprob": -2.8378906,
"logprob": -1.7929688,
"special": false,
"text": "appear"
},
{
"id": 345,
"logprob": -0.87597656,
"logprob": -0.8935547,
"special": false,
"text": "ed"
},
{
"id": 288,
"logprob": -1.8447266,
"id": 281,
"logprob": 0.0,
"special": false,
"text": " to"
"text": " in"
},
{
"id": 35622,
"logprob": -7.1445312,
"id": 287,
"logprob": 0.0,
"special": false,
"text": " cloud"
"text": " the"
},
{
"id": 263,
"logprob": -1.2929688,
"id": 20495,
"logprob": -0.32299805,
"special": false,
"text": "s"
"text": " sky"
},
{
"id": 14701,
"logprob": -3.0761719,
"special": false,
"text": " above"
},
{
"id": 751,
"logprob": -4.4375,
"special": false,
"text": " all"
"id": 1,
"logprob": 0.0,
"special": true,
"text": "</s>"
}
]
},
"generated_text": "Why is the sky blue?blue sky appeared to clouds above all"
"generated_text": "Why is the sky blue?blue sky appeared in the sky"
}

View File

@ -40,7 +40,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
seed=0,
)
assert response.details.generated_tokens == 10
assert response.details.generated_tokens == 5
assert response == response_snapshot

View File

@ -29,7 +29,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
"def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
)
assert response.details.generated_tokens == 12
assert response.details.generated_tokens == 60
assert response == response_snapshot

View File

@ -43,7 +43,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
seed=0,
)
assert response.details.generated_tokens == 10
assert response.details.generated_tokens == 9
assert response == response_snapshot

View File

@ -38,7 +38,7 @@ def default_pb_batch(default_pb_request):
@pytest.fixture
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
return BloomCausalLMBatch.from_pb(
default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
@ -52,7 +52,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return BloomCausalLMBatch.from_pb(
batch_pb, bloom_560m_tokenizer, torch.device("cpu")
batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
@ -286,7 +286,9 @@ def test_batch_concatenate(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens

View File

@ -38,7 +38,9 @@ def default_pb_batch(default_pb_request):
@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
return CausalLMBatch.from_pb(
default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
@ -50,7 +52,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
return CausalLMBatch.from_pb(
batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
@ -285,7 +289,9 @@ def test_batch_concatenate(
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens

View File

@ -45,7 +45,10 @@ def default_fim_pb_batch(default_fim_pb_request):
@pytest.mark.skip
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
batch = CausalLMBatch.from_pb(
default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
default_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch
@ -70,7 +73,10 @@ def test_fim_santacoder_generate_token_completion(
default_santacoder, default_fim_pb_batch
):
batch = CausalLMBatch.from_pb(
default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
default_fim_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch

View File

@ -42,7 +42,7 @@ def default_pb_batch(default_pb_request):
@pytest.fixture
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
return Seq2SeqLMBatch.from_pb(
default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
@ -55,7 +55,9 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
return Seq2SeqLMBatch.from_pb(
batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
@ -323,7 +325,9 @@ def test_batch_concatenate(
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None

View File

@ -39,10 +39,11 @@ class BloomCausalLMBatch(CausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb(
pb=pb, tokenizer=tokenizer, device=device
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
)
batch.keys_head_dim_last = False
return batch

View File

@ -66,6 +66,7 @@ class CausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
inputs = []

View File

@ -18,11 +18,7 @@ from text_generation_server.models.types import (
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
)
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
tracer = trace.get_tracer(__name__)
@ -48,7 +44,7 @@ class FlashCausalLMBatch(Batch):
# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: List[torch.Tensor]
all_input_ids_tensor: torch.Tensor
# Lengths of all generations present in the batch
input_lengths: List[int]
@ -56,7 +52,7 @@ class FlashCausalLMBatch(Batch):
read_offsets: List[Optional[int]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
# Maximum number of tokens this batch will grow to
@ -75,6 +71,7 @@ class FlashCausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
position_ids = []
@ -87,13 +84,14 @@ class FlashCausalLMBatch(Batch):
all_input_ids = []
requests_idx_mapping = {}
next_token_choosers = []
next_token_chooser_parameters = []
stopping_criterias = []
# Cumulative length
cumulative_length = 0
max_tokens = 0
max_length = 0
# Parse batch
for i, r in enumerate(pb.requests):
@ -119,7 +117,7 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@ -130,11 +128,26 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
(len(all_input_ids), max_length), dtype=np.int64
)
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
input_ids = torch.tensor(
np.concatenate(all_input_ids), dtype=torch.int64, device=device
)
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
position_ids = torch.tensor(
np.concatenate(position_ids), dtype=torch.int32, device=device
)
@ -154,8 +167,8 @@ class FlashCausalLMBatch(Batch):
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=[],
next_token_choosers=next_token_choosers,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
@ -176,31 +189,29 @@ class FlashCausalLMBatch(Batch):
# New values after filtering
requests_idx_mapping = {}
input_ids = self.input_ids.new_empty(len(request_ids))
position_ids = self.position_ids.new_empty(len(request_ids))
# Used to index into tensors
indices = []
# Create on CPU to only move to GPU once instead of at every copy
cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
cu_seqlens_q = torch.arange(
0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
)
cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
max_seqlen = 0
past_key_values = []
requests = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
requests_idx_mapping[request_id] = i
requests.append(self.requests[idx])
@ -208,10 +219,6 @@ class FlashCausalLMBatch(Batch):
# Get length
request_input_length = self.input_lengths[idx]
# Copy tensors (GPU)
input_ids[i] = self.input_ids[idx]
position_ids[i] = self.position_ids[idx]
# Copy to tensor (CPU)
cu_seqlens[i + 1] = cumulative_length + request_input_length
max_seqlen = max(max_seqlen, request_input_length)
@ -222,14 +229,11 @@ class FlashCausalLMBatch(Batch):
)
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length)
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
@ -258,6 +262,12 @@ class FlashCausalLMBatch(Batch):
# Cat all past
past_key_values = torch.cat(past_key_values, dim=1)
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
# Move to GPU now that we have the whole tensor
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
@ -276,7 +286,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
@ -290,6 +300,7 @@ class FlashCausalLMBatch(Batch):
total_batch_size = sum([len(b) for b in batches])
dtype = batches[0].past_key_values.dtype
device = batches[0].input_ids.device
input_ids = batches[0].input_ids.new_empty(total_batch_size)
@ -302,19 +313,19 @@ class FlashCausalLMBatch(Batch):
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
next_token_chooser_parameters = []
stopping_criterias = []
# Cumulative length
cumulative_batch_size = 0
cumulative_length = 0
max_tokens = 0
max_length = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
@ -347,25 +358,54 @@ class FlashCausalLMBatch(Batch):
)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
# Update
cumulative_length += batch.cu_seqlens[-1]
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
batch.input_lengths, batch.stopping_criterias
)
),
)
all_input_ids_tensor = torch.zeros(
(total_batch_size, max_length), dtype=torch.int64, device=device
)
cumulative_batch_size = 0
for i, batch in enumerate(batches):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
cumulative_batch_size += len(batch)
# Cat past
past_key_values = torch.cat(past_key_values, dim=1)
# Create final tensor on GPU
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype=dtype, device=device
)
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
requests=requests,
@ -381,7 +421,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
@ -463,6 +503,7 @@ class FlashCausalLM(Model):
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
single_request = len(batch) == 1
if prefill and len(batch) == 1:
# Ask to pre-allocate kv to its max size
@ -483,6 +524,17 @@ class FlashCausalLM(Model):
pre_allocate_past_size,
)
if prefill:
next_token_logits = (
out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
)
else:
next_token_logits = out
next_input_ids, next_token_logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
)
if prefill:
if len(batch) > 1:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -493,15 +545,11 @@ class FlashCausalLM(Model):
batch.cu_seqlens_q = torch.arange(
0, len(batch) + 1, device=self.device, dtype=torch.int32
)
next_input_ids = batch.input_ids.new_empty(len(batch))
next_position_ids = batch.position_ids.new_empty(len(batch))
else:
prefill_logprobs = None
next_input_ids = batch.input_ids
next_position_ids = batch.position_ids
next_token_logprobs = out.new_empty(len(batch))
# Prepare past for next decode
if len(batch) > 1:
# Used to slice next batch past
@ -552,7 +600,6 @@ class FlashCausalLM(Model):
# Zipped iterator
iterator = zip(
batch.input_lengths,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
@ -564,7 +611,6 @@ class FlashCausalLM(Model):
# For each member of the batch
for i, (
input_length,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
@ -573,21 +619,6 @@ class FlashCausalLM(Model):
end_index = cumulative_length + input_length
if prefill:
# Prefill mode
# out is of shape [cumulative_sequence_lengths, vocab_size]
# only take last token logit
logits = out[end_index - 1 : end_index]
# Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
all_input_ids_tensor = batch.input_ids.new_empty(
input_length + stopping_criteria.max_new_tokens
)
# Copy from batch.input_ids to all_input_ids_tensor
all_input_ids_tensor[:input_length] = batch.input_ids[
start_index:end_index
]
batch.all_input_ids_tensor.append(all_input_ids_tensor)
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
@ -603,25 +634,8 @@ class FlashCausalLM(Model):
prefill_tokens_indices = batch.input_ids[
start_index + 1 : end_index
]
else:
# Decode mode
# out is of shape [batch_size, vocab_size]
logits = out[i].view(1, -1)
all_input_ids_tensor = batch.all_input_ids_tensor[i]
# Select next token
next_token_id, logprob = next_token_chooser(
all_input_ids_tensor[None, :input_length], logits
)
# Add to all_input_ids_tensor
next_token_id_squeezed = next_token_id.view(1)
all_input_ids_tensor[input_length] = next_token_id_squeezed
# Set values
next_input_ids[i] = next_token_id_squeezed
next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
cumulative_length += input_length
@ -651,10 +665,11 @@ class FlashCausalLM(Model):
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.all_input_ids_tensor,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
next_token_ids,
next_token_logprobs,
)
@ -665,10 +680,11 @@ class FlashCausalLM(Model):
input_length,
prefix_offset,
read_offset,
next_token_chooser,
stopping_criteria,
all_input_ids,
all_input_ids_tensor,
do_sample,
seed,
next_token_id,
next_token_logprob,
) in enumerate(iterator):
@ -702,14 +718,11 @@ class FlashCausalLM(Model):
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
output_text,
stopping_criteria.current_tokens,
reason,
seed if do_sample else None,
)
else:
generated_text = None
@ -751,8 +764,9 @@ class FlashCausalLM(Model):
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
batch.max_seqlen = batch.max_seqlen + 1
cumulative_length += input_length
batch.max_seqlen = batch.max_seqlen + 1
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None

View File

@ -89,6 +89,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "GalacticaCausalLMBatch":
inputs = []

View File

@ -71,6 +71,7 @@ class Seq2SeqLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "Seq2SeqLMBatch":
"""Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""

View File

@ -21,6 +21,7 @@ class Batch(ABC):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "Batch":
raise NotImplementedError

View File

@ -55,7 +55,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.device
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
generations, next_batch = self.model.generate_token(batch)

View File

@ -9,12 +9,13 @@ from text_generation_server.utils.hub import (
RevisionNotFoundError,
)
from text_generation_server.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,
HeterogeneousNextTokenChooser,
StoppingCriteria,
StopSequenceCriteria,
FinishReason,
Sampling,
Greedy,
)
__all__ = [
@ -25,6 +26,7 @@ __all__ = [
"weight_hub_files",
"download_weights",
"EntryNotFoundError",
"HeterogeneousNextTokenChooser",
"LocalEntryNotFoundError",
"RevisionNotFoundError",
"Greedy",

View File

@ -1,14 +1,10 @@
import concurrent
import time
import datetime
import torch
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from safetensors.torch import save_file
from safetensors import safe_open
from typing import Dict, List

View File

@ -0,0 +1,405 @@
import math
import torch
from functools import lru_cache
from typing import Optional, List, Dict, Union
from transformers import (
LogitsWarper,
LogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
class StaticWarper:
def __init__(
self,
temperature=1.0,
top_k=None,
top_p=None,
typical_p=None,
):
self.warpers = []
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
self.warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
self.warpers.append(TopKLogitsWarper(top_k=top_k))
if top_p is not None and top_p < 1.0:
self.warpers.append(TopPLogitsWarper(top_p=top_p))
if typical_p is not None and typical_p < 1.0:
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
self.cuda_graph = None
self.static_scores = None
self.static_warped_scores = None
self.static_next_logprob = None
def __call__(self, scores):
if self.cuda_graph is None:
self.static_scores = scores
self.cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cuda_graph, pool=mempool):
local_scores = self.static_scores
for warper in self.warpers:
local_scores = warper(None, local_scores)
self.static_warped_scores = local_scores
# Compute logprobs
self.static_next_logprob = torch.log_softmax(
self.static_warped_scores, -1
)
self.static_scores.copy_(scores)
self.cuda_graph.replay()
return self.static_warped_scores, self.static_next_logprob
@lru_cache(10)
def static_warper(
temperature: Optional[float],
top_k: Optional[int],
top_p: Optional[float],
typical_p: Optional[float],
) -> StaticWarper:
return StaticWarper(
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
)
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
repetition_penalty (`List[float]`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
"""
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
self.penalty = penalty
self.penalty_tensor = torch.tensor(
penalty, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
scores.scatter_(1, input_ids, score)
return scores
def filter(self, indices):
self.penalty = [self.penalty[i] for i in indices]
if any([x != 1.0 for x in self.penalty]):
self.penalty_tensor = self.penalty_tensor[indices]
return self
return None
class HeterogeneousTemperatureLogitsWarper:
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
temperature (`float`):
The value used to module the logits distribution.
"""
def __init__(
self, temperature: List[float], dtype: torch.dtype, device: torch.device
):
self.temperature = temperature
self.temperature_tensor = torch.tensor(
temperature, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores.div_(self.temperature_tensor)
return scores
def filter(self, indices):
self.temperature = [self.temperature[i] for i in indices]
if any([x != 1.0 for x in self.temperature]):
self.temperature_tensor = self.temperature_tensor[indices]
return self
return None
class HeterogeneousTopPLogitsWarper(LogitsWarper):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(
self,
top_p: List[float],
dtype: torch.dtype,
device: torch.device,
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.top_p = top_p
self.top_p_opposite = 1 - torch.tensor(
top_p, dtype=dtype, device=device
).unsqueeze(1)
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
probs = sorted_logits.softmax(dim=-1)
# This is way faster for some reason
for i in range(probs.shape[0]):
probs[i] = probs[i].cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = probs <= self.top_p_opposite
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
return warped_scores
def filter(self, indices):
self.top_p = [self.top_p[i] for i in indices]
if any([x < 1.0 for x in self.top_p]):
self.top_p_opposite = self.top_p_opposite[indices]
return self
return None
class HeterogeneousTopKLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(
self,
top_k: List[int],
device: torch.device,
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.top_k = top_k
self.max_top_k = max(top_k)
# value - 1 as we will use top_k to index and python uses 0 based numbering
self.top_k_tensor = torch.tensor(
[max(x - 1, min_tokens_to_keep - 1) for x in top_k],
dtype=torch.int64,
device=device,
).unsqueeze(1)
# 0 is a special value that disables top_k warping for this member of the batch
disabled = [x == 0 for x in top_k]
if any(disabled):
self.top_k_disabled_mask = torch.tensor(
disabled, dtype=torch.bool, device=device
).view(-1, 1)
else:
self.top_k_disabled_mask = None
self.filter_value = filter_value
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
# If max_top_k is superior to the vocab, we need to clamp or the warper will fail
if scores.size(-1) < self.max_top_k:
max_top_k = scores.size(-1)
top_k = torch.clamp_max(self.top_k_tensor, max_top_k)
else:
max_top_k = self.max_top_k
top_k = self.top_k_tensor
# Get the kth score for each member of the batch
kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
# Mask member of kth_scores that do not want to use top_k warping
if self.top_k_disabled_mask is not None:
kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < kth_scores
scores.masked_fill_(indices_to_remove, self.filter_value)
return scores
def filter(self, indices):
self.top_k = [self.top_k[i] for i in indices]
disabled = [x == 0 for x in self.top_k]
if not all(disabled):
self.top_k_tensor = self.top_k_tensor[indices]
self.max_top_k = max(self.top_k)
if self.top_k_disabled_mask is not None:
self.top_k_disabled_mask = (
self.top_k_disabled_mask[indices] if any(disabled) else None
)
return self
return None
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
Args:
mass (`float`):
Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(
self,
mass: List[float],
dtype: torch.dtype,
device: torch.device,
filter_value: float = -math.inf,
min_tokens_to_keep: int = 1,
):
self.mass = mass
self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
# 1 is a special value that disables typical_p warping for this member of the batch
disabled = [x == 1.0 for x in mass]
if any(disabled):
self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
else:
self.disabled_mask = None
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
probs = sorted_logits.softmax(dim=-1)
# This is way faster for some reason
for i in range(probs.shape[0]):
probs[i] = probs[i].cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (probs < self.mass_tensor).sum(dim=1)
last_ind[last_ind < 0] = 0
if self.disabled_mask is not None:
last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
1, last_ind.view(-1, 1)
)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
return warped_scores
def filter(self, indices):
self.mass = [self.mass[i] for i in indices]
disabled = [x == 1.0 for x in self.mass]
if not all(disabled):
self.mass_tensor = self.mass_tensor[indices]
if self.disabled_mask is not None:
self.disabled_mask = (
self.disabled_mask[indices] if any(disabled) else None
)
return self
return None
class HeterogeneousProcessorWrapper(LogitsProcessor):
r"""
A wrapper for logit warpers or processors without heterogeneous parameter support.
Args:
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
A mapping of sample indices to logit warpers or processors, to be run sequentially.
"""
def __init__(
self,
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
):
self.processors = processors
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
for i, processor in self.processors.items():
scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
return scores
def filter(self, indices):
new_processors = {}
for i, idx in enumerate(indices):
if idx in self.processors:
new_processors[i] = self.processors[idx]
if new_processors:
self.processors = new_processors
return self
return None

View File

@ -1,12 +1,7 @@
import re
import torch
from functools import lru_cache
from transformers import (
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
@ -15,82 +10,15 @@ from typing import List, Tuple, Optional
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch multinomial
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
return probs.div_(q).argmax()
class Greedy:
def __call__(self, logits):
return logits.argmax()
class StaticWarper:
def __init__(
self,
temperature=1.0,
top_k=None,
top_p=None,
typical_p=None,
):
self.warpers = []
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
self.warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
self.warpers.append(TopKLogitsWarper(top_k=top_k))
if top_p is not None and top_p < 1.0:
self.warpers.append(TopPLogitsWarper(top_p=top_p))
if typical_p is not None and typical_p < 1.0:
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
self.cuda_graph = None
self.static_scores = None
self.static_warped_scores = None
self.static_next_logprob = None
def __call__(self, scores):
if self.cuda_graph is None:
self.static_scores = scores
self.cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cuda_graph):
for warper in self.warpers:
self.static_warped_scores = warper(None, self.static_scores)
# Compute logprobs
self.static_next_logprob = torch.log_softmax(
self.static_warped_scores, -1
)
self.static_scores.copy_(scores)
self.cuda_graph.replay()
return self.static_warped_scores, self.static_next_logprob
@lru_cache(10)
def static_warper(
temperature: Optional[float],
top_k: Optional[int],
top_p: Optional[float],
typical_p: Optional[float],
) -> StaticWarper:
return StaticWarper(
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
)
from text_generation_server.utils.logits_process import (
static_warper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper,
HeterogeneousProcessorWrapper,
)
class NextTokenChooser:
@ -132,9 +60,9 @@ class NextTokenChooser:
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
if self.watermark_processor:
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor:
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.static_warper is None:
@ -221,3 +149,191 @@ class StoppingCriteria:
pb.max_new_tokens,
pb.ignore_eos_token,
)
class HeterogeneousNextTokenChooser:
def __init__(
self,
dtype: torch.dtype,
device: torch.device,
watermark: List[bool],
temperature: List[float],
repetition_penalty: List[float],
top_k: List[int],
top_p: List[float],
typical_p: List[float],
do_sample: List[bool],
seeds: List[int],
):
warpers = []
self.watermark_processor = (
HeterogeneousProcessorWrapper(
{
i: WatermarkLogitsProcessor(device=device)
for i, do_watermark in enumerate(watermark)
if do_watermark
}
)
if any(watermark)
else None
)
self.repetition_processor = (
HeterogeneousRepetitionPenaltyLogitsProcessor(
repetition_penalty, dtype, device
)
if any([x != 1.0 for x in repetition_penalty])
else None
)
if any([x != 1.0 for x in temperature]):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
]
warpers.append(
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
)
if any([x != 0 for x in top_k]):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
if any([x < 1.0 for x in top_p]):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
if any([x < 1.0 for x in typical_p]):
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
self.warpers = warpers
if any(do_sample):
self.choice = HeterogeneousSampling(do_sample, seeds, device)
else:
self.choice = Greedy()
self.seeds = seeds
self.do_sample = do_sample
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
for warper in self.warpers:
scores = warper(input_ids, scores)
next_ids = self.choice(scores)
next_logprobs = torch.gather(
torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
).view(-1)
return next_ids, next_logprobs
def filter(self, indices):
if self.watermark_processor is not None:
self.watermark_processor = self.watermark_processor.filter(indices)
if self.repetition_processor is not None:
self.repetition_processor = self.repetition_processor.filter(indices)
filtered_warpers = []
for warper in self.warpers:
filtered_warper = warper.filter(indices)
if filtered_warper is not None:
filtered_warpers.append(filtered_warper)
self.warpers = filtered_warpers
self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices]
if any(self.do_sample):
self.choice.filter(indices)
else:
self.choice = Greedy()
return self
@classmethod
def from_pb(
cls,
pb: List[generate_pb2.NextTokenChooserParameters],
dtype: torch.dtype,
device: torch.device,
) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
top_k=[pb_.top_k for pb_ in pb],
top_p=[pb_.top_p for pb_ in pb],
typical_p=[pb_.typical_p for pb_ in pb],
do_sample=[pb_.do_sample for pb_ in pb],
seeds=[pb_.seed for pb_ in pb],
device=device,
dtype=dtype,
)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch multinomial
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
return probs.div_(q).argmax()
class Greedy:
def __call__(self, logits):
return logits.argmax(dim=-1)
class HeterogeneousSampling:
r"""
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
"""
def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
self.seeds = seeds
self.greedy_indices = []
self.sampling_mapping = {}
for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
if sample:
self.sampling_mapping[i] = Sampling(seed, device)
else:
self.greedy_indices.append(i)
self.greedy = Greedy()
def __call__(self, logits):
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
if self.greedy_indices:
# Computing for all indices is faster than slicing
torch.argmax(logits, -1, out=out)
for i, sampling in self.sampling_mapping.items():
out[i] = sampling(logits[i])
return out
def filter(self, indices):
new_greedy_indices = []
new_sampling_mapping = {}
for i, idx in enumerate(indices):
if idx in self.sampling_mapping:
new_sampling_mapping[i] = self.sampling_mapping[idx]
else:
new_greedy_indices.append(i)
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self