feat(server): only compute prefill logprobs when asked (#406)
Close #288
This commit is contained in:
parent
83b84486ad
commit
895c5f1562
1
Makefile
1
Makefile
|
@ -3,6 +3,7 @@ install-server:
|
||||||
|
|
||||||
install-integration-tests:
|
install-integration-tests:
|
||||||
cd integration-tests && pip install -r requirements.txt
|
cd integration-tests && pip install -r requirements.txt
|
||||||
|
cd clients/python && pip install .
|
||||||
|
|
||||||
install-router:
|
install-router:
|
||||||
cd router && cargo install --path .
|
cd router && cargo install --path .
|
||||||
|
|
|
@ -136,6 +136,7 @@ async fn prefill(
|
||||||
let requests = (0..batch_size)
|
let requests = (0..batch_size)
|
||||||
.map(|id| Request {
|
.map(|id| Request {
|
||||||
id: id.into(),
|
id: id.into(),
|
||||||
|
prefill_logprobs: false,
|
||||||
inputs: sequence.clone(),
|
inputs: sequence.clone(),
|
||||||
truncate: sequence_length,
|
truncate: sequence_length,
|
||||||
parameters: Some(parameters.clone()),
|
parameters: Some(parameters.clone()),
|
||||||
|
|
|
@ -107,8 +107,42 @@ print(text)
|
||||||
### Types
|
### Types
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Prompt tokens
|
# Request Parameters
|
||||||
class PrefillToken:
|
class Parameters:
|
||||||
|
# Activate logits sampling
|
||||||
|
do_sample: bool
|
||||||
|
# Maximum number of generated tokens
|
||||||
|
max_new_tokens: int
|
||||||
|
# The parameter for repetition penalty. 1.0 means no penalty.
|
||||||
|
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
repetition_penalty: Optional[float]
|
||||||
|
# Whether to prepend the prompt to the generated text
|
||||||
|
return_full_text: bool
|
||||||
|
# Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
stop: List[str]
|
||||||
|
# Random sampling seed
|
||||||
|
seed: Optional[int]
|
||||||
|
# The value used to modulate the logits distribution.
|
||||||
|
temperature: Optional[float]
|
||||||
|
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_k: Optional[int]
|
||||||
|
# If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
# higher are kept for generation.
|
||||||
|
top_p: Optional[float]
|
||||||
|
# Truncate input tokens to the given size
|
||||||
|
truncate: Optional[int]
|
||||||
|
# Typical Decoding mass
|
||||||
|
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||||
|
typical_p: Optional[float]
|
||||||
|
# Generate best_of sequences and return the one with the highest token logprobs
|
||||||
|
best_of: Optional[int]
|
||||||
|
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
watermark: bool
|
||||||
|
# Get decoder input token logprobs and ids
|
||||||
|
decoder_input_details: bool
|
||||||
|
|
||||||
|
# Decoder input tokens
|
||||||
|
class InputToken:
|
||||||
# Token ID from the model tokenizer
|
# Token ID from the model tokenizer
|
||||||
id: int
|
id: int
|
||||||
# Token text
|
# Token text
|
||||||
|
@ -151,8 +185,8 @@ class BestOfSequence:
|
||||||
generated_tokens: int
|
generated_tokens: int
|
||||||
# Sampling seed if sampling was activated
|
# Sampling seed if sampling was activated
|
||||||
seed: Optional[int]
|
seed: Optional[int]
|
||||||
# Prompt tokens
|
# Decoder input tokens, empty if decoder_input_details is False
|
||||||
prefill: List[PrefillToken]
|
prefill: List[InputToken]
|
||||||
# Generated tokens
|
# Generated tokens
|
||||||
tokens: List[Token]
|
tokens: List[Token]
|
||||||
|
|
||||||
|
@ -165,8 +199,8 @@ class Details:
|
||||||
generated_tokens: int
|
generated_tokens: int
|
||||||
# Sampling seed if sampling was activated
|
# Sampling seed if sampling was activated
|
||||||
seed: Optional[int]
|
seed: Optional[int]
|
||||||
# Prompt tokens
|
# Decoder input tokens, empty if decoder_input_details is False
|
||||||
prefill: List[PrefillToken]
|
prefill: List[InputToken]
|
||||||
# Generated tokens
|
# Generated tokens
|
||||||
tokens: List[Token]
|
tokens: List[Token]
|
||||||
# Additional sequences when using the `best_of` parameter
|
# Additional sequences when using the `best_of` parameter
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "text-generation"
|
name = "text-generation"
|
||||||
version = "0.5.2"
|
version = "0.6.0"
|
||||||
description = "Hugging Face Text Generation Python Client"
|
description = "Hugging Face Text Generation Python Client"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||||
|
|
|
@ -2,28 +2,30 @@ import pytest
|
||||||
|
|
||||||
from text_generation import Client, AsyncClient
|
from text_generation import Client, AsyncClient
|
||||||
from text_generation.errors import NotFoundError, ValidationError
|
from text_generation.errors import NotFoundError, ValidationError
|
||||||
from text_generation.types import FinishReason, PrefillToken, Token
|
from text_generation.types import FinishReason, InputToken
|
||||||
|
|
||||||
|
|
||||||
def test_generate(flan_t5_xxl_url, hf_headers):
|
def test_generate(flan_t5_xxl_url, hf_headers):
|
||||||
client = Client(flan_t5_xxl_url, hf_headers)
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
response = client.generate("test", max_new_tokens=1)
|
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
|
||||||
|
|
||||||
assert response.generated_text == ""
|
assert response.generated_text == ""
|
||||||
assert response.details.finish_reason == FinishReason.Length
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
assert response.details.generated_tokens == 1
|
assert response.details.generated_tokens == 1
|
||||||
assert response.details.seed is None
|
assert response.details.seed is None
|
||||||
assert len(response.details.prefill) == 1
|
assert len(response.details.prefill) == 1
|
||||||
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
|
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
|
||||||
assert len(response.details.tokens) == 1
|
assert len(response.details.tokens) == 1
|
||||||
assert response.details.tokens[0].id == 3
|
assert response.details.tokens[0].id == 3
|
||||||
assert response.details.tokens[0].text == ""
|
assert response.details.tokens[0].text == " "
|
||||||
assert not response.details.tokens[0].special
|
assert not response.details.tokens[0].special
|
||||||
|
|
||||||
|
|
||||||
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
|
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
|
||||||
client = Client(flan_t5_xxl_url, hf_headers)
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
|
response = client.generate(
|
||||||
|
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
assert response.details.seed is not None
|
assert response.details.seed is not None
|
||||||
assert response.details.best_of_sequences is not None
|
assert response.details.best_of_sequences is not None
|
||||||
|
@ -73,17 +75,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_generate_async(flan_t5_xxl_url, hf_headers):
|
async def test_generate_async(flan_t5_xxl_url, hf_headers):
|
||||||
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||||
response = await client.generate("test", max_new_tokens=1)
|
response = await client.generate(
|
||||||
|
"test", max_new_tokens=1, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
assert response.generated_text == ""
|
assert response.generated_text == ""
|
||||||
assert response.details.finish_reason == FinishReason.Length
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
assert response.details.generated_tokens == 1
|
assert response.details.generated_tokens == 1
|
||||||
assert response.details.seed is None
|
assert response.details.seed is None
|
||||||
assert len(response.details.prefill) == 1
|
assert len(response.details.prefill) == 1
|
||||||
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
|
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
|
||||||
assert len(response.details.tokens) == 1
|
assert len(response.details.tokens) == 1
|
||||||
assert response.details.tokens[0].id == 3
|
assert response.details.tokens[0].id == 3
|
||||||
assert response.details.tokens[0].text == ""
|
assert response.details.tokens[0].text == " "
|
||||||
assert not response.details.tokens[0].special
|
assert not response.details.tokens[0].special
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,7 +95,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
|
||||||
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
|
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
|
||||||
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||||
response = await client.generate(
|
response = await client.generate(
|
||||||
"test", max_new_tokens=1, best_of=2, do_sample=True
|
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.seed is not None
|
assert response.details.seed is not None
|
||||||
|
|
|
@ -74,6 +74,7 @@ class Client:
|
||||||
truncate: Optional[int] = None,
|
truncate: Optional[int] = None,
|
||||||
typical_p: Optional[float] = None,
|
typical_p: Optional[float] = None,
|
||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
|
decoder_input_details: bool = False,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following text
|
Given a prompt, generate the following text
|
||||||
|
@ -110,6 +111,8 @@ class Client:
|
||||||
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||||
watermark (`bool`):
|
watermark (`bool`):
|
||||||
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
decoder_input_details (`bool`):
|
||||||
|
Return the decoder input token logprobs and ids
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Response: generated response
|
Response: generated response
|
||||||
|
@ -130,6 +133,7 @@ class Client:
|
||||||
truncate=truncate,
|
truncate=truncate,
|
||||||
typical_p=typical_p,
|
typical_p=typical_p,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
|
decoder_input_details=decoder_input_details,
|
||||||
)
|
)
|
||||||
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||||
|
|
||||||
|
@ -202,6 +206,7 @@ class Client:
|
||||||
parameters = Parameters(
|
parameters = Parameters(
|
||||||
best_of=None,
|
best_of=None,
|
||||||
details=True,
|
details=True,
|
||||||
|
decoder_input_details=False,
|
||||||
do_sample=do_sample,
|
do_sample=do_sample,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
|
@ -311,6 +316,7 @@ class AsyncClient:
|
||||||
truncate: Optional[int] = None,
|
truncate: Optional[int] = None,
|
||||||
typical_p: Optional[float] = None,
|
typical_p: Optional[float] = None,
|
||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
|
decoder_input_details: bool = False,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following text asynchronously
|
Given a prompt, generate the following text asynchronously
|
||||||
|
@ -347,6 +353,8 @@ class AsyncClient:
|
||||||
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||||
watermark (`bool`):
|
watermark (`bool`):
|
||||||
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
decoder_input_details (`bool`):
|
||||||
|
Return the decoder input token logprobs and ids
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Response: generated response
|
Response: generated response
|
||||||
|
@ -355,6 +363,7 @@ class AsyncClient:
|
||||||
parameters = Parameters(
|
parameters = Parameters(
|
||||||
best_of=best_of,
|
best_of=best_of,
|
||||||
details=True,
|
details=True,
|
||||||
|
decoder_input_details=decoder_input_details,
|
||||||
do_sample=do_sample,
|
do_sample=do_sample,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
|
@ -437,6 +446,7 @@ class AsyncClient:
|
||||||
parameters = Parameters(
|
parameters = Parameters(
|
||||||
best_of=None,
|
best_of=None,
|
||||||
details=True,
|
details=True,
|
||||||
|
decoder_input_details=False,
|
||||||
do_sample=do_sample,
|
do_sample=do_sample,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
|
|
|
@ -37,6 +37,8 @@ class Parameters(BaseModel):
|
||||||
watermark: bool = False
|
watermark: bool = False
|
||||||
# Get generation details
|
# Get generation details
|
||||||
details: bool = False
|
details: bool = False
|
||||||
|
# Get decoder input token logprobs and ids
|
||||||
|
decoder_input_details: bool = False
|
||||||
|
|
||||||
@validator("best_of")
|
@validator("best_of")
|
||||||
def valid_best_of(cls, field_value, values):
|
def valid_best_of(cls, field_value, values):
|
||||||
|
@ -129,8 +131,8 @@ class Request(BaseModel):
|
||||||
return field_value
|
return field_value
|
||||||
|
|
||||||
|
|
||||||
# Prompt tokens
|
# Decoder input tokens
|
||||||
class PrefillToken(BaseModel):
|
class InputToken(BaseModel):
|
||||||
# Token ID from the model tokenizer
|
# Token ID from the model tokenizer
|
||||||
id: int
|
id: int
|
||||||
# Token text
|
# Token text
|
||||||
|
@ -173,8 +175,8 @@ class BestOfSequence(BaseModel):
|
||||||
generated_tokens: int
|
generated_tokens: int
|
||||||
# Sampling seed if sampling was activated
|
# Sampling seed if sampling was activated
|
||||||
seed: Optional[int]
|
seed: Optional[int]
|
||||||
# Prompt tokens
|
# Decoder input tokens, empty if decoder_input_details is False
|
||||||
prefill: List[PrefillToken]
|
prefill: List[InputToken]
|
||||||
# Generated tokens
|
# Generated tokens
|
||||||
tokens: List[Token]
|
tokens: List[Token]
|
||||||
|
|
||||||
|
@ -187,8 +189,8 @@ class Details(BaseModel):
|
||||||
generated_tokens: int
|
generated_tokens: int
|
||||||
# Sampling seed if sampling was activated
|
# Sampling seed if sampling was activated
|
||||||
seed: Optional[int]
|
seed: Optional[int]
|
||||||
# Prompt tokens
|
# Decoder input tokens, empty if decoder_input_details is False
|
||||||
prefill: List[PrefillToken]
|
prefill: List[InputToken]
|
||||||
# Generated tokens
|
# Generated tokens
|
||||||
tokens: List[Token]
|
tokens: List[Token]
|
||||||
# Additional sequences when using the `best_of` parameter
|
# Additional sequences when using the `best_of` parameter
|
||||||
|
|
|
@ -16,7 +16,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
|
||||||
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
||||||
|
|
||||||
from text_generation import AsyncClient
|
from text_generation import AsyncClient
|
||||||
from text_generation.types import Response, Details, PrefillToken, Token, BestOfSequence
|
from text_generation.types import Response, Details, InputToken, Token, BestOfSequence
|
||||||
|
|
||||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||||
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
||||||
|
@ -62,7 +62,7 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||||
and token.special == other.special
|
and token.special == other.special
|
||||||
)
|
)
|
||||||
|
|
||||||
def eq_prefill_token(prefill_token: PrefillToken, other: PrefillToken) -> bool:
|
def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
|
||||||
try:
|
try:
|
||||||
return (
|
return (
|
||||||
prefill_token.id == other.id
|
prefill_token.id == other.id
|
||||||
|
@ -332,7 +332,10 @@ def generate_load():
|
||||||
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
|
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
|
||||||
) -> List[Response]:
|
) -> List[Response]:
|
||||||
futures = [
|
futures = [
|
||||||
client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
|
client.generate(
|
||||||
|
prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
|
||||||
|
)
|
||||||
|
for _ in range(n)
|
||||||
]
|
]
|
||||||
|
|
||||||
return await asyncio.gather(*futures)
|
return await asyncio.gather(*futures)
|
||||||
|
|
|
@ -19,6 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
|
||||||
"Pour déguster un ortolan, il faut tout d'abord",
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
|
decoder_input_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,6 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
|
||||||
truncate=5,
|
truncate=5,
|
||||||
typical_p=0.9,
|
typical_p=0.9,
|
||||||
watermark=True,
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
|
||||||
"Pour déguster un ortolan, il faut tout d'abord",
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
|
decoder_input_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
|
||||||
response = await flash_falcon.generate(
|
response = await flash_falcon.generate(
|
||||||
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
decoder_input_details=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
|
@ -40,6 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
|
||||||
truncate=5,
|
truncate=5,
|
||||||
typical_p=0.9,
|
typical_p=0.9,
|
||||||
watermark=True,
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,9 @@ async def flash_llama(flash_llama_handle):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama(flash_llama, response_snapshot):
|
async def test_flash_llama(flash_llama, response_snapshot):
|
||||||
response = await flash_llama.generate("Test request", max_new_tokens=10)
|
response = await flash_llama.generate(
|
||||||
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
@ -37,6 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
|
||||||
truncate=5,
|
truncate=5,
|
||||||
typical_p=0.9,
|
typical_p=0.9,
|
||||||
watermark=True,
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
|
||||||
response = await flash_neox.generate(
|
response = await flash_neox.generate(
|
||||||
"<|USER|>What's your mood today?<|ASSISTANT|>",
|
"<|USER|>What's your mood today?<|ASSISTANT|>",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
decoder_input_details=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
|
|
|
@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
|
||||||
response = await flash_neox_sharded.generate(
|
response = await flash_neox_sharded.generate(
|
||||||
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
decoder_input_details=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
|
|
|
@ -15,7 +15,9 @@ async def flash_santacoder(flash_santacoder_handle):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_flash_santacoder(flash_santacoder, response_snapshot):
|
async def test_flash_santacoder(flash_santacoder, response_snapshot):
|
||||||
response = await flash_santacoder.generate("def print_hello", max_new_tokens=10)
|
response = await flash_santacoder.generate(
|
||||||
|
"def print_hello", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
|
@ -16,7 +16,9 @@ async def flash_starcoder(flash_starcoder_handle):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_starcoder(flash_starcoder, response_snapshot):
|
async def test_flash_starcoder(flash_starcoder, response_snapshot):
|
||||||
response = await flash_starcoder.generate("def print_hello", max_new_tokens=10)
|
response = await flash_starcoder.generate(
|
||||||
|
"def print_hello", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
@ -26,7 +28,12 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
|
async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
|
||||||
response = await flash_starcoder.generate(
|
response = await flash_starcoder.generate(
|
||||||
"def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
|
"def print_hello",
|
||||||
|
max_new_tokens=60,
|
||||||
|
temperature=0.2,
|
||||||
|
top_p=0.95,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 60
|
assert response.details.generated_tokens == 60
|
||||||
|
|
|
@ -19,6 +19,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
|
||||||
"Why is the sky blue?",
|
"Why is the sky blue?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
|
decoder_input_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,6 +41,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
||||||
truncate=5,
|
truncate=5,
|
||||||
typical_p=0.9,
|
typical_p=0.9,
|
||||||
watermark=True,
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
|
||||||
response = await t5_sharded.generate(
|
response = await t5_sharded.generate(
|
||||||
"Please answer the following question. What is the boiling point of Nitrogen?",
|
"Please answer the following question. What is the boiling point of Nitrogen?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
decoder_input_details=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
syrupy
|
syrupy
|
||||||
text-generation==0.5.2
|
text-generation
|
||||||
pytest
|
pytest
|
||||||
pytest-asyncio==0.17.2
|
pytest-asyncio==0.17.2
|
||||||
docker
|
docker
|
|
@ -87,6 +87,8 @@ message Request {
|
||||||
NextTokenChooserParameters parameters = 4;
|
NextTokenChooserParameters parameters = 4;
|
||||||
/// Stopping Criteria Parameters
|
/// Stopping Criteria Parameters
|
||||||
StoppingCriteriaParameters stopping_parameters = 5;
|
StoppingCriteriaParameters stopping_parameters = 5;
|
||||||
|
/// Return prefill logprobs
|
||||||
|
bool prefill_logprobs = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Batch {
|
message Batch {
|
||||||
|
|
|
@ -34,6 +34,7 @@ impl Health {
|
||||||
id: LIVENESS_ID,
|
id: LIVENESS_ID,
|
||||||
inputs: "liveness".to_string(),
|
inputs: "liveness".to_string(),
|
||||||
truncate: 10,
|
truncate: 10,
|
||||||
|
prefill_logprobs: false,
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
top_k: 0,
|
top_k: 0,
|
||||||
|
|
|
@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters {
|
||||||
#[schema(default = "true")]
|
#[schema(default = "true")]
|
||||||
pub details: bool,
|
pub details: bool,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
#[schema(default = "true")]
|
||||||
|
pub decoder_input_details: bool,
|
||||||
|
#[serde(default)]
|
||||||
#[schema(
|
#[schema(
|
||||||
exclusive_minimum = 0,
|
exclusive_minimum = 0,
|
||||||
nullable = true,
|
nullable = true,
|
||||||
|
@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters {
|
||||||
truncate: None,
|
truncate: None,
|
||||||
watermark: false,
|
watermark: false,
|
||||||
details: false,
|
details: false,
|
||||||
|
decoder_input_details: false,
|
||||||
seed: None,
|
seed: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -201,6 +201,7 @@ impl State {
|
||||||
|
|
||||||
batch_requests.push(Request {
|
batch_requests.push(Request {
|
||||||
id,
|
id,
|
||||||
|
prefill_logprobs: entry.request.decoder_input_details,
|
||||||
inputs: entry.request.inputs.clone(),
|
inputs: entry.request.inputs.clone(),
|
||||||
truncate: entry.request.truncate,
|
truncate: entry.request.truncate,
|
||||||
parameters: Some(entry.request.parameters.clone()),
|
parameters: Some(entry.request.parameters.clone()),
|
||||||
|
@ -281,6 +282,7 @@ mod tests {
|
||||||
inputs: "".to_string(),
|
inputs: "".to_string(),
|
||||||
input_length: 0,
|
input_length: 0,
|
||||||
truncate: 0,
|
truncate: 0,
|
||||||
|
decoder_input_details: false,
|
||||||
parameters: NextTokenChooserParameters {
|
parameters: NextTokenChooserParameters {
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
top_k: 0,
|
top_k: 0,
|
||||||
|
|
|
@ -160,7 +160,7 @@ async fn generate(
|
||||||
add_prompt = Some(req.0.inputs.clone());
|
add_prompt = Some(req.0.inputs.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
let details = req.0.parameters.details;
|
let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
let (response, best_of_responses) = match req.0.parameters.best_of {
|
let (response, best_of_responses) = match req.0.parameters.best_of {
|
||||||
|
@ -364,7 +364,17 @@ async fn generate_stream(
|
||||||
let details = req.0.parameters.details;
|
let details = req.0.parameters.details;
|
||||||
|
|
||||||
let best_of = req.0.parameters.best_of.unwrap_or(1);
|
let best_of = req.0.parameters.best_of.unwrap_or(1);
|
||||||
if best_of == 1 {
|
if best_of != 1 {
|
||||||
|
let err = InferError::from(ValidationError::BestOfStream);
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
|
tracing::error!("{err}");
|
||||||
|
yield Ok(Event::from(err));
|
||||||
|
} else if req.0.parameters.decoder_input_details {
|
||||||
|
let err = InferError::from(ValidationError::PrefillDetailsStream);
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
|
tracing::error!("{err}");
|
||||||
|
yield Ok(Event::from(err));
|
||||||
|
} else {
|
||||||
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
|
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||||
// Keep permit as long as generate_stream lives
|
// Keep permit as long as generate_stream lives
|
||||||
Ok((_permit, mut response_stream)) => {
|
Ok((_permit, mut response_stream)) => {
|
||||||
|
@ -474,11 +484,6 @@ async fn generate_stream(
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
yield Ok(Event::from(err));
|
yield Ok(Event::from(err));
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
let err = InferError::from(ValidationError::BestOfStream);
|
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
|
||||||
tracing::error!("{err}");
|
|
||||||
yield Ok(Event::from(err));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -145,6 +145,7 @@ impl Validation {
|
||||||
truncate,
|
truncate,
|
||||||
seed,
|
seed,
|
||||||
watermark,
|
watermark,
|
||||||
|
decoder_input_details,
|
||||||
..
|
..
|
||||||
} = request.parameters;
|
} = request.parameters;
|
||||||
|
|
||||||
|
@ -261,6 +262,7 @@ impl Validation {
|
||||||
|
|
||||||
Ok(ValidGenerateRequest {
|
Ok(ValidGenerateRequest {
|
||||||
inputs,
|
inputs,
|
||||||
|
decoder_input_details,
|
||||||
input_length: input_length as u32,
|
input_length: input_length as u32,
|
||||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
||||||
parameters,
|
parameters,
|
||||||
|
@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest {
|
||||||
pub inputs: String,
|
pub inputs: String,
|
||||||
pub input_length: u32,
|
pub input_length: u32,
|
||||||
pub truncate: u32,
|
pub truncate: u32,
|
||||||
|
pub decoder_input_details: bool,
|
||||||
pub parameters: NextTokenChooserParameters,
|
pub parameters: NextTokenChooserParameters,
|
||||||
pub stopping_parameters: StoppingCriteriaParameters,
|
pub stopping_parameters: StoppingCriteriaParameters,
|
||||||
}
|
}
|
||||||
|
@ -351,6 +354,8 @@ pub enum ValidationError {
|
||||||
BestOfSeed,
|
BestOfSeed,
|
||||||
#[error("`best_of` != 1 is not supported when streaming tokens")]
|
#[error("`best_of` != 1 is not supported when streaming tokens")]
|
||||||
BestOfStream,
|
BestOfStream,
|
||||||
|
#[error("`decoder_input_details` == true is not supported when streaming tokens")]
|
||||||
|
PrefillDetailsStream,
|
||||||
#[error("`temperature` must be strictly positive")]
|
#[error("`temperature` must be strictly positive")]
|
||||||
Temperature,
|
Temperature,
|
||||||
#[error("`repetition_penalty` must be strictly positive")]
|
#[error("`repetition_penalty` must be strictly positive")]
|
||||||
|
|
|
@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="Test",
|
inputs="Test",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
|
|
|
@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="Test",
|
inputs="Test",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
|
|
|
@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="def",
|
inputs="def",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
|
@ -31,6 +32,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
|
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
|
|
|
@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||||
return generate_pb2.Request(
|
return generate_pb2.Request(
|
||||||
id=0,
|
id=0,
|
||||||
inputs="Test",
|
inputs="Test",
|
||||||
|
prefill_logprobs=True,
|
||||||
truncate=100,
|
truncate=100,
|
||||||
parameters=default_pb_parameters,
|
parameters=default_pb_parameters,
|
||||||
stopping_parameters=default_pb_stop_parameters,
|
stopping_parameters=default_pb_stop_parameters,
|
||||||
|
|
|
@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
|
||||||
).to(device)
|
).to(device)
|
||||||
for _ in pb.requests:
|
for _ in pb.requests:
|
||||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||||
prefix_offsets.append(0)
|
prefix_offsets.append(input_len - 5)
|
||||||
read_offsets.append(input_len)
|
read_offsets.append(input_len)
|
||||||
|
|
||||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||||
|
@ -617,7 +617,7 @@ class CausalLM(Model):
|
||||||
generated_text = None
|
generated_text = None
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if stopping_criteria.current_tokens == 1:
|
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||||
# Remove generated token to only have prefill and add nan for first prompt token
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||||
logits, -1
|
logits, -1
|
||||||
|
|
|
@ -443,6 +443,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
max_s,
|
max_s,
|
||||||
past_key_values: Optional[torch.Tensor] = None,
|
past_key_values: Optional[torch.Tensor] = None,
|
||||||
pre_allocate_past_size: Optional[int] = None,
|
pre_allocate_past_size: Optional[int] = None,
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
hidden_states, present = self.model(
|
hidden_states, present = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
@ -453,6 +454,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
past_key_values,
|
past_key_values,
|
||||||
pre_allocate_past_size,
|
pre_allocate_past_size,
|
||||||
)
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
if self.model.tp_embeddings:
|
if self.model.tp_embeddings:
|
||||||
|
|
|
@ -481,6 +481,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||||
max_s,
|
max_s,
|
||||||
past_key_values: Optional[torch.Tensor] = None,
|
past_key_values: Optional[torch.Tensor] = None,
|
||||||
pre_allocate_past_size: Optional[int] = None,
|
pre_allocate_past_size: Optional[int] = None,
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
hidden_states, present = self.gpt_neox(
|
hidden_states, present = self.gpt_neox(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
@ -491,6 +492,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||||
past_key_values,
|
past_key_values,
|
||||||
pre_allocate_past_size,
|
pre_allocate_past_size,
|
||||||
)
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits = self.embed_out(hidden_states)
|
logits = self.embed_out(hidden_states)
|
||||||
|
|
||||||
if self.gpt_neox.tp_embeddings:
|
if self.gpt_neox.tp_embeddings:
|
||||||
|
|
|
@ -752,6 +752,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||||
max_s,
|
max_s,
|
||||||
past_key_values: Optional[torch.Tensor] = None,
|
past_key_values: Optional[torch.Tensor] = None,
|
||||||
pre_allocate_past_size: Optional[int] = None,
|
pre_allocate_past_size: Optional[int] = None,
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
hidden_states, present = self.transformer(
|
hidden_states, present = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
@ -762,6 +763,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||||
past_key_values,
|
past_key_values,
|
||||||
pre_allocate_past_size,
|
pre_allocate_past_size,
|
||||||
)
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
if self.transformer.tp_embeddings:
|
if self.transformer.tp_embeddings:
|
||||||
|
|
|
@ -358,6 +358,7 @@ class FlashSantacoderForCausalLM(nn.Module):
|
||||||
max_s,
|
max_s,
|
||||||
past_key_values: Optional[torch.Tensor] = None,
|
past_key_values: Optional[torch.Tensor] = None,
|
||||||
pre_allocate_past_size: Optional[int] = None,
|
pre_allocate_past_size: Optional[int] = None,
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
hidden_states, present = self.transformer(
|
hidden_states, present = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
@ -368,6 +369,8 @@ class FlashSantacoderForCausalLM(nn.Module):
|
||||||
past_key_values,
|
past_key_values,
|
||||||
pre_allocate_past_size,
|
pre_allocate_past_size,
|
||||||
)
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
if self.transformer.tp_embeddings:
|
if self.transformer.tp_embeddings:
|
||||||
|
|
|
@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch):
|
||||||
past_key_values: Optional[torch.Tensor]
|
past_key_values: Optional[torch.Tensor]
|
||||||
max_seqlen: int
|
max_seqlen: int
|
||||||
|
|
||||||
|
# Prefill metadata tensors to efficiently compute logprobs
|
||||||
|
prefill_head_indices: Optional[torch.Tensor]
|
||||||
|
prefill_next_token_indices: Optional[torch.tensor]
|
||||||
|
prefill_cu_outlens: Optional[List[int]]
|
||||||
|
|
||||||
# All tokens
|
# All tokens
|
||||||
all_input_ids: List[List[int]]
|
all_input_ids: List[List[int]]
|
||||||
all_input_ids_tensor: torch.Tensor
|
all_input_ids_tensor: torch.Tensor
|
||||||
|
@ -84,11 +89,18 @@ class FlashCausalLMBatch(Batch):
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
|
all_prefill_logprobs = True
|
||||||
|
no_prefill_logprobs = True
|
||||||
|
prefill_head_indices = []
|
||||||
|
prefill_next_token_indices = []
|
||||||
|
prefill_cu_outlens = [0]
|
||||||
|
|
||||||
next_token_chooser_parameters = []
|
next_token_chooser_parameters = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
|
|
||||||
# Cumulative length
|
# Cumulative length
|
||||||
cumulative_length = 0
|
cumulative_length = 0
|
||||||
|
prefill_out_cumulative_length = 0
|
||||||
|
|
||||||
max_tokens = 0
|
max_tokens = 0
|
||||||
max_length = 0
|
max_length = 0
|
||||||
|
@ -106,13 +118,14 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_seqlen = max(max_seqlen, input_length)
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
input_lengths.append(input_length)
|
input_lengths.append(input_length)
|
||||||
|
|
||||||
prefix_offsets.append(0)
|
prefix_offsets.append(input_length - 5)
|
||||||
read_offsets.append(input_length)
|
read_offsets.append(input_length)
|
||||||
|
|
||||||
all_input_ids.append(tokenized_input)
|
all_input_ids.append(tokenized_input)
|
||||||
|
|
||||||
# Position ids
|
# Position ids
|
||||||
position_ids.append(np.arange(0, input_length))
|
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
|
||||||
|
position_ids.append(request_position_ids)
|
||||||
|
|
||||||
# Add cumulative lengths of all previous inputs
|
# Add cumulative lengths of all previous inputs
|
||||||
cu_seqlens.append(cumulative_length + input_length)
|
cu_seqlens.append(cumulative_length + input_length)
|
||||||
|
@ -125,6 +138,26 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_new_tokens = stopping_criteria.max_new_tokens
|
max_new_tokens = stopping_criteria.max_new_tokens
|
||||||
stopping_criterias.append(stopping_criteria)
|
stopping_criterias.append(stopping_criteria)
|
||||||
|
|
||||||
|
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
||||||
|
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||||
|
|
||||||
|
if r.prefill_logprobs:
|
||||||
|
prefill_head_indices.append(request_position_ids + cumulative_length)
|
||||||
|
prefill_next_token_indices.append(
|
||||||
|
prefill_out_cumulative_length + input_length - 1
|
||||||
|
)
|
||||||
|
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
|
||||||
|
prefill_out_cumulative_length += input_length
|
||||||
|
else:
|
||||||
|
prefill_head_indices.append(
|
||||||
|
torch.tensor(
|
||||||
|
[cumulative_length + input_length - 1], dtype=torch.int32
|
||||||
|
)
|
||||||
|
)
|
||||||
|
prefill_next_token_indices.append(prefill_out_cumulative_length)
|
||||||
|
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||||
|
prefill_out_cumulative_length += 1
|
||||||
|
|
||||||
# Update
|
# Update
|
||||||
cumulative_length += input_length
|
cumulative_length += input_length
|
||||||
max_tokens += input_length + max_new_tokens
|
max_tokens += input_length + max_new_tokens
|
||||||
|
@ -141,18 +174,35 @@ class FlashCausalLMBatch(Batch):
|
||||||
for i, input_ids in enumerate(all_input_ids):
|
for i, input_ids in enumerate(all_input_ids):
|
||||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||||
|
|
||||||
|
if len(pb.requests) > 1:
|
||||||
|
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||||
|
position_ids = torch.cat(position_ids)
|
||||||
|
else:
|
||||||
|
input_ids = all_input_ids[0]
|
||||||
|
position_ids = position_ids[0]
|
||||||
|
|
||||||
# Create tensors on device
|
# Create tensors on device
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||||
np.concatenate(all_input_ids), dtype=torch.int64, device=device
|
|
||||||
)
|
|
||||||
all_input_ids_tensor = torch.tensor(
|
all_input_ids_tensor = torch.tensor(
|
||||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||||
)
|
)
|
||||||
position_ids = torch.tensor(
|
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
|
||||||
np.concatenate(position_ids), dtype=torch.int32, device=device
|
|
||||||
)
|
|
||||||
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
|
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
|
||||||
|
|
||||||
|
if all_prefill_logprobs:
|
||||||
|
prefill_head_indices = None
|
||||||
|
prefill_next_token_indices = cu_seqlens[1:] - 1
|
||||||
|
elif no_prefill_logprobs:
|
||||||
|
prefill_head_indices = cu_seqlens[1:] - 1
|
||||||
|
prefill_next_token_indices = None
|
||||||
|
else:
|
||||||
|
prefill_head_indices = torch.tensor(
|
||||||
|
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
prefill_next_token_indices = torch.tensor(
|
||||||
|
prefill_next_token_indices, dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
batch_id=pb.id,
|
batch_id=pb.id,
|
||||||
requests=pb.requests,
|
requests=pb.requests,
|
||||||
|
@ -162,6 +212,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
cu_seqlens_q=None,
|
cu_seqlens_q=None,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
|
prefill_head_indices=prefill_head_indices,
|
||||||
|
prefill_next_token_indices=prefill_next_token_indices,
|
||||||
|
prefill_cu_outlens=prefill_cu_outlens,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
|
@ -280,6 +333,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
|
prefill_head_indices=None,
|
||||||
|
prefill_next_token_indices=None,
|
||||||
|
prefill_cu_outlens=None,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
|
@ -415,6 +471,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
|
prefill_head_indices=None,
|
||||||
|
prefill_next_token_indices=None,
|
||||||
|
prefill_cu_outlens=None,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
|
@ -486,6 +545,7 @@ class FlashCausalLM(Model):
|
||||||
max_s: int,
|
max_s: int,
|
||||||
past_key_values: Optional = None,
|
past_key_values: Optional = None,
|
||||||
pre_allocate_past_size: Optional[int] = None,
|
pre_allocate_past_size: Optional[int] = None,
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
|
@ -496,6 +556,7 @@ class FlashCausalLM(Model):
|
||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
pre_allocate_past_size=pre_allocate_past_size,
|
pre_allocate_past_size=pre_allocate_past_size,
|
||||||
|
lm_head_indices=lm_head_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tracer.start_as_current_span("generate_token")
|
@tracer.start_as_current_span("generate_token")
|
||||||
|
@ -503,9 +564,10 @@ class FlashCausalLM(Model):
|
||||||
self, batch: FlashCausalLMBatch
|
self, batch: FlashCausalLMBatch
|
||||||
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
||||||
prefill = batch.past_key_values is None
|
prefill = batch.past_key_values is None
|
||||||
|
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||||
single_request = len(batch) == 1
|
single_request = len(batch) == 1
|
||||||
|
|
||||||
if prefill and len(batch) == 1:
|
if prefill and single_request:
|
||||||
# Ask to pre-allocate kv to its max size
|
# Ask to pre-allocate kv to its max size
|
||||||
# == number of tokens + max_new_tokens
|
# == number of tokens + max_new_tokens
|
||||||
pre_allocate_past_size = (
|
pre_allocate_past_size = (
|
||||||
|
@ -522,11 +584,12 @@ class FlashCausalLM(Model):
|
||||||
batch.max_seqlen,
|
batch.max_seqlen,
|
||||||
batch.past_key_values,
|
batch.past_key_values,
|
||||||
pre_allocate_past_size,
|
pre_allocate_past_size,
|
||||||
|
batch.prefill_head_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
if prefill:
|
if prefill:
|
||||||
next_token_logits = (
|
next_token_logits = (
|
||||||
out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
|
out[batch.prefill_next_token_indices] if prefill_logprobs else out
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
next_token_logits = out
|
next_token_logits = out
|
||||||
|
@ -536,10 +599,10 @@ class FlashCausalLM(Model):
|
||||||
)
|
)
|
||||||
|
|
||||||
if prefill:
|
if prefill:
|
||||||
if len(batch) > 1:
|
if len(batch) > 1 and prefill_logprobs:
|
||||||
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
|
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
|
||||||
# When batch == 1, we will just use the batch.input_ids values directly
|
# When batch == 1, we will just use the batch.input_ids values directly
|
||||||
prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
|
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
|
||||||
|
|
||||||
# Create batch.cu_seqlens_q for decode
|
# Create batch.cu_seqlens_q for decode
|
||||||
batch.cu_seqlens_q = torch.arange(
|
batch.cu_seqlens_q = torch.arange(
|
||||||
|
@ -600,7 +663,6 @@ class FlashCausalLM(Model):
|
||||||
# Zipped iterator
|
# Zipped iterator
|
||||||
iterator = zip(
|
iterator = zip(
|
||||||
batch.input_lengths,
|
batch.input_lengths,
|
||||||
batch.stopping_criterias,
|
|
||||||
batch.all_input_ids,
|
batch.all_input_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -611,28 +673,32 @@ class FlashCausalLM(Model):
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
for i, (
|
for i, (
|
||||||
input_length,
|
input_length,
|
||||||
stopping_criteria,
|
|
||||||
all_input_ids,
|
all_input_ids,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
# Indexing metadata
|
|
||||||
start_index = cumulative_length
|
start_index = cumulative_length
|
||||||
end_index = cumulative_length + input_length
|
end_index = cumulative_length + input_length
|
||||||
|
|
||||||
if prefill:
|
if prefill:
|
||||||
|
# Indexing metadata
|
||||||
|
out_start_index = batch.prefill_cu_outlens[i]
|
||||||
|
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||||
|
out_length = out_end_index - out_start_index
|
||||||
|
|
||||||
# Initialize position_ids
|
# Initialize position_ids
|
||||||
# In decode, we do not need this as we can just increment position ids
|
# In decode, we do not need this as we can just increment position ids
|
||||||
next_position_ids[i] = batch.position_ids[end_index - 1]
|
next_position_ids[i] = batch.position_ids[end_index - 1]
|
||||||
|
|
||||||
# Used to gather prefill logprobs
|
# Used to gather prefill logprobs
|
||||||
# Copy batch.input_ids to prefill_token_indices
|
# Copy batch.input_ids to prefill_token_indices
|
||||||
|
if prefill_logprobs:
|
||||||
if len(batch) > 1:
|
if len(batch) > 1:
|
||||||
prefill_tokens_indices[
|
prefill_tokens_indices[
|
||||||
start_index : end_index - 1
|
out_start_index : out_end_index - 1
|
||||||
] = batch.input_ids[start_index + 1 : end_index]
|
] = batch.input_ids[start_index + 1 : start_index + out_length]
|
||||||
else:
|
else:
|
||||||
# Set prefill_tokens_indices to the correct slice
|
# Set prefill_tokens_indices to the correct slice
|
||||||
prefill_tokens_indices = batch.input_ids[
|
prefill_tokens_indices = batch.input_ids[
|
||||||
start_index + 1 : end_index
|
start_index + 1 : start_index + out_length
|
||||||
]
|
]
|
||||||
|
|
||||||
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
|
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
|
||||||
|
@ -644,7 +710,7 @@ class FlashCausalLM(Model):
|
||||||
batch.position_ids = next_position_ids + 1
|
batch.position_ids = next_position_ids + 1
|
||||||
batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
|
batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
|
||||||
|
|
||||||
if prefill:
|
if prefill and prefill_logprobs:
|
||||||
# Get prefill logprobs
|
# Get prefill logprobs
|
||||||
prefill_logprobs_tensor = torch.log_softmax(out, -1)
|
prefill_logprobs_tensor = torch.log_softmax(out, -1)
|
||||||
prefill_logprobs = torch.gather(
|
prefill_logprobs = torch.gather(
|
||||||
|
@ -657,8 +723,6 @@ class FlashCausalLM(Model):
|
||||||
next_token_logprobs = next_token_logprobs.tolist()
|
next_token_logprobs = next_token_logprobs.tolist()
|
||||||
next_token_ids = batch.input_ids.tolist()
|
next_token_ids = batch.input_ids.tolist()
|
||||||
|
|
||||||
cumulative_length = 0
|
|
||||||
|
|
||||||
# Zipped iterator
|
# Zipped iterator
|
||||||
iterator = zip(
|
iterator = zip(
|
||||||
batch.requests,
|
batch.requests,
|
||||||
|
@ -688,9 +752,6 @@ class FlashCausalLM(Model):
|
||||||
next_token_id,
|
next_token_id,
|
||||||
next_token_logprob,
|
next_token_logprob,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
start_index = cumulative_length
|
|
||||||
end_index = cumulative_length + input_length
|
|
||||||
|
|
||||||
# Append next token to all tokens
|
# Append next token to all tokens
|
||||||
all_input_ids.append(next_token_id)
|
all_input_ids.append(next_token_id)
|
||||||
|
|
||||||
|
@ -728,10 +789,13 @@ class FlashCausalLM(Model):
|
||||||
generated_text = None
|
generated_text = None
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if prefill:
|
if prefill and request.prefill_logprobs:
|
||||||
|
out_start_index = batch.prefill_cu_outlens[i]
|
||||||
|
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||||
|
|
||||||
# Remove generated token to only have prefill and add nan for first prompt token
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
|
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
|
||||||
start_index : end_index - 1
|
out_start_index : out_end_index - 1
|
||||||
]
|
]
|
||||||
prefill_token_ids = all_input_ids[:-1]
|
prefill_token_ids = all_input_ids[:-1]
|
||||||
prefill_texts = self.tokenizer.batch_decode(
|
prefill_texts = self.tokenizer.batch_decode(
|
||||||
|
@ -764,8 +828,10 @@ class FlashCausalLM(Model):
|
||||||
batch.prefix_offsets[i] = prefix_offset
|
batch.prefix_offsets[i] = prefix_offset
|
||||||
batch.read_offsets[i] = read_offset
|
batch.read_offsets[i] = read_offset
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
cumulative_length += input_length
|
|
||||||
|
|
||||||
|
batch.prefill_cu_outlens = None
|
||||||
|
batch.prefill_head_indices = None
|
||||||
|
batch.prefill_next_token_indices = None
|
||||||
batch.max_seqlen = batch.max_seqlen + 1
|
batch.max_seqlen = batch.max_seqlen + 1
|
||||||
|
|
||||||
# No need to return a batch if we know that all requests stopped
|
# No need to return a batch if we know that all requests stopped
|
||||||
|
|
|
@ -688,7 +688,7 @@ class Seq2SeqLM(Model):
|
||||||
generated_text = None
|
generated_text = None
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if stopping_criteria.current_tokens == 1:
|
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||||
prefill_tokens = PrefillTokens(
|
prefill_tokens = PrefillTokens(
|
||||||
[self.tokenizer.bos_token_id],
|
[self.tokenizer.bos_token_id],
|
||||||
[float("nan")],
|
[float("nan")],
|
||||||
|
|
Loading…
Reference in New Issue