feat(server): support new falcon config (#712)
parent 2efd46ef95
commit ab96b9aec3
@@ -200,13 +200,10 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config_dict.get("alibi", False) or (
-                    model_type == "RefinedWebModel"
-                    and config_dict.get("multi_query", True)
-                ):
+                if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
@@ -215,9 +212,7 @@ def get_model(
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRWSharded(
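For context, a rough sketch (not part of the commit) of how the updated get_model routing treats a new-style Falcon checkpoint; the field values below are illustrative assumptions, not taken from a real config.json.

# Illustrative sketch only: the fields the updated check inspects.
config_dict = {
    "model_type": "falcon",   # now matched alongside "RefinedWeb"/"RefinedWebModel"
    "alibi": False,           # sharding is only refused when alibi is enabled
}

supported_types = ["RefinedWeb", "RefinedWebModel", "falcon"]
if config_dict["model_type"] in supported_types:
    # With flash attention available, a sharded load now succeeds whenever
    # alibi is off; the old multi_query restriction for RefinedWebModel is gone.
    can_shard = not config_dict.get("alibi", False)
    print(can_shard)  # True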
@@ -49,8 +49,8 @@ class RWConfig(PretrainedConfig):
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -58,9 +58,10 @@ class RWConfig(PretrainedConfig):
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ class RWConfig(PretrainedConfig):
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
@@ -91,10 +100,21 @@ class RWConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head
+
+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
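A minimal usage sketch (not part of the commit) of how the reworked RWConfig maps both the legacy RefinedWeb keys and the new transformers-style Falcon keys onto the same attributes; the numeric values are placeholders and the import path is assumed.

# Sketch only; import path assumed to be the modified flash_rw_modeling module.
from text_generation_server.models.custom_modeling.flash_rw_modeling import RWConfig

# Legacy RefinedWeb-style keys still work: they fall through to **kwargs and
# are picked up by kwargs.pop("n_layer", ...), kwargs.pop("n_head_kv", ...).
old_style = RWConfig(model_type="RefinedWebModel", n_layer=32, n_head=71, n_head_kv=1)

# New transformers-style Falcon keys resolve to the same internal attributes.
new_style = RWConfig(
    model_type="falcon",
    num_hidden_layers=32,
    num_attention_heads=71,
    num_kv_heads=1,
    new_decoder_architecture=False,
)

assert old_style.n_layer == new_style.n_layer == 32
assert old_style.n_head_kv == new_style.n_head_kv == 1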
@@ -530,15 +550,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
-            self.h = nn.ModuleList(
-                [
-                    FlashRWLayer(layer_id, config, weights)
-                    for layer_id in range(config.num_hidden_layers)
-                ]
-            )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(layer_id, config, weights)
@@ -547,9 +560,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
             )
             self.cache_size = self.h[0].self_attention.num_groups
         else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
-            )
+            self.h = nn.ModuleList(
+                [
+                    FlashRWLayer(layer_id, config, weights)
+                    for layer_id in range(config.num_hidden_layers)
+                ]
+            )
+            self.cache_size = self.h[0].self_attention.num_heads_kv
 
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",
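The net effect in FlashRWModel is that layer selection is keyed on the resolved new_decoder_architecture flag rather than the raw model_type string. A rough sketch (not part of the commit, reusing the RWConfig import from the sketch above):

# Sketch only: old "RefinedWeb" checkpoints resolve new_decoder_architecture=True
# and keep using FlashRWLargeLayer, while RefinedWebModel/falcon checkpoints with
# new_decoder_architecture=False fall back to FlashRWLayer.
def pick_layer_class(config):
    return "FlashRWLargeLayer" if config.new_decoder_architecture else "FlashRWLayer"

assert pick_layer_class(RWConfig(model_type="RefinedWeb")) == "FlashRWLargeLayer"
assert pick_layer_class(RWConfig(model_type="falcon", new_decoder_architecture=False)) == "FlashRWLayer"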