diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index ffc224cc..e9260eed 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -200,13 +200,10 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config_dict.get("alibi", False) or (
-                    model_type == "RefinedWebModel"
-                    and config_dict.get("multi_query", True)
-                ):
+                if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
@@ -215,9 +212,7 @@ def get_model(
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRWSharded(
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 1e9539c4..3570b283 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -49,8 +49,8 @@ class RWConfig(PretrainedConfig):
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -58,9 +58,10 @@ class RWConfig(PretrainedConfig):
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ class RWConfig(PretrainedConfig):
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
@@ -91,10 +100,21 @@ class RWConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id

-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head
+
+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False

         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

@@ -530,15 +550,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
-            self.h = nn.ModuleList(
-                [
-                    FlashRWLayer(layer_id, config, weights)
-                    for layer_id in range(config.num_hidden_layers)
-                ]
-            )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
                     FlashRWLargeLayer(layer_id, config, weights)
@@ -547,9 +560,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
             )
             self.cache_size = self.h[0].self_attention.num_groups
         else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
+            self.h = nn.ModuleList(
+                [
+                    FlashRWLayer(layer_id, config, weights)
+                    for layer_id in range(config.num_hidden_layers)
+                ]
             )
+            self.cache_size = self.h[0].self_attention.num_heads_kv

         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",
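For quick reference, the config normalization introduced above can be sketched standalone. The resolve_falcon_config helper and the sample dicts below are illustrative only and are not part of this patch or of text-generation-inference; they mirror the precedence the updated RWConfig.__init__ applies: new-style keys (num_hidden_layers, num_attention_heads, num_kv_heads, new_decoder_architecture) win, the legacy RefinedWeb keys (n_layer, n_head, n_head_kv) are honored as a fallback, and otherwise new_decoder_architecture is inferred from model_type.

# Illustrative sketch only (not part of this patch): mirrors the precedence
# the updated RWConfig.__init__ uses to normalize Falcon/RefinedWeb configs.
def resolve_falcon_config(cfg: dict) -> dict:
    # New-style keys win; legacy RefinedWeb keys are the fallback.
    n_layer = cfg.get("num_hidden_layers", cfg.get("n_layer", 2))
    n_head = cfg.get("num_attention_heads", cfg.get("n_head", 8))

    # KV heads: num_kv_heads, then the old n_head_kv, then the multi_query rule.
    if cfg.get("num_kv_heads") is not None:
        n_head_kv = cfg["num_kv_heads"]
    elif cfg.get("n_head_kv") is not None:
        n_head_kv = cfg["n_head_kv"]
    else:
        n_head_kv = 1 if cfg.get("multi_query", False) else n_head

    # Architecture flag: an explicit value wins, otherwise model_type decides.
    new_arch = cfg.get("new_decoder_architecture")
    if new_arch is None:
        new_arch = cfg.get("model_type", "RefinedWeb") == "RefinedWeb"

    return {
        "n_layer": n_layer,
        "n_head": n_head,
        "n_head_kv": n_head_kv,
        "new_decoder_architecture": new_arch,
    }

# A legacy multi-query RefinedWebModel config collapses to a single KV head
# and takes the FlashRWLayer (non-"large") path in FlashRWModel.
assert resolve_falcon_config(
    {"model_type": "RefinedWebModel", "n_layer": 32, "n_head": 71, "multi_query": True}
) == {"n_layer": 32, "n_head": 71, "n_head_kv": 1, "new_decoder_architecture": False}

# A new-style "falcon" config with explicit grouped KV heads takes the
# FlashRWLargeLayer path because new_decoder_architecture is set.
assert resolve_falcon_config(
    {
        "model_type": "falcon",
        "num_hidden_layers": 60,
        "num_attention_heads": 128,
        "num_kv_heads": 8,
        "new_decoder_architecture": True,
    }
) == {"n_layer": 60, "n_head": 128, "n_head_kv": 8, "new_decoder_architecture": True}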