diff --git a/README.md b/README.md
index 02a94d2..9d3de4f 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,7 @@ You may set the `TGICHAT_(USER|ASS|SYS)_(PRE|POST)` environment variables, to wr
 ```bash
 model=TheBloke/Llama-2-13B-Chat-fp16 # around 14GB Vram.
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-image=docker.io/michaelf34/tgi:03-10-2023 # docker image by @michaelfeil
+image=docker.io/michaelf34/tgi:05-11-2023 # docker image by @michaelfeil
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data $image --model-id $model --quantize ct2
 ```
diff --git a/server/pyproject.toml b/server/pyproject.toml
index cb0930a..73c853e 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -15,9 +15,11 @@ grpcio-status = "^1.51.1"
 grpcio-reflection = "^1.51.1"
 grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
-accelerate = { version = "^0.19.0", optional = true }
-ctranslate2 = { version = "^3.20.0", optional = true }
-bitsandbytes = { version = "^0.40.0", optional = true }
+accelerate = { version = "^0.20.3", optional = true }
+ctranslate2 = { version = "^3.23.0", optional = true }
+bitsandbytes = { version = "^0.41.1", optional = true }
+torch = { version = "^2.0.1" }
+scipy = "^1.11.3"
 safetensors = "0.3.1"
 loguru = "^0.6.0"
 opentelemetry-api = "^1.15.0"
@@ -26,8 +28,8 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "0.13.3"
-huggingface-hub = "^0.14.1"
-transformers = "4.29.2"
+huggingface-hub = "^0.15.1"
+transformers = "4.32.1"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
diff --git a/server/requirements.txt b/server/requirements.txt
index 98838b3..d1a4933 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -1,10 +1,10 @@
-accelerate==0.19.0 ; python_version >= "3.9" and python_version < "4.0"
+accelerate==0.20.3 ; python_version >= "3.9" and python_version < "4.0"
 aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "4.0"
 aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
 async-timeout==4.0.2 ; python_version >= "3.9" and python_version < "4.0"
 attrs==23.1.0 ; python_version >= "3.9" and python_version < "4.0"
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
-bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
+bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "4.0"
 certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
 charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
 click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
@@ -23,7 +23,7 @@ grpcio-reflection==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
 grpcio-status==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
 grpcio==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
 hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
-huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
+huggingface-hub==0.15.1 ; python_version >= "3.9" and python_version < "4.0"
 idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
 jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -56,12 +56,13 @@ safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
 setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
 six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
+scipy==1.11.3 ; python_version >= "3.9" and python_version < "4.0"
 sympy==1.12 ; python_version >= "3.9" and python_version < "4.0"
 texttable==1.6.7 ; python_version >= "3.9" and python_version < "4.0"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
 torch==2.0.1 ; python_version >= "3.9" and python_version < "4.0"
 tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
-transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
+transformers==4.32.1 ; python_version >= "3.9" and python_version < "4.0"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
 typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
 tzdata==2023.3 ; python_version >= "3.9" and python_version < "4.0"
diff --git a/server/text_generation_server/models/ct2_causal_lm.py b/server/text_generation_server/models/ct2_causal_lm.py
index 7f128a7..6ede9dd 100644
--- a/server/text_generation_server/models/ct2_causal_lm.py
+++ b/server/text_generation_server/models/ct2_causal_lm.py
@@ -23,12 +23,20 @@ import numpy as np
 import os
 import multiprocessing
 from pathlib import Path
+from dataclasses import dataclass
 
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from opentelemetry import trace
 from transformers import (
     AutoTokenizer,
     AutoConfig,
+    PreTrainedTokenizerBase
+)
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
 )
 from typing import Optional, Tuple, List, Type, Dict
 
@@ -38,9 +46,10 @@ from text_generation_server.models.types import (
     Generation,
     GeneratedText,
 )
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 from text_generation_server.utils import Sampling
 
-from text_generation_server.models.causal_lm import CausalLMBatch
 
 try:
     import ctranslate2
@@ -51,6 +60,434 @@ except ImportError:
 
 tracer = trace.get_tracer(__name__)
 
+@dataclass
+class CT2CausalLMBatch(Batch):
+    batch_id: int
+    requests: List[generate_pb2.Request]
+    requests_idx_mapping: Dict[int, int]
+
+    # Decoder values
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+    position_ids: torch.Tensor
+    past_key_values: Optional[List[Tuple]]
+
+    # All tokens
+    all_input_ids: List[torch.Tensor]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
+    prefix_offsets: List[int]
+    read_offsets: List[int]
+
+    # Generation helpers
+    next_token_choosers: List[NextTokenChooser]
+    stopping_criterias: List[StoppingCriteria]
+
+    # Metadata used for padding
+    max_input_length: int
+    padding_right_offset: int
+
+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
+    # Past metadata
+    keys_head_dim_last: bool = True
+
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
+            id=self.batch_id,
+            request_ids=[r.id for r in self.requests],
+            size=len(self),
+            max_tokens=self.max_tokens,
+        )
+
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "CT2CausalLMBatch":
+        inputs = []
+        next_token_choosers = []
+        stopping_criterias = []
+        prefix_offsets = []
+        read_offsets = []
+        requests_idx_mapping = {}
+
+        # Parse batch
+        max_truncation = 0
+        padding_right_offset = 0
+        max_decode_tokens = 0
+        for i, r in enumerate(pb.requests):
+            requests_idx_mapping[r.id] = i
+            inputs.append(r.inputs)
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
+
+        tokenized_inputs = tokenizer(
+            inputs,
+            return_tensors="pt",
+            padding=True,
+            return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
+        ).to(device)
+        for _ in pb.requests:
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            prefix_offsets.append(input_len - 5)
+            read_offsets.append(input_len)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
+        input_ids = tokenized_inputs["input_ids"]
+        # Allocate maximum attention_mask
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_input_length + padding_right_offset)
+        )
+        # Copy tokenizer attention_mask into fully allocated attention_mask
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
+
+        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
+        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
+        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+
+        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
+
+        return cls(
+            batch_id=pb.id,
+            requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=None,
+            all_input_ids=list(all_input_ids),
+            input_lengths=input_lengths.tolist(),
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length.item(),
+            padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
+        )
+
+    @tracer.start_as_current_span("filter")
+    def filter(self, request_ids: List[int]) -> Optional["CT2CausalLMBatch"]:
+        if len(request_ids) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(request_ids) == len(self):
+            return self
+
+        keep_indices = []
+
+        # New values after filtering
+        requests_idx_mapping = {}
+        requests = []
+        input_lengths = []
+        prefix_offsets = []
+        read_offsets = []
+        all_input_ids = []
+        max_input_length = 0
+
+        next_token_choosers = []
+        stopping_criterias = []
+
+        total_remaining_decode_tokens = 0
+        new_padding_right_offset = 0
+
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
+            keep_indices.append(idx)
+
+            requests.append(self.requests[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
+            all_input_ids.append(self.all_input_ids[idx])
+
+            request_input_length = self.input_lengths[idx]
+            input_lengths.append(request_input_length)
+            max_input_length = max(max_input_length, request_input_length)
+
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
+            total_remaining_decode_tokens += remaining_decode_tokens
+            new_padding_right_offset = max(
+                new_padding_right_offset, remaining_decode_tokens
+            )
+
+        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+        input_ids = self.input_ids[keep_indices]
+        position_ids = self.position_ids[keep_indices]
+        self.attention_mask = self.attention_mask[
+            keep_indices,
+            -(self.padding_right_offset + max_input_length) : (
+                self.attention_mask.shape[1] - self.padding_right_offset
+            )
+            + new_padding_right_offset,
+        ]
+
+        # Ensure that past_key_values tensors can be updated in-place
+        if type(self.past_key_values[0]) == tuple:
+            self.past_key_values = [list(layer) for layer in self.past_key_values]
+
+        # Update tensors in-place to allow incremental garbage collection
+        past_kv_length = max_input_length - 1
+        for layer in self.past_key_values:
+            past_keys, past_values = layer
+            if len(past_keys.shape) == 3:
+                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
+                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
+                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
+            if self.keys_head_dim_last:
+                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
+            else:
+                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
+            del past_keys
+            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
+            del past_values
+
+        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
+
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids = input_ids
+        self.position_ids = position_ids
+        self.all_input_ids = all_input_ids
+        self.input_lengths = input_lengths
+        self.prefix_offsets = prefix_offsets
+        self.read_offsets = read_offsets
+        self.next_token_choosers = next_token_choosers
+        self.stopping_criterias = stopping_criterias
+        self.max_input_length = max_input_length
+        self.padding_right_offset = new_padding_right_offset
+        self.max_tokens = max_tokens
+
+        return self
+
+    @classmethod
+    @tracer.start_as_current_span("concatenate")
+    def concatenate(cls, batches: List["CT2CausalLMBatch"]) -> "CT2CausalLMBatch":
+        # Used for padding
+        total_batch_size = 0
+        max_input_length = 0
+        padding_right_offset = 0
+        for batch in batches:
+            total_batch_size += len(batch)
+            max_input_length = max(max_input_length, batch.max_input_length)
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
+
+        # Batch attributes
+        requests = []
+        requests_idx_mapping = {}
+        input_lengths = []
+        prefix_offsets = []
+        read_offsets = []
+        all_input_ids = []
+        next_token_choosers = []
+        stopping_criterias = []
+        max_tokens = 0
+
+        # Batch tensors
+        input_ids = None
+        attention_mask = None
+        position_ids = None
+        past_key_values = []
+
+        # Used for slicing correctly inside the tensors
+        # Equivalent to a cumsum on batch sizes
+        start_index = 0
+        for i, batch in enumerate(batches):
+            requests.extend(batch.requests)
+            input_lengths.extend(batch.input_lengths)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
+            all_input_ids.extend(batch.all_input_ids)
+            next_token_choosers.extend(batch.next_token_choosers)
+            stopping_criterias.extend(batch.stopping_criterias)
+
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
+            # Slicing end index for this batch
+            end_index = start_index + len(batch)
+
+            # We only concatenate batches that did at least one step
+            # if batch.past_key_values is None:
+            #     raise ValueError("only concatenate prefilled batches")
+
+            # Create empty tensor
+            # input_ids is always of shape [batch_size, 1]
+            # We do not need to pad it
+            if input_ids is None:
+                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
+            # Copy to correct indices
+            input_ids[start_index:end_index] = batch.input_ids
+
+            # Create padded tensor
+            if attention_mask is None:
+                attention_mask = batch.attention_mask.new_zeros(
+                    (total_batch_size, max_input_length + padding_right_offset),
+                )
+
+            # We need to slice the attention mask to remove padding from previous steps
+            # and to remove unused allocated space
+            left_offset = max_input_length - batch.max_input_length
+            batch_left_offset = (
+                batch.attention_mask.shape[1]
+                - batch.max_input_length
+                - batch.padding_right_offset
+            )
+            attention_mask[
+                start_index:end_index,
+                left_offset:-padding_right_offset,
+            ] = batch.attention_mask[
+                :,
+                batch_left_offset : -batch.padding_right_offset,
+            ]
+
+            # Create empty tensor
+            # position_ids is always of shape [batch_size, 1]
+            if position_ids is None:
+                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
+            position_ids[start_index:end_index] = batch.position_ids
+
+            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
+            # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
+            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
+            # And ensure that we can update tensors in-place
+            # if type(batch.past_key_values[0]) == tuple:
+            #     batch.past_key_values = [
+            #         [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+            #         for layer in batch.past_key_values
+            #     ]
+            # elif len(batch.past_key_values[0][0].shape) == 3:
+            #     for layer in batch.past_key_values:
+            #         for k, t in enumerate(layer):
+            #             layer[k] = t.view(len(batch), -1, *t.shape[-2:])
+
+            # Add eventual padding tokens that were added while concatenating
+            max_tokens += batch.max_tokens + (
+                max_input_length - batch.max_input_length
+            ) * len(batch)
+
+            start_index = end_index
+
+        # first_past_kvs = batches[0].past_key_values
+        # _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
+
+        # padded_past_values_shape = (
+        #     total_batch_size,
+        #     num_heads,
+        #     max_input_length - 1,
+        #     head_dim,
+        # )
+
+        # if batches[0].keys_head_dim_last:
+        #     padded_past_keys_shape = padded_past_values_shape
+        # else:
+        #     # seq_length is last for BLOOM
+        #     padded_past_keys_shape = (
+        #         total_batch_size,
+        #         num_heads,
+        #         head_dim,
+        #         max_input_length - 1,
+        #     )
+
+        # Iterate over attention layers
+        # Concatenate past key values layer by layer to allow incremental garbage collection
+        # for j in range(len(first_past_kvs)):
+        #     padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
+        #     start_index = 0
+        #     for batch in batches:
+        #         past_keys = batch.past_key_values[j][0]
+        #         # Clear reference to the original tensor
+        #         batch.past_key_values[j][0] = None
+
+        #         # Slicing end index for this batch
+        #         end_index = start_index + len(batch)
+        #         # We slice the keys to remove the padding from previous batches
+        #         past_seq_len = batch.max_input_length - 1
+        #         if batch.keys_head_dim_last:
+        #             padded_past_keys[
+        #                 start_index:end_index, :, -past_seq_len:, :
+        #             ] = past_keys[:, :, -past_seq_len:, :]
+        #         else:
+        #             # BLOOM case
+        #             padded_past_keys[
+        #                 start_index:end_index, :, :, -past_seq_len:
+        #             ] = past_keys[:, :, :, -past_seq_len:]
+        #         del past_keys
+
+        #         start_index = end_index
+
+        #     padded_past_values = first_past_kvs[j][1].new_zeros(
+        #         padded_past_values_shape
+        #     )
+        #     start_index = 0
+        #     for batch in batches:
+        #         past_values = batch.past_key_values[j][1]
+        #         # Clear reference to the original tensor
+        #         batch.past_key_values[j][1] = None
+
+        #         # Slicing end index for this batch
+        #         end_index = start_index + len(batch)
+        #         # We slice the past values to remove the padding from previous batches
+        #         past_seq_len = batch.max_input_length - 1
+        #         padded_past_values[
+        #             start_index:end_index, :, -past_seq_len:, :
+        #         ] = past_values[:, :, -past_seq_len:, :]
+        #         del past_values
+
+        #         # Update values
+        #         start_index = end_index
+
+        #     past_key_values.append([padded_past_keys, padded_past_values])
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            padding_right_offset=padding_right_offset,
+            keys_head_dim_last=batches[0].keys_head_dim_last,
+            max_tokens=max_tokens,
+        )
+
+    def __len__(self):
+        return len(self.requests)
+
 class CT2CausalLM(Model):
     def __init__(
         self,
@@ -176,8 +613,8 @@ class CT2CausalLM(Model):
         )
 
     @property
-    def batch_type(self) -> Type[CausalLMBatch]:
-        return CausalLMBatch
+    def batch_type(self) -> Type[CT2CausalLMBatch]:
+        return CT2CausalLMBatch
 
     def decode(self, generated_ids: List[int]) -> str:
         return self.tokenizer.decode(
@@ -221,8 +658,8 @@ class CT2CausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+        self, batch: CT2CausalLMBatch
+    ) -> Tuple[List[Generation], Optional[CT2CausalLMBatch]]:
         logits, past = self.forward_ct2(batch.all_input_ids, batch.input_lengths)
 
         # Results
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index b628585..c5b0e28 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -185,8 +185,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            dim=self.head_size, device=weights.device, base=10000.0,
+        )
         self.softmax_scale = self.head_size**-0.5
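
Note on the last hunk: the Llama attention module stops loading `rotary_emb.inv_freq` from the checkpoint and instead builds the rotary embedding statically from the head size, presumably because checkpoints exported by the newer `transformers` pinned above no longer persist that buffer. As a rough, hedged illustration only (this is not the repository's `PositionRotaryEmbedding` implementation, just the standard RoPE frequency formula such a static constructor is assumed to compute from `dim` and `base`):

```python
# Hedged sketch: NOT the repo's PositionRotaryEmbedding, just the standard RoPE
# inverse-frequency computation a static (config-derived) constructor is assumed
# to perform, given dim=head_size and base=10000.0 as in the hunk above.
import torch

def rope_inverse_frequencies(dim: int, base: float = 10000.0) -> torch.Tensor:
    # One frequency per pair of head dimensions: 1 / base^(2i / dim).
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Example: a Llama-2-13B head size of 128 (5120 hidden / 40 heads) gives 64 frequencies.
print(rope_inverse_frequencies(128).shape)  # torch.Size([64])
```

Deriving the frequencies from the config rather than from a stored tensor should keep the `ct2` path working with weights saved by transformers 4.32.1, which the dependency bumps above move to.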