feat(server): only compute prefill logprobs when asked (#406)

Close #288
OlivierDehaene 2023-06-02 17:12:30 +02:00 committed by GitHub
parent 83b84486ad
commit 895c5f1562
36 changed files with 252 additions and 73 deletions

View File

@ -3,6 +3,7 @@ install-server:
install-integration-tests:
cd integration-tests && pip install -r requirements.txt
cd clients/python && pip install .
install-router:
cd router && cargo install --path .

View File

@ -136,6 +136,7 @@ async fn prefill(
let requests = (0..batch_size)
.map(|id| Request {
id: id.into(),
prefill_logprobs: false,
inputs: sequence.clone(),
truncate: sequence_length,
parameters: Some(parameters.clone()),

View File

@ -107,8 +107,42 @@ print(text)
### Types
```python
# Prompt tokens
class PrefillToken:
# Request Parameters
class Parameters:
# Activate logits sampling
do_sample: bool
# Maximum number of generated tokens
max_new_tokens: int
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float]
# Whether to prepend the prompt to the generated text
return_full_text: bool
# Stop generating tokens if a member of `stop_sequences` is generated
stop: List[str]
# Random sampling seed
seed: Optional[int]
# The value used to modulate the logits distribution.
temperature: Optional[float]
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_k: Optional[int]
# If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
# higher are kept for generation.
top_p: Optional[float]
# Truncate input tokens to the given size
truncate: Optional[int]
# Typical Decoding mass
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
typical_p: Optional[float]
# Generate best_of sequences and return the one with the highest token logprobs
best_of: Optional[int]
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
watermark: bool
# Get decoder input token logprobs and ids
decoder_input_details: bool
# Decoder input tokens
class InputToken:
# Token ID from the model tokenizer
id: int
# Token text
@ -151,8 +185,8 @@ class BestOfSequence:
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
@ -165,8 +199,8 @@ class Details:
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Additional sequences when using the `best_of` parameter
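For reference, a minimal usage sketch against a local endpoint (the URL is assumed; any deployment serving the updated router works) that requests the decoder input details and reads the prefill tokens back:

```python
from text_generation import Client

# Assumed local endpoint; point this at your own deployment.
client = Client("http://127.0.0.1:8080")

# Ask the server to return decoder input (prompt) token ids and logprobs.
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)

# `prefill` is a List[InputToken]; it stays empty when decoder_input_details is False.
for token in response.details.prefill:
    print(token.id, token.text, token.logprob)
```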

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation"
version = "0.5.2"
version = "0.6.0"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@ -2,28 +2,30 @@ import pytest
from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, PrefillToken, Token
from text_generation.types import FinishReason, InputToken
def test_generate(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", max_new_tokens=1)
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 3
assert response.details.tokens[0].text == ""
assert response.details.tokens[0].text == " "
assert not response.details.tokens[0].special
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
response = client.generate(
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
)
assert response.details.seed is not None
assert response.details.best_of_sequences is not None
@ -73,17 +75,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
@pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
response = await client.generate("test", max_new_tokens=1)
response = await client.generate(
"test", max_new_tokens=1, decoder_input_details=True
)
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 3
assert response.details.tokens[0].text == ""
assert response.details.tokens[0].text == " "
assert not response.details.tokens[0].special
@ -91,7 +95,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
response = await client.generate(
"test", max_new_tokens=1, best_of=2, do_sample=True
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
)
assert response.details.seed is not None

View File

@ -74,6 +74,7 @@ class Client:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
) -> Response:
"""
Given a prompt, generate the following text
@ -110,6 +111,8 @@ class Client:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
Returns:
Response: generated response
@ -130,6 +133,7 @@ class Client:
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
decoder_input_details=decoder_input_details,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -202,6 +206,7 @@ class Client:
parameters = Parameters(
best_of=None,
details=True,
decoder_input_details=False,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
@ -311,6 +316,7 @@ class AsyncClient:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@ -347,6 +353,8 @@ class AsyncClient:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
Returns:
Response: generated response
@ -355,6 +363,7 @@ class AsyncClient:
parameters = Parameters(
best_of=best_of,
details=True,
decoder_input_details=decoder_input_details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
@ -437,6 +446,7 @@ class AsyncClient:
parameters = Parameters(
best_of=None,
details=True,
decoder_input_details=False,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
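The async client takes the same flag; a short sketch with the same assumed endpoint (note that the streaming helpers keep `decoder_input_details=False`, matching the hunks above):

```python
import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # assumed endpoint
    response = await client.generate(
        "test", max_new_tokens=1, decoder_input_details=True
    )
    # Decoder input tokens with their logprobs, now only computed on request.
    print(response.details.prefill)


asyncio.run(main())
```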

View File

@ -37,6 +37,8 @@ class Parameters(BaseModel):
watermark: bool = False
# Get generation details
details: bool = False
# Get decoder input token logprobs and ids
decoder_input_details: bool = False
@validator("best_of")
def valid_best_of(cls, field_value, values):
@ -129,8 +131,8 @@ class Request(BaseModel):
return field_value
# Prompt tokens
class PrefillToken(BaseModel):
# Decoder input tokens
class InputToken(BaseModel):
# Token ID from the model tokenizer
id: int
# Token text
@ -173,8 +175,8 @@ class BestOfSequence(BaseModel):
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
@ -187,8 +189,8 @@ class Details(BaseModel):
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Additional sequences when using the `best_of` parameter

View File

@ -16,7 +16,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from text_generation import AsyncClient
from text_generation.types import Response, Details, PrefillToken, Token, BestOfSequence
from text_generation.types import Response, Details, InputToken, Token, BestOfSequence
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
@ -62,7 +62,7 @@ class ResponseComparator(JSONSnapshotExtension):
and token.special == other.special
)
def eq_prefill_token(prefill_token: PrefillToken, other: PrefillToken) -> bool:
def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
try:
return (
prefill_token.id == other.id
@ -332,7 +332,10 @@ def generate_load():
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
) -> List[Response]:
futures = [
client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
client.generate(
prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
)
for _ in range(n)
]
return await asyncio.gather(*futures)

View File

@ -19,6 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
"Pour déguster un ortolan, il faut tout d'abord",
max_new_tokens=10,
top_p=0.9,
decoder_input_details=True,
seed=0,
)
@ -40,6 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)

View File

@ -19,6 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
"Pour déguster un ortolan, il faut tout d'abord",
max_new_tokens=10,
top_p=0.9,
decoder_input_details=True,
seed=0,
)

View File

@ -19,6 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
response = await flash_falcon.generate(
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10
@ -40,6 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)

View File

@ -16,7 +16,9 @@ async def flash_llama(flash_llama_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama(flash_llama, response_snapshot):
response = await flash_llama.generate("Test request", max_new_tokens=10)
response = await flash_llama.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@ -37,6 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)

View File

@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
response = await flash_neox.generate(
"<|USER|>What's your mood today?<|ASSISTANT|>",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10

View File

@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
response = await flash_neox_sharded.generate(
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10

View File

@ -15,7 +15,9 @@ async def flash_santacoder(flash_santacoder_handle):
@pytest.mark.asyncio
async def test_flash_santacoder(flash_santacoder, response_snapshot):
response = await flash_santacoder.generate("def print_hello", max_new_tokens=10)
response = await flash_santacoder.generate(
"def print_hello", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot

View File

@ -16,7 +16,9 @@ async def flash_starcoder(flash_starcoder_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder(flash_starcoder, response_snapshot):
response = await flash_starcoder.generate("def print_hello", max_new_tokens=10)
response = await flash_starcoder.generate(
"def print_hello", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@ -26,7 +28,12 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
@pytest.mark.private
async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
response = await flash_starcoder.generate(
"def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
"def print_hello",
max_new_tokens=60,
temperature=0.2,
top_p=0.95,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 60

View File

@ -19,6 +19,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
"Why is the sky blue?",
max_new_tokens=10,
top_p=0.9,
decoder_input_details=True,
seed=0,
)
@ -40,6 +41,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)

View File

@ -18,6 +18,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
response = await t5_sharded.generate(
"Please answer the following question. What is the boiling point of Nitrogen?",
max_new_tokens=10,
decoder_input_details=True,
)
assert response == response_snapshot

View File

@ -1,5 +1,5 @@
syrupy
text-generation==0.5.2
text-generation
pytest
pytest-asyncio==0.17.2
docker

View File

@ -87,6 +87,8 @@ message Request {
NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
}
message Batch {
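For gRPC callers, the new field is set directly on the request message; a sketch mirroring the server test fixtures further down (the import path follows those fixtures, and the parameter messages are filled with placeholder values):

```python
from text_generation_server.pb import generate_pb2

# Placeholder sampling / stopping settings; the router fills these from the
# validated HTTP parameters in practice.
request = generate_pb2.Request(
    id=0,
    inputs="Test",
    prefill_logprobs=True,  # new field: only compute prompt logprobs when asked
    truncate=100,
    parameters=generate_pb2.NextTokenChooserParameters(
        temperature=1.0, top_p=1.0, typical_p=1.0, repetition_penalty=1.0
    ),
    stopping_parameters=generate_pb2.StoppingCriteriaParameters(max_new_tokens=1),
)
```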

View File

@ -34,6 +34,7 @@ impl Health {
id: LIVENESS_ID,
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,

View File

@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters {
#[schema(default = "true")]
pub details: bool,
#[serde(default)]
#[schema(default = "true")]
pub decoder_input_details: bool,
#[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters {
truncate: None,
watermark: false,
details: false,
decoder_input_details: false,
seed: None,
}
}
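Over HTTP, the flag rides inside the `parameters` object of the usual generate payload; a hedged sketch with an assumed local router address:

```python
import requests

# Assumed router address; decoder_input_details defaults to false on the wire.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "test",
        "parameters": {"max_new_tokens": 1, "decoder_input_details": True},
    },
)
# Setting decoder_input_details also forces `details`, so prefill is populated.
print(resp.json()["details"]["prefill"])
```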

View File

@ -201,6 +201,7 @@ impl State {
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
inputs: entry.request.inputs.clone(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
@ -281,6 +282,7 @@ mod tests {
inputs: "".to_string(),
input_length: 0,
truncate: 0,
decoder_input_details: false,
parameters: NextTokenChooserParameters {
temperature: 0.0,
top_k: 0,

View File

@ -160,7 +160,7 @@ async fn generate(
add_prompt = Some(req.0.inputs.clone());
}
let details = req.0.parameters.details;
let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
// Inference
let (response, best_of_responses) = match req.0.parameters.best_of {
@ -364,7 +364,17 @@ async fn generate_stream(
let details = req.0.parameters.details;
let best_of = req.0.parameters.best_of.unwrap_or(1);
if best_of == 1 {
if best_of != 1 {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
} else if req.0.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
} else {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
Ok((_permit, mut response_stream)) => {
@ -474,11 +484,6 @@ async fn generate_stream(
tracing::error!("{err}");
yield Ok(Event::from(err));
}
} else {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
}
};

View File

@ -145,6 +145,7 @@ impl Validation {
truncate,
seed,
watermark,
decoder_input_details,
..
} = request.parameters;
@ -261,6 +262,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters,
@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
}
@ -351,6 +354,8 @@ pub enum ValidationError {
BestOfSeed,
#[error("`best_of` != 1 is not supported when streaming tokens")]
BestOfStream,
#[error("`decoder_input_details` == true is not supported when streaming tokens")]
PrefillDetailsStream,
#[error("`temperature` must be strictly positive")]
Temperature,
#[error("`repetition_penalty` must be strictly positive")]

View File

@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,

View File

@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,

View File

@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
@ -31,6 +32,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,

View File

@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,

View File

@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(0)
prefix_offsets.append(input_len - 5)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
@ -617,7 +617,7 @@ class CausalLM(Model):
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1:
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
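The hunk above is truncated by the viewer; in isolation, the prompt-logprob computation it guards boils down to a shift-and-gather. A minimal sketch, assuming `logits` of shape `[prompt_len, vocab]` and `input_ids` of shape `[prompt_len]`:

```python
import torch


def prompt_logprobs(logits: torch.Tensor, input_ids: torch.Tensor) -> list:
    """Logprob of each prompt token under the model; NaN for the first token."""
    logprobs = torch.log_softmax(logits, -1)
    # The logits at position i score token i + 1, so shift by one and gather
    # each prompt token's logprob from the previous position.
    gathered = torch.gather(logprobs[:-1], 1, input_ids[1:, None]).squeeze(1)
    # The first prompt token has no preceding context, hence NaN.
    return [float("nan")] + gathered.tolist()
```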

View File

@ -443,6 +443,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.model(
input_ids,
@ -453,6 +454,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.model.tp_embeddings:
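The same `lm_head_indices` argument is threaded through the NeoX, RW and Santacoder models below; its point is that during prefill the language-model head only has to run on the positions whose logits are consumed. A small sketch with hypothetical shapes:

```python
import torch

hidden_size, vocab_size, total_tokens = 4096, 32000, 512
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

hidden_states = torch.randn(total_tokens, hidden_size)
# Only these positions need logits, e.g. the last prompt token of each request
# when no prefill logprobs were asked for.
lm_head_indices = torch.tensor([127, 255, 383, 511])

# Project 4 rows instead of all 512.
logits = lm_head(hidden_states[lm_head_indices])  # shape [4, vocab_size]
```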

View File

@ -481,6 +481,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.gpt_neox(
input_ids,
@ -491,6 +492,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states)
if self.gpt_neox.tp_embeddings:

View File

@ -752,6 +752,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
input_ids,
@ -762,6 +763,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:

View File

@ -358,6 +358,7 @@ class FlashSantacoderForCausalLM(nn.Module):
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
input_ids,
@ -368,6 +369,8 @@ class FlashSantacoderForCausalLM(nn.Module):
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:

View File

@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch):
past_key_values: Optional[torch.Tensor]
max_seqlen: int
# Prefill metadata tensors to efficiently compute logprobs
prefill_head_indices: Optional[torch.Tensor]
prefill_next_token_indices: Optional[torch.Tensor]
prefill_cu_outlens: Optional[List[int]]
# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: torch.Tensor
@ -84,11 +89,18 @@ class FlashCausalLMBatch(Batch):
all_input_ids = []
requests_idx_mapping = {}
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
next_token_chooser_parameters = []
stopping_criterias = []
# Cumulative length
cumulative_length = 0
prefill_out_cumulative_length = 0
max_tokens = 0
max_length = 0
@ -106,13 +118,14 @@ class FlashCausalLMBatch(Batch):
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
prefix_offsets.append(0)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
all_input_ids.append(tokenized_input)
# Position ids
position_ids.append(np.arange(0, input_length))
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
@ -125,6 +138,26 @@ class FlashCausalLMBatch(Batch):
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
if r.prefill_logprobs:
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1], dtype=torch.int32
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
@ -141,18 +174,35 @@ class FlashCausalLMBatch(Batch):
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
# Create tensors on device
input_ids = torch.tensor(
np.concatenate(all_input_ids), dtype=torch.int64, device=device
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
position_ids = torch.tensor(
np.concatenate(position_ids), dtype=torch.int32, device=device
)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlens[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlens[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
return cls(
batch_id=pb.id,
requests=pb.requests,
@ -162,6 +212,9 @@ class FlashCausalLMBatch(Batch):
cu_seqlens=cu_seqlens,
cu_seqlens_q=None,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
past_key_values=None,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
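To make the bookkeeping above concrete, here is a small pure-Python sketch (prompt lengths are made up) of how `prefill_head_indices`, `prefill_next_token_indices` and `prefill_cu_outlens` end up for a batch mixing requests with and without prefill logprobs:

```python
# Hypothetical batch: three prompts, only the first and last ask for prefill logprobs.
input_lengths = [3, 4, 2]
wants_prefill_logprobs = [True, False, True]

prefill_head_indices = []        # flattened positions whose hidden states reach the lm_head
prefill_next_token_indices = []  # output row holding each request's next-token logits
prefill_cu_outlens = [0]         # cumulative lm_head output length per request

cumulative_length = 0            # offset into the flattened batch
out_cumulative_length = 0        # offset into the lm_head output

for length, wants in zip(input_lengths, wants_prefill_logprobs):
    if wants:
        # Keep every prompt position: all of them need logprobs.
        prefill_head_indices.extend(range(cumulative_length, cumulative_length + length))
        prefill_next_token_indices.append(out_cumulative_length + length - 1)
        prefill_cu_outlens.append(out_cumulative_length + length)
        out_cumulative_length += length
    else:
        # Keep only the last prompt position: just the next-token logits.
        prefill_head_indices.append(cumulative_length + length - 1)
        prefill_next_token_indices.append(out_cumulative_length)
        prefill_cu_outlens.append(out_cumulative_length + 1)
        out_cumulative_length += 1
    cumulative_length += length

print(prefill_head_indices)        # [0, 1, 2, 6, 7, 8]
print(prefill_next_token_indices)  # [2, 3, 5]
print(prefill_cu_outlens)          # [0, 3, 4, 6]
```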
@ -280,6 +333,9 @@ class FlashCausalLMBatch(Batch):
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
@ -415,6 +471,9 @@ class FlashCausalLMBatch(Batch):
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
@ -486,6 +545,7 @@ class FlashCausalLM(Model):
max_s: int,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
return self.model.forward(
@ -496,6 +556,7 @@ class FlashCausalLM(Model):
max_s=max_s,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
lm_head_indices=lm_head_indices,
)
@tracer.start_as_current_span("generate_token")
@ -503,9 +564,10 @@ class FlashCausalLM(Model):
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
prefill_logprobs = batch.prefill_next_token_indices is not None
single_request = len(batch) == 1
if prefill and len(batch) == 1:
if prefill and single_request:
# Ask to pre-allocate kv to its max size
# == number of tokens + max_new_tokens
pre_allocate_past_size = (
@ -522,11 +584,12 @@ class FlashCausalLM(Model):
batch.max_seqlen,
batch.past_key_values,
pre_allocate_past_size,
batch.prefill_head_indices,
)
if prefill:
next_token_logits = (
out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
else:
next_token_logits = out
@ -536,10 +599,10 @@ class FlashCausalLM(Model):
)
if prefill:
if len(batch) > 1:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Create batch.cu_seqlens_q for decode
batch.cu_seqlens_q = torch.arange(
@ -600,7 +663,6 @@ class FlashCausalLM(Model):
# Zipped iterator
iterator = zip(
batch.input_lengths,
batch.stopping_criterias,
batch.all_input_ids,
)
@ -611,29 +673,33 @@ class FlashCausalLM(Model):
# For each member of the batch
for i, (
input_length,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
if prefill:
# Indexing metadata
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
out_length = out_end_index - out_start_index
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if len(batch) > 1:
prefill_tokens_indices[
start_index : end_index - 1
] = batch.input_ids[start_index + 1 : end_index]
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : end_index
]
if prefill_logprobs:
if len(batch) > 1:
prefill_tokens_indices[
out_start_index : out_end_index - 1
] = batch.input_ids[start_index + 1 : start_index + out_length]
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : start_index + out_length
]
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
@ -644,7 +710,7 @@ class FlashCausalLM(Model):
batch.position_ids = next_position_ids + 1
batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
if prefill:
if prefill and prefill_logprobs:
# Get prefill logprobs
prefill_logprobs_tensor = torch.log_softmax(out, -1)
prefill_logprobs = torch.gather(
@ -657,8 +723,6 @@ class FlashCausalLM(Model):
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = batch.input_ids.tolist()
cumulative_length = 0
# Zipped iterator
iterator = zip(
batch.requests,
@ -688,9 +752,6 @@ class FlashCausalLM(Model):
next_token_id,
next_token_logprob,
) in enumerate(iterator):
start_index = cumulative_length
end_index = cumulative_length + input_length
# Append next token to all tokens
all_input_ids.append(next_token_id)
@ -728,10 +789,13 @@ class FlashCausalLM(Model):
generated_text = None
# Prefill
if prefill:
if prefill and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
start_index : end_index - 1
out_start_index : out_end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
@ -764,8 +828,10 @@ class FlashCausalLM(Model):
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
cumulative_length += input_length
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1
# No need to return a batch if we know that all requests stopped
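Downstream of that bookkeeping, the prompt logprobs are read out of the already-sliced lm_head output with one log_softmax plus gather; a hedged sketch reusing the layout from the earlier sketch (lengths 3, 4, 2; only the first and last requests want prefill logprobs):

```python
import torch

vocab_size = 8
out = torch.randn(6, vocab_size)                # lm_head output: 6 kept rows
input_ids = torch.randint(0, vocab_size, (9,))  # flattened prompt ids of the batch
prefill_cu_outlens = [0, 3, 4, 6]

logprobs = torch.log_softmax(out, -1)

# Row j of a prefill-logprob request scores prompt token j + 1, so point each
# kept row (except the last one of each request) at the *next* input id.
prefill_tokens_indices = torch.zeros(len(out), dtype=torch.long)
prefill_tokens_indices[0:2] = input_ids[1:3]   # request 0 occupies output rows 0..2
prefill_tokens_indices[4:5] = input_ids[8:9]   # request 2 occupies output rows 4..5

prefill_logprobs = torch.gather(logprobs, 1, prefill_tokens_indices[:, None]).squeeze(1)
# Request 0's prompt logprobs are then [float("nan")] + prefill_logprobs[0:2].tolist().
```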

View File

@ -688,7 +688,7 @@ class Seq2SeqLM(Model):
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1:
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
prefill_tokens = PrefillTokens(
[self.tokenizer.bos_token_id],
[float("nan")],