feat(server): remove trust_remote_code requirement for falcon models (#396)

OlivierDehaene authored on 2023-06-01 12:07:41 +02:00, committed by GitHub
parent d69a0633be
commit c0928e6f26
1 changed file with 9 additions and 11 deletions


@@ -1,7 +1,7 @@
 import torch
 from loguru import logger
-from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
@@ -138,10 +138,8 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
-    config = AutoConfig.from_pretrained(
-        model_id, revision=revision, trust_remote_code=trust_remote_code
-    )
-    model_type = config.model_type
+    config_dict, _ = PretrainedConfig.get_config_dict(model_id, revision=revision, trust_remote_code=trust_remote_code)
+    model_type = config_dict["model_type"]
     if model_type == "gpt_bigcode":
         if sharded:
@@ -201,9 +199,9 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi or (
-                    config.model_type == "RefinedWebModel"
-                    and config.n_head_kv != config.n_head
+                if config_dict.get("alibi", False) or (
+                    model_type == "RefinedWebModel"
+                    and config_dict.get("multi_query", True)
                 ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
@@ -216,7 +214,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
             )
         else:
-            if FLASH_ATTENTION and not config.alibi:
+            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRW(
                     model_id,
                     revision,
@@ -250,7 +248,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
-    if config.model_type == "opt":
+    if model_type == "opt":
         if sharded:
             return OPTSharded(
                 model_id,
@@ -294,7 +292,7 @@ def get_model(
         model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
     )
-    auto_map = getattr(config, "auto_map", None)
+    auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
             return CausalLM(
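
For context, a minimal sketch of the idea behind this change, assuming only the standard transformers config API; the checkpoint id below is illustrative and not taken from the diff. AutoConfig.from_pretrained has to resolve a config class, and for custom architectures such as the Falcon "RefinedWeb"/"RefinedWebModel" repos that class lives in remote code, so it only works with trust_remote_code=True. PretrainedConfig.get_config_dict just downloads and parses config.json, so the model type and flags like alibi or auto_map can be read as plain dictionary entries without executing any remote code.

# Sketch only; "tiiuae/falcon-7b" is an illustrative checkpoint id.
from transformers.configuration_utils import PretrainedConfig

model_id = "tiiuae/falcon-7b"

# Downloads and parses config.json; no custom modeling/config code is imported.
config_dict, _ = PretrainedConfig.get_config_dict(model_id)

model_type = config_dict["model_type"]        # e.g. "RefinedWebModel"
alibi = config_dict.get("alibi", False)       # dict lookup instead of config.alibi
auto_map = config_dict.get("auto_map", None)  # custom-code class mapping, if any
print(model_type, alibi, auto_map)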