feat(server): support new falcon config (#712)

OlivierDehaene 2023-07-27 18:38:57 +02:00 committed by GitHub
parent 2efd46ef95
commit ab96b9aec3
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
2 changed files with 39 additions and 27 deletions

View File

@@ -200,13 +200,10 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
-    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config_dict.get("alibi", False) or (
-                    model_type == "RefinedWebModel"
-                    and config_dict.get("multi_query", True)
-                ):
+                if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
@@ -215,9 +212,7 @@ def get_model(
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRWSharded(
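
The routing change above boils down to: any config whose model_type is "RefinedWeb", "RefinedWebModel", or the new "falcon" name goes through the same FlashRW path, and sharding is now refused only for alibi checkpoints rather than for every multi-query RefinedWebModel. Below is a minimal standalone sketch of that decision; the name route_falcon and the returned strings are illustrative only (the real code returns FlashRWSharded/RW model instances and reads the module-level FLASH_ATTENTION flag):

def route_falcon(config_dict: dict, model_type: str, sharded: bool, flash_attention: bool) -> str:
    # Mirrors the branch above: all three model_type spellings share one path.
    if model_type not in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        raise ValueError(f"unexpected model_type {model_type}")
    if sharded:
        if flash_attention:
            if config_dict.get("alibi", False):
                # Only alibi checkpoints are rejected now, not multi_query ones.
                raise NotImplementedError("sharded is not supported for this model")
            return "FlashRWSharded"
        # Real code raises FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon") here.
        raise NotImplementedError("Sharded Falcon requires Flash Attention")
    if flash_attention and not config_dict.get("alibi", False):
        return "FlashRWSharded"
    return "RW"  # non-flash fallback handled elsewhere in get_model

# A new-style "falcon" checkpoint without alibi can now be sharded:
assert route_falcon({"alibi": False}, "falcon", sharded=True, flash_attention=True) == "FlashRWSharded"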

View File

@@ -49,8 +49,8 @@ class RWConfig(PretrainedConfig):
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -58,9 +58,10 @@ class RWConfig(PretrainedConfig):
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ class RWConfig(PretrainedConfig):
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
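
A quick illustration of the n_layer / n_head fallback introduced above: new falcon configs ship num_hidden_layers and num_attention_heads, while legacy RefinedWeb configs passed n_layer / n_head, which now arrive through **kwargs. This is a standalone sketch, not repository code (resolve_layers_and_heads is an illustrative name):

def resolve_layers_and_heads(num_hidden_layers=None, num_attention_heads=None, **kwargs):
    # Prefer the new names, fall back to the legacy kwargs, keep the old defaults.
    n_layer = num_hidden_layers if num_hidden_layers is not None else kwargs.pop("n_layer", 2)
    n_head = num_attention_heads if num_attention_heads is not None else kwargs.pop("n_head", 8)
    return n_layer, n_head

assert resolve_layers_and_heads(num_hidden_layers=60, num_attention_heads=128) == (60, 128)  # new-style config
assert resolve_layers_and_heads(n_layer=32, n_head=71) == (32, 71)                           # legacy config
assert resolve_layers_and_heads() == (2, 8)                                                  # old defaults preserved
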
@@ -91,10 +100,21 @@ class RWConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id

-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head

+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False
+
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
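
The same pattern handles the KV heads and the new new_decoder_architecture flag: num_kv_heads wins if present, then the legacy n_head_kv kwarg, then the multi_query default; the architecture flag defaults to True only for the old "RefinedWeb" model_type. A standalone sketch under those assumptions (resolve_kv_and_architecture is an illustrative name, and the example head counts are only for demonstration):

def resolve_kv_and_architecture(model_type, n_head, num_kv_heads=None,
                                multi_query=False, new_decoder_architecture=None, **kwargs):
    if num_kv_heads is not None:
        n_head_kv = num_kv_heads                # new falcon name wins
    else:
        old = kwargs.pop("n_head_kv", None)     # legacy RefinedWeb kwarg
        n_head_kv = old if old is not None else (1 if multi_query else n_head)
    if new_decoder_architecture is None:
        # Only the old "RefinedWeb" model_type implies the new (grouped-KV) decoder.
        new_decoder_architecture = model_type == "RefinedWeb"
    return n_head_kv, new_decoder_architecture

# Grouped-KV config with the flag set explicitly (new-style falcon checkpoint):
assert resolve_kv_and_architecture("falcon", 128, num_kv_heads=8, new_decoder_architecture=True) == (8, True)
# Multi-query config with the old model_type and no flag:
assert resolve_kv_and_architecture("RefinedWebModel", 71, multi_query=True) == (1, False)
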
@@ -530,15 +550,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
-            self.h = nn.ModuleList(
-                [
-                    FlashRWLayer(layer_id, config, weights)
-                    for layer_id in range(config.num_hidden_layers)
-                ]
-            )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(layer_id, config, weights)
@@ -547,9 +560,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
             )
             self.cache_size = self.h[0].self_attention.num_groups
         else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
+            self.h = nn.ModuleList(
+                [
+                    FlashRWLayer(layer_id, config, weights)
+                    for layer_id in range(config.num_hidden_layers)
+                ]
             )
+            self.cache_size = self.h[0].self_attention.num_heads_kv

         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",