diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 4be8b98f..b2bde282 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -23,6 +23,7 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 # Flash attention imports
@@ -43,6 +44,56 @@ from text_generation_server.utils.layers import (
 )
 
 
+class LlamaConfig(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_scaling = rope_scaling
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
 class LlamaRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
         """
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 088b50b9..b699799e 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -2,13 +2,13 @@ import torch
 import torch.distributed
 
 from opentelemetry import trace
-from transformers import AutoConfig
 from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
+    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -52,7 +52,7 @@ class FlashLlama(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-        config = AutoConfig.from_pretrained(
+        config = LlamaConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
 
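
Note (not part of the patch): a minimal sketch of how the vendored LlamaConfig added above behaves, presumably the reason for switching flash_llama.py away from AutoConfig, so that num_key_value_heads is always populated even when an older config.json or an older installed transformers does not define it. The sketch assumes the text_generation_server package and the custom flash-attention kernels its modeling module imports are available in the environment.

# Illustration only, not part of the diff above.
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    LlamaConfig,
)

# Llama v1 style checkpoint: config.json has no num_key_value_heads, so the
# "backward compatibility" branch sets it equal to num_attention_heads (MHA).
legacy = LlamaConfig(num_attention_heads=32)
assert legacy.num_key_value_heads == 32

# Grouped-query attention checkpoint (e.g. Llama 2 70B): 64 query heads, 8 KV heads.
gqa = LlamaConfig(num_attention_heads=64, num_key_value_heads=8)
assert gqa.num_key_value_heads == 8

# flash_llama.py now loads this class instead of AutoConfig:
# config = LlamaConfig.from_pretrained(
#     model_id, revision=revision, trust_remote_code=trust_remote_code
# )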