feat: support phi3.5 moe (#2479)

* feat: support phi3.5 moe model loading

* fix: prefer llama base model and improve rotary logic

* feat: return reasonable generation and add integration test

* fix: run lint and update docs

* fix: rerun lint for openapi docs

* fix: prefer do_sample false unless temp is set by user, and update chat tests

* fix: small typo adjustments

* fix: consolidate long rope paths

* fix: revert greedy by default and test changes

* Vendor configuration so that we don't have to `trust_remote_code`

* Use SparseMoELayer

* Add support for dense MoE

* Some type annotations

* Add the usual model tests

* Ruff.

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
drbh 2024-09-30 11:15:09 +02:00 committed by GitHub
parent 90a1d04a2f
commit 93a7042d7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1164 additions and 17 deletions

View File

@ -20,6 +20,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1) - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder) - [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
- [Phi](https://huggingface.co/microsoft/phi-1_5) - [Phi](https://huggingface.co/microsoft/phi-1_5)
- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) - [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct) - [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) - [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)

View File

@ -0,0 +1,109 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7133789,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027313232,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0623207e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5361328,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17578125,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011539459,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027680397,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
}

View File

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 16030,
"logprob": null,
"text": "gradient"
},
{
"id": 26815,
"logprob": -6.4960938,
"text": "descent"
},
{
"id": 29973,
"logprob": -5.1484375,
"text": "?"
},
{
"id": 13,
"logprob": -4.0351562,
"text": "\n"
},
{
"id": 13,
"logprob": -5.2265625,
"text": "\n"
}
],
"seed": 0,
"tokens": [
{
"id": 10994,
"logprob": -1.1542969,
"special": false,
"text": "Hello"
},
{
"id": 29991,
"logprob": 0.0,
"special": false,
"text": "!"
},
{
"id": 739,
"logprob": 0.0,
"special": false,
"text": " It"
},
{
"id": 2444,
"logprob": -0.42260742,
"special": false,
"text": " seems"
},
{
"id": 366,
"logprob": 0.0,
"special": false,
"text": " you"
},
{
"id": 29915,
"logprob": 0.0,
"special": false,
"text": "'"
},
{
"id": 276,
"logprob": -0.9838867,
"special": false,
"text": "re"
},
{
"id": 3211,
"logprob": 0.0,
"special": false,
"text": " address"
},
{
"id": 292,
"logprob": 0.0,
"special": false,
"text": "ing"
},
{
"id": 263,
"logprob": -0.15124512,
"special": false,
"text": " a"
}
],
"top_tokens": null
},
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
}

View File

@ -0,0 +1,438 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7133789,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028476715,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023971558,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.000116467476,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027871132,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7128906,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.05053711,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0058594,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.018859863,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.002822876,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.0001155138,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027036667,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.71484375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.049346924,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.86328125,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017196655,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028438568,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.026558e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011622906,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.48608398,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7192383,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.050445557,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8276367,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.01727295,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027542114,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011301041,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.48608398,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
}
]

View File

@ -0,0 +1,75 @@
import pytest
@pytest.fixture(scope="module")
def flash_phi35_moe_handle(launcher):
with launcher(
"microsoft/Phi-3.5-MoE-instruct",
num_shard=4,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_phi35_moe(flash_phi35_moe_handle):
await flash_phi35_moe_handle.health(300)
return flash_phi35_moe_handle.client
@pytest.mark.asyncio
async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "Gradient descent is a first-order optimization algorithm"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is gradient descent?\n\nHello! It seems you're addressing a"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot):
responses = await generate_load(
flash_phi35_moe, "What is gradient descent?\n\n", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "Gradient descent is a first-order optimization algorithm"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher): def flash_llama_grammar_tools_handle(launcher):
with launcher( with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False "meta-llama/Meta-Llama-3.1-8B-Instruct",
num_shard=2,
disable_grammar_support=False,
) as handle: ) as handle:
yield handle yield handle
@ -208,7 +210,7 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses: async for response in responses:
count += 1 count += 1
assert count == 48 assert count == 28
assert response == response_snapshot assert response == response_snapshot

View File

