Hotfix: various GPT-based model fixes (#2256)

Author: Daniël de Kok
Date: 2024-07-19 14:42:19 +02:00
Committed by: GitHub
Parent: 80adb5be16
Commit: 18db78f295
Signature: GPG Key ID B5690EEEBB952194 (no known key found for this signature in database)
3 changed files with 21 additions and 8 deletions

server/text_generation_server/models/__init__.py

@@ -573,6 +573,10 @@ def get_model(
             )
     elif model_type == GPT_NEOX:
         if FLASH_ATTENTION:
+            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+                GPTNeoXConfig,
+            )
+
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGPTNeoXForCausalLM,
@@ -582,6 +586,7 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
+                config_class=GPTNeoXConfig,
             )
         elif sharded:
             return CausalLM(
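Note: the branch above wires the flash GPT-NeoX path to a custom config class (defined in the next file) whose `attribute_map` aliases `num_key_value_heads` to `num_attention_heads`, presumably so the shared flash-attention code can read a key/value head count from a model that uses plain multi-head attention. A minimal sketch, not the repository's code, of how transformers' `attribute_map` mechanism behaves; the usage lines are illustrative only:

# Sketch of the attribute_map aliasing used by the new GPTNeoXConfig.
# PretrainedConfig redirects reads and writes of an aliased attribute to its
# target, so `num_key_value_heads` resolves to `num_attention_heads`.
from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig


class GPTNeoXConfig(TransformersGPTNeoXConfig):
    attribute_map = {
        "num_key_value_heads": "num_attention_heads",
    }


config = GPTNeoXConfig(num_attention_heads=32)
# GPT-NeoX has no grouped-query attention, so KV heads == attention heads.
assert config.num_key_value_heads == config.num_attention_heads == 32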

server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
-from transformers.models.gpt_neox import GPTNeoXConfig
+from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple

 from text_generation_server.layers.attention import (
@@ -45,6 +45,13 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+class GPTNeoXConfig(TransformersGPTNeoXConfig):
+    attribute_map = {
+        "num_key_value_heads": "num_attention_heads",
+    }


 def load_row(config, prefix: str, weights, bias: bool):
@@ -65,10 +72,10 @@ def load_row(config, prefix: str, weights, bias: bool):

 def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
     weight = weights.get_multi_weights_col([prefix], dim=0)
-    if isinstance(weight, torch.Tensor):
+    if isinstance(weight, UnquantizedWeight):
         # Only on non quantized versions
-        weight = (
-            weight.view(
+        weight.weight = (
+            weight.weight.view(
                 num_heads,
                 3,
                 head_size,
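Note: the `load_qkv` hunk follows an API shift in the weight loader, which now returns wrapper objects rather than bare tensors, so the old `isinstance(weight, torch.Tensor)` branch no longer matched and the per-head QKV reshape was skipped. A minimal sketch of the pattern using a stand-in dataclass instead of TGI's real `UnquantizedWeight`; the `.permute(...).reshape(...)` tail is assumed from the surrounding code and is illustrative only:

from dataclasses import dataclass

import torch


@dataclass
class UnquantizedWeight:
    """Stand-in for text_generation_server.utils.weights.UnquantizedWeight."""

    weight: torch.Tensor


def reorder_gpt_neox_qkv(weight, num_heads: int, head_size: int, hidden_size: int):
    # Mirrors the branch in the hunk above: only plain (unquantized) weights are
    # re-viewed per head; quantized wrappers keep their packed layout.
    if isinstance(weight, UnquantizedWeight):
        weight.weight = (
            weight.weight.view(num_heads, 3, head_size, hidden_size)
            .permute(1, 0, 2, 3)  # assumed tail: regroup interleaved QKV into Q|K|V blocks
            .reshape(-1, hidden_size)
        )
    return weight


# Toy usage with random data: 4 heads of size 8, hidden size 32.
w = UnquantizedWeight(weight=torch.randn(4 * 3 * 8, 32))
print(reorder_gpt_neox_qkv(w, num_heads=4, head_size=8, hidden_size=32).weight.shape)  # torch.Size([96, 32])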

server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py

@@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class Starcoder2Config(PretrainedConfig):
@@ -129,16 +130,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     if config.use_bias:
         w = [
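Note: the Starcoder2 hunk swaps the quantizer-name check for the same wrapper-type check and reads the shape through `weight.weight`. As a sanity check on the assertion it touches, here is a worked example with hypothetical Starcoder2-like sizes (not taken from a real checkpoint) showing the row count expected for the fused, tensor-parallel-sharded QKV weight:

# Hypothetical sizes for illustration only.
hidden_size = 4096
num_attention_heads = 32
total_num_key_value_heads = 4
tp_world_size = 2  # weights.process_group.size()

head_size = hidden_size // num_attention_heads                     # 128
num_heads = num_attention_heads // tp_world_size                   # 16 query heads per shard
num_key_value_heads = total_num_key_value_heads // tp_world_size   # 2 KV heads per shard

expected_shape = [(num_heads + 2 * num_key_value_heads) * head_size, hidden_size]
print(expected_shape)  # [2560, 4096]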