48 lines
2.1 KiB
Python
48 lines
2.1 KiB
Python
from typing import Tuple
|
|
|
|
import flask
|
|
from flask import jsonify, request
|
|
|
|
from llm_server import opts
|
|
from llm_server.database.database import log_prompt
|
|
from llm_server.routes.helpers.client import format_sillytavern_err
|
|
from llm_server.routes.request_handler import RequestHandler
|
|
|
|
|
|
class OobaRequestHandler(RequestHandler):
    """Request handler that answers in the ``{'results': [{'text': ...}]}`` shape.

    Validation (``validate_request``), generation (``generate_response``) and
    the shared request state used below (``self.used``, ``self.parameters``,
    ``self.request_json_body``, ``self.request``, ``self.client_ip``,
    ``self.token``) are not defined in this class; they presumably come from
    ``RequestHandler``.

    Errors, including rate limiting, are returned with HTTP 200 and the error
    text placed in the normal result slot, so the client does not surface an
    HTTP error.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through: this subclass adds no state of its own.
        super().__init__(*args, **kwargs)

    def handle_request(self):
        """Validate the incoming request and forward it to the backend.

        Returns:
            The error response produced by ``validate_request()`` when the
            request is invalid; otherwise the second element of the tuple
            returned by ``generate_response()``.
        """
        # A handler instance is meant to be single-use. ``self.used`` is
        # presumably flipped by the base class once a request was processed.
        assert not self.used

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        # Reconstruct the request JSON with the validated parameters and prompt.
        # 'prompt' is merged last, so it overrides any 'prompt' key that
        # self.parameters might carry.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        _, backend_response = self.generate_response(llm_request)
        return backend_response

    def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
        """Build the response for a client that exceeded its concurrency limit.

        Args:
            do_log: When true, record the rejected request via ``log_prompt``
                with status 429 and ``is_error=True``.

        Returns:
            ``(response, 200)``, where ``response`` is the body built by
            ``handle_error()``. The logged status is 429, but the client
            receives 200.
        """
        msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
        response, _ = self.handle_error(msg)
        if do_log:
            log_prompt(
                self.client_ip,
                self.token,
                self.request_json_body.get('prompt', ''),
                response.data.decode('utf-8'),
                None,
                self.parameters,
                dict(self.request.headers),
                429,
                self.request.url,
                is_error=True,
            )
        # We only return the response from handle_error(), not the error code.
        return response, 200

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        """Wrap an error message in a normal-looking generation response.

        By default the message is formatted with ``format_sillytavern_err``.
        A client can opt out by sending the header ``LLM-ST-Errors: true``
        (the value is compared case-insensitively). In that case the raw
        message is returned.

        Args:
            error_msg: Human-readable error text.
            error_type: Forwarded to ``format_sillytavern_err``.

        Returns:
            ``(jsonify({'results': [{'text': <message>}]}), 200)``.
        """
        # Read headers from self.request, the same object handle_ratelimited
        # uses, instead of the module-level flask ``request`` proxy.
        header_value = self.request.headers.get('LLM-ST-Errors', '')
        disable_st_error_formatting = header_value.strip().lower() == 'true'

        if disable_st_error_formatting:
            # TODO: how to format this
            response_msg = error_msg
        else:
            response_msg = format_sillytavern_err(error_msg, error_type)

        return jsonify({
            'results': [{'text': response_msg}]
        }), 200  # return 200 so we don't trigger an error message in the client's ST
|