2023-09-12 01:04:11 -06:00
from typing import Tuple , Union
2023-08-30 18:53:26 -06:00
2023-09-13 11:22:33 -06:00
import flask
2023-09-14 18:31:13 -06:00
from llm_server import opts
2023-09-24 13:02:30 -06:00
from llm_server . llm import get_token_count
2023-09-28 18:40:24 -06:00
from llm_server . custom_redis import redis
2023-09-14 18:31:13 -06:00
2023-08-30 18:53:26 -06:00
class LLMBackend:
    """
    Base interface for an LLM inference backend.

    Subclasses must override the methods that raise ``NotImplementedError``
    (``handle_response``, ``validate_params``, ``get_parameters``).
    ``validate_request`` and ``validate_prompt`` provide default behavior
    that subclasses may keep or override.
    """

    # Declared (annotation only) but never assigned in this base class;
    # presumably each subclass supplies its default parameters -- TODO confirm.
    _default_params: dict

    def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
        """
        Handle a backend response. Must be implemented by subclasses.

        :param request: the incoming Flask request.
        :param response_json_body: the response body as a dict.
        :param response_status_code: the response HTTP status code.
        NOTE(review): the meaning of the remaining arguments is defined by
        callers/subclasses, not by this base class -- confirm there.
        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def validate_params(self, params_dict: dict) -> Tuple[bool, str | None]:
        """
        Check backend-specific parameters. Must be implemented by subclasses.

        :param params_dict: the parameters to check.
        :return: ``(ok, error_message)``; presumably ``error_message`` is
                 ``None`` when ``ok`` is true -- confirm in subclasses.
        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        """
        Validate and return the parameters for this backend.
        Lets you set defaults for specific backends.

        :param parameters: the caller-supplied parameters.
        :return: ``(parameters_or_None, error_message_or_None)`` per the
                 annotation; exact semantics are defined by subclasses.
        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def validate_request(self, parameters: dict, prompt: str, request: flask.Request) -> Tuple[bool, str | None]:
        """
        Hook for backend checks not related to the prompt or parameters.
        The default performs no extra checks and always accepts the request.

        :param parameters: the request parameters.
        :param prompt: the prompt text.
        :param request: the incoming Flask request.
        :return: ``(True, None)`` in the base class.
        """
        return True, None

    def validate_prompt(self, prompt: str) -> Tuple[bool, str | None]:
        """
        Reject prompts whose token count exceeds ``opts.context_size - 10``.

        :param prompt: the prompt text.
        :return: ``(True, None)`` if the prompt fits; otherwise
                 ``(False, message)``, where the message includes the prompt
                 length, the enforced limit and the running model name.
        """
        prompt_len = get_token_count(prompt)
        # The enforced limit sits 10 tokens below the configured context size.
        # NOTE(review): the reason for the 10-token margin is not stated here -- confirm.
        max_prompt_len = opts.context_size - 10
        if prompt_len > max_prompt_len:
            model_name = redis.get('running_model', str, 'NO MODEL ERROR')
            # Report the limit actually enforced (not the raw context size) so the
            # numbers in the message agree with the check above.
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {max_prompt_len}, model: {model_name}). Please lower your context size'
        return True, None