2023-08-23 20:12:38 -06:00
import sqlite3
2023-08-22 23:37:39 -06:00
import time
2023-08-21 21:28:52 -06:00
from flask import jsonify , request
2023-08-23 22:22:38 -06:00
from llm_server . routes . stats import SemaphoreCheckerThread
2023-08-22 20:28:41 -06:00
from . import bp
2023-08-23 01:14:19 -06:00
from . . cache import redis
2023-08-22 20:28:41 -06:00
from . . helpers . client import format_sillytavern_err
2023-08-23 20:12:38 -06:00
from . . helpers . http import validate_json
from . . queue import priority_queue
2023-08-21 21:28:52 -06:00
from . . . import opts
from . . . database import log_prompt
2023-08-22 20:28:41 -06:00
from . . . helpers import safe_list_get
2023-08-21 21:28:52 -06:00
2023-08-23 20:12:38 -06:00
# Queue priority for requests whose X-Api-Key header is absent or has no
# priority row in token_auth. NOTE(review): a large value presumably means
# "served last" (lower = sooner) in priority_queue — confirm there.
DEFAULT_PRIORITY = 9_999
2023-08-21 21:28:52 -06:00
def _get_client_ip():
    """Return the requesting client's IP, preferring proxy-supplied headers.

    Checks ``cf-connecting-ip`` first, then ``x-forwarded-for``, then falls back
    to the socket peer address.
    """
    # NOTE(review): both headers are client-controlled unless a trusted proxy
    # overwrites them; confirm the deployment always sits behind one.
    cf_ip = request.headers.get('cf-connecting-ip')
    if cf_ip:
        return cf_ip
    forwarded_for = request.headers.get('x-forwarded-for')
    if forwarded_for:
        # X-Forwarded-For may be a chain "client, proxy1, proxy2"; the first
        # entry is the originating client.
        return forwarded_for.split(',')[0].strip()
    return request.remote_addr


def _lookup_priority(token):
    """Return the priority stored for *token* in ``token_auth``, or None.

    None is returned when *token* is empty, has no row, or its priority is NULL.
    """
    if not token:
        return None
    conn = sqlite3.connect(opts.database_path)
    try:
        row = conn.execute("SELECT priority FROM token_auth WHERE token = ?", (token,)).fetchone()
    finally:
        # Always release the connection, even if the query raises.
        conn.close()
    return row[0] if row else None


def _format_backend_error(message):
    """Format *message* as error text to place in a ``results`` entry.

    Only the ``oobabooga`` backend mode has an error format defined here; any
    other mode raises.

    Raises:
        NotImplementedError: if ``opts.mode`` is not ``'oobabooga'``.
    """
    if opts.mode != 'oobabooga':
        raise NotImplementedError(f'Unsupported backend mode: {opts.mode!r}')
    return format_sillytavern_err(message, 'error')


@bp.route('/generate', methods=['POST'])
def generate():
    """Queue a text-generation request and return the backend's result.

    The request body must be a JSON object containing ``prompt``; every other
    key is forwarded as generation parameters. The ``X-Api-Key`` header is
    looked up in ``token_auth`` to choose the queue priority; a missing or
    unknown token gets ``DEFAULT_PRIORITY``.

    Returns:
        HTTP 400 when the body is not valid JSON or is not an object with a
        ``prompt`` field. Otherwise HTTP 200 with either the backend's JSON
        body (its first result text replaced by an error message when empty),
        or an error envelope ``{'code': 500, 'error': ..., 'results': [...]}``.

    Raises:
        NotImplementedError: if an error must be reported and ``opts.mode`` is
            not ``'oobabooga'``.
    """
    start_time = time.time()

    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
    if not isinstance(request_json_body, dict) or 'prompt' not in request_json_body:
        return jsonify({'code': 400, 'error': 'Request body must be a JSON object with a "prompt" field'}), 400

    client_ip = _get_client_ip()
    SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

    # Everything except the prompt is treated as generation parameters.
    parameters = {key: value for key, value in request_json_body.items() if key != 'prompt'}

    token = request.headers.get('X-Api-Key')
    priority = _lookup_priority(token)
    if priority is None:
        priority = DEFAULT_PRIORITY
    else:
        print(f'Token {token} was given priority {priority}.')

    # Block until the queued request completes; its outcome is read from event.data.
    event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
    event.wait()
    success, response, error_msg = event.data

    # Measured from request arrival, so it includes time spent waiting in the queue.
    elapsed_time = time.time() - start_time

    # On failure `response` may not be a response object (e.g. None when the
    # backend could not be reached), so don't assume it has `.status_code`.
    status_code = getattr(response, 'status_code', None)
    prompt = request_json_body['prompt']
    headers = dict(request.headers)

    # Errors are returned as HTTP 200 with an in-body `code` and the message in
    # `results[0].text` — presumably so clients display the text; keep as-is.
    if not success:
        backend_response = _format_backend_error(f'Failed to reach the backend ({opts.mode}): {error_msg}')
        log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, status_code)
        return jsonify({
            'code': 500,
            'error': 'failed to reach backend',
            'results': [{'text': backend_response}],
        }), 200

    response_valid_json, response_json_body = validate_json(response)
    if not response_valid_json or not isinstance(response_json_body, dict):
        backend_response = _format_backend_error('The backend did not return valid JSON.')
        log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, status_code)
        return jsonify({
            'code': 500,
            'error': 'the backend did not return valid JSON',
            'results': [{'text': backend_response}],
        }), 200

    redis.incr('proompts')
    results = response_json_body.get('results') or []
    backend_response = safe_list_get(results, 0, {}).get('text')
    if not backend_response:
        backend_response = _format_backend_error(
            f'Backend (oobabooga) returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.context_size}. Furthermore, oobabooga does not support concurrent requests so all users have to wait in line and the backend server may have glitched for a moment. Please try again.')
        if results:
            results[0]['text'] = backend_response
        else:
            # No result entry to patch, so create one carrying the error text.
            response_json_body['results'] = [{'text': backend_response}]

    log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, status_code)
    return jsonify(response_json_body), 200