local-llm-server/llm_server/routes/request_handler.py

181 lines
7.8 KiB
Python

import sqlite3
import time
from typing import Tuple, Union
import flask
from flask import Response
from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import SemaphoreCheckerThread
# Priority given to a request that has no X-Api-Key header, or whose key has no
# row in the token_auth table (see RequestHandler.get_priority).
DEFAULT_PRIORITY = 9999
class RequestHandler:
    """Base class that takes one incoming Flask request through validation, queueing and response handling.

    Subclasses must implement ``handle_request``, ``handle_ratelimited`` and
    ``handle_error``; the versions here only raise ``NotImplementedError``.
    """

    def __init__(self, incoming_request: flask.Request):
        """Capture the request, its JSON body, the client identity and the backend to use.

        Also records the current time for this client IP in
        ``SemaphoreCheckerThread.recent_prompters``.
        """
        self.request = incoming_request
        # Routes need to validate the body; here we just load it.
        _, self.request_json_body = validate_json(self.request)
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.request.headers.get('X-Api-Key')
        self.priority = self.get_priority()
        self.backend = get_backend()
        self.parameters = None
        self.used = False
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

    def get_client_ip(self):
        """Return the client address, preferring proxy headers over the socket peer.

        Order of precedence: ``cf-connecting-ip``, then the first entry of
        ``x-forwarded-for``, then ``request.remote_addr``.
        """
        headers = self.request.headers
        cf_ip = headers.get('cf-connecting-ip')
        if cf_ip:
            return cf_ip
        forwarded_for = headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0]
        return self.request.remote_addr

    def get_priority(self):
        """Look up the priority stored for ``self.token`` in the ``token_auth`` table.

        Returns ``DEFAULT_PRIORITY`` when the request has no token or the token
        has no matching row.
        """
        if not self.token:
            return DEFAULT_PRIORITY
        conn = sqlite3.connect(opts.database_path)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
            result = cursor.fetchone()
        finally:
            # Close the connection even if the query raises, so it is not leaked.
            conn.close()
        if result:
            return result[0]
        return DEFAULT_PRIORITY

    def get_parameters(self):
        """Ask the backend to build generation parameters from the request body.

        A truthy ``max_tokens`` key in the body is renamed to ``max_new_tokens``
        before the body is handed to the backend.

        Returns the ``(parameters, parameters_invalid_msg)`` pair produced by
        ``self.backend.get_parameters``.
        """
        if self.request_json_body.get('max_tokens'):
            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
        parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
        return parameters, parameters_invalid_msg

    def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
        """Build ``self.parameters`` and have the backend validate them.

        Returns ``(True, (None, 0))`` when the request is valid. Otherwise the
        failure is logged via ``log_prompt`` and ``(False, self.handle_error(...))``
        is returned.
        """
        self.parameters, parameters_invalid_msg = self.get_parameters()
        request_valid = False
        invalid_request_err_msg = None
        if self.parameters:
            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
        if request_valid:
            return True, (None, 0)

        # Report whichever of the two validation messages are actually set.
        error_messages = [msg for msg in (invalid_request_err_msg, parameters_invalid_msg) if msg]
        combined_error_message = ', '.join(error_messages)
        backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
        return False, self.handle_error(backend_response)

    def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
        """Queue ``llm_request``, wait for the result and turn it into a client response.

        Args:
            llm_request: request payload for the backend; must contain ``'prompt'``.

        Returns:
            A pair ``((success, response, error_msg, elapsed_time), client_response)``.
            On any failure (ratelimit, invalid prompt, backend error, unusable JSON)
            the first element is ``(False, None, None, 0)`` and the second comes
            from ``handle_ratelimited`` or ``handle_error``.
        """
        failure = (False, None, None, 0)
        prompt = llm_request['prompt']

        if self.is_client_ratelimited():
            return failure, self.handle_ratelimited()

        # Validate the prompt right before submission since the backend handler may have changed something.
        prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
        if not prompt_valid:
            backend_response = format_sillytavern_err(f'Validation Error: {invalid_prompt_err_msg}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
            return failure, self.handle_error(backend_response)

        event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        if not event:
            # A falsy result from the queue is reported to the client as a ratelimit.
            return failure, self.handle_ratelimited()

        success, response, error_msg = event.wait()
        elapsed_time = time.time() - self.start_time

        if response:
            try:
                # Be extra careful when getting attributes from the response object.
                response_status_code = response.status_code
            except Exception:
                response_status_code = 0
        else:
            response_status_code = None

        # ===============================================
        # We encountered an error
        if not success or not response or error_msg:
            if not error_msg:
                error_msg = 'Unknown error.'
            else:
                error_msg = error_msg.strip('.') + '.'
            backend_response = format_sillytavern_err(error_msg, 'error')
            log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
            return failure, self.handle_error(backend_response)

        # ===============================================
        response_valid_json, response_json_body = validate_json(response)

        # The backend didn't send valid JSON. The body is only inspected when it
        # is a dict; otherwise calling .get() on it would raise instead of
        # producing the error response below.
        return_json_err = not response_valid_json or not isinstance(response_json_body, dict)
        if not return_json_err:
            # Make sure the backend didn't crap out.
            results = response_json_body.get('results', [])
            if len(results) and not results[0].get('text'):
                return_json_err = True

        if return_json_err:
            error_msg = 'The backend did not return valid JSON.'
            backend_response = format_sillytavern_err(error_msg, 'error')
            log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
            return failure, self.handle_error(backend_response)

        # ===============================================
        self.used = True
        return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def is_client_ratelimited(self) -> bool:
        """Return True when this client IP already has too many requests in flight.

        The count is the sum of the IP's entries in the ``queued_ip_count`` and
        ``processing_ips`` redis dicts, compared against
        ``opts.simultaneous_requests_per_ip``. Priority 0 is never ratelimited.
        """
        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        return not (queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0)

    def handle_request(self) -> Tuple[flask.Response, int]:
        """Entry point to be implemented by subclasses.

        Implementations must include this guard:

            if self.used:
                raise Exception('Can only use a RequestHandler object once.')
        """
        raise NotImplementedError

    def handle_ratelimited(self) -> Tuple[flask.Response, int]:
        """Build the response for a ratelimited client; implemented by subclasses."""
        raise NotImplementedError

    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
        """Build the response for error message ``msg``; implemented by subclasses."""
        raise NotImplementedError
def get_backend():
    """Instantiate the backend selected by ``opts.mode``.

    Returns:
        An ``OobaboogaBackend`` for mode ``'oobabooga'`` or a ``VLLMBackend``
        for mode ``'vllm'``.

    Raises:
        ValueError: ``opts.mode`` is neither of the supported values.
    """
    if opts.mode == 'oobabooga':
        return OobaboogaBackend()
    if opts.mode == 'vllm':
        return VLLMBackend()
    # Name the offending value so a misconfiguration is diagnosable from the traceback.
    raise ValueError(f'Unknown backend mode: {opts.mode!r}')
def delete_dict_key(d: dict, k: Union[str, list]):
    """Remove one key, or a list of keys, from ``d`` in place and return ``d``.

    Keys that are not present are ignored.

    Args:
        d: dictionary to modify.
        k: a single string key, or a list of keys to remove.

    Returns:
        The same dictionary object that was passed in.

    Raises:
        ValueError: ``k`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(k, str):
        d.pop(k, None)
    elif isinstance(k, list):
        for item in k:
            d.pop(item, None)
    else:
        raise ValueError(f'k must be a str or a list, got {type(k).__name__}')
    return d