diff --git a/Dockerfile b/Dockerfile index 587ab9b..1bb7f7c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -188,7 +188,7 @@ COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ pip install -r requirements.txt && \ - pip install ".[bnb, accelerate, quantize]" --no-cache-dir + pip install ".[bnb, accelerate, quantize, ct2]" --no-cache-dir # Install benchmarker COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 88acb2f..0a77f24 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -25,6 +25,7 @@ enum Quantization { BitsandbytesNF4, BitsandbytesFP4, Gptq, + Ct2, } impl std::fmt::Display for Quantization { @@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization { Quantization::Gptq => { write!(f, "gptq") } + Quantization::Ct2 => { + write!(f, "ct2") + } } } } @@ -104,7 +108,7 @@ struct Args { num_shard: Option, /// Whether you want the model to be quantized. This will use `bitsandbytes` for - /// quantization on the fly, or `gptq`. 4bit quantization is available through + /// quantization on the fly, `bnb` or `gptq`, or `ctranslate2`. 4bit quantization is available through /// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options. #[clap(long, env, value_enum)] quantize: Option, diff --git a/server/Makefile b/server/Makefile index a4ce6d8..8338884 100644 --- a/server/Makefile +++ b/server/Makefile @@ -21,7 +21,7 @@ install-torch: install: gen-server install-torch pip install pip --upgrade pip install -r requirements.txt - pip install -e ".[bnb, accelerate]" + pip install -e ".[bnb, accelerate, quantize, ct2]" run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded diff --git a/server/pyproject.toml b/server/pyproject.toml index 9bca55b..cb0930a 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -16,6 +16,7 @@ grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" accelerate = { version = "^0.19.0", optional = true } +ctranslate2 = { version = "^3.20.0", optional = true } bitsandbytes = { version = "^0.40.0", optional = true } safetensors = "0.3.1" loguru = "^0.6.0" @@ -35,6 +36,7 @@ datasets = { version = "^2.14.0", optional = true } accelerate = ["accelerate"] bnb = ["bitsandbytes"] quantize = ["texttable", "datasets", "accelerate"] +ct2 = ["ctranslate2"] [tool.poetry.group.dev.dependencies] grpcio-tools = "^1.51.1" diff --git a/server/tests/models/test_ct2.py b/server/tests/models/test_ct2.py new file mode 100644 index 0000000..ab4b058 --- /dev/null +++ b/server/tests/models/test_ct2.py @@ -0,0 +1,99 @@ +import pytest +import torch +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.models.ct2_causal_lm import CT2CausalLM + + +@pytest.fixture(scope="session") +def default_santacoder(): + return CT2CausalLM("bigcode/gpt_bigcode-santacoder", dtype=torch.float16) + + +@pytest.fixture +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="def", + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_pb_batch(default_pb_request): + return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) + + +@pytest.fixture +def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="defworld", + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_fim_pb_batch(default_fim_pb_request): + return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1) + + +def test_ct2santa_generate_token_completion(default_santacoder, default_pb_batch): + batch = CausalLMBatch.from_pb( + default_pb_batch, + default_santacoder.tokenizer, + default_santacoder.dtype, + default_santacoder.device, + ) + next_batch = batch + + for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): + generations, next_batch = default_santacoder.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch = default_santacoder.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert generations[0].generated_text.text in (" test_get_all_users_with_", ' test_get_all_users(client):') + assert generations[0].request_id == batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == batch.stopping_criterias[0].max_new_tokens + ) + + +def test_fim_ct2santacoder_generate_token_completion( + default_santacoder, default_fim_pb_batch +): + batch = CausalLMBatch.from_pb( + default_fim_pb_batch, + default_santacoder.tokenizer, + default_santacoder.dtype, + default_santacoder.device, + ) + next_batch = batch + + for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): + generations, next_batch = default_santacoder.generate_token(next_batch) + assert len(generations) == len(next_batch) + + generations, next_batch = default_santacoder.generate_token(next_batch) + assert next_batch is None + + assert len(generations) == 1 + assert ( + generations[0].generated_text.text + == """ineProperty(exports, "__esModule", { value""" + ) + assert generations[0].request_id == batch.requests[0].id + assert ( + generations[0].generated_text.generated_tokens + == batch.stopping_criterias[0].max_new_tokens + ) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 9b378ce..0a96698 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -16,6 +16,7 @@ class Quantization(str, Enum): bitsandbytes_nf4 = "bitsandbytes-nf4" bitsandbytes_fp4 = "bitsandbytes-fp4" gptq = "gptq" + ct2 = "ct2" class Dtype(str, Enum): @@ -73,7 +74,7 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value - if dtype is not None and quantize is not None: + if dtype is not None and quantize is not None and quantize != Quantization.ct2: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) @@ -90,6 +91,7 @@ def download_weights( auto_convert: bool = True, logger_level: str = "INFO", json_output: bool = False, + trust_remote_code: bool = False ): # Remove default handler logger.remove() @@ -169,6 +171,7 @@ def download_weights( config = AutoConfig.from_pretrained( model_id, revision=revision, + trust_remote_code=trust_remote_code ) architecture = config.architectures[0] diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 2e843bc..44a053a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.gpt_neox import GPTNeoxSharded +from text_generation_server.models.ct2_causal_lm import CT2CausalLM # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -75,6 +76,7 @@ def get_model( dtype: Optional[str], trust_remote_code: bool, ) -> Model: + dtype_ct2 = dtype if dtype is None: dtype = torch.float16 elif dtype == "float16": @@ -84,6 +86,15 @@ def get_model( else: raise RuntimeError(f"Unknown dtype {dtype}") + if quantize is not None and "ct2" in quantize: + return CT2CausalLM( + model_id, + revision, + quantize=quantize, + dtype=dtype_ct2, + trust_remote_code=trust_remote_code, + ) + if "facebook/galactica" in model_id: return GalacticaSharded( model_id, diff --git a/server/text_generation_server/models/ct2_causal_lm.py b/server/text_generation_server/models/ct2_causal_lm.py new file mode 100644 index 0000000..7f128a7 --- /dev/null +++ b/server/text_generation_server/models/ct2_causal_lm.py @@ -0,0 +1,359 @@ +# coding=utf-8 +# Copyright 2023 Michael Feil. +# +# This code is loosely based on Huggingface text-generation-inference v0.9.3's causal_lm.py implementation. +# While it remains licensed under Apache License, Version 2.0, +# text-generation-inference itself on 7/28/2023 has changed its license. +# This code remains unaffected by this change. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np +import os +import multiprocessing +from pathlib import Path + +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from opentelemetry import trace +from transformers import ( + AutoTokenizer, + AutoConfig, +) +from typing import Optional, Tuple, List, Type, Dict + +from text_generation_server.models import Model +from text_generation_server.models.types import ( + PrefillTokens, + Generation, + GeneratedText, +) + +from text_generation_server.utils import Sampling +from text_generation_server.models.causal_lm import CausalLMBatch + +try: + import ctranslate2 +except ImportError: + ctranslate2 = None + + +tracer = trace.get_tracer(__name__) + + +class CT2CausalLM(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + if ctranslate2 is None: + raise ValueError( + "for quantization with ct2, the installation requires the pip package ctranslate2. " + "install via `text-generation-server[ct2]` or `pip install ctranslate2` is required.", + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + # Start CT2 + ct2_generator_kwargs = { + "inter_threads": int(os.environ.get("TGI_CT2_INTER_THREADS", 1)) + } + if torch.cuda.is_available(): + self.ct2_device = "cuda" + ct2_generator_kwargs["intra_threads"] = int( + os.environ.get("TGI_CT2_INTRA_THREADS", 1) + ) + else: + self.ct2_device = "cpu" + ct2_generator_kwargs["intra_threads"] = int( + os.environ.get( + "TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2 + ) + ) + + if dtype == torch.float16 and self.ct2_device == "cuda": + ct2_compute_type = "float16" + elif dtype == torch.bfloat16 and self.ct2_device == "cuda": + ct2_compute_type = "bfloat16" + elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]: + # float16 is not available on CPU + # and int16 has no stable implementation + ct2_compute_type = "float32" + else: + # default, int8 quantization. + + if "cuda" in self.ct2_device: + # int8 for int8 layers, float16 for non-quantized layers + ct2_compute_type = "int8_float16" + else: + # int8 for int8 layers, float32 for non-quantized layers + ct2_compute_type = "int8" + + # Start CT2 - conversion + out_dir = ( + Path(HUGGINGFACE_HUB_CACHE) + / "ct2models" / f"{model_id.replace('/','--')}--{ct2_compute_type}" + ) + + if not os.path.exists(out_dir / "model.bin"): + try: + converter = ctranslate2.converters.TransformersConverter( + model_id, + activation_scales=None, + load_as_float16=ct2_compute_type != "bfloat16", + revision=revision, + low_cpu_mem_usage=True, + trust_remote_code=trust_remote_code, + ) + converter.convert( + output_dir=out_dir, + vmap=None, + quantization=ct2_compute_type, + force=True, + ) + except Exception as ex: + raise ValueError( + f"conversion with ctranslate2 for {model_id} failed : Error {ex}" + ) + if not os.path.exists(out_dir / "model.bin"): + raise ValueError( + f"no ctranslate2 model for {model_id} found after conversion in {out_dir}" + ) + + # Start CT2 + self.ct2_model = ctranslate2.Generator( + str(out_dir), + device=self.ct2_device, + compute_type=ct2_compute_type, + **ct2_generator_kwargs, + ) + + class DummyModel(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.config = AutoConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + + model = DummyModel() + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + tokenizer.pad_token_id = model.config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + super().__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=torch.int8 if "int8" in ct2_compute_type else torch.float16, + device=torch.device(self.ct2_device), + ) + + @property + def batch_type(self) -> Type[CausalLMBatch]: + return CausalLMBatch + + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + def forward_ct2( + self, + all_input_ids, + input_lengths, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # CT2 forward requires a list of list of input tokens ids and lengths + ids_input = ( + torch.nested.to_padded_tensor( + torch.nested.nested_tensor(all_input_ids), 1234567 + ) + .flatten(1) + .to(torch.int32) + ) + # lengths of the padded ids_input, i.e. how often not pad=1234567 is used. + lengths = np.array(input_lengths, dtype=np.int32) + + if self.ct2_device == "cuda": + lengths = torch.from_numpy(lengths).to(self.ct2_device) + elif self.ct2_device == "cpu": + ids_input = ids_input.numpy() + + ids_input = ctranslate2.StorageView.from_array(ids_input) + lengths = ctranslate2.StorageView.from_array(lengths) + # now, forward through the network + logits = self.ct2_model.forward_batch(ids_input, lengths) + + # continue with logits as torch tensor + if self.ct2_device == "cpu": + # logits is a float32 torch cpu tensor + logits = torch.from_numpy(np.asarray(logits)) + else: + # logits is a float16 torch cuda tensor + logits = torch.as_tensor(logits, device=self.ct2_device) + return logits, None + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: CausalLMBatch + ) -> Tuple[List[Generation], Optional[CausalLMBatch]]: + logits, past = self.forward_ct2(batch.all_input_ids, batch.input_lengths) + + # Results + generations: List[Generation] = [] + stopped = True + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id]) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text = self.decode( + all_input_ids[-stopping_criteria.current_tokens :, 0] + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, prefill_logprobs, prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_squeezed, + next_token_logprob, + next_token_text, + next_token_id_squeezed.item() in self.all_special_ids, + generated_text, + ) + + generations.append(generation) + + # Update values + batch.input_ids[i, 0] = next_token_id + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) + + # We finished all generations in the batch; there is no next batch + if stopped: + return generations, None + + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask[:, -batch.padding_right_offset] = 1 + # Decrease right offset + batch.padding_right_offset -= 1 + + # Update position_ids + batch.position_ids = batch.position_ids[:, -1:] + 1 + + # Update past key values + batch.past_key_values = past + + return generations, batch diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index f424eae..e29321a 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -42,7 +42,7 @@ class StaticWarper: self.static_next_logprob = None def __call__(self, scores): - if torch.cuda.is_available(): + if scores.device.type == "cuda": if self.cuda_graph is None: self.static_scores = scores self.cuda_graph = torch.cuda.CUDAGraph()