feat(server): remove trust_remote_code requirement for falcon models (#396)

OlivierDehaene 2023-06-01 12:07:41 +02:00 committed by GitHub
parent d69a0633be
commit c0928e6f26
1 changed file with 9 additions and 11 deletions
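
The diff below swaps AutoConfig.from_pretrained for PretrainedConfig.get_config_dict when probing a model's architecture. get_config_dict only downloads and parses config.json into a plain dict; it never resolves or imports the custom configuration class that falcon repositories reference through "auto_map", so the probe itself no longer requires trust_remote_code=True. A minimal sketch of the difference, assuming a falcon-era repo that ships custom code (the model ID is illustrative):

    from transformers import AutoConfig
    from transformers.configuration_utils import PretrainedConfig

    model_id = "tiiuae/falcon-7b"

    # Before: AutoConfig must import the repo's custom config class, so this
    # raises unless trust_remote_code=True is passed.
    # config = AutoConfig.from_pretrained(model_id)

    # After: fetch and parse config.json only; no remote code is imported.
    config_dict, _ = PretrainedConfig.get_config_dict(model_id)
    print(config_dict["model_type"])        # "RefinedWebModel" at the time
    print(config_dict.get("alibi", False))  # optional keys are read with .get()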

@@ -1,7 +1,7 @@
 import torch
 
 from loguru import logger
-from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
 
@@ -138,10 +138,8 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    config = AutoConfig.from_pretrained(
-        model_id, revision=revision, trust_remote_code=trust_remote_code
-    )
-    model_type = config.model_type
+    config_dict, _ = PretrainedConfig.get_config_dict(model_id, revision=revision, trust_remote_code=trust_remote_code)
+    model_type = config_dict["model_type"]
 
     if model_type == "gpt_bigcode":
         if sharded:
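
Since the probe now yields a plain dict rather than a PretrainedConfig instance, the rest of get_model switches from attribute access to key access, using .get() with a default wherever a config may omit the key. Roughly, the mapping is (a sketch, not part of the diff):

    # attribute access (old)                  dict access (new)
    # config.model_type                 ->    config_dict["model_type"]           (required key)
    # config.alibi                      ->    config_dict.get("alibi", False)     (optional key)
    # getattr(config, "auto_map", None) ->    config_dict.get("auto_map", None)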
@@ -201,9 +199,9 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi or (
-                    config.model_type == "RefinedWebModel"
-                    and config.n_head_kv != config.n_head
+                if config_dict.get("alibi", False) or (
+                    model_type == "RefinedWebModel"
+                    and config_dict.get("multi_query", True)
                 ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
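
With no config object, the sharded guard can no longer compare config.n_head_kv to config.n_head, so it keys off the "multi_query" flag instead: 7B-style "RefinedWebModel" checkpoints use multi-query attention (a single shared KV head), which the sharded flash path did not support, and the code assumes multi-query when the flag is absent, hence the True default. A hypothetical helper showing the same predicate in isolation (name and factoring are illustrative):

    def sharded_rw_supported(config_dict: dict) -> bool:
        # Mirrors the guard above: reject alibi models and multi-query
        # RefinedWebModel checkpoints on the sharded flash path.
        if config_dict.get("alibi", False):
            return False
        if config_dict["model_type"] == "RefinedWebModel" and config_dict.get(
            "multi_query", True
        ):
            return False
        return True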
@@ -216,7 +214,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
             )
         else:
-            if FLASH_ATTENTION and not config.alibi:
+            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRW(
                     model_id,
                     revision,
@@ -250,7 +248,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if config.model_type == "opt":
+    if model_type == "opt":
         if sharded:
             return OPTSharded(
                 model_id,
@@ -294,7 +292,7 @@ def get_model(
             model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
         )
 
-    auto_map = getattr(config, "auto_map", None)
+    auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
             return CausalLM(
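
The auto_map fallback works the same way on the raw dict: a custom-code repository's config.json maps auto classes to Python objects inside the repo, and the router only needs to see that a causal-LM entry exists before falling through to the generic CausalLM wrapper with trust_remote_code. A hypothetical excerpt of such a config, written as a Python dict (module and class names vary per repository):

    config_dict = {
        "model_type": "RefinedWeb",
        "alibi": False,
        "auto_map": {
            "AutoConfig": "configuration_RW.RWConfig",
            "AutoModelForCausalLM": "modelling_RW.RWForCausalLM",
        },
    }
    auto_map = config_dict.get("auto_map", None)
    # Remote code is only loaded later, and only if the operator opted in.
    assert auto_map is not None and "AutoModelForCausalLM" in auto_map.keys()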