feat: support phi3.5 moe (#2479)

* feat: support phi3.5 moe model loading

* fix: prefer llama base model and improve rotary logic

* feat: return reasonable generation and add integration test

* fix: run lint and update docs

* fix: rerun lint for openapi docs

* fix: prefer do_sample false unless temp is set by user, and update chat tests

* fix: small typo adjustments

* fix: consolidate long rope paths

* fix: revert greedy by default and test changes

* Vendor configuration so that we don't have to `trust_remote_code`

* Use SparseMoELayer

* Add support for dense MoE

* Some type annotations

* Add the usual model tests

* Ruff.

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
drbh 2024-09-30 11:15:09 +02:00 committed by GitHub
parent 90a1d04a2f
commit 93a7042d7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1164 additions and 17 deletions

View File

@ -20,6 +20,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
- [Phi](https://huggingface.co/microsoft/phi-1_5)
- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)

View File

@ -0,0 +1,109 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7133789,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027313232,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0623207e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5361328,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17578125,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011539459,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027680397,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
}

View File

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 16030,
"logprob": null,
"text": "gradient"
},
{
"id": 26815,
"logprob": -6.4960938,
"text": "descent"
},
{
"id": 29973,
"logprob": -5.1484375,
"text": "?"
},
{
"id": 13,
"logprob": -4.0351562,
"text": "\n"
},
{
"id": 13,
"logprob": -5.2265625,
"text": "\n"
}
],
"seed": 0,
"tokens": [
{
"id": 10994,
"logprob": -1.1542969,
"special": false,
"text": "Hello"
},
{
"id": 29991,
"logprob": 0.0,
"special": false,
"text": "!"
},
{
"id": 739,
"logprob": 0.0,
"special": false,
"text": " It"
},
{
"id": 2444,
"logprob": -0.42260742,
"special": false,
"text": " seems"
},
{
"id": 366,
"logprob": 0.0,
"special": false,
"text": " you"
},
{
"id": 29915,
"logprob": 0.0,
"special": false,
"text": "'"
},
{
"id": 276,
"logprob": -0.9838867,
"special": false,
"text": "re"
},
{
"id": 3211,
"logprob": 0.0,
"special": false,
"text": " address"
},
{
"id": 292,
"logprob": 0.0,
"special": false,
"text": "ing"
},
{
"id": 263,
"logprob": -0.15124512,
"special": false,
"text": " a"
}
],
"top_tokens": null
},
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
}

View File

@ -0,0 +1,438 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7133789,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028476715,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023971558,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.000116467476,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027871132,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7128906,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.05053711,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0058594,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.018859863,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.002822876,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.0001155138,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027036667,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.71484375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.049346924,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.86328125,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017196655,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028438568,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.026558e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011622906,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.48608398,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7192383,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.050445557,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8276367,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.01727295,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027542114,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011301041,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.48608398,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
}
]

View File

