feat(server): remove trust_remote_code requirement for falcon models (#396)
parent d69a0633be
commit c0928e6f26
@@ -1,7 +1,7 @@
 import torch

 from loguru import logger
-from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional

@@ -138,10 +138,8 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    config = AutoConfig.from_pretrained(
-        model_id, revision=revision, trust_remote_code=trust_remote_code
-    )
-    model_type = config.model_type
+    config_dict, _ = PretrainedConfig.get_config_dict(model_id, revision=revision, trust_remote_code=trust_remote_code)
+    model_type = config_dict["model_type"]

     if model_type == "gpt_bigcode":
         if sharded:
@@ -201,9 +199,9 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi or (
-                    config.model_type == "RefinedWebModel"
-                    and config.n_head_kv != config.n_head
+                if config_dict.get("alibi", False) or (
+                    model_type == "RefinedWebModel"
+                    and config_dict.get("multi_query", True)
                 ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
@@ -216,7 +214,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
             )
         else:
-            if FLASH_ATTENTION and not config.alibi:
+            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRW(
                     model_id,
                     revision,
@@ -250,7 +248,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )

-    if config.model_type == "opt":
+    if model_type == "opt":
         if sharded:
             return OPTSharded(
                 model_id,
@@ -294,7 +292,7 @@ def get_model(
             model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
         )

-    auto_map = getattr(config, "auto_map", None)
+    auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
             return CausalLM(
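Why this works: AutoConfig.from_pretrained has to resolve a config class, and for checkpoints that ship custom modeling code (as the original Falcon/RefinedWeb repos do) that resolution fails unless trust_remote_code=True is passed. PretrainedConfig.get_config_dict only fetches and parses the repo's config.json into a plain dict, so get_model can branch on model_type without importing any remote code. A minimal sketch of the pattern below; tiiuae/falcon-7b is an illustrative checkpoint, not something this diff pins down.

from transformers.configuration_utils import PretrainedConfig

model_id = "tiiuae/falcon-7b"  # illustrative; any Falcon-style repo id

# get_config_dict() downloads and parses config.json only; unlike
# AutoConfig.from_pretrained(), it never imports the repo's custom
# modeling code, so trust_remote_code=True is not needed just to
# inspect the config.
config_dict, _ = PretrainedConfig.get_config_dict(model_id)

model_type = config_dict["model_type"]              # e.g. "RefinedWebModel"
alibi = config_dict.get("alibi", False)             # .get() tolerates missing keys
multi_query = config_dict.get("multi_query", True)

print(model_type, alibi, multi_query)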
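The reworked sharding guard follows the same logic: the old check config.n_head_kv != config.n_head read attributes that only exist once the custom RWConfig class is loaded, whereas multi_query is a plain key in the 7B-style config.json, taken as True when absent. Restated as a standalone helper under those assumptions (the function name is mine, the conditions are the diff's):

def sharded_rw_supported(config_dict: dict) -> bool:
    """Mirror of the guard above: can sharded FlashRW serve this config?"""
    if config_dict.get("alibi", False):
        return False  # alibi positional encoding has no sharded flash path here
    if config_dict["model_type"] == "RefinedWebModel" and config_dict.get(
        "multi_query", True
    ):
        return False  # 7B-style multi-query checkpoints are rejected when sharded
    return True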