feat(server): remove trust_remote_code requirement for falcon models (#396)
parent d69a0633be
commit c0928e6f26
@@ -1,7 +1,7 @@
 import torch
 
 from loguru import logger
-from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
 
@@ -138,10 +138,8 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    config = AutoConfig.from_pretrained(
-        model_id, revision=revision, trust_remote_code=trust_remote_code
-    )
-    model_type = config.model_type
+    config_dict, _ = PretrainedConfig.get_config_dict(model_id, revision=revision, trust_remote_code=trust_remote_code)
+    model_type = config_dict["model_type"]
 
     if model_type == "gpt_bigcode":
         if sharded:
@@ -201,9 +199,9 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi or (
-                    config.model_type == "RefinedWebModel"
-                    and config.n_head_kv != config.n_head
+                if config_dict.get("alibi", False) or (
+                    model_type == "RefinedWebModel"
+                    and config_dict.get("multi_query", True)
                 ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
@ -216,7 +214,7 @@ def get_model(
|
|||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
|
||||
)
|
||||
else:
|
||||
if FLASH_ATTENTION and not config.alibi:
|
||||
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
||||
return FlashRW(
|
||||
model_id,
|
||||
revision,
|
||||
|
@ -250,7 +248,7 @@ def get_model(
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if config.model_type == "opt":
|
||||
if model_type == "opt":
|
||||
if sharded:
|
||||
return OPTSharded(
|
||||
model_id,
|
||||
|
@@ -294,7 +292,7 @@ def get_model(
             model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
         )
 
-    auto_map = getattr(config, "auto_map", None)
+    auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
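The core of the change is swapping AutoConfig.from_pretrained for PretrainedConfig.get_config_dict. The former instantiates a config class and, for repos whose config.json points at custom code via auto_map (as the Falcon/RefinedWeb checkpoints did when this commit landed), refuses to proceed unless the caller passes trust_remote_code=True; the latter only downloads and parses the raw JSON. A minimal sketch of the difference, not part of the commit — the model id is illustrative, and current Falcon repos no longer ship custom code, so the except branch reflects the behavior at the time:

# Sketch only. Both calls read config.json from the Hub; only the first
# one can trigger (or refuse) execution of remote code.
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig

model_id = "tiiuae/falcon-7b"  # illustrative; required remote code at the time

# get_config_dict returns (config_dict, unused_kwargs) and never imports the
# repo's custom config class, so no trust decision is needed to read it.
config_dict, _ = PretrainedConfig.get_config_dict(model_id)
print(config_dict["model_type"])        # was "RefinedWebModel" for falcon-7b
print(config_dict.get("alibi", False))  # dict lookups replace attribute access

# from_pretrained resolves the class named in config.json's auto_map and, in
# the transformers versions of that era, raised ValueError without opt-in.
try:
    AutoConfig.from_pretrained(model_id, trust_remote_code=False)
except ValueError as err:
    print(f"refused without trust_remote_code: {err}")

The rest of the diff follows mechanically: attribute accesses on config become dict lookups on config_dict (config.alibi → config_dict.get("alibi", False), getattr(config, "auto_map", None) → config_dict.get("auto_map", None)), so get_model can route a Falcon checkpoint to FlashRW/FlashRWSharded without requiring trust_remote_code just to read the config. Note that the sharded guard also changes shape along the way: it now keys off the raw config's multi_query flag rather than comparing n_head_kv to n_head.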