diff --git a/Dockerfile b/Dockerfile index ebe79609..9b6ef835 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,6 +28,7 @@ ENV LANG=C.UTF-8 \ MODEL_NAME=bigscience/bloom \ QUANTIZE=false \ NUM_GPUS=8 \ + SAFETENSORS_FAST_GPU=1 \ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ NCCL_ASYNC_ERROR_HANDLING=1 \ CUDA_HOME=/usr/local/cuda \ @@ -55,12 +56,6 @@ RUN cd server && make install-torch # Install specific version of transformers RUN cd server && make install-transformers -# Install specific version of safetensors -# FIXME: This is a temporary fix while we wait for a new release -RUN curl https://sh.rustup.rs -sSf | bash -s -- -y -ENV PATH="/root/.cargo/bin:${PATH}" -RUN cd server && make install-safetensors - # Install server COPY proto proto COPY server server diff --git a/README.md b/README.md index eadf5e50..cd9e2176 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,8 @@ -A Rust and gRPC server for text generation inference. +A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co) +to power Bloom, BloomZ and MT0-XXL api-inference widgets. ## Features @@ -15,11 +16,11 @@ A Rust and gRPC server for text generation inference. - [Safetensors](https://github.com/huggingface/safetensors) weight loading - 45ms per token generation for BLOOM with 8xA100 80GB -## Officialy supported models +## Officially supported models -- BLOOM -- BLOOMZ -- BLOOM-560m +- [BLOOM](https://huggingface.co/bigscience/bloom) +- [BLOOMZ](https://huggingface.co/bigscience/bloomz) +- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl) Other models are supported on a best effort basis using: @@ -90,5 +91,4 @@ make router-dev ## TODO: - [ ] Add tests for the `server/model` logic -- [ ] Backport custom CUDA kernels to Transformers -- [ ] Install safetensors with pip \ No newline at end of file +- [ ] Backport custom CUDA kernels to Transformers \ No newline at end of file diff --git a/launcher/src/main.rs b/launcher/src/main.rs index b42ed0c5..016a28eb 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -295,6 +295,10 @@ fn shard_manager( "MASTER_PORT".parse().unwrap(), master_port.to_string().parse().unwrap(), ), + ( + "SAFETENSORS_FAST_GPU".parse().unwrap(), + "1".to_string().parse().unwrap(), + ), ]; // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard diff --git a/server/Makefile b/server/Makefile index 99764028..39a98b65 100644 --- a/server/Makefile +++ b/server/Makefile @@ -16,24 +16,13 @@ install-transformers: mv transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 transformers cd transformers && python setup.py install -install-safetensors: - # Install specific version of safetensors - pip install setuptools_rust - rm safetensors || true - rm safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 || true - curl -L -O https://github.com/huggingface/safetensors/archive/634deccbcbad5eaf417935281f8b3be7ebca69c5.zip - unzip 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip - rm 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip - mv safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 safetensors - cd safetensors/bindings/python && python setup.py develop - install-torch: # Install specific version of torch pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir -install: gen-server install-torch install-transformers install-safetensors +install: gen-server install-torch install-transformers pip install pip --upgrade pip install -e . 
--no-cache-dir run-dev: - python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded \ No newline at end of file + SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded \ No newline at end of file diff --git a/server/poetry.lock b/server/poetry.lock index 3c92903e..ebd64ea9 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -145,6 +145,18 @@ category = "main" optional = false python-versions = ">=3.6" +[[package]] +name = "safetensors" +version = "0.2.4" +description = "Fast and Safe Tensor serialization" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +dev = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"] +testing = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"] + [[package]] name = "setuptools" version = "65.5.0" @@ -208,7 +220,7 @@ bnb = ["bitsandbytes"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "224b1e379d6105fe911bff4563946a90dfa6ff5918cf2e7be59f8d4f7c5cd7cf" +content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67" [metadata.files] accelerate = [ @@ -459,6 +471,39 @@ PyYAML = [ {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, ] +safetensors = [ + {file = "safetensors-0.2.4-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:79c4a7610d7699c64d8531c43f758ded4990ebaa7b0887c2078640e6de44e726"}, + {file = "safetensors-0.2.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ef425a4ddd29612fe733a6eeca6ad8f3ee3939f530a032114974aac4c4667b89"}, + {file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77758f8ba4de6e20bf394dd964854a926dee2efee82eaa95e6c0893e2a7d960c"}, + {file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fb956e9090cce515649f00b491b5ddc0f9c3d989139016a8d69f9dcf57e8d3d9"}, + {file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e31b02d27249bd519f05ec9d189097c59fc6851c59daa1a86ef347659e33ac3"}, + {file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c2fead03a1497042efea4358574f3d7acf501b0c82e54d605f393f2b4e2aafe"}, + {file = "safetensors-0.2.4-cp310-cp310-win32.whl", hash = "sha256:dce6ed3c7d13aafa574737eb3309c928adcb6781e879b41f0861be83b439cf3e"}, + {file = "safetensors-0.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:1dfe727325a1342767c6725dc2cc1f00463eb40a1f5df37c338d8e03957e27ce"}, + {file = "safetensors-0.2.4-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:c066bc7b90a582a01ec468fef61a7581b5c726bf12c50491cb6ea5db215ea5e0"}, + {file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca6ed53dad5d7d0e67eb676528ff2ad345cac3a34010e4dc1e3736972de294a5"}, + {file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ada03b44acbb036cfabe7066a8df4ad9b1ac05bb585a6b6c0f285f08e016381d"}, + {file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:58a0902708daa7ec2b2293b46e85df61f4fa359ddfe648e7ac025a79e6f59627"}, + {file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0a4e38f7cbb4bfc513588e52f349b906c941e74fbbe192f2b19fc34221d448"}, + {file = "safetensors-0.2.4-cp37-cp37m-win32.whl", hash = "sha256:4f8695b77dd847203258f035f8468f8b701c90621cb6b457e109f8d89c27f16c"}, + {file = "safetensors-0.2.4-cp37-cp37m-win_amd64.whl", hash = "sha256:16b08f33c753c7da64b3999beea7c30d58204a0820961e33881d05a331e3f5c0"}, + {file = "safetensors-0.2.4-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:a381606804f23db9eede51135f5fbd1f75dda02100415ee150fd39eb1cd6be4c"}, + {file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7aceae84d0c7233d83923029aaf8d184848561e0211ec98c5317327b3db025d6"}, + {file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:da48fc929485cbd9ee22621e388764a7cef27b0205e73aee2ad75aadd7d67662"}, + {file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2619b88f934c4de6b59de90c9dc00eae2d0e30f254a1daebd6eb232ac1f9a7a7"}, + {file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1f78b987ae1f6b71da8ea110164e4cab2ee31b53835d2a66279df89c5d73f0e"}, + {file = "safetensors-0.2.4-cp38-cp38-win32.whl", hash = "sha256:34b3e60b5130fb0fe07114705e51d30aa2c7eae4c1d1e77d6f260fa4ade70ede"}, + {file = "safetensors-0.2.4-cp38-cp38-win_amd64.whl", hash = "sha256:debaa4fa98a7af44ba6dcb6945efee77b8480284c2cb05918ab97cf511c40826"}, + {file = "safetensors-0.2.4-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:90baaafc0c872a736124b341db54b0bdd61765cbf3a61418371066a37905b18d"}, + {file = "safetensors-0.2.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b4bf7e23191d6a3ff00de141512869fc776e8ff159c872cb44af018cb04d45eb"}, + {file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf11a3aba8796e548ceb0a65f34dcd334dcf0c4c891dccabe18a8b53918ae8ab"}, + {file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:95c31935ea71d63a38c546654136d7f0dbf1e7aeb6564dbc2201bc1fe9b34e4c"}, + {file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef31776e2e081d6f075408eed34a0fbd524cbd19e50268bef02c238b209213b7"}, + {file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06bb1d68148f6d6934352124d8cbfcf0db092f969db7187e348bd5cbf183db5"}, + {file = "safetensors-0.2.4-cp39-cp39-win32.whl", hash = "sha256:5d546152b9a5bd58eae97c2ddefba394404d37ddedec305f7639c9b6054513e5"}, + {file = "safetensors-0.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:553ecfd895d379c1e03a7c9241f7343b3af66573436969ed7eb95df81dfbe9af"}, + {file = "safetensors-0.2.4.tar.gz", hash = "sha256:35c0719a898f1f1292464f4cd9370bb6c2698032f1db4d677489f078b66b5a75"}, +] setuptools = [ {file = "setuptools-65.5.0-py3-none-any.whl", hash = "sha256:f62ea9da9ed6289bfe868cd6845968a2c854d1427f8548d52cae02a42b4f0356"}, {file = "setuptools-65.5.0.tar.gz", hash = "sha256:512e5536220e38146176efb833d4a62aa726b7bbff82cfbc8ba9eaa3996e0b17"}, diff --git a/server/pyproject.toml b/server/pyproject.toml index e2ba98a7..cdf89869 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -15,6 +15,7 @@ typer = "^0.6.1" grpcio-reflection = "^1.49.1" accelerate = "^0.12.0" bitsandbytes = "^0.35.1" +safetensors = "^0.2.4" 
[tool.poetry.extras] bnb = ["bitsandbytes"] diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py index ade22d4e..bf44115c 100644 --- a/server/text_generation/models/__init__.py +++ b/server/text_generation/models/__init__.py @@ -9,17 +9,13 @@ __all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"] def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: if model_name.startswith("bigscience/bloom"): if sharded: - return BLOOMSharded(model_name, quantize) + return BLOOMSharded(model_name, quantize=quantize) else: - if quantize: - raise ValueError("quantization is not supported for non-sharded BLOOM") - return CausalLM(model_name) + return CausalLM(model_name, quantize=quantize) else: if sharded: raise ValueError("sharded is not supported for AutoModel") - if quantize: - raise ValueError("quantize is not supported for AutoModel") try: - return CausalLM(model_name) + return CausalLM(model_name, quantize=quantize) except Exception as e: - return Seq2SeqLM(model_name) + return Seq2SeqLM(model_name, quantize=quantize) diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index b07537bd..2ba36b1a 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -2,7 +2,7 @@ import torch from dataclasses import dataclass from transformers import AutoTokenizer, AutoModelForCausalLM -from typing import Optional, Tuple, List, Dict, Type +from typing import Optional, Tuple, List, Type from text_generation.models import Model from text_generation.models.types import GeneratedText @@ -14,11 +14,23 @@ from text_generation.utils import NextTokenChooser, StoppingCriteria class CausalLMBatch: batch_id: int requests: List[generate_pb2.Request] - all_input_lengths: List[int] - input_ids: Dict[str, torch.Tensor] + + # Decoder values + input_ids: torch.Tensor + attention_mask: torch.Tensor + past_key_values: Optional[List[Tuple]] + + # All tokens all_input_ids: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + + # Generation helpers next_token_choosers: List[NextTokenChooser] stopping_criterias: List[StoppingCriteria] + + # Metadata used for padding size: int max_sequence_length: int @@ -36,12 +48,12 @@ class CausalLMBatch: inputs = [] next_token_choosers = [] stopping_criterias = [] - all_input_lengths = [] + input_lengths = [] # Parse batch for r in pb.requests: inputs.append(r.inputs) - all_input_lengths.append(r.input_length) + input_lengths.append(r.input_length) next_token_choosers.append( NextTokenChooser( temperature=r.parameters.temperature, @@ -56,21 +68,23 @@ class CausalLMBatch: ) ) - input_ids = tokenizer( + tokenized_inputs = tokenizer( inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 ).to(device) - all_input_ids = input_ids["input_ids"].unsqueeze(-1) + all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) return cls( batch_id=pb.id, requests=pb.requests, - all_input_lengths=all_input_lengths, - input_ids=input_ids, + input_ids=tokenized_inputs["input_ids"], + attention_mask=tokenized_inputs["attention_mask"], + past_key_values=None, all_input_ids=all_input_ids, + input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=pb.size, - max_sequence_length=max(all_input_lengths), + max_sequence_length=max(input_lengths), ) @classmethod @@ -80,19 +94,23 @@ class CausalLMBatch: max_sequence_length = 
max(batch.max_sequence_length for batch in batches) # Batch attributes - input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} requests = [] - all_input_lengths = [] + input_lengths = [] all_input_ids = [] next_token_choosers = [] stopping_criterias = [] + # Batch tensors + input_ids = None + attention_mask = None + past_key_values = [] + # Used for slicing correctly inside the tensors # Equivalent to a cumsum on batch sizes start_index = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) - all_input_lengths.extend(batch.all_input_lengths) + input_lengths.extend(batch.input_lengths) all_input_ids.extend(batch.all_input_ids) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -101,32 +119,35 @@ class CausalLMBatch: end_index = start_index + batch.size # We only concatenate batches that did at least one step - if batch.input_ids["input_ids"].shape[1] > 1: + if batch.input_ids.shape[1] > 1: raise ValueError("Batch input_ids should be of shape (batch_size, 1)") - # Initialize tensors - if i == 0: - input_ids["input_ids"] = torch.empty( - (total_batch_size, 1), - dtype=batch.input_ids["input_ids"].dtype, - device=batch.input_ids["input_ids"].device, - ) - input_ids["attention_mask"] = torch.zeros( - (total_batch_size, max_sequence_length), - dtype=batch.input_ids["attention_mask"].dtype, - device=batch.input_ids["attention_mask"].device, - ) - - # input_ids["input_ids"] is always of shape [batch_size, 1] + # Create empty tensor + # input_ids is always of shape [batch_size, 1] # We do not need to pad it - input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] + if input_ids is None: + input_ids = torch.empty( + (total_batch_size, 1), + dtype=batch.input_ids.dtype, + device=batch.input_ids.device, + ) + # Copy to correct indices + input_ids[start_index:end_index] = batch.input_ids + + # Create padded tensor + if attention_mask is None: + attention_mask = torch.zeros( + (total_batch_size, max_sequence_length), + dtype=batch.attention_mask.dtype, + device=batch.attention_mask.device, + ) # We need to slice the attention mask to remove padding from previous steps - input_ids["attention_mask"][ + attention_mask[ start_index:end_index, -batch.max_sequence_length : - ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :] + ] = batch.attention_mask[:, -batch.max_sequence_length :] - for j, past in enumerate(batch.input_ids["past_key_values"]): + for j, past in enumerate(batch.past_key_values): # Shenanigans to get dimensions because BLOOM outputs a past with a different shape # BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...] 
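The comment above flags that BLOOM returns past keys and values flattened to `[batch_size * num_heads, ...]` while other causal LMs keep `[batch_size, num_heads, ...]`. A toy sketch, with made-up shapes, of the view that the batching code relies on to slice such tensors per request (mirroring the `t.view(-1, self.num_heads, *t.shape[-2:])` pattern used later in this diff):

```python
# Toy example: normalize a BLOOM-style past tensor so requests can be selected
# along the batch dimension. All shapes here are arbitrary illustrations.
import torch

batch_size, num_heads, head_dim, seq_len = 2, 16, 64, 8

# BLOOM-style keys come back as [batch_size * num_heads, head_dim, seq_len]
bloom_past_keys = torch.randn(batch_size * num_heads, head_dim, seq_len)

# View them as [batch_size, num_heads, head_dim, seq_len] for per-request slicing
per_request = bloom_past_keys.view(-1, num_heads, *bloom_past_keys.shape[-2:])
assert per_request.shape == (batch_size, num_heads, head_dim, seq_len)

# Keeping a subset of requests is then a plain index on dim 0
keep_indices = torch.tensor([0])
kept = per_request[keep_indices]
assert kept.shape == (1, num_heads, head_dim, seq_len)
```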
head_dim, padded_sequence_length = past[0].shape[-2:] @@ -137,8 +158,8 @@ class CausalLMBatch: ) # This will run only once per layer - if j == len(input_ids["past_key_values"]): - input_ids["past_key_values"].append([]) + if j == len(past_key_values): + past_key_values.append([]) # Decoder past for k, t in enumerate(past): @@ -172,21 +193,21 @@ class CausalLMBatch: # Initialize tensors # This will run only once per layer and per past tensor - if k == len(input_ids["past_key_values"][j]): - input_ids["past_key_values"][j].append( + if k == len(past_key_values[j]): + past_key_values[j].append( torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device) ) # We slice the past keys and values to remove the padding from previous batches if not head_dim_last: - input_ids["past_key_values"][j][k][ + past_key_values[j][k][ start_index:end_index, :, :, -(batch.max_sequence_length - 1) :, ] = t[:, :, :, -(batch.max_sequence_length - 1) :] else: - input_ids["past_key_values"][j][k][ + past_key_values[j][k][ start_index:end_index, :, -(batch.max_sequence_length - 1) :, @@ -198,9 +219,11 @@ class CausalLMBatch: return cls( batch_id=batches[0].batch_id, requests=requests, - all_input_lengths=all_input_lengths, input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, all_input_ids=all_input_ids, + input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=total_batch_size, @@ -209,7 +232,7 @@ class CausalLMBatch: class CausalLM(Model): - def __init__(self, model_name: str): + def __init__(self, model_name: str, quantize=False): if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 @@ -223,6 +246,7 @@ class CausalLM(Model): model_name, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None, + load_in_8bit=quantize, ).eval() super(CausalLM, self).__init__( @@ -255,16 +279,19 @@ class CausalLM(Model): torch.no_grad if self.device.type == "cpu" else torch.inference_mode ) with context_manager(): - logits, past = self.forward(**batch.input_ids) + logits, past = self.forward( + batch.input_ids, batch.attention_mask, batch.past_key_values + ) # List of indices to cache next_batch_keep_indices = [] - # New input_ids for next forward + # New values for next forward + next_batch_input_lengths = [] next_batch_input_ids = [] next_batch_all_input_ids = [] - next_all_input_lengths = [] + # Metadata next_batch_size = 0 next_batch_max_sequence_length = 0 @@ -274,7 +301,7 @@ class CausalLM(Model): # Zipped iterator iterator = zip( batch.requests, - batch.all_input_lengths, + batch.input_lengths, logits, batch.next_token_choosers, batch.stopping_criterias, @@ -313,7 +340,7 @@ class CausalLM(Model): next_batch_all_input_ids.append(all_tokens) next_batch_size += 1 new_input_length = input_length + 1 - next_all_input_lengths.append(new_input_length) + next_batch_input_lengths.append(new_input_length) next_batch_max_sequence_length = max( next_batch_max_sequence_length, new_input_length ) @@ -322,15 +349,14 @@ class CausalLM(Model): if not next_batch_keep_indices: return generated_texts, None - # If we finished at least one generation - next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} + next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0) + # If we finished at least one generation, we need to evict the indices of the generations that finished + # from the values of the next batch if generated_texts: # 
Apply indices to attention mask, past key values and other items that need to be cached - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ - next_batch_keep_indices - ] + next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] # Force past to be of dim [batch_size, num_heads, ...] for easy indexing - next_batch_input_ids["past_key_values"] = [ + next_batch_past_key_values = [ [ t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices] for t in layer @@ -345,16 +371,16 @@ class CausalLM(Model): batch.stopping_criterias[i] for i in next_batch_keep_indices ] else: - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] - next_batch_input_ids["past_key_values"] = past + next_batch_attention_mask = batch.attention_mask + next_batch_past_key_values = past next_batch_requests = batch.requests next_batch_next_token_choosers = batch.next_token_choosers next_batch_stopping_criterias = batch.stopping_criterias # Update attention_mask with padding as we added a new token to input_ids - next_batch_input_ids["attention_mask"] = torch.cat( + next_batch_attention_mask = torch.cat( [ - next_batch_input_ids["attention_mask"], + next_batch_attention_mask, torch.ones((next_batch_size, 1)).to(self.device), ], dim=1, @@ -363,9 +389,11 @@ class CausalLM(Model): next_batch = CausalLMBatch( batch_id=batch.batch_id, requests=next_batch_requests, - all_input_lengths=next_all_input_lengths, input_ids=next_batch_input_ids, + attention_mask=next_batch_attention_mask, + past_key_values=next_batch_past_key_values, all_input_ids=next_batch_all_input_ids, + input_lengths=next_batch_input_lengths, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, size=next_batch_size, diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 0607b3d5..cb1291ab 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -15,26 +15,33 @@ class Seq2SeqLMBatch: batch_id: int requests: List[generate_pb2.Request] + # Encoder values input_ids: torch.Tensor attention_mask: torch.Tensor + # Decoder values decoder_input_ids: torch.Tensor decoder_attention_mask: Optional[torch.Tensor] encoder_last_hidden_state: Optional[torch.Tensor] + # Seq2SeqLM keeps track of both encoder and decoder attention keys and values past_key_values: Optional[List[Tuple]] + # Lengths of all generations present in the batch input_lengths: List[int] decoder_input_lengths: List[int] + # Generation helpers next_token_choosers: List[NextTokenChooser] stopping_criterias: List[StoppingCriteria] + # Metadata used for padding size: int max_input_length: int max_decoder_input_length: int def to_pb(self): + """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf""" return generate_pb2.Batch( id=self.batch_id, requests=self.requests, @@ -45,6 +52,7 @@ class Seq2SeqLMBatch: def from_pb( cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device ) -> "Seq2SeqLMBatch": + """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch""" inputs = [] next_token_choosers = [] stopping_criterias = [] @@ -57,6 +65,7 @@ class Seq2SeqLMBatch: for r in pb.requests: inputs.append(r.inputs) input_lengths.append(r.input_length) + # Decoder sequence only contains the bos_token decoder_input_ids.append(tokenizer.bos_token_id) decoder_input_lengths.append(1) next_token_choosers.append( @@ -73,9 +82,11 @@ class Seq2SeqLMBatch: ) ) + # Tokenize 
batch tokenized_inputs = tokenizer( inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 ).to(device) + # Convert decoder_input_ids to torch tensor of size [batch_size, 1] decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1) return cls( @@ -98,6 +109,8 @@ class Seq2SeqLMBatch: @classmethod def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": + """Concatenate multiple batches together by padding internal torch tensors""" + # Used for padding total_batch_size = sum(batch.size for batch in batches) max_input_length = max(batch.max_input_length for batch in batches) @@ -112,6 +125,7 @@ next_token_choosers = [] stopping_criterias = [] + # Batch tensors input_ids = None attention_mask = None decoder_input_ids = None @@ -122,7 +136,9 @@ # Used for slicing correctly inside the tensors # Equivalent to a cumsum on batch sizes start_index = 0 + for i, batch in enumerate(batches): + # Extend all list attributes requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths) @@ -136,51 +152,62 @@ if batch.encoder_last_hidden_state is None: raise ValueError("Batch encoder_last_hidden_state cannot be None") + # Create padded tensor if input_ids is None: input_ids = torch.zeros( (total_batch_size, max_input_length), dtype=batch.input_ids.dtype, device=batch.input_ids.device, ) + # Copy to correct indices input_ids[ start_index:end_index, -batch.max_input_length : ] = batch.input_ids[:, -batch.max_input_length :] + # Create padded tensor if attention_mask is None: attention_mask = torch.zeros( (total_batch_size, max_input_length), dtype=batch.attention_mask.dtype, device=batch.attention_mask.device, ) + # Copy to correct indices attention_mask[ start_index:end_index, -batch.max_input_length : ] = batch.attention_mask[:, -batch.max_input_length :] + # Create padded tensor if decoder_input_ids is None: decoder_input_ids = torch.zeros( (total_batch_size, max_decoder_input_length), dtype=batch.decoder_input_ids.dtype, device=batch.decoder_input_ids.device, ) + # Copy to correct indices decoder_input_ids[ start_index:end_index, -batch.max_decoder_input_length : ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :] + # Create padded tensor if decoder_attention_mask is None: decoder_attention_mask = torch.zeros( (total_batch_size, max_decoder_input_length), - dtype=batch.attention_mask.dtype, - device=batch.attention_mask.device, + dtype=batch.attention_mask.dtype, # As decoder_attention_mask might not exist, + device=batch.attention_mask.device, # we use `batch.attention_mask` for device here ) + # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated + # this batch. All generations are of length `batch.max_decoder_input_length`.
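The `Seq2SeqLMBatch.concatenate` logic above follows one pattern throughout: allocate a zero tensor sized for the merged batch, then copy each sub-batch into its row slice, right-aligned because sequences are left-padded. A toy, self-contained illustration of that pattern (the sizes are made up):

```python
# Toy example of merging two left-padded attention masks into one padded tensor,
# following the allocate-then-copy pattern used in concatenate(). Sizes are arbitrary.
import torch

max_len = 5
mask_a = torch.ones(2, 3)  # sub-batch of 2 requests, sequence length 3
mask_b = torch.ones(1, 5)  # sub-batch of 1 request, sequence length 5

merged = torch.zeros(3, max_len)
# Copy each sub-batch into its slice, aligned to the right (left padding)
merged[0:2, -mask_a.shape[1]:] = mask_a
merged[2:3, -mask_b.shape[1]:] = mask_b
```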
if batch.decoder_attention_mask is None: decoder_attention_mask[ start_index:end_index, -batch.max_decoder_input_length : ] = 1 + # If it exists, we need to index else: decoder_attention_mask[ start_index:end_index, -batch.max_decoder_input_length : ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :] + # Create padded tensor if encoder_last_hidden_state is None: encoder_last_hidden_state = torch.zeros( ( @@ -192,10 +219,12 @@ class Seq2SeqLMBatch: device=batch.encoder_last_hidden_state.device, ) + # Copy to correct indices encoder_last_hidden_state[ start_index:end_index, -batch.max_decoder_input_length :, : ] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :] + # Iterate over attention layers for j, past in enumerate(batch.past_key_values): _, num_heads, _, head_dim = past[0].shape @@ -271,7 +300,7 @@ class Seq2SeqLMBatch: class Seq2SeqLM(Model): - def __init__(self, model_name: str): + def __init__(self, model_name: str, quantize=False): if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 @@ -283,6 +312,7 @@ class Seq2SeqLM(Model): model_name, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None, + load_in_8bit=quantize, ).eval() tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") tokenizer.bos_token_id = self.model.config.decoder_start_token_id @@ -314,14 +344,17 @@ class Seq2SeqLM(Model): if past_key_values is not None: decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1) + # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]` + # internally... + if encoder_last_hidden_state is not None: + encoder_last_hidden_state = [encoder_last_hidden_state] + outputs = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, - encoder_outputs=[encoder_last_hidden_state] - if encoder_last_hidden_state is not None - else None, + encoder_outputs=encoder_last_hidden_state, past_key_values=past_key_values, use_cache=True, ) @@ -351,11 +384,12 @@ class Seq2SeqLM(Model): # List of indices to cache next_batch_keep_indices = [] - # New input_ids for next forward + # New values for next forward next_batch_input_lengths = [] next_batch_decoder_input_ids = [] next_batch_decoder_input_lengths = [] + # Metadata next_batch_size = 0 next_batch_max_input_length = 0 next_batch_max_decoder_input_length = 0 @@ -395,7 +429,7 @@ class Seq2SeqLM(Model): # Evaluate stopping criteria if stopping_criteria(decoder_tokens): - # Decode all tokens + # Decode tokens output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True) # Add to the list of finished generations with the original request generated_texts.append( @@ -420,9 +454,11 @@ class Seq2SeqLM(Model): if not next_batch_keep_indices: return generated_texts, None - # If we finished at least one generation next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids) + # If we finished at least one generation, we need to evict the indices of the generations that finished + # from the values of the next batch if generated_texts: + # Apply indices to attention mask, past key values and other items that need to be cached next_batch_input_ids = batch.input_ids[next_batch_keep_indices] next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] @@ -458,7 +494,7 @@ class Seq2SeqLM(Model): 
next_batch_next_token_choosers = batch.next_token_choosers next_batch_stopping_criterias = batch.stopping_criterias - # Update attention_mask with padding as we added a new token to input_ids + # Update decoder_attention_mask with padding as we added a new token to input_ids if next_batch_decoder_attention_mask is not None: next_batch_decoder_attention_mask = torch.cat( [