@ -0,0 +1,75 @@
import pytest
@pytest.fixture(scope="module")
def flash_phi35_moe_handle(launcher):
with launcher(
"microsoft/Phi-3.5-MoE-instruct",
num_shard=4,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_phi35_moe(flash_phi35_moe_handle):
await flash_phi35_moe_handle.health(300)
return flash_phi35_moe_handle.client
@pytest.mark.asyncio
async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "Gradient descent is a first-order optimization algorithm"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is gradient descent?\n\nHello! It seems you're addressing a"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot):
responses = await generate_load(
flash_phi35_moe, "What is gradient descent?\n\n", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "Gradient descent is a first-order optimization algorithm"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
"meta-llama/Meta-Llama-3.1-8B-Instruct",
num_shard=2,
disable_grammar_support=False,
) as handle:
yield handle
@ -208,7 +210,7 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses:
count += 1
assert count == 48
assert count == 28
assert response == response_snapshot

View File

@ -159,6 +159,7 @@ pub enum Config {
#[serde(rename = "phi-msft")]
PhiMsft,
Phi3,
PhiMoe,
Llama,
Baichuan,
Paligemma(Paligemma),

View File

@ -166,6 +166,20 @@ class PositionRotaryEmbedding(nn.Module):
1 + math.log(scale) / math.log(original_max_position_embeddings)
)
# if short_mscale and long_mscale are provided we need to scale the freqs
# using the Phi3LongRoPEScaledRotaryEmbedding
if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
short_mscale = rope_scaling["short_mscale"]
long_mscale = rope_scaling["long_mscale"]
return Phi3LongRoPEScaledRotaryEmbedding(
short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq,
max_position_embeddings=config.max_position_embeddings,
short_mscale=short_mscale,
long_mscale=long_mscale,
original_max_position_embeddings=original_max_position_embeddings,
)
return SuRotaryEmbedding(
short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq,
@ -287,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
@ -308,6 +323,63 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
def __init__(
self,
short_inv_freq: torch.Tensor,
long_inv_freq: torch.Tensor,
max_position_embeddings: int,
short_mscale: float,
long_mscale: float,
original_max_position_embeddings: int,
):
super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq
self.long_inv_freq = long_inv_freq
self.max_position_embeddings = max_position_embeddings
self.short_mscale = short_mscale
self.long_mscale = long_mscale
self.original_max_position_embeddings = original_max_position_embeddings
# cache
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
short_freqs = torch.outer(
t[: self.original_max_position_embeddings],
self.short_inv_freq.to(device=t.device),
)
long_freqs = torch.outer(
t[self.original_max_position_embeddings :],
self.long_inv_freq.to(device=t.device),
)
short_freqs = short_freqs * self.short_mscale
long_freqs = long_freqs * self.long_mscale
freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
freqs[: self.original_max_position_embeddings] = short_freqs
freqs[self.original_max_position_embeddings :] = long_freqs
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
@ -467,7 +539,6 @@ def apply_llama3_scaling(
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scaling_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor

View File

@ -32,6 +32,9 @@ from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
@ -237,6 +240,11 @@ class ModelType(enum.Enum):
"name": "Phi",
"url": "https://huggingface.co/microsoft/phi-1_5",
}
PHI_MOE = {
"type": "phimoe",
"name": "PhiMoe",
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
}
BAICHUAN = {
"type": "baichuan",
"name": "Baichuan",
@ -768,6 +776,29 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == PHI_MOE:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
config_class=PhiMoEConfig,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "phi-msft":
if FLASH_ATTENTION:
raise NotImplementedError(

View File

@ -19,7 +19,7 @@
# limitations under the License.
from contextlib import contextmanager
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
@ -28,6 +28,7 @@ from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
@ -46,12 +47,19 @@ from text_generation_server.layers import (
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
FastLayerNorm,
)
from text_generation_server.layers import (
FastLinear,
)
from text_generation_server.utils.weights import (
Weights,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
if SYSTEM != "ipex":
pass
if SYSTEM == "rocm":
try:
from vllm import _custom_C
@ -245,6 +253,42 @@ class FlashLlamaAttention(torch.nn.Module):
)
class Phi3MoE(nn.Module):
def __init__(
self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
):
super().__init__()
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.moe = moe_layer_cls(
prefix=f"{prefix}.experts",
n_experts=config.num_local_experts,
n_expert_group=None,
renormalize=True,
topk=config.num_experts_per_tok,
topk_group=None,
weights=weights,
gate_proj_name="w1",
up_proj_name="w3",
down_proj_name="w2",
)
self.process_group = weights.process_group
def forward(self, x, adapter_data) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = self.moe(x, gating_output=router_logits)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights, index):
super().__init__()
@ -358,12 +402,34 @@ class FlashLlamaLayer(nn.Module):
weights=weights,
)
self.mlp = LlamaMLP(
if config.model_type == "phimoe":
moe_layer_cls = (
SparseMoELayer
if SparseMoELayer.is_supported(weights)
else DenseMoELayer
)
self.dense = Phi3MoE(
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
)
# with moe the layernorms are are not rmsnorms and they have bias
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_attention_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
else:
self.dense = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
@ -406,7 +472,7 @@ class FlashLlamaLayer(nn.Module):
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
mlp_output = self.dense(normed_attn_res_output, adapter_data)
return mlp_output, attn_res

View File

@ -0,0 +1,254 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Phi-MoE model."""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
PHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"microsoft/Phi-3.5-MoE-instruct": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json",
}
class PhiMoEConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PhiMoEModel`]. It is used to instantiate a Phi-MoE
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the
[microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32064):
Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`PhiMoEModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 6400):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
allows sequence of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
contain the following keys: `type`, `short_factor`, `long_factor`, `short_mscale`, `long_mscale` and
`original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale` and `long_scale` must
be numbers, the `short_factor` and `long_factor` must be lists of numbers with the same length as half of
the attention head size and the `original_max_position_embeddings` must be an integer.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `262144`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_experts_per_tok (`int`, *optional*, defaults to 2):
The number of experts to root per-token, can be also interpreted as the `top-p` routing
parameter
num_local_experts (`int`, *optional*, defaults to 16):
Number of experts per Sparse MLP layer.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabeling this will also
allow the model to output the auxiliary loss. See [here]() for more details
router_aux_loss_coef (`float`, *optional*, defaults to 0.0):
The aux loss factor for the total loss.
router_jitter_noise (`float`, *optional*, defaults to 0.01):
Amount of noise to add to the router.
```python
>>> from transformers import PhiMoEModel, PhiMoEConfig
>>> # Initializing a Phi-3 style configuration
>>> configuration = PhiMoEConfig.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
>>> # Initializing a model from the configuration
>>> model = PhiMoEModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "phimoe"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32064,
hidden_size=4096,
intermediate_size=6400,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=1e6,
rope_scaling=None,
sliding_window=None,
attention_dropout=0.0,
num_experts_per_tok=2,
num_local_experts=16,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_jitter_noise=0.01,
input_jitter_noise=0.0,
attention_bias=False,
lm_head_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.attention_bias = attention_bias
self.lm_head_bias = lm_head_bias
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
self.input_jitter_noise = input_jitter_noise
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:
raise ValueError(
"`rope_scaling` must be a dictionary with three fields, `type`, `short_factor`, `long_factor`, "
f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
rope_scaling_short_mscale = self.rope_scaling.get("short_mscale", None)
rope_scaling_long_mscale = self.rope_scaling.get("long_mscale", None)
original_max_position_embeddings = self.rope_scaling.get(
"original_max_position_embeddings", None
)
if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}"
)
if not (
isinstance(rope_scaling_short_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
):
raise ValueError(
f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
)
if (
not len(rope_scaling_short_factor)
== self.hidden_size // self.num_attention_heads // 2
):
raise ValueError(
f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
)
if not (
isinstance(rope_scaling_long_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
):
raise ValueError(
f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
)
if (
not len(rope_scaling_long_factor)
== self.hidden_size // self.num_attention_heads // 2
):
raise ValueError(
f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
)
if not isinstance(rope_scaling_short_mscale, (int, float)):
raise ValueError(
f"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}"
)
if not isinstance(rope_scaling_long_mscale, (int, float)):
raise ValueError(
f"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}"
)
if not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}"
)