From 462530c2b05013e0822fd3d9e38fb08adc0a84b6 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 27 Mar 2023 00:23:22 -0700 Subject: [PATCH] fix(server): Avoid using try/except to determine kind of AutoModel (#142) --- server/text_generation_server/models/__init__.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f3a92ad2..90c70cb5 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -3,6 +3,7 @@ import torch from loguru import logger from transformers import AutoConfig +from transformers.models.auto import modeling_auto from typing import Optional from text_generation_server.models.model import Model @@ -65,14 +66,15 @@ def get_model( return SantaCoder(model_id, revision, quantize) config = AutoConfig.from_pretrained(model_id, revision=revision) + model_type = config.model_type - if config.model_type == "bloom": + if model_type == "bloom": if sharded: return BLOOMSharded(model_id, revision, quantize=quantize) else: return BLOOM(model_id, revision, quantize=quantize) - if config.model_type == "gpt_neox": + if model_type == "gpt_neox": if sharded: neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded return neox_cls(model_id, revision, quantize=quantize) @@ -80,7 +82,7 @@ def get_model( neox_cls = FlashNeoX if FLASH_NEOX else CausalLM return neox_cls(model_id, revision, quantize=quantize) - if config.model_type == "t5": + if model_type == "t5": if sharded: return T5Sharded(model_id, revision, quantize=quantize) else: @@ -88,7 +90,10 @@ def get_model( if sharded: raise ValueError("sharded is not supported for AutoModel") - try: + + if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM(model_id, revision, quantize=quantize) - except Exception: + if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: return Seq2SeqLM(model_id, revision, quantize=quantize) + + raise ValueError(f"Unsupported model type {model_type}")