fix(server): fix flash-neox scores warping (#137)

This commit is contained in:
OlivierDehaene 2023-03-24 18:21:41 +01:00 committed by GitHub
parent 05e9a796cc
commit d6a93fe992
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 79 additions and 60 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation" name = "text-generation"
version = "0.4.0" version = "0.4.1"
description = "Hugging Face Text Generation Python Client" description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0" license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@ -14,6 +14,8 @@ def test_parameters_validation():
Parameters(best_of=2, do_sample=True) Parameters(best_of=2, do_sample=True)
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
Parameters(best_of=2) Parameters(best_of=2)
with pytest.raises(ValidationError):
Parameters(best_of=2, seed=1)
# Test repetition_penalty # Test repetition_penalty
Parameters(repetition_penalty=1) Parameters(repetition_penalty=1)

View File

@ -150,7 +150,6 @@ class Client:
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: int = 20, max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -172,8 +171,6 @@ class Client:
Activate logits sampling Activate logits sampling
max_new_tokens (`int`): max_new_tokens (`int`):
Maximum number of generated tokens Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one if the highest token logprobs
repetition_penalty (`float`): repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@ -203,7 +200,7 @@ class Client:
""" """
# Validate parameters # Validate parameters
parameters = Parameters( parameters = Parameters(
best_of=best_of, best_of=None,
details=True, details=True,
do_sample=do_sample, do_sample=do_sample,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
@ -388,7 +385,6 @@ class AsyncClient:
prompt: str, prompt: str,
do_sample: bool = False, do_sample: bool = False,
max_new_tokens: int = 20, max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
return_full_text: bool = False, return_full_text: bool = False,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -410,8 +406,6 @@ class AsyncClient:
Activate logits sampling Activate logits sampling
max_new_tokens (`int`): max_new_tokens (`int`):
Maximum number of generated tokens Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one if the highest token logprobs
repetition_penalty (`float`): repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@ -441,7 +435,7 @@ class AsyncClient:
""" """
# Validate parameters # Validate parameters
parameters = Parameters( parameters = Parameters(
best_of=best_of, best_of=None,
details=True, details=True,
do_sample=do_sample, do_sample=do_sample,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,

View File

@ -43,6 +43,8 @@ class Parameters(BaseModel):
if field_value is not None: if field_value is not None:
if field_value <= 0: if field_value <= 0:
raise ValidationError("`best_of` must be strictly positive") raise ValidationError("`best_of` must be strictly positive")
if field_value > 1 and values["seed"] is not None:
raise ValidationError("`seed` must not be set when `best_of` is > 1")
sampling = ( sampling = (
values["do_sample"] values["do_sample"]
| (values["temperature"] is not None) | (values["temperature"] is not None)

View File

@ -16,6 +16,7 @@ from text_generation_server.models.t5 import T5Sharded
try: try:
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1 FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1
except ImportError: except ImportError:
if int(os.environ.get("FLASH_NEOX", 0)) == 1: if int(os.environ.get("FLASH_NEOX", 0)) == 1:

View File

@ -1,6 +1,8 @@
import torch import torch
import torch.distributed import torch.distributed
from torch.nn import functional as F
from accelerate import init_empty_weights from accelerate import init_empty_weights
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
@ -48,6 +50,7 @@ class FlashNeoXBatch(Batch):
# All tokens # All tokens
all_input_ids: List[List[int]] all_input_ids: List[List[int]]
all_input_ids_tensor: List[torch.Tensor]
# Lengths of all generations present in the batch # Lengths of all generations present in the batch
input_lengths: List[int] input_lengths: List[int]
@ -75,6 +78,7 @@ class FlashNeoXBatch(Batch):
input_lengths = [] input_lengths = []
all_input_ids = [] all_input_ids = []
all_input_ids_tensor = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
@ -84,15 +88,14 @@ class FlashNeoXBatch(Batch):
# Parse batch # Parse batch
for r in pb.requests: for r in pb.requests:
tokenized_input = tokenizer(r.inputs, return_tensors="pt")[ tokenized_input = tokenizer(r.inputs)["input_ids"]
"input_ids"
].squeeze(0)
input_ids.append(tokenized_input)
all_input_ids.append(tokenized_input.tolist())
input_length = len(tokenized_input) input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length) max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length) input_lengths.append(input_length)
all_input_ids.append(tokenized_input)
tokenized_input = torch.tensor(tokenized_input, device=device)
input_ids.append(tokenized_input)
# Position ids # Position ids
position_ids.append(torch.arange(0, input_length, dtype=torch.int32)) position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
@ -101,14 +104,18 @@ class FlashNeoXBatch(Batch):
cu_seqlens.append(cumulative_length + input_length) cu_seqlens.append(cumulative_length + input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append( stopping_criteria = StoppingCriteria.from_pb(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
all_input_ids_tensor.append(
F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
) )
# Update # Update
cumulative_length += input_length cumulative_length += input_length
input_ids = torch.concat(input_ids).unsqueeze(1) input_ids = torch.concat(input_ids)
position_ids = torch.concat(position_ids) position_ids = torch.concat(position_ids)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
@ -122,6 +129,7 @@ class FlashNeoXBatch(Batch):
past_key_values=None, past_key_values=None,
input_lengths=input_lengths, input_lengths=input_lengths,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
) )
@ -133,6 +141,7 @@ class FlashNeoXBatch(Batch):
requests = [] requests = []
input_lengths = [] input_lengths = []
all_input_ids = [] all_input_ids = []
all_input_ids_tensor = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
@ -150,6 +159,7 @@ class FlashNeoXBatch(Batch):
requests.extend(batch.requests) requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
@ -181,6 +191,7 @@ class FlashNeoXBatch(Batch):
past_key_values=past_key_values, past_key_values=past_key_values,
input_lengths=input_lengths, input_lengths=input_lengths,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
) )
@ -255,11 +266,10 @@ class FlashNeoX(Model):
) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]: ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
# Better to send to device here to avoid device issues in concatenate # Better to send to device here to avoid device issues in concatenate
position_ids = batch.position_ids.to(self.device, non_blocking=True) position_ids = batch.position_ids.to(self.device, non_blocking=True)
cu_seqlens = batch.cu_seqlens.to(self.device, non_blocking=True) cu_seqlens = batch.cu_seqlens.to(self.device)
input_ids = batch.input_ids.squeeze(1).to(self.device)
out, present = self.forward( out, present = self.forward(
input_ids, batch.input_ids,
position_ids, position_ids,
cu_seqlens, cu_seqlens,
batch.max_seqlen, batch.max_seqlen,
@ -277,6 +287,7 @@ class FlashNeoX(Model):
next_batch_past_key_values = [] next_batch_past_key_values = []
next_batch_input_lengths = [] next_batch_input_lengths = []
next_batch_all_input_ids = [] next_batch_all_input_ids = []
next_batch_all_input_ids_tensor = []
# Cumulative length # Cumulative length
cumulative_length = 0 cumulative_length = 0
@ -291,6 +302,7 @@ class FlashNeoX(Model):
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
batch.all_input_ids_tensor,
) )
# For each member of the batch # For each member of the batch
@ -300,6 +312,7 @@ class FlashNeoX(Model):
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
all_input_ids, all_input_ids,
all_input_ids_tensor,
) in enumerate(iterator): ) in enumerate(iterator):
# Indexing metadata # Indexing metadata
start_index = cumulative_length start_index = cumulative_length
@ -315,20 +328,19 @@ class FlashNeoX(Model):
logits = out[i].unsqueeze(0) logits = out[i].unsqueeze(0)
# Select next token # Select next token
next_token_id, logprobs = next_token_chooser(all_input_ids, logits) next_token_id, logprobs = next_token_chooser(
# Copy to cpu to avoid other copies when indexing and calling .item() all_input_ids_tensor[None, :input_length], logits
next_token_id = next_token_id.to("cpu", non_blocking=True) )
logprobs = logprobs.to("cpu")
next_token_id_squeezed = next_token_id.squeeze() next_token_id_squeezed = next_token_id.squeeze()
next_token_id_item = next_token_id_squeezed.item() next_token_id_item = next_token_id_squeezed.item()
# Append next token to all tokens # Append next token to all tokens
all_input_ids.append(next_token_id_item) all_input_ids.append(next_token_id_item)
all_input_ids_tensor[input_length] = next_token_id_item
new_input_length = input_length + 1 new_input_length = input_length + 1
# Generated token # Generated token
next_token_logprob = logprobs[-1, next_token_id] next_token_logprob = logprobs[-1, next_token_id_item]
next_token_text = self.decode_token( next_token_text = self.decode_token(
next_token_id_item, next_token_id_item,
) )
@ -372,13 +384,14 @@ class FlashNeoX(Model):
) )
next_batch_input_lengths.append(new_input_length) next_batch_input_lengths.append(new_input_length)
next_batch_all_input_ids.append(all_input_ids) next_batch_all_input_ids.append(all_input_ids)
next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
# Prefill # Prefill
if stopping_criteria.current_tokens == 1: if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token # Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather( prefill_logprobs = [float("nan")] + logprobs.gather(
1, torch.tensor(all_input_ids[1:]).unsqueeze(1) 1, all_input_ids_tensor[1:input_length].unsqueeze(1)
).squeeze(1)[:-1].tolist() ).squeeze(1)[:-1].tolist()
prefill_token_ids = all_input_ids[:-1] prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode( prefill_texts = self.tokenizer.batch_decode(
@ -431,12 +444,14 @@ class FlashNeoX(Model):
) )
next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32) next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
if len(next_batch_keep_indices) > 1: if len(next_batch_keep_indices) > 1:
next_batch_input_ids = torch.concat(next_batch_input_ids) next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1)
next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1) next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
else: else:
next_batch_input_ids = next_batch_input_ids[0] next_batch_input_ids = next_batch_input_ids[0].view(1)
next_batch_past_key_values = next_batch_past_key_values[0] next_batch_past_key_values = next_batch_past_key_values[0]
print(next_batch_input_ids.shape)
next_batch = FlashNeoXBatch( next_batch = FlashNeoXBatch(
batch_id=batch.batch_id, batch_id=batch.batch_id,
requests=next_batch_requests, requests=next_batch_requests,
@ -447,6 +462,7 @@ class FlashNeoX(Model):
past_key_values=next_batch_past_key_values, past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths, input_lengths=next_batch_input_lengths,
all_input_ids=next_batch_all_input_ids, all_input_ids=next_batch_all_input_ids,
all_input_ids_tensor=next_batch_all_input_ids_tensor,
next_token_choosers=next_batch_next_token_choosers, next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias, stopping_criterias=next_batch_stopping_criterias,
) )