@ -159,6 +159,7 @@ pub enum Config {
#[serde(rename = "phi-msft")] #[serde(rename = "phi-msft")]
PhiMsft, PhiMsft,
Phi3, Phi3,
PhiMoe,
Llama, Llama,
Baichuan, Baichuan,
Paligemma(Paligemma), Paligemma(Paligemma),

View File

@ -166,6 +166,20 @@ class PositionRotaryEmbedding(nn.Module):
1 + math.log(scale) / math.log(original_max_position_embeddings) 1 + math.log(scale) / math.log(original_max_position_embeddings)
) )
# if short_mscale and long_mscale are provided we need to scale the freqs
# using the Phi3LongRoPEScaledRotaryEmbedding
if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
short_mscale = rope_scaling["short_mscale"]
long_mscale = rope_scaling["long_mscale"]
return Phi3LongRoPEScaledRotaryEmbedding(
short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq,
max_position_embeddings=config.max_position_embeddings,
short_mscale=short_mscale,
long_mscale=long_mscale,
original_max_position_embeddings=original_max_position_embeddings,
)
return SuRotaryEmbedding( return SuRotaryEmbedding(
short_inv_freq=short_inv_freq, short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq, long_inv_freq=long_inv_freq,
@ -287,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
# or if we're on a new device (possibly due to tracing for instance) # or if we're on a new device (possibly due to tracing for instance)
if ( if (
seqlen > self._seq_len_cached seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device or self._cos_cached.device != device
or self._cos_cached.dtype != dtype or self._cos_cached.dtype != dtype
): ):
@ -308,6 +323,63 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype) self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
def __init__(
self,
short_inv_freq: torch.Tensor,
long_inv_freq: torch.Tensor,
max_position_embeddings: int,
short_mscale: float,
long_mscale: float,
original_max_position_embeddings: int,
):
super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq
self.long_inv_freq = long_inv_freq
self.max_position_embeddings = max_position_embeddings
self.short_mscale = short_mscale
self.long_mscale = long_mscale
self.original_max_position_embeddings = original_max_position_embeddings
# cache
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
short_freqs = torch.outer(
t[: self.original_max_position_embeddings],
self.short_inv_freq.to(device=t.device),
)
long_freqs = torch.outer(
t[self.original_max_position_embeddings :],
self.long_inv_freq.to(device=t.device),
)
short_freqs = short_freqs * self.short_mscale
long_freqs = long_freqs * self.long_mscale
freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
freqs[: self.original_max_position_embeddings] = short_freqs
freqs[self.original_max_position_embeddings :] = long_freqs
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
@ -467,7 +539,6 @@ def apply_llama3_scaling(
elif wavelen > low_freq_wavelen: elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scaling_factor) new_freqs.append(freq / scaling_factor)
else: else:
assert low_freq_wavelen != high_freq_wavelen assert low_freq_wavelen != high_freq_wavelen
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / ( smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor high_freq_factor - low_freq_factor

View File

@ -32,6 +32,9 @@ from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig, PhiConfig,
PhiForCausalLM, PhiForCausalLM,
) )
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
)
from text_generation_server.models.custom_modeling.t5_modeling import ( from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration, T5ForConditionalGeneration,
) )
@ -237,6 +240,11 @@ class ModelType(enum.Enum):
"name": "Phi", "name": "Phi",
"url": "https://huggingface.co/microsoft/phi-1_5", "url": "https://huggingface.co/microsoft/phi-1_5",
} }
PHI_MOE = {
"type": "phimoe",
"name": "PhiMoe",
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
}
BAICHUAN = { BAICHUAN = {
"type": "baichuan", "type": "baichuan",
"name": "Baichuan", "name": "Baichuan",
@ -768,6 +776,29 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == PHI_MOE:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
config_class=PhiMoEConfig,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "phi-msft": elif model_type == "phi-msft":
if FLASH_ATTENTION: if FLASH_ATTENTION:
raise NotImplementedError( raise NotImplementedError(

View File

@ -19,7 +19,7 @@
# limitations under the License. # limitations under the License.
from contextlib import contextmanager from contextlib import contextmanager
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Type
import torch import torch
import torch.distributed import torch.distributed
@ -28,6 +28,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
@ -46,12 +47,19 @@ from text_generation_server.layers import (
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
FastLayerNorm,
)
from text_generation_server.layers import (
FastLinear,
) )
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
Weights, Weights,
) )
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
if SYSTEM != "ipex":
pass
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
from vllm import _custom_C from vllm import _custom_C
@ -245,6 +253,42 @@ class FlashLlamaAttention(torch.nn.Module):
) )
class Phi3MoE(nn.Module):
def __init__(
self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
):
super().__init__()
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.moe = moe_layer_cls(
prefix=f"{prefix}.experts",
n_experts=config.num_local_experts,
n_expert_group=None,
renormalize=True,
topk=config.num_experts_per_tok,
topk_group=None,
weights=weights,
gate_proj_name="w1",
up_proj_name="w3",
down_proj_name="w2",
)
self.process_group = weights.process_group
def forward(self, x, adapter_data) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = self.moe(x, gating_output=router_logits)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights, index): def __init__(self, prefix, config, weights, index):
super().__init__() super().__init__()
@ -358,18 +402,40 @@ class FlashLlamaLayer(nn.Module):
weights=weights, weights=weights,
) )
self.mlp = LlamaMLP( if config.model_type == "phimoe":
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index moe_layer_cls = (
) SparseMoELayer
if SparseMoELayer.is_supported(weights)
self.input_layernorm = FastRMSNorm.load( else DenseMoELayer
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps )
) self.dense = Phi3MoE(
self.post_attention_layernorm = FastRMSNorm.load( f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
prefix=f"{prefix}.post_attention_layernorm", )
weights=weights, # with moe the layernorms are are not rmsnorms and they have bias
eps=config.rms_norm_eps, self.input_layernorm = FastLayerNorm.load(
) prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_attention_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
else:
self.dense = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward( def forward(
self, self,
@ -406,7 +472,7 @@ class FlashLlamaLayer(nn.Module):
attn_output, res attn_output, res
) )
mlp_output = self.mlp(normed_attn_res_output, adapter_data) mlp_output = self.dense(normed_attn_res_output, adapter_data)
return mlp_output, attn_res return mlp_output, attn_res

View File

@ -0,0 +1,254 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Phi-MoE model."""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
PHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"microsoft/Phi-3.5-MoE-instruct": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json",
}
class PhiMoEConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PhiMoEModel`]. It is used to instantiate a Phi-MoE
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the
[microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32064):
Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`PhiMoEModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 6400):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
allows sequence of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
contain the following keys: `type`, `short_factor`, `long_factor`, `short_mscale`, `long_mscale` and
`original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale` and `long_scale` must
be numbers, the `short_factor` and `long_factor` must be lists of numbers with the same length as half of
the attention head size and the `original_max_position_embeddings` must be an integer.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `262144`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_experts_per_tok (`int`, *optional*, defaults to 2):
The number of experts to root per-token, can be also interpreted as the `top-p` routing
parameter
num_local_experts (`int`, *optional*, defaults to 16):
Number of experts per Sparse MLP layer.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabeling this will also
allow the model to output the auxiliary loss. See [here]() for more details
router_aux_loss_coef (`float`, *optional*, defaults to 0.0):
The aux loss factor for the total loss.
router_jitter_noise (`float`, *optional*, defaults to 0.01):
Amount of noise to add to the router.
```python
>>> from transformers import PhiMoEModel, PhiMoEConfig
>>> # Initializing a Phi-3 style configuration
>>> configuration = PhiMoEConfig.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
>>> # Initializing a model from the configuration
>>> model = PhiMoEModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "phimoe"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32064,
hidden_size=4096,
intermediate_size=6400,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=1e6,
rope_scaling=None,
sliding_window=None,
attention_dropout=0.0,
num_experts_per_tok=2,
num_local_experts=16,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_jitter_noise=0.01,
input_jitter_noise=0.0,
attention_bias=False,
lm_head_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.attention_bias = attention_bias
self.lm_head_bias = lm_head_bias
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
self.input_jitter_noise = input_jitter_noise
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:
raise ValueError(
"`rope_scaling` must be a dictionary with three fields, `type`, `short_factor`, `long_factor`, "
f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
rope_scaling_short_mscale = self.rope_scaling.get("short_mscale", None)
rope_scaling_long_mscale = self.rope_scaling.get("long_mscale", None)
original_max_position_embeddings = self.rope_scaling.get(
"original_max_position_embeddings", None
)
if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}"
)
if not (
isinstance(rope_scaling_short_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
):
raise ValueError(
f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
)
if (
not len(rope_scaling_short_factor)
== self.hidden_size // self.num_attention_heads // 2
):
raise ValueError(
f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
)
if not (
isinstance(rope_scaling_long_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
):
raise ValueError(
f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
)
if (
not len(rope_scaling_long_factor)
== self.hidden_size // self.num_attention_heads // 2
):
raise ValueError(
f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
)
if not isinstance(rope_scaling_short_mscale, (int, float)):
raise ValueError(
f"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}"
)
if not isinstance(rope_scaling_long_mscale, (int, float)):
raise ValueError(
f"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}"
)
if not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}"
)