from typing import Tuple, Union

import flask

from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.llm import get_token_count


class LLMBackend:
    _default_params: dict

    def __init__(self, backend_url: str):
        self.backend_url = backend_url

    def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
        raise NotImplementedError

    def validate_params(self, params_dict: dict) -> Tuple[bool, str | None]:
        raise NotImplementedError

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     raise NotImplementedError

    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        """
        Validate and return the parameters for this backend.
        Lets you set defaults for specific backends.
        :param parameters: the parameters sent with the request.
        :return: (parameters, None) on success, (None, error message) on failure.
        """
        raise NotImplementedError

    def validate_request(self, parameters: dict, prompt: str, request: flask.Request) -> Tuple[bool, Union[str, None]]:
        """
        Hook for backends that need to run checks not related to the prompt
        or parameters. By default, no extra checks are performed.
        :param parameters: the parameters sent with the request.
        :return: (True, None) if the request is valid, (False, error message) otherwise.
        """
        return True, None

    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
        prompt_len = get_token_count(prompt, self.backend_url)
        token_limit = cluster_config.get_backend(self.backend_url)['model_config']['max_position_embeddings']
        # Leave a 10-token safety margin below the model's context window.
        if prompt_len > token_limit - 10:
            model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {model_name}). Please lower your context size.'
        return True, None
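# ---------------------------------------------------------------------------
# Example (illustrative, not part of the original module): a minimal sketch of
# a concrete backend, assuming hypothetical default parameters. The real
# project defines its own subclasses, so treat this purely as a demonstration
# of the interface above.
# ---------------------------------------------------------------------------
class ExampleBackend(LLMBackend):
    # Hypothetical defaults; a real backend would mirror its API's parameters.
    _default_params = {'temperature': 0.7, 'max_new_tokens': 200}

    def validate_params(self, params_dict: dict) -> Tuple[bool, str | None]:
        # Reject keys this backend does not recognize; a real implementation
        # would also range-check the values.
        unknown = set(params_dict) - set(self._default_params)
        if unknown:
            return False, f'unknown parameters: {", ".join(sorted(unknown))}'
        return True, None

    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        valid, error = self.validate_params(parameters)
        if not valid:
            return None, error
        # Merge the client's parameters over the backend defaults.
        return {**self._default_params, **parameters}, None

# Usage sketch (the URL is illustrative):
#   backend = ExampleBackend('http://127.0.0.1:5000')
#   params, error = backend.get_parameters({'temperature': 0.5})
#   -> ({'temperature': 0.5, 'max_new_tokens': 200}, None)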