View File

@ -1,8 +1,6 @@
import torch import torch
import torch.distributed import torch.distributed
import torch.nn.functional as F
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
@ -16,7 +14,29 @@ import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.layers.rotary import RotaryEmbedding
class TensorParallelColumnLinear(nn.Linear): class FastLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.swap_dims = True
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.swap_dims:
self.weight = nn.Parameter(self.weight.T)
self.swap_dims = False
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear):
def __init__( def __init__(
self, self,
in_features, in_features,
@ -39,15 +59,11 @@ class TensorParallelColumnLinear(nn.Linear):
dtype=dtype, dtype=dtype,
) )
@staticmethod
def linear(input, weight, bias):
return F.linear(input, weight, bias)
def forward(self, input): def forward(self, input):
return self.linear(input, self.weight, self.bias) return super(TensorParallelColumnLinear, self).forward(input)
class TensorParallelRowLinear(nn.Linear): class TensorParallelRowLinear(FastLinear):
def __init__( def __init__(
self, self,
in_features, in_features,
@ -70,12 +86,8 @@ class TensorParallelRowLinear(nn.Linear):
dtype=dtype, dtype=dtype,
) )
@staticmethod
def linear(input, weight, bias):
return F.linear(input, weight, bias)
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
out = self.linear(input, self.weight, self.bias) out = super(TensorParallelRowLinear, self).forward(input)
torch.distributed.all_reduce(out, group=self.process_group) torch.distributed.all_reduce(out, group=self.process_group)
return out return out
@ -122,14 +134,6 @@ class TensorParallelEmbedding(nn.Embedding):
) )
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
# Sanity check
if torch.any(
torch.logical_or(0 > input, input >= self.original_num_embeddings)
):
raise IndexError(
f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
)
# `0` if input is in the correct interval, else `1` # `0` if input is in the correct interval, else `1`
input_mask = torch.logical_or(self.min_id > input, input >= self.max_id) input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
# translate for [0, self.max_id - self.min_id[ # translate for [0, self.max_id - self.min_id[
@ -196,8 +200,8 @@ class FlashNeoxAttention(torch.nn.Module):
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
if process_group is None: if process_group is None:
self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size) self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
self.dense = nn.Linear(hidden_size, hidden_size) self.dense = FastLinear(hidden_size, hidden_size)
else: else:
self.num_heads = self.num_heads // process_group.size() self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear( self.query_key_value = TensorParallelColumnLinear(
@ -312,8 +316,8 @@ class FlashMLP(nn.Module):
) )
if process_group is None: if process_group is None:
self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size) self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size)
self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size) self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size)
else: else:
self.dense_h_to_4h = TensorParallelColumnLinear( self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size, hidden_size,
@ -556,7 +560,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
# Create indices from cumulative sequence lengths # Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1 layer_past_present_indices = cu_seqlens[1:] - 1
cu_seqlens_q = torch.arange( cu_seqlens_q = torch.arange(
len(cu_seqlens), dtype=torch.int32, device=hidden_states.device cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
) )
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
@ -613,13 +617,13 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
self.gpt_neox = FlashGPTNeoXModel(config, process_group) self.gpt_neox = FlashGPTNeoXModel(config, process_group)
if self.gpt_neox.tp_embeddings: if self.gpt_neox.tp_embeddings:
self.embed_out = nn.Linear( self.embed_out = FastLinear(
config.hidden_size, config.hidden_size,
config.vocab_size // process_group.size(), config.vocab_size // process_group.size(),
bias=False, bias=False,
) )
else: else:
self.embed_out = nn.Linear( self.embed_out = FastLinear(
config.hidden_size, config.vocab_size, bias=False config.hidden_size, config.vocab_size, bias=False
) )

View File

@ -44,8 +44,8 @@ class WatermarkLogitsProcessor(LogitsProcessor):
), "requires at least a 1 token prefix sequence to seed rng" ), "requires at least a 1 token prefix sequence to seed rng"
prev_token = input_ids[-1] prev_token = input_ids[-1]
else: else:
input_ids = input_ids[0]
assert len(input_ids) == 1 assert len(input_ids) == 1
input_ids = input_ids[0]
assert ( assert (
input_ids.shape[-1] >= 1 input_ids.shape[-1] >= 1
), "requires at least a 1 token prefix sequence to seed rng" ), "requires at least a 1 token prefix sequence to seed rng"