feat(server): support vectorized warpers in flash causal lm (#317)
Co-authored-by: Joel Lamy-Poirier <joel.lamy-poirier@servicenow.com>
Parent: 951930fbff
Commit: 62f91f78ac

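For orientation, here is a minimal sketch of what "vectorized warpers" means in this commit, assuming a [batch, vocab] logits tensor and one sampling parameter per request: instead of looping over requests and warping each row separately, per-request parameters are packed into a tensor and applied to the whole batch in one call. The class below is a simplified stand-in for illustration, not the HeterogeneousNextTokenChooser / warper classes added by the diff.

import torch

class BatchedTemperatureWarper:
    def __init__(self, temperatures, dtype=torch.float32, device="cpu"):
        # One temperature per request, kept as a [batch, 1] column so it
        # broadcasts over the vocabulary dimension.
        self.t = torch.tensor(temperatures, dtype=dtype, device=device).unsqueeze(1)

    def __call__(self, scores: torch.Tensor) -> torch.Tensor:
        # scores: [batch_size, vocab_size]; divide in place, one temperature per row.
        return scores.div_(self.t)

logits = torch.randn(3, 8)
warped = BatchedTemperatureWarper([0.7, 1.0, 1.3])(logits)
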
@@ -34,65 +34,65 @@
    "tokens": [
-      { "id": 408, "logprob": -1.9267578, "special": false, "text": " que" },
-      { "id": 20288, "logprob": -2.9257812, "special": false, "text": " l'on" },
-      { "id": 22255, "logprob": -2.8964844, "special": false, "text": " trouve" },
-      { "id": 1622, "logprob": -1.1083984, "special": false, "text": " une" },
-      { "id": 187079, "logprob": -7.796875, "special": false, "text": " posture" },
-      { "id": 501, "logprob": -5.390625, "special": false, "text": " par" },
-      { "id": 8741, "logprob": -0.34936523, "special": false, "text": " rapport" },
-      { "id": 693, "logprob": 0.0, "special": false, "text": " à" },
-      { "id": 366, "logprob": -2.3378906, "special": false, "text": " la" },
-      { "id": 36503, "logprob": -3.6640625, "special": false, "text": " pratique" }
+      { "id": 408, "logprob": -0.07891846, "special": false, "text": " que" },
+      { "id": 366, "logprob": -1.2939453, "special": false, "text": " la" },
+      { "id": 8769, "logprob": -0.3708496, "special": false, "text": " personne" },
+      { "id": 1479, "logprob": -2.2871094, "special": false, "text": " qui" },
+      { "id": 2997, "logprob": -0.8671875, "special": false, "text": " vous" },
+      { "id": 35977, "logprob": -1.5097656, "special": false, "text": " suit" },
+      { "id": 21558, "logprob": -0.07891846, "special": false, "text": " ait" },
+      { "id": 447, "logprob": -0.12695312, "special": false, "text": " un" },
+      { "id": 78606, "logprob": -2.21875, "special": false, "text": " profil" },
+      { "id": 3899, "logprob": -1.3535156, "special": false, "text": " bien" }
    ]
  },
-  "generated_text": "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique"
+  "generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien"
}

@@ -1,8 +1,8 @@
{
  "details": {
    "best_of_sequences": null,
-    "finish_reason": "length",
-    "generated_tokens": 10,
+    "finish_reason": "stop_sequence",
+    "generated_tokens": 5,
    "prefill": [
      {
        "id": 1,
@@ -24,65 +24,35 @@
    "tokens": [
-      { "id": 5229, "logprob": -3.3085938, "special": false, "text": " failed" },
-      { "id": 363, "logprob": -3.984375, "special": false, "text": " for" },
-      { "id": 5641, "logprob": -6.53125, "special": false, "text": " IP" },
-      { "id": 16428, "logprob": -3.1835938, "special": false, "text": " Address" },
-      { "id": 29901, "logprob": -1.2324219, "special": false, "text": ":" },
-      { "id": 525, "logprob": -2.6855469, "special": false, "text": " '" },
-      { "id": 8516, "logprob": -7.1601562, "special": false, "text": "None" },
-      { "id": 4286, "logprob": -2.4433594, "special": false, "text": "'." },
-      { "id": 13, "logprob": -0.06530762, "special": false, "text": "\n" },
-      { "id": 294, "logprob": -7.953125, "special": false, "text": "as" }
+      { "id": 5229, "logprob": -2.5683594, "special": false, "text": " failed" },
+      { "id": 29901, "logprob": -0.45336914, "special": false, "text": ":" },
+      { "id": 4829, "logprob": -1.8408203, "special": false, "text": " Error" },
+      { "id": 297, "logprob": -1.0556641, "special": false, "text": " in" },
+      { "id": 1243, "logprob": 0.0, "special": false, "text": " test" }
    ]
  },
-  "generated_text": "Test requestfailed for IP Address: 'None'.\nas"
+  "generated_text": "Test requestfailed: Error in test"
}

@@ -1,8 +1,8 @@
{
  "details": {
    "best_of_sequences": null,
-    "finish_reason": "eos_token",
-    "generated_tokens": 12,
+    "finish_reason": "length",
+    "generated_tokens": 60,
    "prefill": [
      {
        "id": 589,
@@ -29,77 +29,365 @@
    "tokens": [
-      { "id": 2262, "logprob": -0.7451172, "special": false, "text": "():" },
-      { "id": 284, "logprob": -0.21325684, "special": false, "text": "\n " },
-      { "id": 5741, "logprob": -5.734375, "special": false, "text": " logging" },
-      { "id": 32, "logprob": 0.0, "special": false, "text": "." },
-      { "id": 1338, "logprob": -0.3232422, "special": false, "text": "info" },
-      { "id": 463, "logprob": -1.0380859, "special": false, "text": "('" },
-      { "id": 8279, "logprob": -0.8378906, "special": false, "text": "Hello" },
-      { "id": 30, "logprob": -1.9501953, "special": false, "text": "," },
-      { "id": 10896, "logprob": -1.3476562, "special": false, "text": " World" },
-      { "id": 683, "logprob": -1.796875, "special": false, "text": "')" },
-      { "id": 203, "logprob": -0.9873047, "special": false, "text": "\n" },
-      { "id": 0, "logprob": -0.7495117, "special": true, "text": "<|endoftext|>" }
+      { "id": 2262, "logprob": -0.042999268, "special": false, "text": "():" },
+      { "id": 284, "logprob": 0.0, "special": false, "text": "\n " },
+      { "id": 1459, "logprob": 0.0, "special": false, "text": " print" },
+      { "id": 440, "logprob": 0.0, "special": false, "text": "(\"" },
+      { "id": 8279, "logprob": 0.0, "special": false, "text": "Hello" },
+      { "id": 10896, "logprob": -0.3659668, "special": false, "text": " World" },
+      { "id": 657, "logprob": -0.49804688, "special": false, "text": "\")" },
+      { "id": 203, "logprob": -0.11279297, "special": false, "text": "\n" },
+      { "id": 203, "logprob": 0.0, "special": false, "text": "\n" },
+      { "id": 589, "logprob": -0.20141602, "special": false, "text": "def" },
+      { "id": 1459, "logprob": 0.0, "special": false, "text": " print" },
+      { "id": 81, "logprob": 0.0, "special": false, "text": "_" },
+      { "id": 7656, "logprob": 0.0, "special": false, "text": "hello" },
+      { "id": 81, "logprob": 0.0, "special": false, "text": "_" },
+      { "id": 426, "logprob": -0.051635742, "special": false, "text": "name" },
+      { "id": 26, "logprob": 0.0, "special": false, "text": "(" },
+      { "id": 426, "logprob": 0.0, "special": false, "text": "name" },
+      { "id": 711, "logprob": 0.0, "special": false, "text": "):" },
+      { "id": 284, "logprob": 0.0, "special": false, "text": "\n " },
+      { "id": 1459, "logprob": 0.0, "special": false, "text": " print" },
+      { "id": 440, "logprob": -0.16027832, "special": false, "text": "(\"" },
+      { "id": 8279, "logprob": 0.0, "special": false, "text": "Hello" },
+      { "id": 313, "logprob": 0.0, "special": false, "text": " \"" },
+      { "id": 474, "logprob": 0.0, "special": false, "text": " +" },
+      { "id": 636, "logprob": 0.0, "special": false, "text": " name" },
+      { "id": 27, "logprob": 0.0, "special": false, "text": ")" },
+      { "id": 203, "logprob": 0.0, "special": false, "text": "\n" },
+      { "id": 203, "logprob": 0.0, "special": false, "text": "\n" },
+      { "id": 589, "logprob": 0.0, "special": false, "text": "def" },
+      { "id": 1459, "logprob": 0.0, "special": false, "text": " print" },
+      { "id": 81, "logprob": 0.0, "special": false, "text": "_" },
+      { "id": 7656, "logprob": 0.0, "special": false, "text": "hello" },
+      { "id": 81, "logprob": 0.0, "special": false, "text": "_" },
+      { "id": 426, "logprob": 0.0, "special": false, "text": "name" },
+      { "id": 81, "logprob": 0.0, "special": false, "text": "_" },
+      { "id": 381, "logprob": 0.0, "special": false, "text": "age" },
+      { "id": 26, "logprob": 0.0, "special": false, "text": "(" },
+      { "id": 426, "logprob": 0.0, "special": false, "text": "name" },
+      { "id": 30, "logprob": 0.0, "special": false, "text": "," },
+      { "id": 11442, "logprob": 0.0, "special": false, "text": " age" },
+      { "id": 711, "logprob": 0.0, "special": false, "text": "):" },
+      { "id": 284, "logprob": 0.0, "special": false, "text": "\n " },
+      { "id": 1459, "logprob": 0.0, "special": false, "text": " print" },
+      { "id": 440, "logprob": 0.0, "special": false, "text": "(\"" },
+      { "id": 8279, "logprob": 0.0, "special": false, "text": "Hello" },
+      { "id": 313, "logprob": 0.0, "special": false, "text": " \"" },
+      { "id": 474, "logprob": 0.0, "special": false, "text": " +" },
+      { "id": 636, "logprob": 0.0, "special": false, "text": " name" },
+      { "id": 474, "logprob": 0.0, "special": false, "text": " +" },
+      { "id": 313, "logprob": -0.6328125, "special": false, "text": " \"" },
+      { "id": 313, "logprob": -1.7011719, "special": false, "text": " \"" },
+      { "id": 474, "logprob": 0.0, "special": false, "text": " +" },
+      { "id": 596, "logprob": 0.0, "special": false, "text": " str" },
+      { "id": 26, "logprob": 0.0, "special": false, "text": "(" },
+      { "id": 381, "logprob": 0.0, "special": false, "text": "age" },
+      { "id": 490, "logprob": 0.0, "special": false, "text": "))" },
+      { "id": 203, "logprob": 0.0, "special": false, "text": "\n" },
+      { "id": 203, "logprob": 0.0, "special": false, "text": "\n" },
+      { "id": 589, "logprob": 0.0, "special": false, "text": "def" },
+      { "id": 1459, "logprob": 0.0, "special": false, "text": " print" }
    ]
  },
-  "generated_text": "():\n logging.info('Hello, World')\n<|endoftext|>"
+  "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print"
}

@@ -1,8 +1,8 @@
{
  "details": {
    "best_of_sequences": null,
-    "finish_reason": "length",
-    "generated_tokens": 10,
+    "finish_reason": "eos_token",
+    "generated_tokens": 9,
    "prefill": [
      {
        "id": 0,
@@ -14,65 +14,59 @@
    "tokens": [
-      { "id": 16017, "logprob": -1.3505859, "special": false, "text": " blue" },
-      { "id": 20495, "logprob": -0.50439453, "special": false, "text": " sky" },
-      { "id": 259, "logprob": -1.2011719, "special": false, "text": " " },
-      { "id": 15484, "logprob": -2.8378906, "special": false, "text": "appear" },
-      { "id": 345, "logprob": -0.87597656, "special": false, "text": "ed" },
-      { "id": 288, "logprob": -1.8447266, "special": false, "text": " to" },
-      { "id": 35622, "logprob": -7.1445312, "special": false, "text": " cloud" },
-      { "id": 263, "logprob": -1.2929688, "special": false, "text": "s" },
-      { "id": 14701, "logprob": -3.0761719, "special": false, "text": " above" },
-      { "id": 751, "logprob": -4.4375, "special": false, "text": " all" }
+      { "id": 16017, "logprob": -0.30908203, "special": false, "text": " blue" },
+      { "id": 20495, "logprob": 0.0, "special": false, "text": " sky" },
+      { "id": 259, "logprob": -0.28271484, "special": false, "text": " " },
+      { "id": 15484, "logprob": -1.7929688, "special": false, "text": "appear" },
+      { "id": 345, "logprob": -0.8935547, "special": false, "text": "ed" },
+      { "id": 281, "logprob": 0.0, "special": false, "text": " in" },
+      { "id": 287, "logprob": 0.0, "special": false, "text": " the" },
+      { "id": 20495, "logprob": -0.32299805, "special": false, "text": " sky" },
+      { "id": 1, "logprob": 0.0, "special": true, "text": "</s>" }
    ]
  },
-  "generated_text": "Why is the sky blue?blue sky appeared to clouds above all"
+  "generated_text": "Why is the sky blue?blue sky appeared in the sky"
}

@@ -40,7 +40,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
        seed=0,
    )

-    assert response.details.generated_tokens == 10
+    assert response.details.generated_tokens == 5
    assert response == response_snapshot

@@ -29,7 +29,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
        "def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
    )

-    assert response.details.generated_tokens == 12
+    assert response.details.generated_tokens == 60
    assert response == response_snapshot

@@ -43,7 +43,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
        seed=0,
    )

-    assert response.details.generated_tokens == 10
+    assert response.details.generated_tokens == 9
    assert response == response_snapshot

@@ -38,7 +38,7 @@ def default_pb_batch(default_pb_request):
@pytest.fixture
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
    return BloomCausalLMBatch.from_pb(
-        default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
+        default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
    )

@@ -52,7 +52,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
    return BloomCausalLMBatch.from_pb(
-        batch_pb, bloom_560m_tokenizer, torch.device("cpu")
+        batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
    )

@@ -286,7 +286,9 @@ def test_batch_concatenate(
        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
    )

-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )

    for _ in range(
        default_bloom_batch.stopping_criterias[0].max_new_tokens

@@ -38,7 +38,9 @@ def default_pb_batch(default_pb_request):

@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
-    return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )


@pytest.fixture
@@ -50,7 +52,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
    req_1.stopping_parameters.max_new_tokens = 5

    batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
-    return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
+    return CausalLMBatch.from_pb(
+        batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
+    )


def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
@@ -285,7 +289,9 @@ def test_batch_concatenate(
        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
    )

-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )

    for _ in range(
        default_causal_lm_batch.stopping_criterias[0].max_new_tokens

@@ -45,7 +45,10 @@ def default_fim_pb_batch(default_fim_pb_request):
@pytest.mark.skip
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
    batch = CausalLMBatch.from_pb(
-        default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
    )
    next_batch = batch

@@ -70,7 +73,10 @@ def test_fim_santacoder_generate_token_completion(
    default_santacoder, default_fim_pb_batch
):
    batch = CausalLMBatch.from_pb(
-        default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+        default_fim_pb_batch,
+        default_santacoder.tokenizer,
+        default_santacoder.dtype,
+        default_santacoder.device,
    )
    next_batch = batch

@@ -42,7 +42,7 @@ def default_pb_batch(default_pb_request):
@pytest.fixture
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
    return Seq2SeqLMBatch.from_pb(
-        default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
+        default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
    )

@@ -55,7 +55,9 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
    req_1.stopping_parameters.max_new_tokens = 5

    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
-    return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
+    return Seq2SeqLMBatch.from_pb(
+        batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
+    )


def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
@@ -323,7 +325,9 @@ def test_batch_concatenate(
    )
    assert generations[2].generated_text.generated_tokens == 5

-    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].id, next_batch.requests[1].id]
+    )

    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

@@ -39,10 +39,11 @@ class BloomCausalLMBatch(CausalLMBatch):
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super(BloomCausalLMBatch, cls).from_pb(
-            pb=pb, tokenizer=tokenizer, device=device
+            pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
        )
        batch.keys_head_dim_last = False
        return batch

@@ -66,6 +66,7 @@ class CausalLMBatch(Batch):
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        inputs = []

@@ -18,11 +18,7 @@ from text_generation_server.models.types import (
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import (
-    NextTokenChooser,
-    StoppingCriteria,
-    Sampling,
-)
+from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

tracer = trace.get_tracer(__name__)

@@ -48,7 +44,7 @@ class FlashCausalLMBatch(Batch):

    # All tokens
    all_input_ids: List[List[int]]
-    all_input_ids_tensor: List[torch.Tensor]
+    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
@@ -56,7 +52,7 @@ class FlashCausalLMBatch(Batch):
    read_offsets: List[Optional[int]]

    # Generation helpers
-    next_token_choosers: List[NextTokenChooser]
+    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]

    # Maximum number of tokens this batch will grow to
@@ -75,6 +71,7 @@ class FlashCausalLMBatch(Batch):
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        position_ids = []
@@ -87,13 +84,14 @@ class FlashCausalLMBatch(Batch):
        all_input_ids = []
        requests_idx_mapping = {}

-        next_token_choosers = []
+        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        max_tokens = 0
+        max_length = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
@@ -119,7 +117,7 @@ class FlashCausalLMBatch(Batch):
            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
@@ -130,11 +128,26 @@ class FlashCausalLMBatch(Batch):
            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens
+            max_length = max(max_length, input_length + max_new_tokens)
+
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            next_token_chooser_parameters, dtype, device
+        )
+
+        # Padded all_input_ids_tensor
+        all_input_ids_tensor = np.zeros(
+            (len(all_input_ids), max_length), dtype=np.int64
+        )
+        for i, input_ids in enumerate(all_input_ids):
+            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        input_ids = torch.tensor(
            np.concatenate(all_input_ids), dtype=torch.int64, device=device
        )
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )
        position_ids = torch.tensor(
            np.concatenate(position_ids), dtype=torch.int32, device=device
        )
@@ -154,8 +167,8 @@ class FlashCausalLMBatch(Batch):
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
-            all_input_ids_tensor=[],
-            next_token_choosers=next_token_choosers,
+            all_input_ids_tensor=all_input_ids_tensor,
+            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )
@@ -176,31 +189,29 @@ class FlashCausalLMBatch(Batch):
        # New values after filtering
        requests_idx_mapping = {}

-        input_ids = self.input_ids.new_empty(len(request_ids))
-        position_ids = self.position_ids.new_empty(len(request_ids))
+        # Used to index into tensors
+        indices = []

        # Create on CPU to only move to GPU once instead of at every copy
        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
-        cu_seqlens_q = torch.arange(
-            0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
-        )
+        cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
        max_seqlen = 0
        past_key_values = []

        requests = []
        all_input_ids = []
-        all_input_ids_tensor = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

-        next_token_choosers = []
        stopping_criterias = []

        max_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
+            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])
@@ -208,10 +219,6 @@ class FlashCausalLMBatch(Batch):
            # Get length
            request_input_length = self.input_lengths[idx]

-            # Copy tensors (GPU)
-            input_ids[i] = self.input_ids[idx]
-            position_ids[i] = self.position_ids[idx]
-
            # Copy to tensor (CPU)
            cu_seqlens[i + 1] = cumulative_length + request_input_length
            max_seqlen = max(max_seqlen, request_input_length)
@@ -222,14 +229,11 @@ class FlashCausalLMBatch(Batch):
            )

            all_input_ids.append(self.all_input_ids[idx])
-            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

-            next_token_choosers.append(self.next_token_choosers[idx])
-
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

@@ -258,6 +262,12 @@ class FlashCausalLMBatch(Batch):
        # Cat all past
        past_key_values = torch.cat(past_key_values, dim=1)

+        # Index into tensors
+        input_ids = self.input_ids[indices]
+        position_ids = self.position_ids[indices]
+        all_input_ids_tensor = self.all_input_ids_tensor[indices]
+        next_token_chooser = self.next_token_chooser.filter(indices)
+
        # Move to GPU now that we have the whole tensor
        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

@@ -276,7 +286,7 @@ class FlashCausalLMBatch(Batch):
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
-            next_token_choosers=next_token_choosers,
+            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )
@@ -290,6 +300,7 @@ class FlashCausalLMBatch(Batch):

        total_batch_size = sum([len(b) for b in batches])

+        dtype = batches[0].past_key_values.dtype
        device = batches[0].input_ids.device

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
@@ -302,19 +313,19 @@ class FlashCausalLMBatch(Batch):
        past_key_values = []

        all_input_ids = []
-        all_input_ids_tensor = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

-        next_token_choosers = []
+        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0
        max_tokens = 0
+        max_length = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
@@ -347,25 +358,54 @@ class FlashCausalLMBatch(Batch):
            )

            all_input_ids.extend(batch.all_input_ids)
-            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

-            next_token_choosers.extend(batch.next_token_choosers)
+            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens
+            max_length = max(
+                max_length,
+                max(
+                    input_length
+                    + stopping_criteria.max_new_tokens
+                    - stopping_criteria.current_tokens
+                    for input_length, stopping_criteria in zip(
+                        batch.input_lengths, batch.stopping_criterias
+                    )
+                ),
+            )
+
+        all_input_ids_tensor = torch.zeros(
+            (total_batch_size, max_length), dtype=torch.int64, device=device
+        )
+
+        cumulative_batch_size = 0
+        for i, batch in enumerate(batches):
+            start_index = cumulative_batch_size
+            end_index = cumulative_batch_size + len(batch)
+
+            all_input_ids_tensor[
+                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
+            ] = batch.all_input_ids_tensor[:, :max_length]
+
+            cumulative_batch_size += len(batch)

        # Cat past
        past_key_values = torch.cat(past_key_values, dim=1)
        # Create final tensor on GPU
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            next_token_chooser_parameters, dtype=dtype, device=device
+        )
+
        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
@@ -381,7 +421,7 @@ class FlashCausalLMBatch(Batch):
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
-            next_token_choosers=next_token_choosers,
+            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )
@@ -463,6 +503,7 @@ class FlashCausalLM(Model):
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        prefill = batch.past_key_values is None
+        single_request = len(batch) == 1

        if prefill and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
@@ -483,6 +524,17 @@ class FlashCausalLM(Model):
            pre_allocate_past_size,
        )

+        if prefill:
+            next_token_logits = (
+                out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
+            )
+        else:
+            next_token_logits = out
+
+        next_input_ids, next_token_logprobs = batch.next_token_chooser(
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
+        )
+
        if prefill:
            if len(batch) > 1:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -493,15 +545,11 @@ class FlashCausalLM(Model):
            batch.cu_seqlens_q = torch.arange(
                0, len(batch) + 1, device=self.device, dtype=torch.int32
            )
-            next_input_ids = batch.input_ids.new_empty(len(batch))
            next_position_ids = batch.position_ids.new_empty(len(batch))
        else:
            prefill_logprobs = None
-            next_input_ids = batch.input_ids
            next_position_ids = batch.position_ids

-        next_token_logprobs = out.new_empty(len(batch))
-
        # Prepare past for next decode
        if len(batch) > 1:
            # Used to slice next batch past
@@ -552,7 +600,6 @@ class FlashCausalLM(Model):
        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
-            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )
@@ -564,7 +611,6 @@ class FlashCausalLM(Model):
        # For each member of the batch
        for i, (
            input_length,
-            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
@@ -573,21 +619,6 @@ class FlashCausalLM(Model):
            end_index = cumulative_length + input_length

            if prefill:
-                # Prefill mode
-                # out is of shape [cumulative_sequence_lengths, vocab_size]
-                # only take last token logit
-                logits = out[end_index - 1 : end_index]
-
-                # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
-                all_input_ids_tensor = batch.input_ids.new_empty(
-                    input_length + stopping_criteria.max_new_tokens
-                )
-                # Copy from batch.input_ids to all_input_ids_tensor
-                all_input_ids_tensor[:input_length] = batch.input_ids[
-                    start_index:end_index
-                ]
-                batch.all_input_ids_tensor.append(all_input_ids_tensor)
-
                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]
@@ -603,25 +634,8 @@ class FlashCausalLM(Model):
                    prefill_tokens_indices = batch.input_ids[
                        start_index + 1 : end_index
                    ]
-            else:
-                # Decode mode
-                # out is of shape [batch_size, vocab_size]
-                logits = out[i].view(1, -1)
-
-            all_input_ids_tensor = batch.all_input_ids_tensor[i]
-
-            # Select next token
-            next_token_id, logprob = next_token_chooser(
-                all_input_ids_tensor[None, :input_length], logits
-            )
-
-            # Add to all_input_ids_tensor
-            next_token_id_squeezed = next_token_id.view(1)
-            all_input_ids_tensor[input_length] = next_token_id_squeezed
-
-            # Set values
-            next_input_ids[i] = next_token_id_squeezed
-            next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
+
+            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

@@ -651,10 +665,11 @@ class FlashCausalLM(Model):
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
-            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
+            batch.next_token_chooser.do_sample,
+            batch.next_token_chooser.seeds,
            next_token_ids,
            next_token_logprobs,
        )
@@ -665,10 +680,11 @@ class FlashCausalLM(Model):
            input_length,
            prefix_offset,
            read_offset,
-            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
+            do_sample,
+            seed,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
@@ -702,14 +718,11 @@ class FlashCausalLM(Model):
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :]
                )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
+                    output_text,
+                    stopping_criteria.current_tokens,
+                    reason,
+                    seed if do_sample else None,
                )
            else:
                generated_text = None
@@ -751,8 +764,9 @@ class FlashCausalLM(Model):
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids
-            batch.max_seqlen = batch.max_seqlen + 1
            cumulative_length += input_length

+        batch.max_seqlen = batch.max_seqlen + 1
+
        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None

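To make the control-flow change above concrete: the per-request loop that previously called a NextTokenChooser on one row of logits at a time is replaced by a single batched call over all rows. The helper below is a hedged, greedy-only stand-in that only illustrates the batched shape handling; it is not the HeterogeneousNextTokenChooser API from the diff, and the token-history argument is unused here.

import torch

def choose_next_tokens(all_input_ids: torch.Tensor, logits: torch.Tensor):
    # all_input_ids: [batch, seq] token history (a real chooser would use it for
    # repetition penalty / watermarking); logits: [batch, vocab].
    logprobs = torch.log_softmax(logits, dim=-1)
    next_ids = torch.argmax(logprobs, dim=-1)                      # greedy stand-in
    next_logprobs = logprobs.gather(1, next_ids.unsqueeze(1)).squeeze(1)
    return next_ids, next_logprobs

ids, lps = choose_next_tokens(torch.zeros(2, 4, dtype=torch.long), torch.randn(2, 16))
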
@@ -89,6 +89,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        inputs = []

@@ -71,6 +71,7 @@ class Seq2SeqLMBatch(Batch):
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""

@@ -21,6 +21,7 @@ class Batch(ABC):
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
        device: torch.device,
    ) -> "Batch":
        raise NotImplementedError

@@ -55,7 +55,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

    async def Prefill(self, request, context):
        batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.device
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
        )

        generations, next_batch = self.model.generate_token(batch)

@@ -9,12 +9,13 @@ from text_generation_server.utils.hub import (
    RevisionNotFoundError,
)
from text_generation_server.utils.tokens import (
-    Greedy,
    NextTokenChooser,
-    Sampling,
+    HeterogeneousNextTokenChooser,
    StoppingCriteria,
    StopSequenceCriteria,
    FinishReason,
+    Sampling,
+    Greedy,
)

__all__ = [
@@ -25,6 +26,7 @@ __all__ = [
    "weight_hub_files",
    "download_weights",
    "EntryNotFoundError",
+    "HeterogeneousNextTokenChooser",
    "LocalEntryNotFoundError",
    "RevisionNotFoundError",
    "Greedy",

@@ -1,14 +1,10 @@
-import concurrent
-import time
import datetime
import torch

-from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
-from datetime import timedelta
from loguru import logger
from pathlib import Path
-from safetensors.torch import load_file, save_file
+from safetensors.torch import save_file
from safetensors import safe_open
from typing import Dict, List

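The new file in the diff below introduces the batched ("heterogeneous") processors. As a minimal, self-contained sketch of the core pattern it uses for the repetition penalty (gather the scores of previously generated ids, penalize them with a per-request factor, scatter them back in place), here is an illustrative free function; the real class in the diff is HeterogeneousRepetitionPenaltyLogitsProcessor, and the function name and example values here are assumptions for demonstration only.

import torch

def apply_repetition_penalty(scores, input_ids, penalties):
    # scores: [batch, vocab]; input_ids: [batch, seq] of previously seen ids;
    # penalties: one factor per request (1.0 means no penalty).
    penalty = torch.tensor(penalties, dtype=scores.dtype).unsqueeze(1)
    picked = torch.gather(scores, 1, input_ids)
    # Negative scores are multiplied, positive ones divided, as in the diff.
    picked = torch.where(picked < 0, picked * penalty, picked / penalty)
    scores.scatter_(1, input_ids, picked)
    return scores

scores = torch.randn(2, 10)
history = torch.tensor([[1, 2, 3], [4, 4, 5]])
apply_repetition_penalty(scores, history, [1.2, 1.0])
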
@ -0,0 +1,405 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Optional, List, Dict, Union
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
LogitsWarper,
|
||||||
|
LogitsProcessor,
|
||||||
|
TemperatureLogitsWarper,
|
||||||
|
TopKLogitsWarper,
|
||||||
|
TopPLogitsWarper,
|
||||||
|
TypicalLogitsWarper,
|
||||||
|
)
|
||||||
|
|
||||||
|
mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||||
|
|
||||||
|
|
||||||
|
class StaticWarper:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
temperature=1.0,
|
||||||
|
top_k=None,
|
||||||
|
top_p=None,
|
||||||
|
typical_p=None,
|
||||||
|
):
|
||||||
|
self.warpers = []
|
||||||
|
|
||||||
|
if temperature is not None and temperature != 1.0:
|
||||||
|
temperature = float(temperature)
|
||||||
|
self.warpers.append(TemperatureLogitsWarper(temperature))
|
||||||
|
if top_k is not None and top_k != 0:
|
||||||
|
self.warpers.append(TopKLogitsWarper(top_k=top_k))
|
||||||
|
if top_p is not None and top_p < 1.0:
|
||||||
|
self.warpers.append(TopPLogitsWarper(top_p=top_p))
|
||||||
|
if typical_p is not None and typical_p < 1.0:
|
||||||
|
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
|
||||||
|
|
||||||
|
self.cuda_graph = None
|
||||||
|
self.static_scores = None
|
||||||
|
self.static_warped_scores = None
|
||||||
|
self.static_next_logprob = None
|
||||||
|
|
||||||
|
def __call__(self, scores):
|
||||||
|
if self.cuda_graph is None:
|
||||||
|
self.static_scores = scores
|
||||||
|
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
|
with torch.cuda.graph(self.cuda_graph, pool=mempool):
|
||||||
|
local_scores = self.static_scores
|
||||||
|
for warper in self.warpers:
|
||||||
|
local_scores = warper(None, local_scores)
|
||||||
|
|
||||||
|
self.static_warped_scores = local_scores
|
||||||
|
# Compute logprobs
|
||||||
|
self.static_next_logprob = torch.log_softmax(
|
||||||
|
self.static_warped_scores, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.static_scores.copy_(scores)
|
||||||
|
self.cuda_graph.replay()
|
||||||
|
|
||||||
|
return self.static_warped_scores, self.static_next_logprob
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(10)
|
||||||
|
def static_warper(
|
||||||
|
temperature: Optional[float],
|
||||||
|
top_k: Optional[int],
|
||||||
|
top_p: Optional[float],
|
||||||
|
typical_p: Optional[float],
|
||||||
|
) -> StaticWarper:
|
||||||
|
return StaticWarper(
|
||||||
|
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
    This version allows for a separate value for each sample and runs inplace when possible.
    It doesn't validate inputs.

    Args:
        repetition_penalty (`List[float]`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(
            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
        )

        scores.scatter_(1, input_ids, score)
        return scores

    def filter(self, indices):
        self.penalty = [self.penalty[i] for i in indices]
        if any([x != 1.0 for x in self.penalty]):
            self.penalty_tensor = self.penalty_tensor[indices]
            return self
        return None
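
A small CPU sketch of the per-sample behaviour, with illustrative values: row 0 keeps penalty 1.0 and is untouched, while row 1 has its previously generated tokens down-weighted.

import torch
from text_generation_server.utils.logits_process import (
    HeterogeneousRepetitionPenaltyLogitsProcessor,
)

processor = HeterogeneousRepetitionPenaltyLogitsProcessor(
    penalty=[1.0, 1.2], dtype=torch.float32, device=torch.device("cpu")
)
input_ids = torch.tensor([[2, 3], [2, 3]])  # previously generated tokens per row
scores = torch.zeros(2, 5)
scores[:, 2] = 1.5                          # token 2 is currently favoured in both rows
scores = processor(input_ids, scores)
# row 0 is unchanged; in row 1 the score of token 2 drops from 1.5 to 1.25 (1.5 / 1.2)
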
class HeterogeneousTemperatureLogitsWarper:
    r"""
    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
    This version allows for a separate value for each sample and runs inplace when possible.
    It doesn't validate inputs.

    Args:
        temperature (`List[float]`):
            The values used to modulate the logits distribution.
    """

    def __init__(
        self, temperature: List[float], dtype: torch.dtype, device: torch.device
    ):
        self.temperature = temperature
        self.temperature_tensor = torch.tensor(
            temperature, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores.div_(self.temperature_tensor)
        return scores

    def filter(self, indices):
        self.temperature = [self.temperature[i] for i in indices]
        if any([x != 1.0 for x in self.temperature]):
            self.temperature_tensor = self.temperature_tensor[indices]
            return self
        return None
class HeterogeneousTopPLogitsWarper(LogitsWarper):
    """
    [`LogitsWarper`] that performs top-p, i.e. restricting to the smallest set of top tokens whose cumulative probability is at least `top_p`.
    This version allows for a separate value for each sample and runs inplace when possible.
    It doesn't validate inputs.

    Args:
        top_p (`List[float]`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        top_p: List[float],
        dtype: torch.dtype,
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.top_p = top_p
        self.top_p_opposite = 1 - torch.tensor(
            top_p, dtype=dtype, device=device
        ).unsqueeze(1)
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        probs = sorted_logits.softmax(dim=-1)
        # This is way faster for some reason
        for i in range(probs.shape[0]):
            probs[i] = probs[i].cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = probs <= self.top_p_opposite
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

        return warped_scores

    def filter(self, indices):
        self.top_p = [self.top_p[i] for i in indices]
        if any([x < 1.0 for x in self.top_p]):
            self.top_p_opposite = self.top_p_opposite[indices]
            return self
        return None
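
A CPU sketch of the "opposite" trick with illustrative probabilities: sorting ascending and comparing the cumulative mass against 1 - top_p drops the low-probability tail in place.

import torch
from text_generation_server.utils.logits_process import HeterogeneousTopPLogitsWarper

warper = HeterogeneousTopPLogitsWarper(
    top_p=[1.0, 0.5], dtype=torch.float32, device=torch.device("cpu")
)
scores = torch.log(torch.tensor([[0.1, 0.2, 0.3, 0.4],
                                 [0.1, 0.2, 0.3, 0.4]]))
warped = warper(None, scores)  # input_ids is unused by this warper
# row 0 (top_p=1.0) keeps all four tokens; row 1 (top_p=0.5) keeps only the 0.3 and
# 0.4 tokens and sets the rest to -inf
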
class HeterogeneousTopKLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
    This version allows for a separate value for each sample and runs inplace when possible.
    It doesn't validate inputs.

    Args:
        top_k (`List[int]`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        top_k: List[int],
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.top_k = top_k
        self.max_top_k = max(top_k)
        # value - 1 as we will use top_k to index and python uses 0 based numbering
        self.top_k_tensor = torch.tensor(
            [max(x - 1, min_tokens_to_keep - 1) for x in top_k],
            dtype=torch.int64,
            device=device,
        ).unsqueeze(1)

        # 0 is a special value that disables top_k warping for this member of the batch
        disabled = [x == 0 for x in top_k]

        if any(disabled):
            self.top_k_disabled_mask = torch.tensor(
                disabled, dtype=torch.bool, device=device
            ).view(-1, 1)
        else:
            self.top_k_disabled_mask = None

        self.filter_value = filter_value

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # If max_top_k is larger than the vocab size, we need to clamp or the warper will fail
        if scores.size(-1) < self.max_top_k:
            max_top_k = scores.size(-1)
            top_k = torch.clamp_max(self.top_k_tensor, max_top_k)
        else:
            max_top_k = self.max_top_k
            top_k = self.top_k_tensor

        # Get the kth score for each member of the batch
        kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)

        # Mask members of kth_scores that do not want to use top_k warping
        if self.top_k_disabled_mask is not None:
            kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)

        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < kth_scores
        scores.masked_fill_(indices_to_remove, self.filter_value)
        return scores

    def filter(self, indices):
        self.top_k = [self.top_k[i] for i in indices]
        disabled = [x == 0 for x in self.top_k]

        if not all(disabled):
            self.top_k_tensor = self.top_k_tensor[indices]
            self.max_top_k = max(self.top_k)

            if self.top_k_disabled_mask is not None:
                self.top_k_disabled_mask = (
                    self.top_k_disabled_mask[indices] if any(disabled) else None
                )

            return self
        return None
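
A CPU sketch of the disabled-row convention with illustrative scores: top_k=0 leaves its row untouched, while the other row is cut down to its k best tokens.

import torch
from text_generation_server.utils.logits_process import HeterogeneousTopKLogitsWarper

warper = HeterogeneousTopKLogitsWarper(top_k=[0, 2], device=torch.device("cpu"))
scores = torch.tensor([[4.0, 3.0, 2.0, 1.0],
                       [4.0, 3.0, 2.0, 1.0]])
warped = warper(None, scores)
# row 0 (top_k=0) is untouched; in row 1 the two lowest scores become -inf
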
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
    Generation](https://arxiv.org/abs/2202.00666) for more information.
    This version allows for a separate value for each sample and runs inplace when possible.
    It doesn't validate inputs.

    Args:
        mass (`List[float]`):
            Value of typical_p between 0 and 1 inclusive.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        mass: List[float],
        dtype: torch.dtype,
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.mass = mass
        self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)

        # 1 is a special value that disables typical_p warping for this member of the batch
        disabled = [x == 1.0 for x in mass]

        if any(disabled):
            self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
        else:
            self.disabled_mask = None

        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # calculate entropy
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
        ent = -(normalized * p).nansum(-1, keepdim=True)

        # shift and sort
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        probs = sorted_logits.softmax(dim=-1)
        # This is way faster for some reason
        for i in range(probs.shape[0]):
            probs[i] = probs[i].cumsum(dim=-1)

        # Remove tokens with cumulative mass above the threshold
        last_ind = (probs < self.mass_tensor).sum(dim=1)
        last_ind[last_ind < 0] = 0

        if self.disabled_mask is not None:
            last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)

        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
            1, last_ind.view(-1, 1)
        )
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )

        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

        return warped_scores

    def filter(self, indices):
        self.mass = [self.mass[i] for i in indices]
        disabled = [x == 1.0 for x in self.mass]

        if not all(disabled):
            self.mass_tensor = self.mass_tensor[indices]

            if self.disabled_mask is not None:
                self.disabled_mask = (
                    self.disabled_mask[indices] if any(disabled) else None
                )

            return self
        return None
class HeterogeneousProcessorWrapper(LogitsProcessor):
    r"""
    A wrapper for logit warpers or processors without heterogeneous parameter support.
    Args:
        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
            A mapping of sample indices to logit warpers or processors, to be run sequentially.
    """

    def __init__(
        self,
        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
    ):
        self.processors = processors

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        for i, processor in self.processors.items():
            scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
        return scores

    def filter(self, indices):
        new_processors = {}
        for i, idx in enumerate(indices):
            if idx in self.processors:
                new_processors[i] = self.processors[idx]

        if new_processors:
            self.processors = new_processors
            return self
        return None
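
The wrapper is what lets processors without per-sample parameters (the diff uses it for WatermarkLogitsProcessor) participate in the vectorized pipeline: each wrapped processor only ever sees its own row. A CPU sketch, using transformers' TemperatureLogitsWarper as a stand-in for the wrapped processor:

import torch
from transformers import TemperatureLogitsWarper
from text_generation_server.utils.logits_process import HeterogeneousProcessorWrapper

# Rows 0 and 2 each get their own wrapped warper; rows 1 and 3 are left untouched.
wrapper = HeterogeneousProcessorWrapper(
    {0: TemperatureLogitsWarper(0.5), 2: TemperatureLogitsWarper(2.0)}
)
input_ids = torch.tensor([[7], [7], [7], [7]])
scores = torch.randn(4, 32)
scores = wrapper(input_ids, scores)
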
@@ -1,12 +1,7 @@
import re
import torch

-from functools import lru_cache
from transformers import (
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
-    TypicalLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
    PreTrainedTokenizerBase,
)
@@ -15,82 +10,15 @@ from typing import List, Tuple, Optional
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from text_generation_server.utils.logits_process import (
+    static_warper,
+    HeterogeneousRepetitionPenaltyLogitsProcessor,
+    HeterogeneousTemperatureLogitsWarper,
+    HeterogeneousTopKLogitsWarper,
+    HeterogeneousTopPLogitsWarper,
+    HeterogeneousTypicalLogitsWarper,
+    HeterogeneousProcessorWrapper,
+)
-
-
-class Sampling:
-    def __init__(self, seed: int, device: str = "cpu"):
-        self.generator = torch.Generator(device)
-        self.generator.manual_seed(seed)
-        self.seed = seed
-
-    def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits, -1)
-        # Avoid GPU<->CPU sync done by torch multinomial
-        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
-        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
-        return probs.div_(q).argmax()
-
-
-class Greedy:
-    def __call__(self, logits):
-        return logits.argmax()
-
-
-class StaticWarper:
-    def __init__(
-        self,
-        temperature=1.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-    ):
-        self.warpers = []
-
-        if temperature is not None and temperature != 1.0:
-            temperature = float(temperature)
-            self.warpers.append(TemperatureLogitsWarper(temperature))
-        if top_k is not None and top_k != 0:
-            self.warpers.append(TopKLogitsWarper(top_k=top_k))
-        if top_p is not None and top_p < 1.0:
-            self.warpers.append(TopPLogitsWarper(top_p=top_p))
-        if typical_p is not None and typical_p < 1.0:
-            self.warpers.append(TypicalLogitsWarper(mass=typical_p))
-
-        self.cuda_graph = None
-        self.static_scores = None
-        self.static_warped_scores = None
-        self.static_next_logprob = None
-
-    def __call__(self, scores):
-        if self.cuda_graph is None:
-            self.static_scores = scores
-            self.cuda_graph = torch.cuda.CUDAGraph()
-
-            with torch.cuda.graph(self.cuda_graph):
-                for warper in self.warpers:
-                    self.static_warped_scores = warper(None, self.static_scores)
-
-                # Compute logprobs
-                self.static_next_logprob = torch.log_softmax(
-                    self.static_warped_scores, -1
-                )
-
-        self.static_scores.copy_(scores)
-        self.cuda_graph.replay()
-
-        return self.static_warped_scores, self.static_next_logprob
-
-
-@lru_cache(10)
-def static_warper(
-    temperature: Optional[float],
-    top_k: Optional[int],
-    top_p: Optional[float],
-    typical_p: Optional[float],
-) -> StaticWarper:
-    return StaticWarper(
-        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
-    )


class NextTokenChooser:
@@ -132,9 +60,9 @@ class NextTokenChooser:
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
-        if self.watermark_processor:
+        if self.watermark_processor is not None:
            scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor:
+        if self.repetition_processor is not None:
            scores = self.repetition_processor(input_ids, scores)

        if self.static_warper is None:
@@ -221,3 +149,191 @@ class StoppingCriteria:
            pb.max_new_tokens,
            pb.ignore_eos_token,
        )

class HeterogeneousNextTokenChooser:
    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
        watermark: List[bool],
        temperature: List[float],
        repetition_penalty: List[float],
        top_k: List[int],
        top_p: List[float],
        typical_p: List[float],
        do_sample: List[bool],
        seeds: List[int],
    ):
        warpers = []

        self.watermark_processor = (
            HeterogeneousProcessorWrapper(
                {
                    i: WatermarkLogitsProcessor(device=device)
                    for i, do_watermark in enumerate(watermark)
                    if do_watermark
                }
            )
            if any(watermark)
            else None
        )

        self.repetition_processor = (
            HeterogeneousRepetitionPenaltyLogitsProcessor(
                repetition_penalty, dtype, device
            )
            if any([x != 1.0 for x in repetition_penalty])
            else None
        )

        if any([x != 1.0 for x in temperature]):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
            ]
            warpers.append(
                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
            )

        if any([x != 0 for x in top_k]):
            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

        if any([x < 1.0 for x in top_p]):
            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
            warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))

        if any([x < 1.0 for x in typical_p]):
            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))

        self.warpers = warpers

        if any(do_sample):
            self.choice = HeterogeneousSampling(do_sample, seeds, device)
        else:
            self.choice = Greedy()

        self.seeds = seeds
        self.do_sample = do_sample

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
        if self.watermark_processor is not None:
            scores = self.watermark_processor(input_ids, scores)
        if self.repetition_processor is not None:
            scores = self.repetition_processor(input_ids, scores)

        for warper in self.warpers:
            scores = warper(input_ids, scores)

        next_ids = self.choice(scores)
        next_logprobs = torch.gather(
            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
        ).view(-1)

        return next_ids, next_logprobs

    def filter(self, indices):
        if self.watermark_processor is not None:
            self.watermark_processor = self.watermark_processor.filter(indices)

        if self.repetition_processor is not None:
            self.repetition_processor = self.repetition_processor.filter(indices)

        filtered_warpers = []
        for warper in self.warpers:
            filtered_warper = warper.filter(indices)
            if filtered_warper is not None:
                filtered_warpers.append(filtered_warper)
        self.warpers = filtered_warpers

        self.seeds = [self.seeds[i] for i in indices]
        self.do_sample = [self.do_sample[i] for i in indices]

        if any(self.do_sample):
            self.choice.filter(indices)
        else:
            self.choice = Greedy()

        return self

    @classmethod
    def from_pb(
        cls,
        pb: List[generate_pb2.NextTokenChooserParameters],
        dtype: torch.dtype,
        device: torch.device,
    ) -> "HeterogeneousNextTokenChooser":
        return HeterogeneousNextTokenChooser(
            watermark=[pb_.watermark for pb_ in pb],
            temperature=[pb_.temperature for pb_ in pb],
            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
            top_k=[pb_.top_k for pb_ in pb],
            top_p=[pb_.top_p for pb_ in pb],
            typical_p=[pb_.typical_p for pb_ in pb],
            do_sample=[pb_.do_sample for pb_ in pb],
            seeds=[pb_.seed for pb_ in pb],
            device=device,
            dtype=dtype,
        )
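
A CPU sketch of how the vectorized chooser is driven, with illustrative parameters and assuming the modified module is text_generation_server.utils.tokens (in the server it is built via from_pb and called by the flash causal LM batch): row 0 stays greedy while row 1 is sampled, and filter keeps only the rows of requests that are still running.

import torch
from text_generation_server.utils.tokens import HeterogeneousNextTokenChooser

chooser = HeterogeneousNextTokenChooser(
    dtype=torch.float32,
    device=torch.device("cpu"),
    watermark=[False, False],
    temperature=[1.0, 0.8],
    repetition_penalty=[1.0, 1.1],
    top_k=[0, 40],
    top_p=[1.0, 0.9],
    typical_p=[1.0, 1.0],
    do_sample=[False, True],
    seeds=[0, 42],
)
input_ids = torch.tensor([[3, 5], [3, 5]])   # previously generated tokens per row
scores = torch.randn(2, 64)                  # [batch, vocab], illustrative sizes
next_ids, next_logprobs = chooser(input_ids, scores)

# When request 0 finishes, keep only the parameters of the remaining request.
chooser = chooser.filter([1])
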
class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, -1)
        # Avoid GPU<->CPU sync done by torch multinomial
        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
        return probs.div_(q).argmax()


class Greedy:
    def __call__(self, logits):
        return logits.argmax(dim=-1)
class HeterogeneousSampling:
    r"""
    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
    """

    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
        self.seeds = seeds

        self.greedy_indices = []
        self.sampling_mapping = {}
        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
            if sample:
                self.sampling_mapping[i] = Sampling(seed, device)
            else:
                self.greedy_indices.append(i)

        self.greedy = Greedy()

    def __call__(self, logits):
        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
        if self.greedy_indices:
            # Computing for all indices is faster than slicing
            torch.argmax(logits, -1, out=out)

        for i, sampling in self.sampling_mapping.items():
            out[i] = sampling(logits[i])
        return out

    def filter(self, indices):
        new_greedy_indices = []
        new_sampling_mapping = {}
        for i, idx in enumerate(indices):
            if idx in self.sampling_mapping:
                new_sampling_mapping[i] = self.sampling_mapping[idx]
            else:
                new_greedy_indices.append(i)

        self.greedy_indices = new_greedy_indices
        self.sampling_mapping = new_sampling_mapping
        return self
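
A CPU sketch of the mixed strategy in isolation, with illustrative logits and the same module assumption as above: the greedy rows are resolved with a single batched argmax, and only the sampled rows pay for per-row seeded sampling.

import torch
from text_generation_server.utils.tokens import HeterogeneousSampling

# Row 0 is greedy, row 1 uses seeded sampling; both are resolved in one call.
choice = HeterogeneousSampling(
    do_sample=[False, True], seeds=[0, 1234], device=torch.device("cpu")
)
logits = torch.randn(2, 16)
next_ids = choice(logits)  # int64 tensor of shape [2]
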