# 2023-09-12 16:40:09 -06:00
import time
from flask import jsonify
from llm_server import opts
from llm_server . database import log_prompt
from llm_server . routes . helpers . client import format_sillytavern_err
from llm_server . routes . helpers . http import validate_json
from llm_server . routes . queue import priority_queue
from llm_server . routes . request_handler import RequestHandler
class OobaRequestHandler(RequestHandler):
    """Request handler for the Oobabooga (text-generation-webui) style API.

    Validates the incoming JSON body, enqueues the generation request on the
    priority queue, waits for a worker to process it, and returns the result
    (or an error) in the Ooba response shape: ``{'results': [{'text': ...}]}``.
    Handlers are single-shot: one instance serves exactly one request.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def handle_request(self):
        """Process a single generation request end to end.

        Returns:
            A Flask ``(response, status_code)`` tuple.

        Raises:
            Exception: if this handler instance has already been used.
        """
        if self.used:
            # One-shot guard: reusing a handler would double-submit to the queue.
            raise Exception('Can only use a RequestHandler object once.')

        request_valid_json, self.request_json_body = validate_json(self.request)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        params_valid, request_valid = self.validate_request()
        if not request_valid[0] or not params_valid[0]:
            # Each validator returns a (valid, message) pair; collect every failure message.
            error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            # NOTE: HTTP 200 with an embedded error is deliberate so the client
            # (SillyTavern) renders the error text instead of a generic failure.
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        # Reconstruct the request JSON with the validated parameters and prompt.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        if not self.is_client_ratelimited():
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            event = None
        if not event:
            # Either the client is ratelimited or the queue refused the item.
            return self.handle_ratelimited()

        # Block until a worker has finished processing this request.
        event.wait()
        success, response, error_msg = event.data

        end_time = time.time()
        elapsed_time = end_time - self.start_time
        self.used = True
        return self.backend.handle_response(success, self.request, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def handle_ratelimited(self):
        """Log and return a 'too many simultaneous requests' error response.

        Returns HTTP 200 with the error embedded in the results text (see the
        note in handle_request about client-side error rendering), while the
        prompt log records the real 429 status.
        """
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
        return jsonify({
            'results': [{'text': backend_response}]
        }), 200