Hotfix: various GPT-based model fixes (#2256)
This commit is contained in:
parent
80adb5be16
commit
18db78f295
|
@ -573,6 +573,10 @@ def get_model(
|
|||
)
|
||||
elif model_type == GPT_NEOX:
|
||||
if FLASH_ATTENTION:
|
||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||
GPTNeoXConfig,
|
||||
)
|
||||
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGPTNeoXForCausalLM,
|
||||
|
@ -582,6 +586,7 @@ def get_model(
|
|||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=GPTNeoXConfig,
|
||||
)
|
||||
elif sharded:
|
||||
return CausalLM(
|
||||
|
|
|
@ -24,7 +24,7 @@ import torch.distributed
|
|||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.gpt_neox import GPTNeoXConfig
|
||||
from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.layers.attention import (
|
||||
|
@ -45,6 +45,13 @@ from text_generation_server.layers.layernorm import (
|
|||
from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
|
||||
|
||||
class GPTNeoXConfig(TransformersGPTNeoXConfig):
|
||||
attribute_map = {
|
||||
"num_key_value_heads": "num_attention_heads",
|
||||
}
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
|
@ -65,10 +72,10 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
|
||||
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
|
||||
weight = weights.get_multi_weights_col([prefix], dim=0)
|
||||
if isinstance(weight, torch.Tensor):
|
||||
if isinstance(weight, UnquantizedWeight):
|
||||
# Only on non quantized versions
|
||||
weight = (
|
||||
weight.view(
|
||||
weight.weight = (
|
||||
weight.weight.view(
|
||||
num_heads,
|
||||
3,
|
||||
head_size,
|
||||
|
|
|
@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
|
|||
from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
|
||||
|
||||
class Starcoder2Config(PretrainedConfig):
|
||||
|
@ -129,16 +130,16 @@ def _load_gqa(config, prefix: str, weights):
|
|||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq", "marlin"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
if isinstance(weight, UnquantizedWeight):
|
||||
weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
assert list(weight.weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
if config.use_bias:
|
||||
w = [
|
||||
|
|
Loading…
Reference in New Issue