From d6a93fe992bc932027df6f4a8f2b87c68d233f55 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 24 Mar 2023 18:21:41 +0100 Subject: [PATCH] fix(server): fix flash-neox scores warping (#137) --- clients/python/pyproject.toml | 2 +- clients/python/tests/test_types.py | 2 + clients/python/text_generation/client.py | 10 +-- clients/python/text_generation/types.py | 2 + .../text_generation_server/models/__init__.py | 1 + .../models/flash_neox.py | 58 ++++++++++------- .../models/flash_neox_modeling.py | 62 ++++++++++--------- .../text_generation_server/utils/watermark.py | 2 +- 8 files changed, 79 insertions(+), 60 deletions(-) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 505717de..d07f1dbc 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation" -version = "0.4.0" +version = "0.4.1" description = "Hugging Face Text Generation Python Client" license = "Apache-2.0" authors = ["Olivier Dehaene "] diff --git a/clients/python/tests/test_types.py b/clients/python/tests/test_types.py index 4c9d4c89..77689ade 100644 --- a/clients/python/tests/test_types.py +++ b/clients/python/tests/test_types.py @@ -14,6 +14,8 @@ def test_parameters_validation(): Parameters(best_of=2, do_sample=True) with pytest.raises(ValidationError): Parameters(best_of=2) + with pytest.raises(ValidationError): + Parameters(best_of=2, seed=1) # Test repetition_penalty Parameters(repetition_penalty=1) diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 03bc3888..8b8742fc 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -150,7 +150,6 @@ class Client: prompt: str, do_sample: bool = False, max_new_tokens: int = 20, - best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -172,8 +171,6 @@ class Client: Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens - best_of (`int`): - Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. @@ -203,7 +200,7 @@ class Client: """ # Validate parameters parameters = Parameters( - best_of=best_of, + best_of=None, details=True, do_sample=do_sample, max_new_tokens=max_new_tokens, @@ -388,7 +385,6 @@ class AsyncClient: prompt: str, do_sample: bool = False, max_new_tokens: int = 20, - best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -410,8 +406,6 @@ class AsyncClient: Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens - best_of (`int`): - Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. @@ -441,7 +435,7 @@ class AsyncClient: """ # Validate parameters parameters = Parameters( - best_of=best_of, + best_of=None, details=True, do_sample=do_sample, max_new_tokens=max_new_tokens, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index ea2070b8..21a9849b 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -43,6 +43,8 @@ class Parameters(BaseModel): if field_value is not None: if field_value <= 0: raise ValidationError("`best_of` must be strictly positive") + if field_value > 1 and values["seed"] is not None: + raise ValidationError("`seed` must not be set when `best_of` is > 1") sampling = ( values["do_sample"] | (values["temperature"] is not None) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 2f637ae1..f3a92ad2 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -16,6 +16,7 @@ from text_generation_server.models.t5 import T5Sharded try: from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded + FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1 except ImportError: if int(os.environ.get("FLASH_NEOX", 0)) == 1: diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 7be4708b..cbaa78ca 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -1,6 +1,8 @@ import torch import torch.distributed +from torch.nn import functional as F + from accelerate import init_empty_weights from dataclasses import dataclass from opentelemetry import trace @@ -48,6 +50,7 @@ class FlashNeoXBatch(Batch): # All tokens all_input_ids: List[List[int]] + all_input_ids_tensor: List[torch.Tensor] # Lengths of all generations present in the batch input_lengths: List[int] @@ -75,6 +78,7 @@ class FlashNeoXBatch(Batch): input_lengths = [] all_input_ids = [] + all_input_ids_tensor = [] next_token_choosers = [] stopping_criterias = [] @@ -84,15 +88,14 @@ class FlashNeoXBatch(Batch): # Parse batch for r in pb.requests: - tokenized_input = tokenizer(r.inputs, return_tensors="pt")[ - "input_ids" - ].squeeze(0) - input_ids.append(tokenized_input) - all_input_ids.append(tokenized_input.tolist()) - + tokenized_input = tokenizer(r.inputs)["input_ids"] input_length = len(tokenized_input) max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) + all_input_ids.append(tokenized_input) + + tokenized_input = torch.tensor(tokenized_input, device=device) + input_ids.append(tokenized_input) # Position ids position_ids.append(torch.arange(0, input_length, dtype=torch.int32)) @@ -101,14 +104,18 @@ class FlashNeoXBatch(Batch): cu_seqlens.append(cumulative_length + input_length) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) - stopping_criterias.append( - StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + all_input_ids_tensor.append( + F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens)) ) # Update cumulative_length += input_length - input_ids = torch.concat(input_ids).unsqueeze(1) + input_ids = torch.concat(input_ids) position_ids = torch.concat(position_ids) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) @@ -122,6 +129,7 @@ class FlashNeoXBatch(Batch): past_key_values=None, input_lengths=input_lengths, all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, ) @@ -133,6 +141,7 @@ class FlashNeoXBatch(Batch): requests = [] input_lengths = [] all_input_ids = [] + all_input_ids_tensor = [] next_token_choosers = [] stopping_criterias = [] @@ -150,6 +159,7 @@ class FlashNeoXBatch(Batch): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) all_input_ids.extend(batch.all_input_ids) + all_input_ids_tensor.extend(batch.all_input_ids_tensor) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -181,6 +191,7 @@ class FlashNeoXBatch(Batch): past_key_values=past_key_values, input_lengths=input_lengths, all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, ) @@ -255,11 +266,10 @@ class FlashNeoX(Model): ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]: # Better to send to device here to avoid device issues in concatenate position_ids = batch.position_ids.to(self.device, non_blocking=True) - cu_seqlens = batch.cu_seqlens.to(self.device, non_blocking=True) - input_ids = batch.input_ids.squeeze(1).to(self.device) + cu_seqlens = batch.cu_seqlens.to(self.device) out, present = self.forward( - input_ids, + batch.input_ids, position_ids, cu_seqlens, batch.max_seqlen, @@ -277,6 +287,7 @@ class FlashNeoX(Model): next_batch_past_key_values = [] next_batch_input_lengths = [] next_batch_all_input_ids = [] + next_batch_all_input_ids_tensor = [] # Cumulative length cumulative_length = 0 @@ -291,6 +302,7 @@ class FlashNeoX(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, + batch.all_input_ids_tensor, ) # For each member of the batch @@ -300,6 +312,7 @@ class FlashNeoX(Model): next_token_chooser, stopping_criteria, all_input_ids, + all_input_ids_tensor, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length @@ -315,20 +328,19 @@ class FlashNeoX(Model): logits = out[i].unsqueeze(0) # Select next token - next_token_id, logprobs = next_token_chooser(all_input_ids, logits) - # Copy to cpu to avoid other copies when indexing and calling .item() - next_token_id = next_token_id.to("cpu", non_blocking=True) - logprobs = logprobs.to("cpu") - + next_token_id, logprobs = next_token_chooser( + all_input_ids_tensor[None, :input_length], logits + ) next_token_id_squeezed = next_token_id.squeeze() next_token_id_item = next_token_id_squeezed.item() # Append next token to all tokens all_input_ids.append(next_token_id_item) + all_input_ids_tensor[input_length] = next_token_id_item new_input_length = input_length + 1 # Generated token - next_token_logprob = logprobs[-1, next_token_id] + next_token_logprob = logprobs[-1, next_token_id_item] next_token_text = self.decode_token( next_token_id_item, ) @@ -372,13 +384,14 @@ class FlashNeoX(Model): ) next_batch_input_lengths.append(new_input_length) next_batch_all_input_ids.append(all_input_ids) + next_batch_all_input_ids_tensor.append(all_input_ids_tensor) next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) # Prefill if stopping_criteria.current_tokens == 1: # Remove generated token to only have prefill and add nan for first prompt token prefill_logprobs = [float("nan")] + logprobs.gather( - 1, torch.tensor(all_input_ids[1:]).unsqueeze(1) + 1, all_input_ids_tensor[1:input_length].unsqueeze(1) ).squeeze(1)[:-1].tolist() prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( @@ -431,12 +444,14 @@ class FlashNeoX(Model): ) next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32) if len(next_batch_keep_indices) > 1: - next_batch_input_ids = torch.concat(next_batch_input_ids) + next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1) next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1) else: - next_batch_input_ids = next_batch_input_ids[0] + next_batch_input_ids = next_batch_input_ids[0].view(1) next_batch_past_key_values = next_batch_past_key_values[0] + print(next_batch_input_ids.shape) + next_batch = FlashNeoXBatch( batch_id=batch.batch_id, requests=next_batch_requests, @@ -447,6 +462,7 @@ class FlashNeoX(Model): past_key_values=next_batch_past_key_values, input_lengths=next_batch_input_lengths, all_input_ids=next_batch_all_input_ids, + all_input_ids_tensor=next_batch_all_input_ids_tensor, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, ) diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py index d67ca3c0..dcfb613d 100644 --- a/server/text_generation_server/models/flash_neox_modeling.py +++ b/server/text_generation_server/models/flash_neox_modeling.py @@ -1,8 +1,6 @@ import torch import torch.distributed -import torch.nn.functional as F - from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel @@ -16,7 +14,29 @@ import dropout_layer_norm from flash_attn.layers.rotary import RotaryEmbedding -class TensorParallelColumnLinear(nn.Linear): +class FastLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + self.swap_dims = True + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.swap_dims: + self.weight = nn.Parameter(self.weight.T) + self.swap_dims = False + + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) + + +class TensorParallelColumnLinear(FastLinear): def __init__( self, in_features, @@ -39,15 +59,11 @@ class TensorParallelColumnLinear(nn.Linear): dtype=dtype, ) - @staticmethod - def linear(input, weight, bias): - return F.linear(input, weight, bias) - def forward(self, input): - return self.linear(input, self.weight, self.bias) + return super(TensorParallelColumnLinear, self).forward(input) -class TensorParallelRowLinear(nn.Linear): +class TensorParallelRowLinear(FastLinear): def __init__( self, in_features, @@ -70,12 +86,8 @@ class TensorParallelRowLinear(nn.Linear): dtype=dtype, ) - @staticmethod - def linear(input, weight, bias): - return F.linear(input, weight, bias) - def forward(self, input: torch.Tensor) -> torch.Tensor: - out = self.linear(input, self.weight, self.bias) + out = super(TensorParallelRowLinear, self).forward(input) torch.distributed.all_reduce(out, group=self.process_group) return out @@ -122,14 +134,6 @@ class TensorParallelEmbedding(nn.Embedding): ) def forward(self, input: torch.Tensor) -> torch.Tensor: - # Sanity check - if torch.any( - torch.logical_or(0 > input, input >= self.original_num_embeddings) - ): - raise IndexError( - f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}" - ) - # `0` if input is in the correct interval, else `1` input_mask = torch.logical_or(self.min_id > input, input >= self.max_id) # translate for [0, self.max_id - self.min_id[ @@ -196,8 +200,8 @@ class FlashNeoxAttention(torch.nn.Module): self.softmax_scale = self.head_size ** (-0.5) if process_group is None: - self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size) - self.dense = nn.Linear(hidden_size, hidden_size) + self.query_key_value = FastLinear(hidden_size, 3 * hidden_size) + self.dense = FastLinear(hidden_size, hidden_size) else: self.num_heads = self.num_heads // process_group.size() self.query_key_value = TensorParallelColumnLinear( @@ -312,8 +316,8 @@ class FlashMLP(nn.Module): ) if process_group is None: - self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size) - self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size) + self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size) + self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size) else: self.dense_h_to_4h = TensorParallelColumnLinear( hidden_size, @@ -556,7 +560,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): # Create indices from cumulative sequence lengths layer_past_present_indices = cu_seqlens[1:] - 1 cu_seqlens_q = torch.arange( - len(cu_seqlens), dtype=torch.int32, device=hidden_states.device + cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device ) # Get rotary cos and sin for this forward @@ -613,13 +617,13 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): self.gpt_neox = FlashGPTNeoXModel(config, process_group) if self.gpt_neox.tp_embeddings: - self.embed_out = nn.Linear( + self.embed_out = FastLinear( config.hidden_size, config.vocab_size // process_group.size(), bias=False, ) else: - self.embed_out = nn.Linear( + self.embed_out = FastLinear( config.hidden_size, config.vocab_size, bias=False ) diff --git a/server/text_generation_server/utils/watermark.py b/server/text_generation_server/utils/watermark.py index 1850561d..df7b90e3 100644 --- a/server/text_generation_server/utils/watermark.py +++ b/server/text_generation_server/utils/watermark.py @@ -44,8 +44,8 @@ class WatermarkLogitsProcessor(LogitsProcessor): ), "requires at least a 1 token prefix sequence to seed rng" prev_token = input_ids[-1] else: - input_ids = input_ids[0] assert len(input_ids) == 1 + input_ids = input_ids[0] assert ( input_ids.shape[-1] >= 1 ), "requires at least a 1 token prefix sequence to seed rng"