2023-08-23 20:12:38 -06:00
import sqlite3
2023-08-22 23:37:39 -06:00
import time
2023-08-21 21:28:52 -06:00
from flask import jsonify , request
2023-08-23 22:22:38 -06:00
from llm_server . routes . stats import SemaphoreCheckerThread
2023-08-22 20:28:41 -06:00
from . import bp
2023-08-23 01:14:19 -06:00
from . . cache import redis
2023-08-22 20:28:41 -06:00
from . . helpers . client import format_sillytavern_err
2023-08-23 20:12:38 -06:00
from . . helpers . http import validate_json
from . . queue import priority_queue
2023-08-21 21:28:52 -06:00
from . . . import opts
from . . . database import log_prompt
2023-08-29 13:46:41 -06:00
from . . . helpers import safe_list_get , indefinite_article
2023-08-21 21:28:52 -06:00
2023-08-23 20:12:38 -06:00
# Queue priority used by generate() when the request has no X-Api-Key header,
# or the key has no (non-NULL) priority in the token_auth table.
# NOTE(review): presumably lower numbers are served first by priority_queue — confirm.
DEFAULT_PRIORITY = 9999
2023-08-21 21:28:52 -06:00
2023-08-29 13:46:41 -06:00
# TODO: clean this up and make the ooba vs hf-textgen more object-oriented
2023-08-21 21:28:52 -06:00
@bp.route('/generate', methods=['POST'])
def generate():
    """Queue a text-generation request and return the backend's result.

    Steps:
    1. Validate the JSON request body.
    2. Identify the client IP.
    3. Look up an optional priority for the ``X-Api-Key`` token.
    4. Enforce the per-IP queue limit and enqueue the request.
    5. Block until the queue worker signals the event.
    6. Normalise the backend response into ``{'results': [{'text': ...}]}``
       and record the exchange with ``log_prompt``.

    Errors meant for the client (rate limiting, backend failures) are returned
    as HTTP 200, with a ``format_sillytavern_err`` message in
    ``results[0]['text']``. A malformed request body yields HTTP 400.

    Raises:
        Exception: when ``opts.mode`` is neither ``'oobabooga'`` nor
            ``'hf-textgen'`` and a backend response has to be interpreted.
    """
    start_time = time.time()

    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
    if not isinstance(request_json_body, dict) or 'prompt' not in request_json_body:
        # Without this check, `del parameters['prompt']` below would crash with a 500.
        return jsonify({'code': 400, 'error': 'Missing "prompt" field'}), 400

    # NOTE(review): x-forwarded-for is client-controllable unless the fronting proxy
    # overwrites it, and it may hold a comma-separated chain. It is used verbatim
    # here as the per-IP rate-limit key.
    client_ip = (request.headers.get('cf-connecting-ip')
                 or request.headers.get('x-forwarded-for')
                 or request.remote_addr)

    SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

    # Everything except the prompt is treated as generation parameters.
    parameters = request_json_body.copy()
    del parameters['prompt']

    token = request.headers.get('X-Api-Key')
    priority = _lookup_token_priority(token) if token else None
    if priority is None:
        priority = DEFAULT_PRIORITY
    else:
        # Never write the full API key to the logs.
        print(f'Token {_mask_token(token)} was given priority {priority}.')

    queued_ip_count = (redis.get_dict('queued_ip_count').get(client_ip, 0)
                       + redis.get_dict('processing_ips').get(client_ip, 0))
    event = None
    # A priority of exactly 0 bypasses the per-IP limit.
    if queued_ip_count < opts.ip_in_queue_max or priority == 0:
        event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
    if not event:
        backend_response = format_sillytavern_err(
            f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.',
            'error')
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters,
                   dict(request.headers), 429, is_error=True)
        return jsonify(_results_body(backend_response)), 200

    event.wait()
    success, response, error_msg = event.data

    elapsed_time = time.time() - start_time

    # NOTE(review): if `response` is a requests.Response, it is falsy for 4xx/5xx.
    # So `not response` also catches HTTP errors in oobabooga mode.
    # A None response means the backend was not reached, whatever the mode.
    if response is None or ((not success or not response) and opts.mode == 'oobabooga'):
        # Ooba doesn't return any error messages, so report the failure ourselves.
        backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters,
                   dict(request.headers), _status_code(response), is_error=True)
        return jsonify({
            'code': 500,
            'error': 'failed to reach backend',
            **_results_body(backend_response),
        }), 200

    response_valid_json, response_json_body = validate_json(response)
    if not response_valid_json:
        if opts.mode not in ('oobabooga', 'hf-textgen'):
            raise Exception(f'Unknown backend mode: {opts.mode}')
        backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters,
                   dict(request.headers), _status_code(response), is_error=True)
        return jsonify({
            'code': 500,
            'error': 'the backend did not return valid JSON',
            **_results_body(backend_response),
        }), 200

    # Read the token count now: the hf-textgen branch replaces response_json_body
    # below, which previously discarded 'details' before it was logged.
    generated_tokens = (response_json_body.get('details') or {}).get('generated_tokens')

    backend_err = False
    if opts.mode == 'oobabooga':
        backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
        if not backend_response:
            backend_err = True
            backend_response = format_sillytavern_err(
                'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                'error')
            results = response_json_body.get('results')
            if isinstance(results, list) and results and isinstance(results[0], dict):
                results[0]['text'] = backend_response
            else:
                # Missing or empty `results` would otherwise raise IndexError/KeyError.
                response_json_body = _results_body(backend_response)
    elif opts.mode == 'hf-textgen':
        backend_response = response_json_body.get('generated_text', '')
        error_message = response_json_body.get('error')
        if error_message:
            backend_err = True
            error_type = response_json_body.get('error_type')
            if error_type:
                error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
            else:
                error_type_string = 'returned an error'
            response_json_body = _results_body(format_sillytavern_err(
                f'Backend ({opts.mode}) {error_type_string}: {error_message}', 'error'))
        else:
            response_json_body = _results_body(backend_response)
    else:
        raise Exception(f'Unknown backend mode: {opts.mode}')

    redis.incr('proompts')
    log_prompt(client_ip, token, request_json_body['prompt'],
               '' if backend_err else backend_response,
               None if backend_err else elapsed_time,
               parameters, dict(request.headers), _status_code(response),
               generated_tokens, is_error=backend_err)
    return jsonify(response_json_body), 200


def _lookup_token_priority(token):
    """Return the ``priority`` stored for ``token`` in ``token_auth``, or None if absent."""
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
        row = cursor.fetchone()
    finally:
        # Close even when the query raises, so connections are not leaked.
        conn.close()
    return row[0] if row else None


def _mask_token(token):
    """Return a log-safe form of ``token`` that does not reveal the full secret."""
    return f'{token[:4]}...' if len(token) > 4 else '****'


def _status_code(response):
    """Return ``response.status_code``, or 0 when there is no usable response object."""
    return getattr(response, 'status_code', 0)


def _results_body(text):
    """Wrap ``text`` in the ``{'results': [{'text': ...}]}`` shape returned to clients."""
    return {'results': [{'text': text}]}