From 26cdea5c0c0c4316b8a2ea4b449da8c4aecbc796 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 28 Feb 2024 15:50:31 +0100 Subject: [PATCH] feat: Qwen2 (#1608) See #1584 --------- Co-authored-by: Cheng Kuan Yong Jason --- .../test_flash_qwen2/test_flash_qwen2.json | 84 ++++ .../test_flash_qwen2_all_params.json | 84 ++++ .../test_flash_qwen2_load.json | 338 +++++++++++++++ integration-tests/models/test_flash_qwen2.py | 59 +++ .../text_generation_server/models/__init__.py | 68 ++- .../models/causal_lm.py | 3 + .../models/custom_modeling/bloom_modeling.py | 2 +- .../custom_modeling/flash_qwen2_modeling.py | 400 ++++++++++++++++++ .../models/custom_modeling/neox_modeling.py | 17 +- .../models/custom_modeling/opt_modeling.py | 17 +- .../models/flash_mistral.py | 2 +- .../models/flash_qwen2.py | 88 ++++ .../models/flash_starcoder2.py | 2 +- .../models/galactica.py | 6 +- .../text_generation_server/models/gpt_neox.py | 6 +- server/text_generation_server/models/rw.py | 4 + .../models/seq2seq_lm.py | 3 + 17 files changed, 1159 insertions(+), 24 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json create mode 100644 integration-tests/models/test_flash_qwen2.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py create mode 100644 server/text_generation_server/models/flash_qwen2.py diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json new file mode 100644 index 00000000..7219f9e6 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9160156, + "special": false, + "text": "#" + }, + { + "id": 4230, + "logprob": -3.1035156, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.1025391, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" + }, + { + "id": 198, + "logprob": -1.1953125, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": -1.3203125, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13537598, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2402344, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" +} diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json new file mode 100644 index 00000000..4a2936af --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 311, + "logprob": -1.4277344, + "special": false, + "text": " to" + }, + { + "id": 279, + "logprob": -0.65478516, + "special": false, + "text": " the" + }, + { + "id": 2473, + "logprob": -1.8300781, + "special": false, + "text": " service" + }, + { + "id": 382, + "logprob": -0.75, + "special": false, + "text": ".\n\n" + }, + { + "id": 286, + "logprob": -0.11621094, + "special": false, + "text": " " + }, + { + "id": 549, + "logprob": 0.0, + "special": false, + "text": " :" + }, + { + "id": 689, + "logprob": -0.48608398, + "special": false, + "text": "return" + }, + { + "id": 25, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 5949, + "logprob": -0.5756836, + "special": false, + "text": " Response" + }, + { + "id": 504, + "logprob": -0.24499512, + "special": false, + "text": " from" + } + ], + "top_tokens": null + }, + "generated_text": "Test request to the service.\n\n :return: Response from" +} diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json new file mode 100644 index 00000000..4786ff24 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json @@ -0,0 +1,338 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" + }, + { + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" + }, + { + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" + }, + { + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" + }, + { + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" + }, + { + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" + }, + { + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2271, + "logprob": null, + "text": "Test" + }, + { + "id": 1681, + "logprob": -8.8515625, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" + }, + { + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" + }, + { + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" + }, + { + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" + }, + { + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" + }, + { + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" + }, + { + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" + }, + { + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" + }, + { + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" + }, + { + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" + } + ], + "top_tokens": null + }, + "generated_text": "\n# Create a request\nrequest = requests.get" + } +] diff --git a/integration-tests/models/test_flash_qwen2.py b/integration-tests/models/test_flash_qwen2.py new file mode 100644 index 00000000..2963aeb4 --- /dev/null +++ b/integration-tests/models/test_flash_qwen2.py @@ -0,0 +1,59 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_qwen2_handle(launcher): + with launcher("Qwen/Qwen1.5-0.5B") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_qwen2(flash_qwen2_handle): + await flash_qwen2_handle.health(300) + return flash_qwen2_handle.client + + +@pytest.mark.asyncio +async def test_flash_qwen2(flash_qwen2, response_snapshot): + response = await flash_qwen2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "\n# Create a request\nrequest = requests.get" + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): + response = await flash_qwen2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot): + responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + assert responses[0].generated_text == "\n# Create a request\nrequest = requests.get" + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index e2edbfa9..e7b0b9e2 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -54,6 +54,9 @@ try: from text_generation_server.models.flash_llama import ( FlashLlama, ) + from text_generation_server.models.flash_qwen2 import ( + FlashQwen2, + ) from text_generation_server.models.flash_gemma import ( FlashGemma, ) @@ -81,6 +84,7 @@ if FLASH_ATTENTION: __all__.append(FlashMistral) __all__.append(FlashMixtral) __all__.append(FlashPhi) + __all__.append(FlashQwen2) __all__.append(FlashStarcoder2) MAMBA_AVAILABLE = True @@ -339,9 +343,7 @@ def get_model( trust_remote_code=trust_remote_code, ) elif sharded: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate") - ) + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) else: return CausalLM( model_id, @@ -399,6 +401,17 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "mixtral": sliding_window = config_dict.get("sliding_window", -1) @@ -413,6 +426,18 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if model_type == "starcoder2": sliding_window = config_dict.get("sliding_window", -1) if ( @@ -425,6 +450,43 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") + ) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == "qwen2": + sliding_window = config_dict.get("sliding_window", -1) + if ( + (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION + ) or HAS_FLASH_ATTN_V2_CUDA: + return FlashQwen2( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "opt": return OPTSharded( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index bbcef210..93ec6ba4 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -486,6 +486,9 @@ class CausalLM(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if use_medusa: + raise RuntimeError("Medusa decoding is not enabled for AutoModel") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 10b40483..c8f02bca 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -870,7 +870,7 @@ class BloomForCausalLM(BloomPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py new file mode 100644 index 00000000..94023b33 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -0,0 +1,400 @@ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from typing import Optional, List, Tuple + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + SpeculativeHead, + get_linear, + FastRMSNorm, +) + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + w = [ + weights.get_sharded(f"{p}.bias", dim=0) + for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + ] + bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device) + + return TensorParallelColumnLinear( + get_linear(weight, bias=bias, quantize=config.quantize) + ) + + +class Qwen2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + paged_attention.reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class Qwen2MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class Qwen2Layer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = Qwen2Attention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class Qwen2Model(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + Qwen2Layer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, true_max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class Qwen2ForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = Qwen2Model(config, weights) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + self.max_past = config.sliding_window + self.max_past_tensor = ( + torch.tensor(config.sliding_window, device=weights.device) + if self.max_past is not None + else None + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + true_max_s = max_s + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif self.max_past is not None: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + true_max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 2550d2d1..1b060060 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -721,7 +721,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): ) hidden_states = outputs[0] - lm_logits = self.embed_out(hidden_states) + lm_logits, speculative_logits = self.embed_out(hidden_states) lm_loss = None if labels is not None: @@ -739,12 +739,15 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): output = (lm_logits,) + outputs[1:] return ((lm_loss,) + output) if lm_loss is not None else output - return CausalLMOutputWithPast( - loss=lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + return ( + CausalLMOutputWithPast( + loss=lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ), + speculative_logits, ) def prepare_inputs_for_generation( diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index de5e95af..7a5cf917 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -792,16 +792,19 @@ class OPTForCausalLM(OPTPreTrainedModel): return_dict=return_dict, ) - logits = self.lm_head(outputs[0]).contiguous() + logits, speculative_logits = self.lm_head(outputs) loss = None - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + return ( + CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ), + speculative_logits, ) def prepare_inputs_for_generation( diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index fd5c18e0..8149c1b0 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -315,7 +315,7 @@ class BaseFlashMistral(FlashCausalLM): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype else: - raise NotImplementedError("FlashLlama is only available on GPU") + raise NotImplementedError("FlashMistral is only available on GPU") tokenizer = LlamaTokenizerFast.from_pretrained( model_id, diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py new file mode 100644 index 00000000..c3c63516 --- /dev/null +++ b/server/text_generation_server/models/flash_qwen2.py @@ -0,0 +1,88 @@ +import math + +import torch +import torch.distributed + +from opentelemetry import trace +from transformers.models.qwen2 import Qwen2Tokenizer +from typing import Optional + +from text_generation_server.models.cache_manager import BLOCK_SIZE +from text_generation_server.models.flash_mistral import ( + BaseFlashMistral, + set_sliding_window, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2ForCausalLM, +) +from transformers.models.qwen2 import Qwen2Config +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + + +class FlashQwen2(BaseFlashMistral): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + use_medusa: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashQwen2 is only available on GPU") + + tokenizer = Qwen2Tokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = Qwen2Config.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.use_medusa = use_medusa + + # Set context windows + if config.sliding_window is not None: + set_sliding_window( + config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) + ) + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id, revision) + + model = Qwen2ForCausalLM(config, weights) + + self.cuda_graphs = {} + + torch.distributed.barrier(group=self.process_group) + super(BaseFlashMistral, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + sliding_window=config.sliding_window, + ) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index 2f6ae757..68e726d8 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -38,7 +38,7 @@ class FlashStarcoder2(BaseFlashMistral): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype else: - raise NotImplementedError("FlashLlama is only available on GPU") + raise NotImplementedError("FlashStarcoder2 is only available on GPU") tokenizer = GPT2TokenizerFast.from_pretrained( model_id, diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 3607c285..a46f86be 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -167,6 +167,7 @@ class GalacticaSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -194,6 +195,7 @@ class GalacticaSharded(CausalLM): ) config.quantize = quantize tokenizer.pad_token_id = config.pad_token_id + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") @@ -229,10 +231,10 @@ class GalacticaSharded(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=True, ) - return outputs.logits, outputs.past_key_values + return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 45df4839..1c4cfe7d 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -24,6 +24,7 @@ class GPTNeoxSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -50,6 +51,7 @@ class GPTNeoxSharded(CausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") @@ -75,7 +77,7 @@ class GPTNeoxSharded(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -84,4 +86,4 @@ class GPTNeoxSharded(CausalLM): ) logits = outputs.logits - return logits, outputs.past_key_values + return logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index 22ab093e..92c93542 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -12,9 +12,13 @@ class RW(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if use_medusa: + raise RuntimeError("Medusa decoding is not enabled for AutoModel") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index fae9a2df..e55a661c 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -536,6 +536,9 @@ class Seq2SeqLM(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if use_medusa: + raise RuntimeError("Medusa decoding is not enabled for AutoModel") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype