feat(server): support vectorized warpers in flash causal lm (#317)
Co-authored-by: Joel Lamy-Poirier <joel.lamy-poirier@servicenow.com>
This commit is contained in:
parent
951930fbff
commit
62f91f78ac
|
@ -34,65 +34,65 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 408,
|
||||
"logprob": -1.9267578,
|
||||
"logprob": -0.07891846,
|
||||
"special": false,
|
||||
"text": " que"
|
||||
},
|
||||
{
|
||||
"id": 20288,
|
||||
"logprob": -2.9257812,
|
||||
"special": false,
|
||||
"text": " l'on"
|
||||
},
|
||||
{
|
||||
"id": 22255,
|
||||
"logprob": -2.8964844,
|
||||
"special": false,
|
||||
"text": " trouve"
|
||||
},
|
||||
{
|
||||
"id": 1622,
|
||||
"logprob": -1.1083984,
|
||||
"special": false,
|
||||
"text": " une"
|
||||
},
|
||||
{
|
||||
"id": 187079,
|
||||
"logprob": -7.796875,
|
||||
"special": false,
|
||||
"text": " posture"
|
||||
},
|
||||
{
|
||||
"id": 501,
|
||||
"logprob": -5.390625,
|
||||
"special": false,
|
||||
"text": " par"
|
||||
},
|
||||
{
|
||||
"id": 8741,
|
||||
"logprob": -0.34936523,
|
||||
"special": false,
|
||||
"text": " rapport"
|
||||
},
|
||||
{
|
||||
"id": 693,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " à"
|
||||
},
|
||||
{
|
||||
"id": 366,
|
||||
"logprob": -2.3378906,
|
||||
"logprob": -1.2939453,
|
||||
"special": false,
|
||||
"text": " la"
|
||||
},
|
||||
{
|
||||
"id": 36503,
|
||||
"logprob": -3.6640625,
|
||||
"id": 8769,
|
||||
"logprob": -0.3708496,
|
||||
"special": false,
|
||||
"text": " pratique"
|
||||
"text": " personne"
|
||||
},
|
||||
{
|
||||
"id": 1479,
|
||||
"logprob": -2.2871094,
|
||||
"special": false,
|
||||
"text": " qui"
|
||||
},
|
||||
{
|
||||
"id": 2997,
|
||||
"logprob": -0.8671875,
|
||||
"special": false,
|
||||
"text": " vous"
|
||||
},
|
||||
{
|
||||
"id": 35977,
|
||||
"logprob": -1.5097656,
|
||||
"special": false,
|
||||
"text": " suit"
|
||||
},
|
||||
{
|
||||
"id": 21558,
|
||||
"logprob": -0.07891846,
|
||||
"special": false,
|
||||
"text": " ait"
|
||||
},
|
||||
{
|
||||
"id": 447,
|
||||
"logprob": -0.12695312,
|
||||
"special": false,
|
||||
"text": " un"
|
||||
},
|
||||
{
|
||||
"id": 78606,
|
||||
"logprob": -2.21875,
|
||||
"special": false,
|
||||
"text": " profil"
|
||||
},
|
||||
{
|
||||
"id": 3899,
|
||||
"logprob": -1.3535156,
|
||||
"special": false,
|
||||
"text": " bien"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique"
|
||||
"generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien"
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"finish_reason": "stop_sequence",
|
||||
"generated_tokens": 5,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
|
@ -24,65 +24,35 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 5229,
|
||||
"logprob": -3.3085938,
|
||||
"logprob": -2.5683594,
|
||||
"special": false,
|
||||
"text": " failed"
|
||||
},
|
||||
{
|
||||
"id": 363,
|
||||
"logprob": -3.984375,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 5641,
|
||||
"logprob": -6.53125,
|
||||
"special": false,
|
||||
"text": " IP"
|
||||
},
|
||||
{
|
||||
"id": 16428,
|
||||
"logprob": -3.1835938,
|
||||
"special": false,
|
||||
"text": " Address"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -1.2324219,
|
||||
"logprob": -0.45336914,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 525,
|
||||
"logprob": -2.6855469,
|
||||
"id": 4829,
|
||||
"logprob": -1.8408203,
|
||||
"special": false,
|
||||
"text": " '"
|
||||
"text": " Error"
|
||||
},
|
||||
{
|
||||
"id": 8516,
|
||||
"logprob": -7.1601562,
|
||||
"id": 297,
|
||||
"logprob": -1.0556641,
|
||||
"special": false,
|
||||
"text": "None"
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 4286,
|
||||
"logprob": -2.4433594,
|
||||
"id": 1243,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "'."
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.06530762,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 294,
|
||||
"logprob": -7.953125,
|
||||
"special": false,
|
||||
"text": "as"
|
||||
"text": " test"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "Test requestfailed for IP Address: 'None'.\nas"
|
||||
"generated_text": "Test requestfailed: Error in test"
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 12,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 60,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
|
@ -29,77 +29,365 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 2262,
|
||||
"logprob": -0.7451172,
|
||||
"logprob": -0.042999268,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.21325684,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 5741,
|
||||
"logprob": -5.734375,
|
||||
"special": false,
|
||||
"text": " logging"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "."
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 1338,
|
||||
"logprob": -0.3232422,
|
||||
"id": 440,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "info"
|
||||
},
|
||||
{
|
||||
"id": 463,
|
||||
"logprob": -1.0380859,
|
||||
"special": false,
|
||||
"text": "('"
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8279,
|
||||
"logprob": -0.8378906,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.9501953,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 10896,
|
||||
"logprob": -1.3476562,
|
||||
"logprob": -0.3659668,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 683,
|
||||
"logprob": -1.796875,
|
||||
"id": 657,
|
||||
"logprob": -0.49804688,
|
||||
"special": false,
|
||||
"text": "')"
|
||||
"text": "\")"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": -0.9873047,
|
||||
"logprob": -0.11279297,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": -0.7495117,
|
||||
"special": true,
|
||||
"text": "<|endoftext|>"
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": -0.20141602,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7656,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "hello"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 426,
|
||||
"logprob": -0.051635742,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 426,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 711,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "):"
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 440,
|
||||
"logprob": -0.16027832,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8279,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 313,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 474,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 636,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " name"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ")"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7656,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "hello"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 426,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 381,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "age"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 426,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 11442,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " age"
|
||||
},
|
||||
{
|
||||
"id": 711,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "):"
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 440,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8279,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 313,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 474,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 636,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " name"
|
||||
},
|
||||
{
|
||||
"id": 474,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 313,
|
||||
"logprob": -0.6328125,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 313,
|
||||
"logprob": -1.7011719,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 474,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 596,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " str"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 381,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "age"
|
||||
},
|
||||
{
|
||||
"id": 490,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "))"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "():\n logging.info('Hello, World')\n<|endoftext|>"
|
||||
"generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print"
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 9,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
|
@ -14,65 +14,59 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 16017,
|
||||
"logprob": -1.3505859,
|
||||
"logprob": -0.30908203,
|
||||
"special": false,
|
||||
"text": " blue"
|
||||
},
|
||||
{
|
||||
"id": 20495,
|
||||
"logprob": -0.50439453,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " sky"
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"logprob": -1.2011719,
|
||||
"logprob": -0.28271484,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 15484,
|
||||
"logprob": -2.8378906,
|
||||
"logprob": -1.7929688,
|
||||
"special": false,
|
||||
"text": "appear"
|
||||
},
|
||||
{
|
||||
"id": 345,
|
||||
"logprob": -0.87597656,
|
||||
"logprob": -0.8935547,
|
||||
"special": false,
|
||||
"text": "ed"
|
||||
},
|
||||
{
|
||||
"id": 288,
|
||||
"logprob": -1.8447266,
|
||||
"id": 281,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 35622,
|
||||
"logprob": -7.1445312,
|
||||
"id": 287,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " cloud"
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": -1.2929688,
|
||||
"id": 20495,
|
||||
"logprob": -0.32299805,
|
||||
"special": false,
|
||||
"text": "s"
|
||||
"text": " sky"
|
||||
},
|
||||
{
|
||||
"id": 14701,
|
||||
"logprob": -3.0761719,
|
||||
"special": false,
|
||||
"text": " above"
|
||||
},
|
||||
{
|
||||
"id": 751,
|
||||
"logprob": -4.4375,
|
||||
"special": false,
|
||||
"text": " all"
|
||||
"id": 1,
|
||||
"logprob": 0.0,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "Why is the sky blue?blue sky appeared to clouds above all"
|
||||
"generated_text": "Why is the sky blue?blue sky appeared in the sky"
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
|
|||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.details.generated_tokens == 5
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
|
|||
"def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 12
|
||||
assert response.details.generated_tokens == 60
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
|||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.details.generated_tokens == 9
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ def default_pb_batch(default_pb_request):
|
|||
@pytest.fixture
|
||||
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
|
||||
return BloomCausalLMBatch.from_pb(
|
||||
default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
|
||||
default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
|
@ -52,7 +52,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
|
|||
|
||||
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
|
||||
return BloomCausalLMBatch.from_pb(
|
||||
batch_pb, bloom_560m_tokenizer, torch.device("cpu")
|
||||
batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
|
@ -286,7 +286,9 @@ def test_batch_concatenate(
|
|||
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
|
||||
next_batch = next_batch.filter(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
|
|
|
@ -38,7 +38,9 @@ def default_pb_batch(default_pb_request):
|
|||
|
||||
@pytest.fixture
|
||||
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
|
||||
return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
|
||||
return CausalLMBatch.from_pb(
|
||||
default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -50,7 +52,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
|
|||
req_1.stopping_parameters.max_new_tokens = 5
|
||||
|
||||
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
|
||||
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
|
||||
return CausalLMBatch.from_pb(
|
||||
batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
|
||||
|
@ -285,7 +289,9 @@ def test_batch_concatenate(
|
|||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
|
||||
next_batch = next_batch.filter(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
|
|
|
@ -45,7 +45,10 @@ def default_fim_pb_batch(default_fim_pb_request):
|
|||
@pytest.mark.skip
|
||||
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
|
||||
batch = CausalLMBatch.from_pb(
|
||||
default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
|
||||
default_pb_batch,
|
||||
default_santacoder.tokenizer,
|
||||
default_santacoder.dtype,
|
||||
default_santacoder.device,
|
||||
)
|
||||
next_batch = batch
|
||||
|
||||
|
@ -70,7 +73,10 @@ def test_fim_santacoder_generate_token_completion(
|
|||
default_santacoder, default_fim_pb_batch
|
||||
):
|
||||
batch = CausalLMBatch.from_pb(
|
||||
default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
|
||||
default_fim_pb_batch,
|
||||
default_santacoder.tokenizer,
|
||||
default_santacoder.dtype,
|
||||
default_santacoder.device,
|
||||
)
|
||||
next_batch = batch
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ def default_pb_batch(default_pb_request):
|
|||
@pytest.fixture
|
||||
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
|
||||
return Seq2SeqLMBatch.from_pb(
|
||||
default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
|
||||
default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
|
@ -55,7 +55,9 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
|
|||
req_1.stopping_parameters.max_new_tokens = 5
|
||||
|
||||
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
|
||||
return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
|
||||
return Seq2SeqLMBatch.from_pb(
|
||||
batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
|
||||
|
@ -323,7 +325,9 @@ def test_batch_concatenate(
|
|||
)
|
||||
assert generations[2].generated_text.generated_tokens == 5
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
|
||||
next_batch = next_batch.filter(
|
||||
[next_batch.requests[0].id, next_batch.requests[1].id]
|
||||
)
|
||||
|
||||
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
|
|
@ -39,10 +39,11 @@ class BloomCausalLMBatch(CausalLMBatch):
|
|||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "CausalLMBatch":
|
||||
batch = super(BloomCausalLMBatch, cls).from_pb(
|
||||
pb=pb, tokenizer=tokenizer, device=device
|
||||
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
|
||||
)
|
||||
batch.keys_head_dim_last = False
|
||||
return batch
|
||||
|
|
|
@ -66,6 +66,7 @@ class CausalLMBatch(Batch):
|
|||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "CausalLMBatch":
|
||||
inputs = []
|
||||
|
|
|
@ -18,11 +18,7 @@ from text_generation_server.models.types import (
|
|||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import (
|
||||
NextTokenChooser,
|
||||
StoppingCriteria,
|
||||
Sampling,
|
||||
)
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
@ -48,7 +44,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
# All tokens
|
||||
all_input_ids: List[List[int]]
|
||||
all_input_ids_tensor: List[torch.Tensor]
|
||||
all_input_ids_tensor: torch.Tensor
|
||||
|
||||
# Lengths of all generations present in the batch
|
||||
input_lengths: List[int]
|
||||
|
@ -56,7 +52,7 @@ class FlashCausalLMBatch(Batch):
|
|||
read_offsets: List[Optional[int]]
|
||||
|
||||
# Generation helpers
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
next_token_chooser: HeterogeneousNextTokenChooser
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
|
||||
# Maximum number of tokens this batch will grow to
|
||||
|
@ -75,6 +71,7 @@ class FlashCausalLMBatch(Batch):
|
|||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
position_ids = []
|
||||
|
@ -87,13 +84,14 @@ class FlashCausalLMBatch(Batch):
|
|||
all_input_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
next_token_choosers = []
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
|
||||
max_tokens = 0
|
||||
max_length = 0
|
||||
|
||||
# Parse batch
|
||||
for i, r in enumerate(pb.requests):
|
||||
|
@ -119,7 +117,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlens.append(cumulative_length + input_length)
|
||||
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
next_token_chooser_parameters.append(r.parameters)
|
||||
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
|
@ -130,11 +128,26 @@ class FlashCausalLMBatch(Batch):
|
|||
# Update
|
||||
cumulative_length += input_length
|
||||
max_tokens += input_length + max_new_tokens
|
||||
max_length = max(max_length, input_length + max_new_tokens)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype, device
|
||||
)
|
||||
|
||||
# Padded all_input_ids_tensor
|
||||
all_input_ids_tensor = np.zeros(
|
||||
(len(all_input_ids), max_length), dtype=np.int64
|
||||
)
|
||||
for i, input_ids in enumerate(all_input_ids):
|
||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||
|
||||
# Create tensors on device
|
||||
input_ids = torch.tensor(
|
||||
np.concatenate(all_input_ids), dtype=torch.int64, device=device
|
||||
)
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
position_ids = torch.tensor(
|
||||
np.concatenate(position_ids), dtype=torch.int32, device=device
|
||||
)
|
||||
|
@ -154,8 +167,8 @@ class FlashCausalLMBatch(Batch):
|
|||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=[],
|
||||
next_token_choosers=next_token_choosers,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
@ -176,31 +189,29 @@ class FlashCausalLMBatch(Batch):
|
|||
# New values after filtering
|
||||
requests_idx_mapping = {}
|
||||
|
||||
input_ids = self.input_ids.new_empty(len(request_ids))
|
||||
position_ids = self.position_ids.new_empty(len(request_ids))
|
||||
# Used to index into tensors
|
||||
indices = []
|
||||
|
||||
# Create on CPU to only move to GPU once instead of at every copy
|
||||
cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
|
||||
cu_seqlens_q = torch.arange(
|
||||
0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
|
||||
)
|
||||
cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
|
||||
max_seqlen = 0
|
||||
past_key_values = []
|
||||
|
||||
requests = []
|
||||
all_input_ids = []
|
||||
all_input_ids_tensor = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
max_tokens = 0
|
||||
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
indices.append(idx)
|
||||
requests_idx_mapping[request_id] = i
|
||||
|
||||
requests.append(self.requests[idx])
|
||||
|
@ -208,10 +219,6 @@ class FlashCausalLMBatch(Batch):
|
|||
# Get length
|
||||
request_input_length = self.input_lengths[idx]
|
||||
|
||||
# Copy tensors (GPU)
|
||||
input_ids[i] = self.input_ids[idx]
|
||||
position_ids[i] = self.position_ids[idx]
|
||||
|
||||
# Copy to tensor (CPU)
|
||||
cu_seqlens[i + 1] = cumulative_length + request_input_length
|
||||
max_seqlen = max(max_seqlen, request_input_length)
|
||||
|
@ -222,14 +229,11 @@ class FlashCausalLMBatch(Batch):
|
|||
)
|
||||
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
|
||||
|
||||
input_lengths.append(request_input_length)
|
||||
prefix_offsets.append(self.prefix_offsets[idx])
|
||||
read_offsets.append(self.read_offsets[idx])
|
||||
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
|
||||
|
@ -258,6 +262,12 @@ class FlashCausalLMBatch(Batch):
|
|||
# Cat all past
|
||||
past_key_values = torch.cat(past_key_values, dim=1)
|
||||
|
||||
# Index into tensors
|
||||
input_ids = self.input_ids[indices]
|
||||
position_ids = self.position_ids[indices]
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||
|
||||
# Move to GPU now that we have the whole tensor
|
||||
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
|
||||
|
||||
|
@ -276,7 +286,7 @@ class FlashCausalLMBatch(Batch):
|
|||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_choosers=next_token_choosers,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
@ -290,6 +300,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
total_batch_size = sum([len(b) for b in batches])
|
||||
|
||||
dtype = batches[0].past_key_values.dtype
|
||||
device = batches[0].input_ids.device
|
||||
|
||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||
|
@ -302,19 +313,19 @@ class FlashCausalLMBatch(Batch):
|
|||
past_key_values = []
|
||||
|
||||
all_input_ids = []
|
||||
all_input_ids_tensor = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
|
||||
next_token_choosers = []
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
|
||||
# Cumulative length
|
||||
cumulative_batch_size = 0
|
||||
cumulative_length = 0
|
||||
max_tokens = 0
|
||||
max_length = 0
|
||||
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
|
@ -347,25 +358,54 @@ class FlashCausalLMBatch(Batch):
|
|||
)
|
||||
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
|
||||
|
||||
input_lengths.extend(batch.input_lengths)
|
||||
prefix_offsets.extend(batch.prefix_offsets)
|
||||
read_offsets.extend(batch.read_offsets)
|
||||
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
||||
# Update
|
||||
cumulative_length += batch.cu_seqlens[-1]
|
||||
cumulative_batch_size += len(batch)
|
||||
max_tokens += batch.max_tokens
|
||||
max_length = max(
|
||||
max_length,
|
||||
max(
|
||||
input_length
|
||||
+ stopping_criteria.max_new_tokens
|
||||
- stopping_criteria.current_tokens
|
||||
for input_length, stopping_criteria in zip(
|
||||
batch.input_lengths, batch.stopping_criterias
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
all_input_ids_tensor = torch.zeros(
|
||||
(total_batch_size, max_length), dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
cumulative_batch_size = 0
|
||||
for i, batch in enumerate(batches):
|
||||
start_index = cumulative_batch_size
|
||||
end_index = cumulative_batch_size + len(batch)
|
||||
|
||||
all_input_ids_tensor[
|
||||
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
||||
] = batch.all_input_ids_tensor[:, :max_length]
|
||||
|
||||
cumulative_batch_size += len(batch)
|
||||
|
||||
# Cat past
|
||||
past_key_values = torch.cat(past_key_values, dim=1)
|
||||
# Create final tensor on GPU
|
||||
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
return FlashCausalLMBatch(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
|
@ -381,7 +421,7 @@ class FlashCausalLMBatch(Batch):
|
|||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_choosers=next_token_choosers,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
@ -463,6 +503,7 @@ class FlashCausalLM(Model):
|
|||
self, batch: FlashCausalLMBatch
|
||||
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
||||
prefill = batch.past_key_values is None
|
||||
single_request = len(batch) == 1
|
||||
|
||||
if prefill and len(batch) == 1:
|
||||
# Ask to pre-allocate kv to its max size
|
||||
|
@ -483,6 +524,17 @@ class FlashCausalLM(Model):
|
|||
pre_allocate_past_size,
|
||||
)
|
||||
|
||||
if prefill:
|
||||
next_token_logits = (
|
||||
out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
|
||||
)
|
||||
else:
|
||||
next_token_logits = out
|
||||
|
||||
next_input_ids, next_token_logprobs = batch.next_token_chooser(
|
||||
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
|
||||
)
|
||||
|
||||
if prefill:
|
||||
if len(batch) > 1:
|
||||
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
|
||||
|
@ -493,15 +545,11 @@ class FlashCausalLM(Model):
|
|||
batch.cu_seqlens_q = torch.arange(
|
||||
0, len(batch) + 1, device=self.device, dtype=torch.int32
|
||||
)
|
||||
next_input_ids = batch.input_ids.new_empty(len(batch))
|
||||
next_position_ids = batch.position_ids.new_empty(len(batch))
|
||||
else:
|
||||
prefill_logprobs = None
|
||||
next_input_ids = batch.input_ids
|
||||
next_position_ids = batch.position_ids
|
||||
|
||||
next_token_logprobs = out.new_empty(len(batch))
|
||||
|
||||
# Prepare past for next decode
|
||||
if len(batch) > 1:
|
||||
# Used to slice next batch past
|
||||
|
@ -552,7 +600,6 @@ class FlashCausalLM(Model):
|
|||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.input_lengths,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
)
|
||||
|
@ -564,7 +611,6 @@ class FlashCausalLM(Model):
|
|||
# For each member of the batch
|
||||
for i, (
|
||||
input_length,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
) in enumerate(iterator):
|
||||
|
@ -573,21 +619,6 @@ class FlashCausalLM(Model):
|
|||
end_index = cumulative_length + input_length
|
||||
|
||||
if prefill:
|
||||
# Prefill mode
|
||||
# out is of shape [cumulative_sequence_lengths, vocab_size]
|
||||
# only take last token logit
|
||||
logits = out[end_index - 1 : end_index]
|
||||
|
||||
# Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
|
||||
all_input_ids_tensor = batch.input_ids.new_empty(
|
||||
input_length + stopping_criteria.max_new_tokens
|
||||
)
|
||||
# Copy from batch.input_ids to all_input_ids_tensor
|
||||
all_input_ids_tensor[:input_length] = batch.input_ids[
|
||||
start_index:end_index
|
||||
]
|
||||
batch.all_input_ids_tensor.append(all_input_ids_tensor)
|
||||
|
||||
# Initialize position_ids
|
||||
# In decode, we do not need this as we can just increment position ids
|
||||
next_position_ids[i] = batch.position_ids[end_index - 1]
|
||||
|
@ -603,25 +634,8 @@ class FlashCausalLM(Model):
|
|||
prefill_tokens_indices = batch.input_ids[
|
||||
start_index + 1 : end_index
|
||||
]
|
||||
else:
|
||||
# Decode mode
|
||||
# out is of shape [batch_size, vocab_size]
|
||||
logits = out[i].view(1, -1)
|
||||
|
||||
all_input_ids_tensor = batch.all_input_ids_tensor[i]
|
||||
|
||||
# Select next token
|
||||
next_token_id, logprob = next_token_chooser(
|
||||
all_input_ids_tensor[None, :input_length], logits
|
||||
)
|
||||
|
||||
# Add to all_input_ids_tensor
|
||||
next_token_id_squeezed = next_token_id.view(1)
|
||||
all_input_ids_tensor[input_length] = next_token_id_squeezed
|
||||
|
||||
# Set values
|
||||
next_input_ids[i] = next_token_id_squeezed
|
||||
next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
|
||||
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
|
||||
|
||||
cumulative_length += input_length
|
||||
|
||||
|
@ -651,10 +665,11 @@ class FlashCausalLM(Model):
|
|||
batch.input_lengths,
|
||||
batch.prefix_offsets,
|
||||
batch.read_offsets,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
batch.all_input_ids_tensor,
|
||||
batch.next_token_chooser.do_sample,
|
||||
batch.next_token_chooser.seeds,
|
||||
next_token_ids,
|
||||
next_token_logprobs,
|
||||
)
|
||||
|
@ -665,10 +680,11 @@ class FlashCausalLM(Model):
|
|||
input_length,
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
all_input_ids_tensor,
|
||||
do_sample,
|
||||
seed,
|
||||
next_token_id,
|
||||
next_token_logprob,
|
||||
) in enumerate(iterator):
|
||||
|
@ -702,14 +718,11 @@ class FlashCausalLM(Model):
|
|||
output_text = self.decode(
|
||||
all_input_ids[-stopping_criteria.current_tokens :]
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
output_text,
|
||||
stopping_criteria.current_tokens,
|
||||
reason,
|
||||
seed if do_sample else None,
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
|
@ -751,8 +764,9 @@ class FlashCausalLM(Model):
|
|||
batch.prefix_offsets[i] = prefix_offset
|
||||
batch.read_offsets[i] = read_offset
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.max_seqlen = batch.max_seqlen + 1
|
||||
cumulative_length += input_length
|
||||
|
||||
batch.max_seqlen = batch.max_seqlen + 1
|
||||
|
||||
# No need to return a batch if we know that all requests stopped
|
||||
return generations, batch if not stopped else None
|
||||
|
|
|
@ -89,6 +89,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "GalacticaCausalLMBatch":
|
||||
inputs = []
|
||||
|
|
|
@ -71,6 +71,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "Seq2SeqLMBatch":
|
||||
"""Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
|
||||
|
|
|
@ -21,6 +21,7 @@ class Batch(ABC):
|
|||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "Batch":
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -55,7 +55,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
|
||||
async def Prefill(self, request, context):
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.device
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
|
|
|
@ -9,12 +9,13 @@ from text_generation_server.utils.hub import (
|
|||
RevisionNotFoundError,
|
||||
)
|
||||
from text_generation_server.utils.tokens import (
|
||||
Greedy,
|
||||
NextTokenChooser,
|
||||
Sampling,
|
||||
HeterogeneousNextTokenChooser,
|
||||
StoppingCriteria,
|
||||
StopSequenceCriteria,
|
||||
FinishReason,
|
||||
Sampling,
|
||||
Greedy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
@ -25,6 +26,7 @@ __all__ = [
|
|||
"weight_hub_files",
|
||||
"download_weights",
|
||||
"EntryNotFoundError",
|
||||
"HeterogeneousNextTokenChooser",
|
||||
"LocalEntryNotFoundError",
|
||||
"RevisionNotFoundError",
|
||||
"Greedy",
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
import concurrent
|
||||
import time
|
||||
import datetime
|
||||
import torch
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file, save_file
|
||||
from safetensors.torch import save_file
|
||||
from safetensors import safe_open
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
|
@ -0,0 +1,405 @@
|
|||
import math
|
||||
import torch
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional, List, Dict, Union
|
||||
|
||||
from transformers import (
|
||||
LogitsWarper,
|
||||
LogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
)
|
||||
|
||||
mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
|
||||
|
||||
class StaticWarper:
|
||||
def __init__(
|
||||
self,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
typical_p=None,
|
||||
):
|
||||
self.warpers = []
|
||||
|
||||
if temperature is not None and temperature != 1.0:
|
||||
temperature = float(temperature)
|
||||
self.warpers.append(TemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
self.warpers.append(TopKLogitsWarper(top_k=top_k))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
self.warpers.append(TopPLogitsWarper(top_p=top_p))
|
||||
if typical_p is not None and typical_p < 1.0:
|
||||
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
|
||||
|
||||
self.cuda_graph = None
|
||||
self.static_scores = None
|
||||
self.static_warped_scores = None
|
||||
self.static_next_logprob = None
|
||||
|
||||
def __call__(self, scores):
|
||||
if self.cuda_graph is None:
|
||||
self.static_scores = scores
|
||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph, pool=mempool):
|
||||
local_scores = self.static_scores
|
||||
for warper in self.warpers:
|
||||
local_scores = warper(None, local_scores)
|
||||
|
||||
self.static_warped_scores = local_scores
|
||||
# Compute logprobs
|
||||
self.static_next_logprob = torch.log_softmax(
|
||||
self.static_warped_scores, -1
|
||||
)
|
||||
|
||||
self.static_scores.copy_(scores)
|
||||
self.cuda_graph.replay()
|
||||
|
||||
return self.static_warped_scores, self.static_next_logprob
|
||||
|
||||
|
||||
@lru_cache(10)
|
||||
def static_warper(
|
||||
temperature: Optional[float],
|
||||
top_k: Optional[int],
|
||||
top_p: Optional[float],
|
||||
typical_p: Optional[float],
|
||||
) -> StaticWarper:
|
||||
return StaticWarper(
|
||||
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
||||
)
|
||||
|
||||
|
||||
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
repetition_penalty (`List[float]`):
|
||||
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
|
||||
self.penalty = penalty
|
||||
self.penalty_tensor = torch.tensor(
|
||||
penalty, dtype=dtype, device=device
|
||||
).unsqueeze(1)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(
|
||||
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
||||
)
|
||||
|
||||
scores.scatter_(1, input_ids, score)
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.penalty = [self.penalty[i] for i in indices]
|
||||
if any([x != 1.0 for x in self.penalty]):
|
||||
self.penalty_tensor = self.penalty_tensor[indices]
|
||||
return self
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTemperatureLogitsWarper:
|
||||
r"""
|
||||
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
temperature (`float`):
|
||||
The value used to module the logits distribution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, temperature: List[float], dtype: torch.dtype, device: torch.device
|
||||
):
|
||||
self.temperature = temperature
|
||||
self.temperature_tensor = torch.tensor(
|
||||
temperature, dtype=dtype, device=device
|
||||
).unsqueeze(1)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
scores.div_(self.temperature_tensor)
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.temperature = [self.temperature[i] for i in indices]
|
||||
if any([x != 1.0 for x in self.temperature]):
|
||||
self.temperature_tensor = self.temperature_tensor[indices]
|
||||
return self
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||
"""
|
||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation.
|
||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
top_p: List[float],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
filter_value: float = -math.inf,
|
||||
min_tokens_to_keep: int = 1,
|
||||
):
|
||||
self.top_p = top_p
|
||||
self.top_p_opposite = 1 - torch.tensor(
|
||||
top_p, dtype=dtype, device=device
|
||||
).unsqueeze(1)
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
|
||||
probs = sorted_logits.softmax(dim=-1)
|
||||
# This is way faster for some reason
|
||||
for i in range(probs.shape[0]):
|
||||
probs[i] = probs[i].cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = probs <= self.top_p_opposite
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||
|
||||
return warped_scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.top_p = [self.top_p[i] for i in indices]
|
||||
if any([x < 1.0 for x in self.top_p]):
|
||||
self.top_p_opposite = self.top_p_opposite[indices]
|
||||
return self
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
top_k (`int`):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
top_k: List[int],
|
||||
device: torch.device,
|
||||
filter_value: float = -math.inf,
|
||||
min_tokens_to_keep: int = 1,
|
||||
):
|
||||
self.top_k = top_k
|
||||
self.max_top_k = max(top_k)
|
||||
# value - 1 as we will use top_k to index and python uses 0 based numbering
|
||||
self.top_k_tensor = torch.tensor(
|
||||
[max(x - 1, min_tokens_to_keep - 1) for x in top_k],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
).unsqueeze(1)
|
||||
|
||||
# 0 is a special value that disables top_k warping for this member of the batch
|
||||
disabled = [x == 0 for x in top_k]
|
||||
|
||||
if any(disabled):
|
||||
self.top_k_disabled_mask = torch.tensor(
|
||||
disabled, dtype=torch.bool, device=device
|
||||
).view(-1, 1)
|
||||
else:
|
||||
self.top_k_disabled_mask = None
|
||||
|
||||
self.filter_value = filter_value
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
# If max_top_k is superior to the vocab, we need to clamp or the warper will fail
|
||||
if scores.size(-1) < self.max_top_k:
|
||||
max_top_k = scores.size(-1)
|
||||
top_k = torch.clamp_max(self.top_k_tensor, max_top_k)
|
||||
else:
|
||||
max_top_k = self.max_top_k
|
||||
top_k = self.top_k_tensor
|
||||
|
||||
# Get the kth score for each member of the batch
|
||||
kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
|
||||
|
||||
# Mask member of kth_scores that do not want to use top_k warping
|
||||
if self.top_k_disabled_mask is not None:
|
||||
kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)
|
||||
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = scores < kth_scores
|
||||
scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.top_k = [self.top_k[i] for i in indices]
|
||||
disabled = [x == 0 for x in self.top_k]
|
||||
|
||||
if not all(disabled):
|
||||
self.top_k_tensor = self.top_k_tensor[indices]
|
||||
self.max_top_k = max(self.top_k)
|
||||
|
||||
if self.top_k_disabled_mask is not None:
|
||||
self.top_k_disabled_mask = (
|
||||
self.top_k_disabled_mask[indices] if any(disabled) else None
|
||||
)
|
||||
|
||||
return self
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
|
||||
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
mass (`float`):
|
||||
Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
|
||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mass: List[float],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
filter_value: float = -math.inf,
|
||||
min_tokens_to_keep: int = 1,
|
||||
):
|
||||
self.mass = mass
|
||||
self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
|
||||
|
||||
# 1 is a special value that disables typical_p warping for this member of the batch
|
||||
disabled = [x == 1.0 for x in mass]
|
||||
|
||||
if any(disabled):
|
||||
self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
|
||||
else:
|
||||
self.disabled_mask = None
|
||||
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
# calculate entropy
|
||||
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
|
||||
p = torch.exp(normalized)
|
||||
ent = -(normalized * p).nansum(-1, keepdim=True)
|
||||
|
||||
# shift and sort
|
||||
shifted_scores = torch.abs((-normalized) - ent)
|
||||
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
|
||||
sorted_logits = scores.gather(-1, sorted_indices)
|
||||
probs = sorted_logits.softmax(dim=-1)
|
||||
# This is way faster for some reason
|
||||
for i in range(probs.shape[0]):
|
||||
probs[i] = probs[i].cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with cumulative mass above the threshold
|
||||
last_ind = (probs < self.mass_tensor).sum(dim=1)
|
||||
last_ind[last_ind < 0] = 0
|
||||
|
||||
if self.disabled_mask is not None:
|
||||
last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
|
||||
|
||||
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
|
||||
1, last_ind.view(-1, 1)
|
||||
)
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
|
||||
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||
|
||||
return warped_scores
|
||||
|
||||
def filter(self, indices):
|
||||
self.mass = [self.mass[i] for i in indices]
|
||||
disabled = [x == 1.0 for x in self.mass]
|
||||
|
||||
if not all(disabled):
|
||||
self.mass_tensor = self.mass_tensor[indices]
|
||||
|
||||
if self.disabled_mask is not None:
|
||||
self.disabled_mask = (
|
||||
self.disabled_mask[indices] if any(disabled) else None
|
||||
)
|
||||
|
||||
return self
|
||||
return None
|
||||
|
||||
|
||||
class HeterogeneousProcessorWrapper(LogitsProcessor):
|
||||
r"""
|
||||
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
||||
Args:
|
||||
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
|
||||
A mapping of sample indices to logit warpers or processors, to be run sequentially.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
|
||||
):
|
||||
self.processors = processors
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
for i, processor in self.processors.items():
|
||||
scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
|
||||
return scores
|
||||
|
||||
def filter(self, indices):
|
||||
new_processors = {}
|
||||
for i, idx in enumerate(indices):
|
||||
if idx in self.processors:
|
||||
new_processors[i] = self.processors[idx]
|
||||
|
||||
if new_processors:
|
||||
self.processors = new_processors
|
||||
return self
|
||||
return None
|
|
@ -1,12 +1,7 @@
|
|||
import re
|
||||
import torch
|
||||
|
||||
from functools import lru_cache
|
||||
from transformers import (
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
@ -15,82 +10,15 @@ from typing import List, Tuple, Optional
|
|||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.pb.generate_pb2 import FinishReason
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
|
||||
|
||||
class Sampling:
|
||||
def __init__(self, seed: int, device: str = "cpu"):
|
||||
self.generator = torch.Generator(device)
|
||||
self.generator.manual_seed(seed)
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, logits):
|
||||
probs = torch.nn.functional.softmax(logits, -1)
|
||||
# Avoid GPU<->CPU sync done by torch multinomial
|
||||
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
|
||||
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
|
||||
return probs.div_(q).argmax()
|
||||
|
||||
|
||||
class Greedy:
|
||||
def __call__(self, logits):
|
||||
return logits.argmax()
|
||||
|
||||
|
||||
class StaticWarper:
|
||||
def __init__(
|
||||
self,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
typical_p=None,
|
||||
):
|
||||
self.warpers = []
|
||||
|
||||
if temperature is not None and temperature != 1.0:
|
||||
temperature = float(temperature)
|
||||
self.warpers.append(TemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
self.warpers.append(TopKLogitsWarper(top_k=top_k))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
self.warpers.append(TopPLogitsWarper(top_p=top_p))
|
||||
if typical_p is not None and typical_p < 1.0:
|
||||
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
|
||||
|
||||
self.cuda_graph = None
|
||||
self.static_scores = None
|
||||
self.static_warped_scores = None
|
||||
self.static_next_logprob = None
|
||||
|
||||
def __call__(self, scores):
|
||||
if self.cuda_graph is None:
|
||||
self.static_scores = scores
|
||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph):
|
||||
for warper in self.warpers:
|
||||
self.static_warped_scores = warper(None, self.static_scores)
|
||||
|
||||
# Compute logprobs
|
||||
self.static_next_logprob = torch.log_softmax(
|
||||
self.static_warped_scores, -1
|
||||
)
|
||||
|
||||
self.static_scores.copy_(scores)
|
||||
self.cuda_graph.replay()
|
||||
|
||||
return self.static_warped_scores, self.static_next_logprob
|
||||
|
||||
|
||||
@lru_cache(10)
|
||||
def static_warper(
|
||||
temperature: Optional[float],
|
||||
top_k: Optional[int],
|
||||
top_p: Optional[float],
|
||||
typical_p: Optional[float],
|
||||
) -> StaticWarper:
|
||||
return StaticWarper(
|
||||
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
||||
)
|
||||
from text_generation_server.utils.logits_process import (
|
||||
static_warper,
|
||||
HeterogeneousRepetitionPenaltyLogitsProcessor,
|
||||
HeterogeneousTemperatureLogitsWarper,
|
||||
HeterogeneousTopKLogitsWarper,
|
||||
HeterogeneousTopPLogitsWarper,
|
||||
HeterogeneousTypicalLogitsWarper,
|
||||
HeterogeneousProcessorWrapper,
|
||||
)
|
||||
|
||||
|
||||
class NextTokenChooser:
|
||||
|
@ -132,9 +60,9 @@ class NextTokenChooser:
|
|||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||
|
||||
def __call__(self, input_ids, scores):
|
||||
if self.watermark_processor:
|
||||
if self.watermark_processor is not None:
|
||||
scores = self.watermark_processor(input_ids, scores)
|
||||
if self.repetition_processor:
|
||||
if self.repetition_processor is not None:
|
||||
scores = self.repetition_processor(input_ids, scores)
|
||||
|
||||
if self.static_warper is None:
|
||||
|
@ -221,3 +149,191 @@ class StoppingCriteria:
|
|||
pb.max_new_tokens,
|
||||
pb.ignore_eos_token,
|
||||
)
|
||||
|
||||
|
||||
class HeterogeneousNextTokenChooser:
|
||||
def __init__(
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
watermark: List[bool],
|
||||
temperature: List[float],
|
||||
repetition_penalty: List[float],
|
||||
top_k: List[int],
|
||||
top_p: List[float],
|
||||
typical_p: List[float],
|
||||
do_sample: List[bool],
|
||||
seeds: List[int],
|
||||
):
|
||||
warpers = []
|
||||
|
||||
self.watermark_processor = (
|
||||
HeterogeneousProcessorWrapper(
|
||||
{
|
||||
i: WatermarkLogitsProcessor(device=device)
|
||||
for i, do_watermark in enumerate(watermark)
|
||||
if do_watermark
|
||||
}
|
||||
)
|
||||
if any(watermark)
|
||||
else None
|
||||
)
|
||||
|
||||
self.repetition_processor = (
|
||||
HeterogeneousRepetitionPenaltyLogitsProcessor(
|
||||
repetition_penalty, dtype, device
|
||||
)
|
||||
if any([x != 1.0 for x in repetition_penalty])
|
||||
else None
|
||||
)
|
||||
|
||||
if any([x != 1.0 for x in temperature]):
|
||||
do_sample = [
|
||||
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
||||
]
|
||||
warpers.append(
|
||||
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
|
||||
)
|
||||
|
||||
if any([x != 0 for x in top_k]):
|
||||
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
|
||||
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
|
||||
|
||||
if any([x < 1.0 for x in top_p]):
|
||||
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
|
||||
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
|
||||
|
||||
if any([x < 1.0 for x in typical_p]):
|
||||
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
|
||||
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
|
||||
|
||||
self.warpers = warpers
|
||||
|
||||
if any(do_sample):
|
||||
self.choice = HeterogeneousSampling(do_sample, seeds, device)
|
||||
else:
|
||||
self.choice = Greedy()
|
||||
|
||||
self.seeds = seeds
|
||||
self.do_sample = do_sample
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
|
||||
if self.watermark_processor is not None:
|
||||
scores = self.watermark_processor(input_ids, scores)
|
||||
if self.repetition_processor is not None:
|
||||
scores = self.repetition_processor(input_ids, scores)
|
||||
|
||||
for warper in self.warpers:
|
||||
scores = warper(input_ids, scores)
|
||||
|
||||
next_ids = self.choice(scores)
|
||||
next_logprobs = torch.gather(
|
||||
torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
|
||||
).view(-1)
|
||||
|
||||
return next_ids, next_logprobs
|
||||
|
||||
def filter(self, indices):
|
||||
if self.watermark_processor is not None:
|
||||
self.watermark_processor = self.watermark_processor.filter(indices)
|
||||
|
||||
if self.repetition_processor is not None:
|
||||
self.repetition_processor = self.repetition_processor.filter(indices)
|
||||
|
||||
filtered_warpers = []
|
||||
for warper in self.warpers:
|
||||
filtered_warper = warper.filter(indices)
|
||||
if filtered_warper is not None:
|
||||
filtered_warpers.append(filtered_warper)
|
||||
self.warpers = filtered_warpers
|
||||
|
||||
self.seeds = [self.seeds[i] for i in indices]
|
||||
self.do_sample = [self.do_sample[i] for i in indices]
|
||||
|
||||
if any(self.do_sample):
|
||||
self.choice.filter(indices)
|
||||
else:
|
||||
self.choice = Greedy()
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: List[generate_pb2.NextTokenChooserParameters],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "HeterogeneousNextTokenChooser":
|
||||
return HeterogeneousNextTokenChooser(
|
||||
watermark=[pb_.watermark for pb_ in pb],
|
||||
temperature=[pb_.temperature for pb_ in pb],
|
||||
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
|
||||
top_k=[pb_.top_k for pb_ in pb],
|
||||
top_p=[pb_.top_p for pb_ in pb],
|
||||
typical_p=[pb_.typical_p for pb_ in pb],
|
||||
do_sample=[pb_.do_sample for pb_ in pb],
|
||||
seeds=[pb_.seed for pb_ in pb],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class Sampling:
|
||||
def __init__(self, seed: int, device: str = "cpu"):
|
||||
self.generator = torch.Generator(device)
|
||||
self.generator.manual_seed(seed)
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, logits):
|
||||
probs = torch.nn.functional.softmax(logits, -1)
|
||||
# Avoid GPU<->CPU sync done by torch multinomial
|
||||
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
|
||||
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
|
||||
return probs.div_(q).argmax()
|
||||
|
||||
|
||||
class Greedy:
|
||||
def __call__(self, logits):
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
|
||||
class HeterogeneousSampling:
|
||||
r"""
|
||||
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
|
||||
"""
|
||||
|
||||
def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
|
||||
self.seeds = seeds
|
||||
|
||||
self.greedy_indices = []
|
||||
self.sampling_mapping = {}
|
||||
for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
|
||||
if sample:
|
||||
self.sampling_mapping[i] = Sampling(seed, device)
|
||||
else:
|
||||
self.greedy_indices.append(i)
|
||||
|
||||
self.greedy = Greedy()
|
||||
|
||||
def __call__(self, logits):
|
||||
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
|
||||
if self.greedy_indices:
|
||||
# Computing for all indices is faster than slicing
|
||||
torch.argmax(logits, -1, out=out)
|
||||
|
||||
for i, sampling in self.sampling_mapping.items():
|
||||
out[i] = sampling(logits[i])
|
||||
return out
|
||||
|
||||
def filter(self, indices):
|
||||
new_greedy_indices = []
|
||||
new_sampling_mapping = {}
|
||||
for i, idx in enumerate(indices):
|
||||
if idx in self.sampling_mapping:
|
||||
new_sampling_mapping[i] = self.sampling_mapping[idx]
|
||||
else:
|
||||
new_greedy_indices.append(i)
|
||||
|
||||
self.greedy_indices = new_greedy_indices
|
||||
self.sampling_mapping = new_sampling_mapping
|
||||
return self
|
||||
|
|
Loading…
Reference in New Issue