From 93a7042d7e7b40ef204b8752dc817c6fe192a825 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 30 Sep 2024 11:15:09 +0200 Subject: [PATCH] feat: support phi3.5 moe (#2479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support phi3.5 moe model loading * fix: prefer llama base model and improve rotary logic * feat: return reasonable generation and add integration test * fix: run lint and update docs * fix: rerun lint for openapi docs * fix: prefer do_sample false unless temp is set by user, and update chat tests * fix: small typo adjustments * fix: consolidate long rope paths * fix: revert greedy by default and test changes * Vendor configuration so that we don't have to `trust_remote_code` * Use SparseMoELayer * Add support for dense MoE * Some type annotations * Add the usual model tests * Ruff. --------- Co-authored-by: Daniƫl de Kok Co-authored-by: Nicolas Patry --- docs/source/supported_models.md | 1 + .../test_flash_phi35_moe.json | 109 +++++ .../test_flash_phi35_moe_all_params.json | 99 ++++ .../test_flash_phi35_moe_load.json | 438 ++++++++++++++++++ .../models/test_flash_phi35_moe.py | 75 +++ integration-tests/models/test_tools_llama.py | 6 +- router/src/config.rs | 1 + .../text_generation_server/layers/rotary.py | 73 ++- .../text_generation_server/models/__init__.py | 31 ++ .../custom_modeling/flash_llama_modeling.py | 94 +++- .../custom_modeling/flash_phi_moe_modeling.py | 254 ++++++++++ 11 files changed, 1164 insertions(+), 17 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe.json create mode 100644 integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_load.json create mode 100644 integration-tests/models/test_flash_phi35_moe.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 832f88ef..3fa78ee9 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -20,6 +20,7 @@ Text Generation Inference enables serving optimized models on specific hardware - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1) - [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder) - [Phi](https://huggingface.co/microsoft/phi-1_5) +- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct) - [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) - [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct) - [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) diff --git a/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe.json b/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe.json new file mode 100644 index 00000000..0d6dca31 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe.json @@ -0,0 +1,109 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1724, + "logprob": null, + "text": "What" + }, + { + "id": 338, + "logprob": -0.7133789, + "text": "is" + }, + { + "id": 16030, + "logprob": -13.9296875, + "text": "gradient" + }, + { + "id": 26815, + "logprob": -0.048919678, + "text": "descent" + }, + { + "id": 29973, + "logprob": -3.0078125, + "text": "?" + }, + { + "id": 13, + "logprob": -2.8105469, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.84521484, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 25584, + "logprob": -0.017028809, + "special": false, + "text": "Grad" + }, + { + "id": 993, + "logprob": -0.0027313232, + "special": false, + "text": "ient" + }, + { + "id": 26815, + "logprob": -0.023254395, + "special": false, + "text": " descent" + }, + { + "id": 338, + "logprob": -2.0623207e-05, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.5361328, + "special": false, + "text": " a" + }, + { + "id": 937, + "logprob": -0.17578125, + "special": false, + "text": " first" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 2098, + "logprob": -0.00011539459, + "special": false, + "text": "order" + }, + { + "id": 13883, + "logprob": -0.47436523, + "special": false, + "text": " optimization" + }, + { + "id": 5687, + "logprob": -0.00027680397, + "special": false, + "text": " algorithm" + } + ], + "top_tokens": null + }, + "generated_text": "Gradient descent is a first-order optimization algorithm" +} diff --git a/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_all_params.json b/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_all_params.json new file mode 100644 index 00000000..38b80335 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 16030, + "logprob": null, + "text": "gradient" + }, + { + "id": 26815, + "logprob": -6.4960938, + "text": "descent" + }, + { + "id": 29973, + "logprob": -5.1484375, + "text": "?" + }, + { + "id": 13, + "logprob": -4.0351562, + "text": "\n" + }, + { + "id": 13, + "logprob": -5.2265625, + "text": "\n" + } + ], + "seed": 0, + "tokens": [ + { + "id": 10994, + "logprob": -1.1542969, + "special": false, + "text": "Hello" + }, + { + "id": 29991, + "logprob": 0.0, + "special": false, + "text": "!" + }, + { + "id": 739, + "logprob": 0.0, + "special": false, + "text": " It" + }, + { + "id": 2444, + "logprob": -0.42260742, + "special": false, + "text": " seems" + }, + { + "id": 366, + "logprob": 0.0, + "special": false, + "text": " you" + }, + { + "id": 29915, + "logprob": 0.0, + "special": false, + "text": "'" + }, + { + "id": 276, + "logprob": -0.9838867, + "special": false, + "text": "re" + }, + { + "id": 3211, + "logprob": 0.0, + "special": false, + "text": " address" + }, + { + "id": 292, + "logprob": 0.0, + "special": false, + "text": "ing" + }, + { + "id": 263, + "logprob": -0.15124512, + "special": false, + "text": " a" + } + ], + "top_tokens": null + }, + "generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a" +} diff --git a/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_load.json b/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_load.json new file mode 100644 index 00000000..f1f81152 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_load.json @@ -0,0 +1,438 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1724, + "logprob": null, + "text": "What" + }, + { + "id": 338, + "logprob": -0.7133789, + "text": "is" + }, + { + "id": 16030, + "logprob": -13.9296875, + "text": "gradient" + }, + { + "id": 26815, + "logprob": -0.048919678, + "text": "descent" + }, + { + "id": 29973, + "logprob": -3.0078125, + "text": "?" + }, + { + "id": 13, + "logprob": -2.8105469, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.84521484, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 25584, + "logprob": -0.017028809, + "special": false, + "text": "Grad" + }, + { + "id": 993, + "logprob": -0.0028476715, + "special": false, + "text": "ient" + }, + { + "id": 26815, + "logprob": -0.023971558, + "special": false, + "text": " descent" + }, + { + "id": 338, + "logprob": -2.0384789e-05, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.5229492, + "special": false, + "text": " a" + }, + { + "id": 937, + "logprob": -0.17602539, + "special": false, + "text": " first" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 2098, + "logprob": -0.000116467476, + "special": false, + "text": "order" + }, + { + "id": 13883, + "logprob": -0.47436523, + "special": false, + "text": " optimization" + }, + { + "id": 5687, + "logprob": -0.00027871132, + "special": false, + "text": " algorithm" + } + ], + "top_tokens": null + }, + "generated_text": "Gradient descent is a first-order optimization algorithm" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1724, + "logprob": null, + "text": "What" + }, + { + "id": 338, + "logprob": -0.7128906, + "text": "is" + }, + { + "id": 16030, + "logprob": -13.9375, + "text": "gradient" + }, + { + "id": 26815, + "logprob": -0.05053711, + "text": "descent" + }, + { + "id": 29973, + "logprob": -3.0058594, + "text": "?" + }, + { + "id": 13, + "logprob": -2.8242188, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.84521484, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 25584, + "logprob": -0.018859863, + "special": false, + "text": "Grad" + }, + { + "id": 993, + "logprob": -0.002822876, + "special": false, + "text": "ient" + }, + { + "id": 26815, + "logprob": -0.023254395, + "special": false, + "text": " descent" + }, + { + "id": 338, + "logprob": -2.0384789e-05, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.5229492, + "special": false, + "text": " a" + }, + { + "id": 937, + "logprob": -0.17126465, + "special": false, + "text": " first" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 2098, + "logprob": -0.0001155138, + "special": false, + "text": "order" + }, + { + "id": 13883, + "logprob": -0.47436523, + "special": false, + "text": " optimization" + }, + { + "id": 5687, + "logprob": -0.00027036667, + "special": false, + "text": " algorithm" + } + ], + "top_tokens": null + }, + "generated_text": "Gradient descent is a first-order optimization algorithm" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1724, + "logprob": null, + "text": "What" + }, + { + "id": 338, + "logprob": -0.71484375, + "text": "is" + }, + { + "id": 16030, + "logprob": -13.9375, + "text": "gradient" + }, + { + "id": 26815, + "logprob": -0.049346924, + "text": "descent" + }, + { + "id": 29973, + "logprob": -3.0078125, + "text": "?" + }, + { + "id": 13, + "logprob": -2.8242188, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.86328125, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 25584, + "logprob": -0.017196655, + "special": false, + "text": "Grad" + }, + { + "id": 993, + "logprob": -0.0028438568, + "special": false, + "text": "ient" + }, + { + "id": 26815, + "logprob": -0.023254395, + "special": false, + "text": " descent" + }, + { + "id": 338, + "logprob": -2.026558e-05, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.5229492, + "special": false, + "text": " a" + }, + { + "id": 937, + "logprob": -0.17602539, + "special": false, + "text": " first" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 2098, + "logprob": -0.00011622906, + "special": false, + "text": "order" + }, + { + "id": 13883, + "logprob": -0.48608398, + "special": false, + "text": " optimization" + }, + { + "id": 5687, + "logprob": -0.00027894974, + "special": false, + "text": " algorithm" + } + ], + "top_tokens": null + }, + "generated_text": "Gradient descent is a first-order optimization algorithm" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1724, + "logprob": null, + "text": "What" + }, + { + "id": 338, + "logprob": -0.7192383, + "text": "is" + }, + { + "id": 16030, + "logprob": -13.9375, + "text": "gradient" + }, + { + "id": 26815, + "logprob": -0.050445557, + "text": "descent" + }, + { + "id": 29973, + "logprob": -3.0078125, + "text": "?" + }, + { + "id": 13, + "logprob": -2.8242188, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.8276367, + "text": "\n" + } + ], + "seed": null, + "tokens": [ + { + "id": 25584, + "logprob": -0.01727295, + "special": false, + "text": "Grad" + }, + { + "id": 993, + "logprob": -0.0027542114, + "special": false, + "text": "ient" + }, + { + "id": 26815, + "logprob": -0.023254395, + "special": false, + "text": " descent" + }, + { + "id": 338, + "logprob": -2.0384789e-05, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.5229492, + "special": false, + "text": " a" + }, + { + "id": 937, + "logprob": -0.17126465, + "special": false, + "text": " first" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 2098, + "logprob": -0.00011301041, + "special": false, + "text": "order" + }, + { + "id": 13883, + "logprob": -0.48608398, + "special": false, + "text": " optimization" + }, + { + "id": 5687, + "logprob": -0.00027894974, + "special": false, + "text": " algorithm" + } + ], + "top_tokens": null + }, + "generated_text": "Gradient descent is a first-order optimization algorithm" + } +] diff --git a/integration-tests/models/test_flash_phi35_moe.py b/integration-tests/models/test_flash_phi35_moe.py new file mode 100644 index 00000000..2173740a --- /dev/null +++ b/integration-tests/models/test_flash_phi35_moe.py @@ -0,0 +1,75 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_phi35_moe_handle(launcher): + with launcher( + "microsoft/Phi-3.5-MoE-instruct", + num_shard=4, + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_phi35_moe(flash_phi35_moe_handle): + await flash_phi35_moe_handle.health(300) + return flash_phi35_moe_handle.client + + +@pytest.mark.asyncio +async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot): + response = await flash_phi35_moe.generate( + "What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text + == "Gradient descent is a first-order optimization algorithm" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot): + response = await flash_phi35_moe.generate( + "What is gradient descent?\n\n", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text + == "What is gradient descent?\n\nHello! It seems you're addressing a" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot): + responses = await generate_load( + flash_phi35_moe, "What is gradient descent?\n\n", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert responses[0].details.generated_tokens == 10 + assert ( + responses[0].generated_text + == "Gradient descent is a first-order optimization algorithm" + ) + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + + assert responses == response_snapshot diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 9855cfda..c337afa1 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -4,7 +4,9 @@ import pytest @pytest.fixture(scope="module") def flash_llama_grammar_tools_handle(launcher): with launcher( - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + "meta-llama/Meta-Llama-3.1-8B-Instruct", + num_shard=2, + disable_grammar_support=False, ) as handle: yield handle @@ -208,7 +210,7 @@ async def test_flash_llama_grammar_tools_stream( async for response in responses: count += 1 - assert count == 48 + assert count == 28 assert response == response_snapshot diff --git a/router/src/config.rs b/router/src/config.rs index 5d0be9c8..9b770d06 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -159,6 +159,7 @@ pub enum Config { #[serde(rename = "phi-msft")] PhiMsft, Phi3, + PhiMoe, Llama, Baichuan, Paligemma(Paligemma), diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index fc4a59b9..a2076bb2 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -166,6 +166,20 @@ class PositionRotaryEmbedding(nn.Module): 1 + math.log(scale) / math.log(original_max_position_embeddings) ) + # if short_mscale and long_mscale are provided we need to scale the freqs + # using the Phi3LongRoPEScaledRotaryEmbedding + if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling): + short_mscale = rope_scaling["short_mscale"] + long_mscale = rope_scaling["long_mscale"] + return Phi3LongRoPEScaledRotaryEmbedding( + short_inv_freq=short_inv_freq, + long_inv_freq=long_inv_freq, + max_position_embeddings=config.max_position_embeddings, + short_mscale=short_mscale, + long_mscale=long_mscale, + original_max_position_embeddings=original_max_position_embeddings, + ) + return SuRotaryEmbedding( short_inv_freq=short_inv_freq, long_inv_freq=long_inv_freq, @@ -287,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): # or if we're on a new device (possibly due to tracing for instance) if ( seqlen > self._seq_len_cached + or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype ): @@ -308,6 +323,63 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype) +class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): + def __init__( + self, + short_inv_freq: torch.Tensor, + long_inv_freq: torch.Tensor, + max_position_embeddings: int, + short_mscale: float, + long_mscale: float, + original_max_position_embeddings: int, + ): + super(PositionRotaryEmbedding, self).__init__() + self.short_inv_freq = short_inv_freq + self.long_inv_freq = long_inv_freq + self.max_position_embeddings = max_position_embeddings + self.short_mscale = short_mscale + self.long_mscale = long_mscale + self.original_max_position_embeddings = original_max_position_embeddings + + # cache + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.dynamic_args = None + + def _update_cos_sin_cache(self, dtype, device, seqlen): + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype) + + short_freqs = torch.outer( + t[: self.original_max_position_embeddings], + self.short_inv_freq.to(device=t.device), + ) + + long_freqs = torch.outer( + t[self.original_max_position_embeddings :], + self.long_inv_freq.to(device=t.device), + ) + + short_freqs = short_freqs * self.short_mscale + long_freqs = long_freqs * self.long_mscale + + freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device) + freqs[: self.original_max_position_embeddings] = short_freqs + freqs[self.original_max_position_embeddings :] = long_freqs + + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): inv_freq = _create_inv_freq(dim, base, device) @@ -467,7 +539,6 @@ def apply_llama3_scaling( elif wavelen > low_freq_wavelen: new_freqs.append(freq / scaling_factor) else: - assert low_freq_wavelen != high_freq_wavelen smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / ( high_freq_factor - low_freq_factor diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index e5e5aabb..99a6ba76 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -32,6 +32,9 @@ from text_generation_server.models.custom_modeling.phi_modeling import ( PhiConfig, PhiForCausalLM, ) +from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import ( + PhiMoEConfig, +) from text_generation_server.models.custom_modeling.t5_modeling import ( T5ForConditionalGeneration, ) @@ -237,6 +240,11 @@ class ModelType(enum.Enum): "name": "Phi", "url": "https://huggingface.co/microsoft/phi-1_5", } + PHI_MOE = { + "type": "phimoe", + "name": "PhiMoe", + "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct", + } BAICHUAN = { "type": "baichuan", "name": "Baichuan", @@ -768,6 +776,29 @@ def get_model( trust_remote_code=trust_remote_code, ) + elif model_type == PHI_MOE: + if FLASH_ATTENTION: + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + config_class=PhiMoEConfig, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + else: + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == "phi-msft": if FLASH_ATTENTION: raise NotImplementedError( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index df48c6f7..358fbbaa 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -19,7 +19,7 @@ # limitations under the License. from contextlib import contextmanager -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type import torch import torch.distributed @@ -28,6 +28,7 @@ from torch import nn from transformers.activations import ACT2FN from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, @@ -46,12 +47,19 @@ from text_generation_server.layers import ( from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, + FastLayerNorm, +) +from text_generation_server.layers import ( + FastLinear, ) from text_generation_server.utils.weights import ( Weights, ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader +if SYSTEM != "ipex": + pass + if SYSTEM == "rocm": try: from vllm import _custom_C @@ -245,6 +253,42 @@ class FlashLlamaAttention(torch.nn.Module): ) +class Phi3MoE(nn.Module): + def __init__( + self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights + ): + super().__init__() + + # gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.moe = moe_layer_cls( + prefix=f"{prefix}.experts", + n_experts=config.num_local_experts, + n_expert_group=None, + renormalize=True, + topk=config.num_experts_per_tok, + topk_group=None, + weights=weights, + gate_proj_name="w1", + up_proj_name="w3", + down_proj_name="w2", + ) + + self.process_group = weights.process_group + + def forward(self, x, adapter_data) -> torch.Tensor: + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(x) + out = self.moe(x, gating_output=router_logits) + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + class LlamaMLP(nn.Module): def __init__(self, prefix, config, weights, index): super().__init__() @@ -358,18 +402,40 @@ class FlashLlamaLayer(nn.Module): weights=weights, ) - self.mlp = LlamaMLP( - prefix=f"{prefix}.mlp", config=config, weights=weights, index=index - ) - - self.input_layernorm = FastRMSNorm.load( - prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = FastRMSNorm.load( - prefix=f"{prefix}.post_attention_layernorm", - weights=weights, - eps=config.rms_norm_eps, - ) + if config.model_type == "phimoe": + moe_layer_cls = ( + SparseMoELayer + if SparseMoELayer.is_supported(weights) + else DenseMoELayer + ) + self.dense = Phi3MoE( + f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights + ) + # with moe the layernorms are are not rmsnorms and they have bias + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + else: + self.dense = LlamaMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, index=index + ) + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) def forward( self, @@ -406,7 +472,7 @@ class FlashLlamaLayer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output, adapter_data) + mlp_output = self.dense(normed_attn_res_output, adapter_data) return mlp_output, attn_res diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py new file mode 100644 index 00000000..bb585cc4 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py @@ -0,0 +1,254 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch Phi-MoE model.""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +PHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/Phi-3.5-MoE-instruct": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json", +} + + +class PhiMoEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PhiMoEModel`]. It is used to instantiate a Phi-MoE + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`PhiMoEModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 6400): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor`, `long_factor`, `short_mscale`, `long_mscale` and + `original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale` and `long_scale` must + be numbers, the `short_factor` and `long_factor` must be lists of numbers with the same length as half of + the attention head size and the `original_max_position_embeddings` must be an integer. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `262144`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 16): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabeling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.0): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.01): + Amount of noise to add to the router. + + ```python + >>> from transformers import PhiMoEModel, PhiMoEConfig + + >>> # Initializing a Phi-3 style configuration + >>> configuration = PhiMoEConfig.from_pretrained("microsoft/Phi-3.5-MoE-instruct") + + >>> # Initializing a model from the configuration + >>> model = PhiMoEModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phimoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32064, + hidden_size=4096, + intermediate_size=6400, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + rope_scaling=None, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=16, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.01, + input_jitter_noise=0.0, + attention_bias=False, + lm_head_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + self.input_jitter_noise = input_jitter_noise + + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor`, `long_factor`, " + f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + rope_scaling_short_mscale = self.rope_scaling.get("short_mscale", None) + rope_scaling_long_mscale = self.rope_scaling.get("long_mscale", None) + original_max_position_embeddings = self.rope_scaling.get( + "original_max_position_embeddings", None + ) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}" + ) + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if ( + not len(rope_scaling_short_factor) + == self.hidden_size // self.num_attention_heads // 2 + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if ( + not len(rope_scaling_long_factor) + == self.hidden_size // self.num_attention_heads // 2 + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) + if not isinstance(rope_scaling_short_mscale, (int, float)): + raise ValueError( + f"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}" + ) + if not isinstance(rope_scaling_long_mscale, (int, float)): + raise ValueError( + f"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}" + ) + if not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}" + )