local-llm-server/llm_server/routes/request_handler.py

import sqlite3
import time
from typing import Tuple, Union
from flask import jsonify
from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend
from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import SemaphoreCheckerThread

DEFAULT_PRIORITY = 9999


def delete_dict_key(d: dict, k: Union[str, list]):
    """Remove a key (or a list of keys) from a dict, ignoring keys that are absent."""
    if isinstance(k, str):
        if k in d:
            del d[k]
    elif isinstance(k, list):
        for item in k:
            if item in d:
                del d[item]
    else:
        raise ValueError(f'k must be a str or list, not {type(k)}')
    return d


class OobaRequestHandler:
    def __init__(self, incoming_request):
        self.request_json_body = None
        self.request = incoming_request
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.request.headers.get('X-Api-Key')
        self.priority = self.get_priority()
        self.backend = self.get_backend()
        self.parameters = self.parameters_invalid_msg = None

    def validate_request(self) -> Tuple[bool, Union[str, None]]:
        # TODO: move this to LLMBackend
        if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens or self.parameters.get('max_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None

    def get_client_ip(self):
        # Prefer the Cloudflare header, then the first hop in X-Forwarded-For, then the socket address.
        if self.request.headers.get('cf-connecting-ip'):
            return self.request.headers.get('cf-connecting-ip')
        elif self.request.headers.get('x-forwarded-for'):
            return self.request.headers.get('x-forwarded-for').split(',')[0]
        else:
            return self.request.remote_addr

    # def get_parameters(self):
    #     # TODO: make this a LLMBackend method
    #     return self.backend.get_parameters()

    def get_priority(self):
        # Tokens in the token_auth table can carry a custom priority; anonymous requests get the default.
        if self.token:
            conn = sqlite3.connect(opts.database_path)
            cursor = conn.cursor()
            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
            result = cursor.fetchone()
            conn.close()
            if result:
                return result[0]
        return DEFAULT_PRIORITY

    def get_backend(self):
        if opts.mode == 'oobabooga':
            return OobaboogaLLMBackend()
        elif opts.mode == 'hf-textgen':
            return HfTextgenLLMBackend()
        elif opts.mode == 'vllm':
            return VLLMBackend()
        else:
            raise Exception(f'Unknown backend mode: {opts.mode}')

    def get_parameters(self):
        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

    def handle_request(self):
        request_valid_json, self.request_json_body = validate_json(self.request.data)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
        self.get_parameters()
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

        request_valid, invalid_request_err_msg = self.validate_request()
        params_valid = bool(self.parameters)
        if not request_valid or not params_valid:
            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        # Reconstruct the request JSON with the validated parameters and prompt.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        # Enqueue the request unless this IP already has too many queued or in-flight requests.
        # Priority 0 clients are exempt from the per-IP limit.
        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            # Client was rate limited
            event = None
        if not event:
            return self.handle_ratelimited()

        # Block until the worker finishes this request, then let the backend format the response.
        event.wait()
        success, response, error_msg = event.data
        end_time = time.time()
        elapsed_time = end_time - self.start_time
        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def handle_ratelimited(self):
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
        return jsonify({
            'results': [{'text': backend_response}]
        }), 200
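

# --- Usage sketch (illustrative, not part of the original file) ---
# OobaRequestHandler is driven from a Flask view: the view hands it the raw `request`
# object and returns whatever handle_request() produces (a response/status pair).
# The blueprint name, view name, and URL rule below are assumptions for illustration
# only; the project's real routing lives elsewhere.
from flask import Blueprint, request

example_bp = Blueprint('example_generate', __name__)


@example_bp.route('/api/v1/generate', methods=['POST'])
def example_generate():
    # Build a handler around the incoming Flask request and let it validate,
    # queue, and format the response.
    return OobaRequestHandler(request).handle_request()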