diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 90c70cb5..bcaf6ec1 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -8,6 +8,7 @@ from typing import Optional from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM +from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.bloom import BLOOM, BLOOMSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.galactica import Galactica, GalacticaSharded @@ -17,18 +18,22 @@ from text_generation_server.models.t5 import T5Sharded try: from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded + from text_generation_server.models.flash_santacoder import FlashSantacoder - FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1 + FLASH_ATTENTION = ( + torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1 + ) except ImportError: - if int(os.environ.get("FLASH_NEOX", 0)) == 1: - logger.exception("Could not import FlashNeoX") - FLASH_NEOX = False + if int(os.environ.get("FLASH_ATTENTION", 0)) == 1: + logger.exception("Could not import Flash Attention models") + FLASH_ATTENTION = False __all__ = [ "Model", "BLOOM", "BLOOMSharded", "CausalLM", + "FlashCausalLM", "Galactica", "GalacticaSharded", "GPTNeoxSharded", @@ -38,9 +43,10 @@ __all__ = [ "get_model", ] -if FLASH_NEOX: +if FLASH_ATTENTION: __all__.append(FlashNeoX) __all__.append(FlashNeoXSharded) + __all__.append(FlashSantacoder) # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. 
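# --- editor's note (illustrative sketch, not part of this diff) -------------
# The FLASH_ATTENTION switch is evaluated once, at import time, so it has to be
# set before text_generation_server.models is imported, and it only takes
# effect when CUDA plus the flash_attn_cuda / dropout_layer_norm kernels import
# cleanly; otherwise the server silently falls back to the regular models.
# get_model's argument names are assumed from the hunks below.
import os

os.environ["FLASH_ATTENTION"] = "1"  # opt in to the Flash Attention code path

from text_generation_server.models import get_model

# Resolves to FlashSantacoder when FLASH_ATTENTION is usable, SantaCoder otherwise.
model = get_model("bigcode/santacoder", revision=None, sharded=False, quantize=False)
# ----------------------------------------------------------------------------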
@@ -63,7 +69,11 @@ def get_model( return Galactica(model_id, revision, quantize=quantize) if "santacoder" in model_id: - return SantaCoder(model_id, revision, quantize) + if sharded: + raise NotImplementedError("sharded is not supported for Santacoder") + else: + santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder + return santacoder_cls(model_id, revision, quantize) config = AutoConfig.from_pretrained(model_id, revision=revision) model_type = config.model_type @@ -76,10 +86,10 @@ def get_model( if model_type == "gpt_neox": if sharded: - neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded + neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded return neox_cls(model_id, revision, quantize=quantize) else: - neox_cls = FlashNeoX if FLASH_NEOX else CausalLM + neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM return neox_cls(model_id, revision, quantize=quantize) if model_type == "t5": diff --git a/server/text_generation_server/models/custom_modeling/__init__.py b/server/text_generation_server/models/custom_modeling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py similarity index 100% rename from server/text_generation_server/models/flash_neox_modeling.py rename to server/text_generation_server/models/custom_modeling/flash_neox_modeling.py diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py new file mode 100644 index 00000000..ef073636 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -0,0 +1,357 @@ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN + +# Flash attention imports +import flash_attn_cuda +import dropout_layer_norm + + +class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 6144: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + + +class FastLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + + def transpose_weight(self): + self.weight = nn.Parameter(self.weight.T) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) + + +class FlashMQAttention(torch.nn.Module): + def __init__( + self, + num_heads, + hidden_size, + process_group=None, + ): + super().__init__() + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + self.softmax_scale = self.head_size ** (-0.5) + + if process_group is None: + self.attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size) + self.c_proj = 
FastLinear(hidden_size, hidden_size) + else: + raise NotImplementedError + + def forward( + self, + hidden_states, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + qkv = self.attn(hidden_states) + + # Split query from key_value + query, key_value = qkv.split([self.hidden_size, 2 * self.head_size], dim=1) + + # Prepare query and key_value for indexing + query = query.view(-1, self.num_heads, self.head_size) + key_value = key_value.view(-1, 2, 1, self.head_size) + + # Prefill + if layer_past_present_indices is None: + # Copy to layer past + layer_past[...] = key_value + # Expand from 1 to num_heads + key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + key_value[:, 0], + key_value[:, 1], + attn_output, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + False, + 0, + None, + ) + # Decode + else: + # Add present to the layer_past tensor at the correct indices + layer_past[layer_past_present_indices] = key_value + # Expand from 1 to num_heads + key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size) + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + key_value[:, 0], + key_value[:, 1], + attn_output, + cu_seqlens_q, + cu_seqlens, + 1, + max_s, + 0.0, + self.softmax_scale, + False, + False, + False, + 0, + None, + ) + + return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class MLP(nn.Module): + def __init__( + self, act, hidden_size, intermediate_size, process_group=None + ): + super().__init__() + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu(x, approximate="tanh" if act in ["gelu_fast", + "gelu_pytorch_tanh"] else None) + ) + + if process_group is None: + self.c_fc = FastLinear(hidden_size, intermediate_size) + self.c_proj = FastLinear(intermediate_size, hidden_size) + else: + raise NotImplementedError + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + return hidden_states + + +class Block(nn.Module): + def __init__( + self, + num_heads, + act, + hidden_size, + intermediate_size, + layer_norm_eps, + process_group=None, + ): + super().__init__() + self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.attn = FlashMQAttention( + num_heads, + hidden_size, + process_group, + ) + self.mlp = MLP( + act, + hidden_size, + intermediate_size, + process_group, + ) + + def forward( + self, + hidden_states, + residual, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + hidden_states, residual = self.ln_1(hidden_states, residual) + + hidden_states = self.attn( + hidden_states, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + hidden_states, residual = self.ln_2( + hidden_states, residual + ) + + mlp_output = self.mlp(hidden_states) + + return mlp_output, residual + + +class FlashSantacoderModel(nn.Module): + def __init__(self, config, process_group=None): + super().__init__() + self.config = config + + if process_group is not None: + raise NotImplementedError + + self.wte = nn.Embedding(config.vocab_size, config.hidden_size) + self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size) 
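+        # Santacoder uses GPT-2 style learned absolute position embeddings:
+        # forward() below simply adds wpe(position_ids) to wte(input_ids)
+        # (no rotary embeddings, unlike the NeoX flash model).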
+ + self.h = nn.ModuleList( + [ + Block( + config.num_attention_heads, + config.activation_function, + config.hidden_size, + config.n_inner if config.n_inner is not None else 4 * config.hidden_size, + config.layer_norm_epsilon, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.ln_f = FastLayerNorm( + config.hidden_size, eps=config.layer_norm_epsilon + ) + + self.head_size = self.h[0].attn.head_size + self.num_heads = self.h[0].attn.num_heads + + def post_load_weights(self): + for layer in self.h: + layer: Block + layer.attn.attn.transpose_weight() + layer.attn.c_proj.transpose_weight() + layer.mlp.c_fc.transpose_weight() + layer.mlp.c_proj.transpose_weight() + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states = self.wte(input_ids) + self.wpe(position_ids) + + # Prefill + if past_key_values is None: + # Create past tensor + past_key_values = hidden_states.new_empty( + ( + len(self.h), + len(hidden_states), + 2, + 1, + self.head_size, + ) + ) + layer_past_present_indices = None + cu_seqlens_q = None + # Decode + else: + # Create indices from cumulative sequence lengths + layer_past_present_indices = cu_seqlens[1:] - 1 + cu_seqlens_q = torch.arange( + cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device + ) + + residual = None + for i, layer in enumerate(self.h): + hidden_states, residual = layer( + hidden_states, + residual, + cu_seqlens, + max_s, + past_key_values[i], + layer_past_present_indices, + cu_seqlens_q, + ) + + hidden_states, _ = self.ln_f(hidden_states, residual) + + return hidden_states, past_key_values + + +class FlashSantacoderForCausalLM(nn.Module): + def __init__(self, config, process_group=None): + super().__init__() + + self.transformer = FlashSantacoderModel(config, process_group) + + self.lm_head = FastLinear( + config.hidden_size, config.vocab_size, bias=False + ) + + def post_load_weights(self): + self.transformer.post_load_weights() + self.lm_head.transpose_weight() + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states, present = self.transformer( + input_ids, position_ids, cu_seqlens, max_s, past_key_values + ) + return self.lm_head(hidden_states), present diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py new file mode 100644 index 00000000..e1a10cbf --- /dev/null +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -0,0 +1,458 @@ +import torch +import torch.distributed + +from torch.nn import functional as F + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel +from typing import Optional, Tuple, List, Type, Union + +from text_generation_server.models import Model +from text_generation_server.models.types import ( + Batch, + PrefillTokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + NextTokenChooser, + StoppingCriteria, + Sampling, +) + +tracer = trace.get_tracer(__name__) + + +@dataclass +class FlashCausalLMBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + + # Decoder values + input_ids: torch.Tensor + position_ids: torch.Tensor + # cumulative sequence lengths + cu_seqlens: torch.Tensor + max_seqlen: int + past_key_values: Optional[torch.Tensor] + + # All tokens + all_input_ids: List[List[int]] 
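+    # Same ids as all_input_ids, but kept as per-request tensors padded with
+    # max_new_tokens extra slots so generate_token can write new ids in place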
+ all_input_ids_tensor: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + + def to_pb(self) -> generate_pb2.Batch: + return generate_pb2.Batch( + id=self.batch_id, requests=self.requests, size=len(self) + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, + ) -> "CausalLMBatch": + input_ids = [] + position_ids = [] + cu_seqlens = [0] + max_seqlen = 0 + + input_lengths = [] + all_input_ids = [] + all_input_ids_tensor = [] + + next_token_choosers = [] + stopping_criterias = [] + + # Cumulative length + cumulative_length = 0 + + # Parse batch + for r in pb.requests: + tokenized_input = tokenizer(r.inputs)["input_ids"] + input_length = len(tokenized_input) + max_seqlen = max(max_seqlen, input_length) + input_lengths.append(input_length) + all_input_ids.append(tokenized_input) + + tokenized_input = torch.tensor(tokenized_input, device=device) + input_ids.append(tokenized_input) + + # Position ids + position_ids.append(torch.arange(0, input_length, dtype=torch.int32)) + + # Add cumulative lengths of all previous inputs + cu_seqlens.append(cumulative_length + input_length) + + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + all_input_ids_tensor.append( + F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens)) + ) + + # Update + cumulative_length += input_length + + input_ids = torch.concat(input_ids) + position_ids = torch.concat(position_ids) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) + + return cls( + batch_id=pb.id, + requests=pb.requests, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=None, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + ) + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": + # Batch attributes + requests = [] + input_lengths = [] + all_input_ids = [] + all_input_ids_tensor = [] + next_token_choosers = [] + stopping_criterias = [] + + # Batch tensors + input_ids = [] + position_ids = [] + cu_seqlens = [torch.tensor([0], dtype=torch.int32)] + max_seqlen = 0 + past_key_values = [] + + # Cumulative length + cumulative_length = torch.tensor(0) + + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + all_input_ids.extend(batch.all_input_ids) + all_input_ids_tensor.extend(batch.all_input_ids_tensor) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + # Add cumulative lengths of all previous inputs + cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length) + + input_ids.append(batch.input_ids) + position_ids.append(batch.position_ids) + past_key_values.append(batch.past_key_values) + + max_seqlen = max(max_seqlen, batch.max_seqlen) + + # Update + cumulative_length += batch.cu_seqlens[-1] + + input_ids = torch.concat(input_ids) + position_ids = torch.concat(position_ids) + # Concat on dim=1 as first dim represents the model 
layers + past_key_values = torch.concat(past_key_values, dim=1) + cu_seqlens = torch.concat(cu_seqlens) + + return FlashCausalLMBatch( + batch_id=batches[0].batch_id, + requests=requests, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + ) + + def __len__(self): + return len(self.requests) + + +class FlashCausalLM(Model): + def __init__( + self, + model_cls: Type[PreTrainedModel], + model_id: str, + revision: Optional[str] = None, + quantize=False, + ): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashCausalLM is only available on GPU") + + if quantize: + raise NotImplementedError("FlashCausalLM does not support quantization") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + self.model = ( + model_cls.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + ) + .eval() + .cuda() + ) + + super(FlashCausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @property + def batch_type(self) -> Type[FlashCausalLMBatch]: + return FlashCausalLMBatch + + def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_s: int, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Model Forward + return self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_s=max_s, + past_key_values=past_key_values, + ) + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: FlashCausalLMBatch + ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: + # Better to send to device here to avoid device issues in concatenate + position_ids = batch.position_ids.to(self.device, non_blocking=True) + cu_seqlens = batch.cu_seqlens.to(self.device) + + out, present = self.forward( + batch.input_ids, + position_ids, + cu_seqlens, + batch.max_seqlen, + batch.past_key_values, + ) + + # List of indices to cache + next_batch_keep_indices = [] + + # New values for next forward + next_batch_input_ids = [] + next_batch_position_ids = [] + next_batch_cu_seqlens = [0] + next_batch_max_seqlen = 0 + next_batch_past_key_values = [] + next_batch_input_lengths = [] + next_batch_all_input_ids = [] + next_batch_all_input_ids_tensor = [] + + # Cumulative length + cumulative_length = 0 + + # Results + generations: List[Generation] = [] + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + batch.all_input_ids_tensor, + ) + + # For each member of the batch + for i, ( + request, + input_length, + next_token_chooser, + stopping_criteria, + all_input_ids, + all_input_ids_tensor, + ) in enumerate(iterator): + # Indexing metadata + start_index = cumulative_length + end_index = cumulative_length + input_length + + if batch.past_key_values is None: + # Prefill mode + # out is of shape 
[cumulative_sequence_lengths, vocab_size] + logits = out[start_index:end_index] + else: + # Decode mode + # out is of shape [batch_size, vocab_size] + logits = out[i].unsqueeze(0) + + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids_tensor[None, :input_length], logits + ) + next_token_id_squeezed = next_token_id.squeeze() + next_token_id_item = next_token_id_squeezed.item() + + # Append next token to all tokens + all_input_ids.append(next_token_id_item) + all_input_ids_tensor[input_length] = next_token_id_item + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id_item] + next_token_text = self.decode_token( + next_token_id_item, + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_item, + next_token_text, + ) + + if stop: + # Decode generated tokens + output_text = self.decode( + all_input_ids[-stopping_criteria.current_tokens :] + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + # Keep request in the batch + next_batch_keep_indices.append(i) + generated_text = None + + # Get sequence present + seq_present = present[:, start_index:end_index] + # Pad it for next iter attention + past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1)) + next_batch_past_key_values.append(past) + + next_batch_input_ids.append(next_token_id) + next_batch_position_ids.append(input_length) + # Cumulative sum + next_batch_cu_seqlens.append( + next_batch_cu_seqlens[-1] + new_input_length + ) + next_batch_input_lengths.append(new_input_length) + next_batch_all_input_ids.append(all_input_ids) + next_batch_all_input_ids_tensor.append(all_input_ids_tensor) + next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) + + # Prefill + if stopping_criteria.current_tokens == 1: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + logprobs.gather( + 1, all_input_ids_tensor[1:input_length].unsqueeze(1) + ).squeeze(1)[:-1].tolist() + prefill_token_ids = all_input_ids[:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, prefill_logprobs, prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_item, + next_token_logprob, + next_token_text, + next_token_id_item in self.all_special_ids, + generated_text, + ) + + generations.append(generation) + cumulative_length += input_length + + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generations, None + + # If we finished at least one generation, we need to evict the indices of the generations that finished + # from the values of the next batch + if len(next_batch_keep_indices) != len(batch): + # Apply indices to requests, token_choosers and stopping_criterias that need to be cached + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices + ] + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices + ] + else: + next_batch_requests = 
batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias + + # Create final next batch tensors + next_batch_position_ids = torch.tensor( + next_batch_position_ids, dtype=torch.int32 + ) + next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32) + if len(next_batch_keep_indices) > 1: + next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1) + next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1) + else: + next_batch_input_ids = next_batch_input_ids[0].view(1) + next_batch_past_key_values = next_batch_past_key_values[0] + + next_batch = FlashCausalLMBatch( + batch_id=batch.batch_id, + requests=next_batch_requests, + input_ids=next_batch_input_ids, + position_ids=next_batch_position_ids, + cu_seqlens=next_batch_cu_seqlens, + max_seqlen=next_batch_max_seqlen, + past_key_values=next_batch_past_key_values, + input_lengths=next_batch_input_lengths, + all_input_ids=next_batch_all_input_ids, + all_input_ids_tensor=next_batch_all_input_ids_tensor, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + ) + return generations, next_batch diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index b97f342a..e415a725 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -1,33 +1,20 @@ import torch import torch.distributed -from torch.nn import functional as F - from accelerate import init_empty_weights -from dataclasses import dataclass from opentelemetry import trace from safetensors import safe_open -from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig -from typing import Optional, Tuple, List, Type, Union +from transformers import AutoTokenizer, AutoConfig +from typing import Optional, Tuple, List -from text_generation_server.models import Model -from text_generation_server.models.flash_neox_modeling import ( +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_neox_modeling import ( FlashGPTNeoXForCausalLM, TensorParallelEmbedding, TensorParallelRowLinear, TensorParallelColumnLinear, ) -from text_generation_server.models.types import ( - Batch, - PrefillTokens, - Generation, - GeneratedText, -) -from text_generation_server.pb import generate_pb2 from text_generation_server.utils import ( - NextTokenChooser, - StoppingCriteria, - Sampling, initialize_torch_distributed, weight_files, ) @@ -35,437 +22,12 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -@dataclass -class FlashNeoXBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - - # Decoder values - input_ids: torch.Tensor - position_ids: torch.Tensor - # cumulative sequence lengths - cu_seqlens: torch.Tensor - max_seqlen: int - past_key_values: Optional[torch.Tensor] - - # All tokens - all_input_ids: List[List[int]] - all_input_ids_tensor: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - - def to_pb(self) -> generate_pb2.Batch: - return generate_pb2.Batch( - id=self.batch_id, requests=self.requests, size=len(self) - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - device: 
torch.device, - ) -> "CausalLMBatch": - input_ids = [] - position_ids = [] - cu_seqlens = [0] - max_seqlen = 0 - - input_lengths = [] - all_input_ids = [] - all_input_ids_tensor = [] - - next_token_choosers = [] - stopping_criterias = [] - - # Cumulative length - cumulative_length = 0 - - # Parse batch - for r in pb.requests: - tokenized_input = tokenizer(r.inputs)["input_ids"] - input_length = len(tokenized_input) - max_seqlen = max(max_seqlen, input_length) - input_lengths.append(input_length) - all_input_ids.append(tokenized_input) - - tokenized_input = torch.tensor(tokenized_input, device=device) - input_ids.append(tokenized_input) - - # Position ids - position_ids.append(torch.arange(0, input_length, dtype=torch.int32)) - - # Add cumulative lengths of all previous inputs - cu_seqlens.append(cumulative_length + input_length) - - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - all_input_ids_tensor.append( - F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens)) - ) - - # Update - cumulative_length += input_length - - input_ids = torch.concat(input_ids) - position_ids = torch.concat(position_ids) - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) - - return cls( - batch_id=pb.id, - requests=pb.requests, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=None, - input_lengths=input_lengths, - all_input_ids=all_input_ids, - all_input_ids_tensor=all_input_ids_tensor, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - ) - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": - # Batch attributes - requests = [] - input_lengths = [] - all_input_ids = [] - all_input_ids_tensor = [] - next_token_choosers = [] - stopping_criterias = [] - - # Batch tensors - input_ids = [] - position_ids = [] - cu_seqlens = [torch.tensor([0], dtype=torch.int32)] - max_seqlen = 0 - past_key_values = [] - - # Cumulative length - cumulative_length = torch.tensor(0) - - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - all_input_ids.extend(batch.all_input_ids) - all_input_ids_tensor.extend(batch.all_input_ids_tensor) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - # Add cumulative lengths of all previous inputs - cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length) - - input_ids.append(batch.input_ids) - position_ids.append(batch.position_ids) - past_key_values.append(batch.past_key_values) - - max_seqlen = max(max_seqlen, batch.max_seqlen) - - # Update - cumulative_length += batch.cu_seqlens[-1] - - input_ids = torch.concat(input_ids) - position_ids = torch.concat(position_ids) - # Concat on dim=1 as first dim represents the model layers - past_key_values = torch.concat(past_key_values, dim=1) - cu_seqlens = torch.concat(cu_seqlens) - - return FlashNeoXBatch( - batch_id=batches[0].batch_id, - requests=requests, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - input_lengths=input_lengths, - all_input_ids=all_input_ids, - all_input_ids_tensor=all_input_ids_tensor, - next_token_choosers=next_token_choosers, - 
stopping_criterias=stopping_criterias, - ) - - def __len__(self): - return len(self.requests) - - -class FlashNeoX(Model): +class FlashNeoX(FlashCausalLM): def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - else: - raise NotImplementedError("FlashNeoX is only available on GPU") - - if quantize: - raise NotImplementedError("FlashNeoX does not support quantization") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, revision=revision, padding_side="left" - ) - self.model = ( - FlashGPTNeoXForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - ) - .eval() - .cuda() - ) - tokenizer.pad_token_id = ( - self.model.config.pad_token_id - if self.model.config.pad_token_id is not None - else self.model.config.eos_token_id - ) - super(FlashNeoX, self).__init__( - tokenizer=tokenizer, - device=device, + FlashGPTNeoXForCausalLM, model_id, revision, quantize ) - @property - def batch_type(self) -> Type[FlashNeoXBatch]: - return FlashNeoXBatch - - def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False - ) - - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlens: torch.Tensor, - max_s: int, - past_key_values: Optional = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Model Forward - return self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_s=max_s, - past_key_values=past_key_values, - ) - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batch: FlashNeoXBatch - ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]: - # Better to send to device here to avoid device issues in concatenate - position_ids = batch.position_ids.to(self.device, non_blocking=True) - cu_seqlens = batch.cu_seqlens.to(self.device) - - out, present = self.forward( - batch.input_ids, - position_ids, - cu_seqlens, - batch.max_seqlen, - batch.past_key_values, - ) - - # List of indices to cache - next_batch_keep_indices = [] - - # New values for next forward - next_batch_input_ids = [] - next_batch_position_ids = [] - next_batch_cu_seqlens = [0] - next_batch_max_seqlen = 0 - next_batch_past_key_values = [] - next_batch_input_lengths = [] - next_batch_all_input_ids = [] - next_batch_all_input_ids_tensor = [] - - # Cumulative length - cumulative_length = 0 - - # Results - generations: List[Generation] = [] - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - batch.all_input_ids_tensor, - ) - - # For each member of the batch - for i, ( - request, - input_length, - next_token_chooser, - stopping_criteria, - all_input_ids, - all_input_ids_tensor, - ) in enumerate(iterator): - # Indexing metadata - start_index = cumulative_length - end_index = cumulative_length + input_length - - if batch.past_key_values is None: - # Prefill mode - # out is of shape [cumulative_sequence_lengths, vocab_size] - logits = out[start_index:end_index] - else: - # Decode mode - # out is of shape [batch_size, vocab_size] - logits = out[i].unsqueeze(0) - - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids_tensor[None, :input_length], logits - ) - next_token_id_squeezed = 
next_token_id.squeeze() - next_token_id_item = next_token_id_squeezed.item() - - # Append next token to all tokens - all_input_ids.append(next_token_id_item) - all_input_ids_tensor[input_length] = next_token_id_item - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id_item] - next_token_text = self.decode_token( - next_token_id_item, - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_item, - next_token_text, - ) - - if stop: - # Decode generated tokens - output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :] - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - # Keep request in the batch - next_batch_keep_indices.append(i) - generated_text = None - - # Get sequence present - seq_present = present[:, start_index:end_index] - # Pad it for next iter attention - past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1)) - next_batch_past_key_values.append(past) - - next_batch_input_ids.append(next_token_id) - next_batch_position_ids.append(input_length) - # Cumulative sum - next_batch_cu_seqlens.append( - next_batch_cu_seqlens[-1] + new_input_length - ) - next_batch_input_lengths.append(new_input_length) - next_batch_all_input_ids.append(all_input_ids) - next_batch_all_input_ids_tensor.append(all_input_ids_tensor) - next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) - - # Prefill - if stopping_criteria.current_tokens == 1: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + logprobs.gather( - 1, all_input_ids_tensor[1:input_length].unsqueeze(1) - ).squeeze(1)[:-1].tolist() - prefill_token_ids = all_input_ids[:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = PrefillTokens( - prefill_token_ids, prefill_logprobs, prefill_texts - ) - else: - prefill_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - next_token_id_item, - next_token_logprob, - next_token_text, - next_token_id_item in self.all_special_ids, - generated_text, - ) - - generations.append(generation) - cumulative_length += input_length - - # We finished all generations in the batch; there is no next batch - if not next_batch_keep_indices: - return generations, None - - # If we finished at least one generation, we need to evict the indices of the generations that finished - # from the values of the next batch - if len(next_batch_keep_indices) != len(batch): - # Apply indices to requests, token_choosers and stopping_criterias that need to be cached - next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] - next_batch_next_token_choosers = [ - batch.next_token_choosers[i] for i in next_batch_keep_indices - ] - next_batch_stopping_criterias = [ - batch.stopping_criterias[i] for i in next_batch_keep_indices - ] - else: - next_batch_requests = batch.requests - next_batch_next_token_choosers = batch.next_token_choosers - next_batch_stopping_criterias = batch.stopping_criterias - - # Create final next batch tensors - next_batch_position_ids = torch.tensor( - next_batch_position_ids, dtype=torch.int32 - ) - next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32) - 
if len(next_batch_keep_indices) > 1: - next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1) - next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1) - else: - next_batch_input_ids = next_batch_input_ids[0].view(1) - next_batch_past_key_values = next_batch_past_key_values[0] - - next_batch = FlashNeoXBatch( - batch_id=batch.batch_id, - requests=next_batch_requests, - input_ids=next_batch_input_ids, - position_ids=next_batch_position_ids, - cu_seqlens=next_batch_cu_seqlens, - max_seqlen=next_batch_max_seqlen, - past_key_values=next_batch_past_key_values, - input_lengths=next_batch_input_lengths, - all_input_ids=next_batch_all_input_ids, - all_input_ids_tensor=next_batch_all_input_ids_tensor, - next_token_choosers=next_batch_next_token_choosers, - stopping_criterias=next_batch_stopping_criterias, - ) - return generations, next_batch - class FlashNeoXSharded(FlashNeoX): def __init__( @@ -508,7 +70,7 @@ class FlashNeoXSharded(FlashNeoX): model.post_load_weights() self.model = model.eval().to(dtype) torch.distributed.barrier(group=self.process_group) - super(FlashNeoX, self).__init__( + super(FlashCausalLM, self).__init__( tokenizer=tokenizer, device=device, ) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py new file mode 100644 index 00000000..b33d0477 --- /dev/null +++ b/server/text_generation_server/models/flash_santacoder.py @@ -0,0 +1,138 @@ +import torch +import torch.distributed + +from accelerate import init_empty_weights +from opentelemetry import trace +from pathlib import Path +from transformers import AutoTokenizer, AutoConfig +from typing import Optional, List + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( + FlashSantacoderForCausalLM +) +from text_generation_server.utils import ( + weight_files, + download_weights, + weight_hub_files, + LocalEntryNotFoundError, +) + +tracer = trace.get_tracer(__name__) + + +class FlashSantacoder(FlashCausalLM): + def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashSantacoder is only available on GPU") + + if quantize: + raise NotImplementedError("FlashSantacoder does not support quantization") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + + config = AutoConfig.from_pretrained( + model_id, revision=revision, + trust_remote_code=True # Needed as the config is not part of Transformers + ) + + # We do not use from_pretrained as we modified the model internal module layout + try: + filenames = weight_files(model_id, revision, ".bin") + # Local files not found + except LocalEntryNotFoundError: + hub_files = weight_hub_files(model_id, revision, ".bin") + filenames = download_weights(hub_files, model_id, revision) + + with init_empty_weights(): + model = FlashSantacoderForCausalLM(config) + + self.load_weights( + model, + filenames, + ) + self.model = model.eval().to(device).to(dtype) + + super(FlashCausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model: FlashSantacoderForCausalLM, + filenames: List[Path], + ): + for filename in filenames: + state_dict = torch.load(filename, map_location="cpu") + for key, value in 
state_dict.items(): + layer_name = ".".join(key.split(".")[:4]) + + # Fused qkv + if "q_attn.weight" in key or "kv_attn.weight" in key: + final_key = layer_name + ".attn.weight" + elif "q_attn.bias" in key or "kv_attn.bias" in key: + final_key = layer_name + ".attn.bias" + + else: + final_key = key + + module_name, param_name = final_key.rsplit(".", 1) + module = model.get_submodule(module_name) + + try: + current_parameter_tensor = module._parameters[param_name] + except KeyError: + current_parameter_tensor = None + + if current_parameter_tensor is not None: + if "c_fc.weight" in key or "c_proj.weight" in key or "q_attn.weight" in key or "kv_attn.weight" in key: + # Tranpose as we use nn.Linear instead of Conv1D + value = value.T + + if current_parameter_tensor.device == torch.device("meta"): + # Init qkv + if "attn.weight" in final_key: + module._parameters[param_name] = value.new_empty( + (model.transformer.head_size * (model.transformer.num_heads + 2), value.shape[1]) + ) + elif "attn.bias" in final_key: + module._parameters[param_name] = value.new_empty( + (model.transformer.head_size * (model.transformer.num_heads + 2)) + ) + + # Copy to correct slice + if "q_attn.weight" in key: + module._parameters[param_name][: value.shape[0]] = value + elif "q_attn.bias" in key: + module._parameters[param_name][: value.shape[0]] = value + elif "kv_attn.weight" in key: + module._parameters[param_name][ + model.transformer.head_size * model.transformer.num_heads: + ] = value + elif "kv_attn.bias" in key: + module._parameters[param_name][ + model.transformer.head_size * model.transformer.num_heads: + ] = value + else: + if current_parameter_tensor.shape != value.shape: + raise ValueError( + f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" + ) + module._parameters[param_name] = value + else: + module._buffers[param_name] = value + + torch.cuda.empty_cache() + model.post_load_weights() + + def decode(self, generated_ids: List[int]) -> str: + # Do not skip special tokens as they are used for custom parsing rules of the generated text + return self.tokenizer.decode( + generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False + ) diff --git a/supported_models.json b/supported_models.json index 86d3bdfe..8e9dd9e0 100644 --- a/supported_models.json +++ b/supported_models.json @@ -1,9 +1,9 @@ [ + "bigcode/santacoder", "bigscience/bloom", "bigscience/bloomz", "EleutherAI/gpt-neox-20b", "google/flan-ul2", "google/flan-t5-xxl", - "OpenAssistant/oasst-sft-1-pythia-12b", - "olivierdehaene/optimized-santacoder" + "OpenAssistant/oasst-sft-1-pythia-12b" ]
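# --- editor's note (illustrative sketch, not part of this diff) -------------
# FlashSantacoder.load_weights above folds the checkpoint's separate q_attn /
# kv_attn projections into the single fused `attn` FastLinear used by
# FlashMQAttention (multi-query attention: many query heads, one shared
# key/value head). A toy reconstruction of that weight layout, with made-up
# sizes, to make the slicing explicit:
import torch

hidden_size, num_heads = 8, 4
head_size = hidden_size // num_heads                    # 2

q_weight = torch.randn(hidden_size, hidden_size)        # per-head query projection
kv_weight = torch.randn(2 * head_size, hidden_size)     # single shared key + value head

fused = torch.empty(head_size * (num_heads + 2), hidden_size)
fused[:hidden_size] = q_weight                          # rows [0, hidden_size): queries
fused[hidden_size:] = kv_weight                         # last 2 * head_size rows: key/value

# FlashMQAttention.forward splits the fused projection back apart the same way:
x = torch.randn(3, hidden_size)                         # 3 flattened tokens
qkv = x @ fused.T
query, key_value = qkv.split([hidden_size, 2 * head_size], dim=1)
assert query.shape == (3, hidden_size) and key_value.shape == (3, 2 * head_size)
# ----------------------------------------------------------------------------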