This repository has been archived on 2024-10-27. You can view files and clone it, but cannot push or open issues or pull requests.
local-llm-server/llm_server/routes/request_handler.py

278 lines
12 KiB
Python

import time
from typing import Tuple, Union
import flask
from flask import Response, request
from llm_server.cluster.cluster_config import get_a_cluster_backend, cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis
from llm_server.database.database import get_token_ratelimit
from llm_server.database.log_to_db import log_to_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.logging import create_logger
from llm_server.routes.auth import parse_token
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue
# Module-level logger for this file; used by RequestHandler.is_client_ratelimited() to record rejected requests.
_logger = create_logger('RequestHandler')
class RequestHandler:
    """
    Base class for route-specific request handlers.

    On construction it loads the JSON body, resolves the client IP and API token,
    looks up the token's rate-limit settings and picks a cluster backend for the
    requested model. It then provides the shared validate -> enqueue -> wait -> log
    pipeline (`validate_request`, `generate_response`). Subclasses implement
    `handle_request`, `handle_ratelimited` and `handle_error`.
    """

    def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
        """
        :param incoming_request: the Flask request being served.
        :param selected_model: model name passed to `get_a_cluster_backend`; replaced by the
                               backend's reported model when the backend is usable.
        :param incoming_json: JSON body to use instead of reading it from `incoming_request`.
        :raises Exception: if the body is not valid JSON (routes are supposed to reject it first).
        """
        self.request = incoming_request

        # Routes need to validate it, here we just load it.
        json_source = incoming_json if incoming_json else self.request
        self.request_valid_json, self.request_json_body = validate_json(json_source)
        if not self.request_valid_json:
            raise Exception('Not valid JSON. Routes are supposed to reject invalid JSON.')

        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.get_auth_token()
        self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
        self.parameters = None
        self.used = False

        # This is null by default since most handlers need to transform the prompt in a specific way.
        self.prompt = None

        self.selected_model = selected_model
        self.backend_url = get_a_cluster_backend(selected_model)
        self.cluster_backend_info = cluster_config.get_backend(self.backend_url)

        # The backend is treated as offline unless it reports a mode, a model and a model config.
        required_keys = ('mode', 'model', 'model_config')
        self.offline = not all(self.cluster_backend_info.get(key) for key in required_keys)
        if self.offline:
            # Define the attribute explicitly so later access fails predictably instead of AttributeError.
            self.backend = None
        else:
            self.selected_model = self.cluster_backend_info['model']
            self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
            if self.token and not self.token.startswith('SYSTEM__'):
                # "recent_prompters" is only used for stats.
                redis.zadd('recent_prompters', {self.client_ip: time.time()})

    def check_online(self) -> bool:
        """Re-read this backend's info from the cluster config and return its `online` flag."""
        self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
        # A backend entry without an 'online' flag is reported as offline rather than raising KeyError.
        return bool(self.cluster_backend_info.get('online', False))

    def get_auth_token(self):
        """
        Return the API token supplied by the client, or None.

        Sources, in order: the `X-API-KEY` field of the JSON body, the `X-Api-Key`
        header, then the `Authorization` header (parsed with `parse_token`).
        """
        if self.request_json_body.get('X-API-KEY'):
            return self.request_json_body['X-API-KEY']
        if self.request.headers.get('X-Api-Key'):
            return self.request.headers['X-Api-Key']
        if self.request.headers.get('Authorization'):
            return parse_token(self.request.headers['Authorization'])
        return None

    def get_client_ip(self):
        """
        Return the client's IP address.

        Proxy-provided headers are checked first, then the first entry of
        `X-Forwarded-For`, falling back to the socket's `remote_addr`.
        """
        for header in ('Llm-Connecting-Ip', 'X-Connecting-IP', 'Cf-Connecting-Ip'):
            value = self.request.headers.get(header)
            if value:
                return value
        forwarded_for = self.request.headers.get('X-Forwarded-For')
        if forwarded_for:
            # "client, proxy1, proxy2" -> "client" (strip any stray whitespace around the entry).
            return forwarded_for.split(',')[0].strip()
        return self.request.remote_addr

    def get_parameters(self):
        """Ask the backend to extract parameters from the JSON body; returns (parameters, invalid_msg)."""
        parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
        return parameters, parameters_invalid_msg

    @staticmethod
    def _as_sentence(msg, default: str = 'Unknown error') -> str:
        """Return `msg` ending in exactly one period, or `default` + '.' when `msg` is empty/None."""
        text = str(msg).rstrip('.') if msg else ''
        return (text or default) + '.'

    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
        """
        This needs to be called at the start of the subclass handle_request() method.

        Sets `self.parameters` from the backend, then runs the generic checks
        (`max_new_tokens` limit) and the backend's prompt/request checks.

        :param prompt: prompt to validate; prompt validation is skipped when falsy.
        :param do_log: if True, log a failed validation to the database.
        :return: (True, (None, 0)) when valid, otherwise (False, <handle_error() result>).
        """
        invalid_request_err_msgs = []
        self.parameters, parameters_invalid_msg = self.get_parameters()  # Parameters will be None if invalid.
        if self.parameters and not parameters_invalid_msg:
            # Backends shouldn't check max_new_tokens, but rather things specific to their backend.
            # Let the RequestHandler do the generic checks.
            max_new_tokens_limit = GlobalConfig.get().max_new_tokens
            # `or 0` so an explicit None is treated like "not set" instead of raising TypeError.
            if (self.parameters.get('max_new_tokens') or 0) > max_new_tokens_limit:
                invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {max_new_tokens_limit}')

            if prompt:
                prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
                if not prompt_valid:
                    invalid_request_err_msgs.append(invalid_prompt_err_msg)

            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters, prompt, self.request)
            if not request_valid:
                invalid_request_err_msgs.append(invalid_request_err_msg)
        else:
            # Guard against a backend that rejects parameters without giving a reason.
            invalid_request_err_msgs.append(parameters_invalid_msg or 'Invalid parameters')

        if not invalid_request_err_msgs:
            return True, (None, 0)

        if len(invalid_request_err_msgs) > 1:
            # Format multiple error messages each on a new line.
            combined_error_message = '\n'.join(f'\n{self._as_sentence(x)}' for x in invalid_request_err_msgs)
        else:
            # Otherwise, just grab the first and only one.
            combined_error_message = self._as_sentence(invalid_request_err_msgs[0])

        backend_response = self.handle_error(combined_error_message, 'Validation Error')
        if do_log:
            log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
        return False, backend_response

    def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
        """
        Validate `llm_request`, enqueue it on the priority queue, wait for the result and build the reply.

        :param llm_request: request payload for the worker; must contain a 'prompt' key.
        :return: ((success, backend_response, error_msg, elapsed_time), (flask_response, status_code)).
                 On any failure the first tuple is (False, None, None, 0) and the second comes from
                 handle_ratelimited()/handle_error().
        """
        prompt = llm_request['prompt']

        event = None
        if not self.is_client_ratelimited():
            # Validate again before submission since the backend handler may have changed something.
            # Also, this is the first time we validate the prompt.
            request_valid, invalid_response = self.validate_request(prompt, do_log=True)
            if not request_valid:
                return (False, None, None, 0), invalid_response
            event = priority_queue.put(self.backend_url, (llm_request, self.client_ip, self.token, self.parameters), self.token_priority, self.selected_model)

        # Either the client is over its limit or priority_queue.put() returned nothing.
        if not event:
            return (False, None, None, 0), self.handle_ratelimited()

        # TODO: add wait timeout
        success, response, error_msg = event.wait()
        if error_msg == 'closed':
            return (False, None, None, 0), (self.handle_error('Request Timeout')[0], 408)

        elapsed_time = time.time() - self.start_time

        # NOTE(review): if `response` is a requests.Response, its truthiness follows `.ok`, so
        # error-status responses take the `else` branch and are logged with a None status code.
        # Confirm whether that is intended.
        if response:
            try:
                # Be extra careful when getting attributes from the response object.
                response_status_code = response.status_code
            except Exception:
                response_status_code = 0
        else:
            response_status_code = None

        # ===============================================
        # We encountered an error
        if not success or not response or error_msg:
            backend_response = self.handle_error(self._as_sentence(error_msg))
            log_to_db(ip=self.client_ip,
                      token=self.token,
                      prompt=prompt,
                      response=backend_response[0].data.decode('utf-8'),
                      gen_time=None,
                      parameters=self.parameters,
                      headers=dict(self.request.headers),
                      backend_response_code=response_status_code,
                      request_url=self.request.url,
                      backend_url=self.backend_url,
                      is_error=True)
            return (False, None, None, 0), backend_response

        # ===============================================
        response_valid_json, response_json_body = validate_json(response)
        # Only look inside the body once we know it parsed to an object. On a parse failure the body
        # may not be a dict (e.g. None), and calling .get() on it would crash instead of reporting the error.
        return_json_err = not response_valid_json or not isinstance(response_json_body, dict)
        if not return_json_err:
            # Make sure the backend didn't crap out.
            results = response_json_body.get('results', [])
            if len(results) and not results[0].get('text'):
                return_json_err = True

        if return_json_err:
            error_msg = 'The backend did not return valid JSON.'
            backend_response = self.handle_error(error_msg)
            log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
            return (False, None, None, 0), backend_response

        # ===============================================
        self.used = True
        return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def is_client_ratelimited(self) -> bool:
        """
        Return True if this client's queued + processing requests already reach the token's limit.

        Tokens with priority 0 are never rate limited. The processing count is read from
        the `processing_ips` Redis hash, keyed by client IP.
        """
        if self.token_priority == 0:
            return False

        queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
        processing_raw = redis.hget('processing_ips', self.client_ip)
        processing_ip = int(processing_raw) if processing_raw else 0

        if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
            _logger.debug(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
            return True
        return False

    def handle_request(self) -> Tuple[flask.Response, int]:
        """
        Serve the request. Must be implemented by subclasses.

        Subclass implementations are expected to begin with the equivalent of::

            assert not self.used
            if self.offline:
                msg = f'{self.selected_model} is not a valid model choice.'
                return format_sillytavern_err(msg)

        and to call validate_request() before doing any work.
        """
        raise NotImplementedError

    def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
        """Build the response for a rate-limited client. Must be implemented by subclasses."""
        raise NotImplementedError

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        """Build an error response carrying `error_msg`. Must be implemented by subclasses."""
        raise NotImplementedError
def get_backend_handler(mode, backend_url: str):
    """
    Instantiate the backend client matching a cluster backend's `mode`.

    :param mode: backend mode string; 'oobabooga' or 'vllm'.
    :param backend_url: URL of the backend, passed to the backend class constructor.
    :raises ValueError: for any other mode (a subclass of Exception, so existing
                        `except Exception` handlers still catch it).
    """
    if mode == 'oobabooga':
        return OobaboogaBackend(backend_url)
    if mode == 'vllm':
        return VLLMBackend(backend_url)
    raise ValueError(f'Unknown backend mode: {mode!r}')
def delete_dict_key(d: dict, k: Union[str, list]):
    """
    Remove key `k` (a str) or every key in `k` (a list) from `d` in place, ignoring missing keys.

    :return: the same dict object `d`, for convenience.
    :raises ValueError: if `k` is neither a str nor a list.
    """
    if isinstance(k, str):
        d.pop(k, None)
    elif isinstance(k, list):
        for item in k:
            d.pop(item, None)
    else:
        raise ValueError(f'k must be a str or a list of keys, got {type(k).__name__}')
    return d
def before_request():
    """
    Request hook (presumably registered via Flask's `before_request`).

    Calls `auto_set_base_client_api` with the current request, then requires an
    API key for every endpoint except `v1.get_stats`. Returns the response from
    `require_api_key()` when it produces one; otherwise returns None.
    """
    auto_set_base_client_api(request)
    if request.endpoint == 'v1.get_stats':
        # No API key check for this endpoint.
        return None
    auth_failure = require_api_key()
    if auth_failure is not None:
        return auth_failure
    return None