from typing import Tuple, Union

import flask

from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import get_token_count

class LLMBackend:
    _default_params: dict

    def __init__(self, backend_url: str):
        self.backend_url = backend_url
        self.backend_info = cluster_config.get_backend(self.backend_url)

    def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
        raise NotImplementedError

    def validate_params(self, params_dict: dict) -> Tuple[bool, str | None]:
        raise NotImplementedError

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     raise NotImplementedError

    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        """
        Validate the given parameters and return the dict that will be sent to this backend.
        Lets subclasses apply backend-specific defaults.
        :param parameters: the raw request parameters.
        :return: a tuple of (cleaned parameters or None, error message or None).
        """
        raise NotImplementedError

    def validate_request(self, parameters: dict, prompt: str, request: flask.Request) -> Tuple[bool, Union[str, None]]:
        """
        Hook for backend-specific checks not related to the prompt or parameters.
        By default, no extra checks are performed.
        :param request: the incoming Flask request.
        :param prompt: the prompt string.
        :param parameters: the request parameters.
        :return: a tuple of (success, error message or None).
        """
        return True, None

    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
        prompt_len = get_token_count(prompt, self.backend_url)
        token_limit = self.backend_info['model_config']['max_position_embeddings']
        # Leave a small safety margin below the model's context window.
        if prompt_len > token_limit - 10:
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size.'
        return True, None
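

# A minimal sketch of how a concrete backend might implement this interface.
# ExampleBackend is hypothetical: the 'temperature' check and the 0.7 default
# are illustrative assumptions, not this project's actual validation rules.
class ExampleBackend(LLMBackend):
    def validate_params(self, params_dict: dict) -> Tuple[bool, str | None]:
        # Reject an obviously invalid sampling temperature.
        if params_dict.get('temperature', 1.0) < 0:
            return False, 'temperature must be >= 0'
        return True, None

    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        ok, err = self.validate_params(parameters)
        if not ok:
            return None, err
        # Apply a backend-specific default; values from the request win.
        return {'temperature': 0.7, **parameters}, None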