2023-09-14 17:38:20 -06:00
from typing import Tuple
2023-09-12 16:40:09 -06:00
2023-09-14 17:38:20 -06:00
import flask
2023-09-12 16:40:09 -06:00
from flask import jsonify
from llm_server import opts
2023-09-20 20:30:31 -06:00
from llm_server . database . database import log_prompt
2023-09-12 16:40:09 -06:00
from llm_server . routes . helpers . client import format_sillytavern_err
from llm_server . routes . request_handler import RequestHandler
class OobaRequestHandler(RequestHandler):
    """Request handler whose error replies use the ``{'results': [{'text': ...}]}`` JSON shape.

    Construction is inherited unchanged from ``RequestHandler``; this class
    only supplies the request, rate-limit and error handling hooks.
    """

    def handle_request(self):
        """Validate the incoming request and forward it to the backend.

        Returns:
            The response produced by ``validate_request()`` when validation
            fails, otherwise the second element of the pair returned by
            ``generate_response()``.
        """
        # NOTE(review): presumably a handler instance is single-use -- confirm
        # against RequestHandler. `assert` is stripped under `python -O`, so
        # this guard disappears in optimized runs.
        assert not self.used

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        # Reconstruct the request JSON with the validated parameters and prompt.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        # Only the response half of the returned pair is passed on to the caller.
        _, backend_response = self.generate_response(llm_request)
        return backend_response

    def handle_ratelimited(self) -> Tuple[flask.Response, int]:
        """Log the rejected request and reply with a rate-limit error.

        Returns:
            A ``(response, 429)`` pair whose JSON body carries the formatted
            rate-limit message as ``results[0].text``.
        """
        backend_response = format_sillytavern_err(
            f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} '
            'simultaneous requests at a time. Please complete your other requests before sending another.',
            'error',
        )
        # Record the rejected prompt together with the error text and the 429 status.
        log_prompt(
            self.client_ip,
            self.token,
            self.request_json_body.get('prompt', ''),
            backend_response,
            None,
            self.parameters,
            dict(self.request.headers),
            429,
            self.request.url,
            is_error=True,
        )
        return jsonify({
            'results': [{'text': backend_response}]
        }), 429

    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
        """Wrap an error message in the results/text JSON shape.

        Args:
            msg: Text placed in ``results[0].text`` of the JSON body.

        Returns:
            A ``(response, 400)`` pair.
        """
        return jsonify({
            'results': [{'text': msg}]
        }), 400