feat(server): support new falcon config (#712)

OlivierDehaene authored 2023-07-27 18:38:57 +02:00, committed by GitHub
parent 2efd46ef95
commit ab96b9aec3
2 changed files with 39 additions and 27 deletions


@@ -200,13 +200,10 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config_dict.get("alibi", False) or (
-                    model_type == "RefinedWebModel"
-                    and config_dict.get("multi_query", True)
-                ):
+                if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
@@ -215,9 +212,7 @@ def get_model(
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRWSharded(
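On the dispatch side, the change is that the new "falcon" model_type joins the two legacy RefinedWeb names, and the sharding guard now only rejects alibi checkpoints instead of also rejecting multi-query RefinedWebModel ones. A minimal standalone sketch of that branching, with a plain dict and a boolean standing in for the real config object and FLASH_ATTENTION flag (pick_branch is a made-up name, not part of the server):

# Hypothetical standalone sketch of the dispatch above; not the server's code.
FLASH_ATTENTION = True  # stand-in for the real capability flag


def pick_branch(model_type: str, config_dict: dict, sharded: bool) -> str:
    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError("Sharded Falcon requires flash attention")
            if config_dict.get("alibi", False):
                raise NotImplementedError("sharded is not supported for this model")
            return "FlashRWSharded"
        # non-sharded path: alibi checkpoints fall back to a slower implementation
        if FLASH_ATTENTION and not config_dict.get("alibi", False):
            return "FlashRWSharded"
        return "RW fallback"
    return "some other architecture"


# The new "falcon" model_type resolves exactly like the legacy names.
assert pick_branch("falcon", {"alibi": False}, sharded=True) == "FlashRWSharded"
assert pick_branch("RefinedWebModel", {"multi_query": True}, sharded=True) == "FlashRWSharded"

The second assertion would have raised before this commit, since sharded multi-query RefinedWebModel checkpoints were previously rejected by the removed condition.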


@@ -49,8 +49,8 @@ class RWConfig(PretrainedConfig):
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -58,9 +58,10 @@ class RWConfig(PretrainedConfig):
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ class RWConfig(PretrainedConfig):
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
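The hunk above keeps self.n_layer and self.n_head as the internal attribute names while accepting the new Falcon field names, falling back to the legacy kwargs (and their old defaults) only when the new ones are absent. A standalone sketch of that resolution order, with a made-up helper name used purely for illustration:

# Hypothetical helper mirroring the alias resolution above; not the real RWConfig.
def resolve_layers_and_heads(num_hidden_layers=None, num_attention_heads=None, **kwargs):
    n_layer = (
        num_hidden_layers
        if num_hidden_layers is not None
        else kwargs.pop("n_layer", 2)  # legacy field name and default
    )
    n_head = (
        num_attention_heads
        if num_attention_heads is not None
        else kwargs.pop("n_head", 8)  # legacy field name and default
    )
    return n_layer, n_head


# New-style Falcon config fields ...
assert resolve_layers_and_heads(num_hidden_layers=60, num_attention_heads=71) == (60, 71)
# ... and old-style RefinedWeb fields resolve to the same internal attributes.
assert resolve_layers_and_heads(n_layer=60, n_head=71) == (60, 71)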
@@ -91,10 +100,21 @@
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head
+
+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
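The same backward-compatibility pattern covers the KV-head count, which now has three accepted spellings, and the new new_decoder_architecture flag, which is inferred from model_type when a config does not set it explicitly. A standalone sketch of that precedence (the function name is illustrative only):

# Hypothetical helper mirroring the precedence above; not the real RWConfig.
def resolve_kv_and_arch(
    model_type,
    n_head=8,
    num_kv_heads=None,
    multi_query=False,
    new_decoder_architecture=None,
    **kwargs,
):
    # KV heads: new field name, then the legacy n_head_kv kwarg, then multi_query.
    if num_kv_heads is not None:
        n_head_kv = num_kv_heads
    else:
        old_n_head_kv = kwargs.pop("n_head_kv", None)
        if old_n_head_kv is not None:
            n_head_kv = old_n_head_kv
        else:
            n_head_kv = 1 if multi_query else n_head

    # Decoder architecture: an explicit flag wins; otherwise infer it from model_type.
    if new_decoder_architecture is not None:
        new_arch = new_decoder_architecture
    elif model_type == "RefinedWeb":
        new_arch = True
    else:
        new_arch = False
    return n_head_kv, new_arch


# New falcon-style config: explicit fields.
assert resolve_kv_and_arch("falcon", num_kv_heads=8, new_decoder_architecture=True) == (8, True)
# Legacy RefinedWeb config: the old n_head_kv kwarg still works, architecture is inferred.
assert resolve_kv_and_arch("RefinedWeb", n_head_kv=8) == (8, True)
# Legacy RefinedWebModel config: multi_query implies a single KV head.
assert resolve_kv_and_arch("RefinedWebModel", multi_query=True) == (1, False)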
@@ -530,15 +550,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
-            self.h = nn.ModuleList(
-                [
-                    FlashRWLayer(layer_id, config, weights)
-                    for layer_id in range(config.num_hidden_layers)
-                ]
-            )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(layer_id, config, weights)
@@ -547,9 +560,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 )
             self.cache_size = self.h[0].self_attention.num_groups
         else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
+            self.h = nn.ModuleList(
+                [
+                    FlashRWLayer(layer_id, config, weights)
+                    for layer_id in range(config.num_hidden_layers)
+                ]
             )
+            self.cache_size = self.h[0].self_attention.num_heads_kv
 
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",