65 lines
2.9 KiB
Python
65 lines
2.9 KiB
Python
import time
|
|
|
|
from flask import jsonify
|
|
|
|
from llm_server import opts
|
|
from llm_server.database import log_prompt
|
|
from llm_server.routes.helpers.client import format_sillytavern_err
|
|
from llm_server.routes.helpers.http import validate_json
|
|
from llm_server.routes.queue import priority_queue
|
|
from llm_server.routes.request_handler import RequestHandler
|
|
|
|
|
|
class OobaRequestHandler(RequestHandler):
    """Request handler that answers in the Ooba response shape.

    Successful generations are delegated to ``self.backend.handle_response``.
    Validation and rate-limit errors are returned as
    ``{'results': [{'text': <error>}]}`` built here.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``RequestHandler.__init__``."""
        super().__init__(*args, **kwargs)

    def handle_request(self):
        """Validate, enqueue and complete one generation request.

        A handler instance is single-use: it is marked as used as soon as this
        method starts, so every exit path (success, early error return or
        exception) prevents a second call.

        Returns:
            A Flask response or ``(response, status)`` tuple:
            - ``({'code': 400, 'msg': 'Invalid JSON'}, 400)`` if the body is not
              valid JSON;
            - an HTTP 200 response carrying the error text in
              ``results[0]['text']`` on validation failure or rate limiting;
            - otherwise whatever ``self.backend.handle_response`` returns.

        Raises:
            Exception: if ``handle_request`` was already called on this instance.
        """
        if self.used:
            raise Exception('OobaRequestHandler.handle_request() was called more than once on the same instance')
        # Mark the instance consumed up front. Previously this only happened on
        # the success path, so the early returns below left it reusable.
        self.used = True

        request_valid_json, self.request_json_body = validate_json(self.request)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        # Each of these is a (is_valid, message) pair.
        params_valid, request_valid = self.validate_request()
        if not request_valid[0] or not params_valid[0]:
            return self._validation_error_response(request_valid, params_valid)

        # Reconstruct the request JSON with the validated parameters and prompt.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        # A rate-limited client is never queued. A falsy return from put() is
        # handled exactly like a rate-limited client.
        event = None if self.is_client_ratelimited() else priority_queue.put(
            (llm_request, self.client_ip, self.token, self.parameters), self.priority)
        if not event:
            return self.handle_ratelimited()

        # NOTE(review): no timeout is passed here, so the request waits for as
        # long as the queue worker takes to fill in event.data.
        event.wait()
        success, response, error_msg = event.data

        elapsed_time = time.time() - self.start_time
        return self.backend.handle_response(success, self.request, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def _validation_error_response(self, request_valid, params_valid):
        """Log and build the response for a failed ``validate_request`` check.

        Args:
            request_valid: ``(is_valid, message)`` pair for the request checks.
            params_valid: ``(is_valid, message)`` pair for the parameter checks.

        Returns:
            ``(jsonify(...), 200)``. The body carries ``'code': 400`` and the
            formatted error text. The HTTP status stays 200, presumably so clients
            such as SillyTavern display the text; confirm before changing.
        """
        # Join only the non-empty messages of the checks that failed.
        error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
        combined_error_message = ', '.join(error_messages)
        err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
        # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
        return jsonify({
            'code': 400,
            'msg': 'parameter validation error',
            'results': [{'text': err}]
        }), 200

    def handle_ratelimited(self):
        """Log and build the response for a client that may not be queued.

        Returns:
            ``(jsonify({'results': [{'text': <error>}]}), 200)``. The error text
            mentions ``opts.simultaneous_requests_per_ip``. The prompt is logged
            with code 429 and ``is_error=True``.
        """
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
        return jsonify({
            'results': [{'text': backend_response}]
        }), 200
|