feat(server): support new falcon config (#712)
parent 2efd46ef95
commit ab96b9aec3

@@ -200,13 +200,10 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config_dict.get("alibi", False) or (
-                    model_type == "RefinedWebModel"
-                    and config_dict.get("multi_query", True)
-                ):
+                if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
@@ -215,9 +212,7 @@ def get_model(
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRWSharded(
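
The two hunks above widen the RefinedWeb branch to accept the upstream `falcon` model type and drop the extra multi-query restriction on sharding, leaving `alibi` as the only blocker. A minimal sketch of the resulting routing — `select_rw_branch`, `flash_attention`, and the return strings are hypothetical stand-ins for the real `get_model` flow, not TGI code:

```python
def select_rw_branch(config_dict, model_type, sharded, flash_attention):
    """Hypothetical mirror of the routing shown in the diff above."""
    # "falcon" is now accepted alongside the legacy RefinedWeb model types.
    if model_type not in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        return "not a Falcon/RefinedWeb model"

    if sharded:
        if flash_attention:
            # Sharding is now refused only for alibi models; the extra
            # multi_query restriction on "RefinedWebModel" was dropped.
            if config_dict.get("alibi", False):
                raise NotImplementedError("sharded is not supported for this model")
            return "FlashRWSharded"
        raise NotImplementedError("Sharded Falcon requires Flash Attention")

    # Unsharded path, unchanged by this commit.
    if flash_attention and not config_dict.get("alibi", False):
        return "FlashRWSharded"
    return "fallback path"


# A new-style Falcon checkpoint carries model_type == "falcon" in its config.json.
print(select_rw_branch({"alibi": False}, "falcon", sharded=True, flash_attention=True))
```
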
@@ -49,8 +49,8 @@ class RWConfig(PretrainedConfig):
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -58,9 +58,10 @@ class RWConfig(PretrainedConfig):
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ class RWConfig(PretrainedConfig):
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
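
The constructor now prefers the transformers-style names and only falls back to the legacy RefinedWeb keys (with their old defaults) when the new ones are absent. A small sketch of that resolution outside the class — `resolve_layer_counts` is a hypothetical helper, the real code runs inside `RWConfig.__init__`:

```python
def resolve_layer_counts(num_hidden_layers=None, num_attention_heads=None, **kwargs):
    """Hypothetical helper mirroring the RWConfig.__init__ logic above."""
    # Prefer the new transformers-style names; otherwise fall back to the
    # legacy RefinedWeb keys with their old defaults of 2 and 8.
    n_layer = (
        num_hidden_layers if num_hidden_layers is not None else kwargs.pop("n_layer", 2)
    )
    n_head = (
        num_attention_heads if num_attention_heads is not None else kwargs.pop("n_head", 8)
    )
    return n_layer, n_head


# Legacy keys still work; new-style keys are read directly.
assert resolve_layer_counts(n_layer=32, n_head=71) == (32, 71)
assert resolve_layer_counts(num_hidden_layers=60, num_attention_heads=128) == (60, 128)
```
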
@@ -91,10 +100,21 @@ class RWConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head
+
+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
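
The KV-head count and the new `new_decoder_architecture` flag follow the same pattern: take the explicit new key if present, else the legacy key, else infer it (from `multi_query` for KV heads, from the legacy `model_type` for the architecture flag, since `"RefinedWeb"` maps to the new decoder architecture). A sketch operating on a plain config dict — `normalize_falcon_config` is a hypothetical helper, not the real `RWConfig`:

```python
def normalize_falcon_config(cfg: dict) -> dict:
    """Hypothetical helper mirroring the RWConfig.__init__ fallbacks above."""
    n_head = cfg.get("num_attention_heads", cfg.get("n_head", 8))

    # KV heads: new key first, then the legacy n_head_kv, then the
    # multi_query heuristic.
    if cfg.get("num_kv_heads") is not None:
        n_head_kv = cfg["num_kv_heads"]
    elif cfg.get("n_head_kv") is not None:
        n_head_kv = cfg["n_head_kv"]
    else:
        n_head_kv = 1 if cfg.get("multi_query", False) else n_head

    # new_decoder_architecture: an explicit value wins, otherwise it is
    # inferred from the legacy model_type.
    if cfg.get("new_decoder_architecture") is not None:
        new_arch = cfg["new_decoder_architecture"]
    else:
        new_arch = cfg.get("model_type") == "RefinedWeb"

    return {"n_head": n_head, "n_head_kv": n_head_kv, "new_decoder_architecture": new_arch}


# A legacy RefinedWeb-style config and a new-style falcon config resolve identically.
legacy = {"model_type": "RefinedWeb", "n_head": 128, "n_head_kv": 8}
new = {"model_type": "falcon", "num_attention_heads": 128, "num_kv_heads": 8,
       "new_decoder_architecture": True}
assert normalize_falcon_config(legacy) == normalize_falcon_config(new)
```
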
@@ -530,15 +550,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
-            self.h = nn.ModuleList(
-                [
-                    FlashRWLayer(layer_id, config, weights)
-                    for layer_id in range(config.num_hidden_layers)
-                ]
-            )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(layer_id, config, weights)
@@ -547,9 +560,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
             )
             self.cache_size = self.h[0].self_attention.num_groups
         else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
+            self.h = nn.ModuleList(
+                [
+                    FlashRWLayer(layer_id, config, weights)
+                    for layer_id in range(config.num_hidden_layers)
+                ]
             )
+            self.cache_size = self.h[0].self_attention.num_heads_kv
 
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",
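
With the config normalized, `FlashRWModel` no longer dispatches on the `model_type` string; the `new_decoder_architecture` flag picks between the two block variants. A toy sketch of that choice — `LargeLayerStub`/`LayerStub` are hypothetical stand-ins for `FlashRWLargeLayer`/`FlashRWLayer`, and `torch` is assumed to be installed:

```python
from types import SimpleNamespace

from torch import nn


class LargeLayerStub(nn.Module):
    """Hypothetical stand-in for FlashRWLargeLayer."""


class LayerStub(nn.Module):
    """Hypothetical stand-in for FlashRWLayer."""


def build_decoder(config) -> nn.ModuleList:
    # The flag on the config, not the model_type string, now selects the block.
    layer_cls = LargeLayerStub if config.new_decoder_architecture else LayerStub
    return nn.ModuleList([layer_cls() for _ in range(config.num_hidden_layers)])


cfg = SimpleNamespace(new_decoder_architecture=True, num_hidden_layers=2)
print(build_decoder(cfg))  # ModuleList holding two LargeLayerStub blocks
```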