Hotfix: various GPT-based model fixes (#2256)

commit 18db78f295
parent 80adb5be16
@@ -573,6 +573,10 @@ def get_model(
             )
     elif model_type == GPT_NEOX:
         if FLASH_ATTENTION:
+            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+                GPTNeoXConfig,
+            )
+
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGPTNeoXForCausalLM,
@@ -582,6 +586,7 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
+                config_class=GPTNeoXConfig,
             )
         elif sharded:
             return CausalLM(
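The first two hunks make the flash GPT-NeoX path load the model with the patched GPTNeoXConfig defined in flash_neox_modeling.py by passing it as config_class to FlashCausalLM. A minimal sketch of what such an override amounts to, assuming FlashCausalLM resolves the config roughly like the hypothetical load_config helper below (the helper and its signature are illustrative, not the actual FlashCausalLM internals):

    # Hypothetical helper: illustrates the effect of a config_class override.
    # When a custom class is supplied, it parses config.json instead of the
    # stock transformers class picked by AutoConfig.
    from transformers import AutoConfig

    def load_config(model_id, revision=None, trust_remote_code=False, config_class=None):
        cls = config_class if config_class is not None else AutoConfig
        return cls.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )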
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
-from transformers.models.gpt_neox import GPTNeoXConfig
+from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.layers.attention import (
@@ -45,6 +45,13 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+class GPTNeoXConfig(TransformersGPTNeoXConfig):
+    attribute_map = {
+        "num_key_value_heads": "num_attention_heads",
+    }
 
 
 def load_row(config, prefix: str, weights, bias: bool):
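The new GPTNeoXConfig subclass relies on transformers' PretrainedConfig attribute_map mechanism: mapped attribute names are redirected on get and set, so code that expects a GQA-style num_key_value_heads field transparently reads num_attention_heads (GPT-NeoX uses plain multi-head attention, where the two are equal). A small self-contained illustration with made-up numbers, not taken from any real checkpoint:

    from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig

    class GPTNeoXConfig(TransformersGPTNeoXConfig):
        # Redirect num_key_value_heads to num_attention_heads (MHA: KV heads == query heads).
        attribute_map = {
            "num_key_value_heads": "num_attention_heads",
        }

    config = GPTNeoXConfig(num_attention_heads=32)
    assert config.num_key_value_heads == config.num_attention_heads == 32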
@@ -65,10 +72,10 @@ def load_row(config, prefix: str, weights, bias: bool):
 
 def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
     weight = weights.get_multi_weights_col([prefix], dim=0)
-    if isinstance(weight, torch.Tensor):
+    if isinstance(weight, UnquantizedWeight):
         # Only on non quantized versions
-        weight = (
-            weight.view(
+        weight.weight = (
+            weight.weight.view(
                 num_heads,
                 3,
                 head_size,
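The load_qkv change follows from get_multi_weights_col now returning a weight wrapper rather than a bare torch.Tensor: for unquantized checkpoints that wrapper is an UnquantizedWeight carrying the tensor in its .weight attribute, so the de-interleaving view has to be applied to weight.weight and stored back on the wrapper. A standalone sketch of the shape manipulation on a plain tensor with toy dimensions; the regrouping shown illustrates de-interleaving a fused, per-head-interleaved QKV projection and is not copied from the repository:

    import torch

    num_heads, head_size, hidden_size = 4, 8, 16  # toy dimensions

    # Stand-in for weight.weight of the fused QKV projection.
    fused_qkv = torch.randn(num_heads * 3 * head_size, hidden_size)

    # GPT-NeoX checkpoints interleave [q_h, k_h, v_h] blocks per head; viewing as
    # (num_heads, 3, head_size, hidden_size) exposes that layout so Q, K and V
    # can be regrouped into contiguous slabs.
    per_head = fused_qkv.view(num_heads, 3, head_size, hidden_size)
    q, k, v = (per_head[:, i].reshape(num_heads * head_size, hidden_size) for i in range(3))
    print(q.shape, k.shape, v.shape)  # each: (num_heads * head_size, hidden_size)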
@@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight
 
 
 class Starcoder2Config(PretrainedConfig):
@@ -129,16 +130,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
     if config.use_bias:
         w = [
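In _load_gqa the quantization check is likewise replaced by an isinstance check on the weight wrapper, and the shape assertion moves to weight.weight. The assertion itself encodes the fused GQA layout: after tensor-parallel sharding, the column-concatenated QKV matrix has (num_heads + 2 * num_key_value_heads) * head_size rows, since K and V only carry num_key_value_heads heads each. A quick arithmetic sketch with illustrative numbers (not necessarily a real Starcoder2 config):

    hidden_size = 6144
    num_attention_heads = 48
    key_value_heads = 4
    tp_size = 2  # stand-in for weights.process_group.size()

    head_size = hidden_size // num_attention_heads      # 128
    num_heads = num_attention_heads // tp_size          # 24 query heads per shard
    num_key_value_heads = key_value_heads // tp_size    # 2 KV heads per shard

    expected_rows = (num_heads + 2 * num_key_value_heads) * head_size
    print(expected_rows, hidden_size)  # 3584 x 6144 expected shard of the fused QKV weight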