fix(server): fix llamav2 config (#635)

OlivierDehaene 2023-07-18 18:49:42 +02:00 committed by GitHub
parent cf83f9b66f
commit 5e6ddfd6a4
2 changed files with 53 additions and 2 deletions

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -23,6 +23,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

 # Flash attention imports
@@ -43,6 +44,56 @@ from text_generation_server.utils.layers import (
 )
+
+class LlamaConfig(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_scaling = rope_scaling
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
 class LlamaRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
         """

server/text_generation_server/models/flash_llama.py

@@ -2,13 +2,13 @@ import torch
 import torch.distributed

 from opentelemetry import trace
-from transformers import AutoConfig
 from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
+    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -52,7 +52,7 @@ class FlashLlama(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
-        config = AutoConfig.from_pretrained(
+        config = LlamaConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )