2023-08-22 23:37:39 -06:00
import time
2023-08-21 21:28:52 -06:00
from flask import jsonify , request
2023-08-23 01:14:19 -06:00
from llm_server . routes . stats import SemaphoreCheckerThread , concurrent_semaphore
2023-08-22 20:28:41 -06:00
from . import bp
2023-08-23 01:14:19 -06:00
from . . cache import redis
2023-08-22 20:28:41 -06:00
from . . helpers . client import format_sillytavern_err
2023-08-21 22:49:44 -06:00
from . . helpers . http import cache_control , validate_json
2023-08-21 21:28:52 -06:00
from . . . import opts
from . . . database import log_prompt
2023-08-22 20:28:41 -06:00
from . . . helpers import safe_list_get
2023-08-21 21:28:52 -06:00
2023-08-22 19:58:31 -06:00
def generator(request_json_body):
    """Dispatch a generation request to the configured backend.

    Imports lazily so only the active backend's module is loaded.

    :param request_json_body: dict parsed from the client's JSON request.
    :return: whatever the backend's generate() returns
             (success flag, response, error message — see callers).
    :raises Exception: if opts.mode is not a recognized backend.
    """
    if opts.mode == 'oobabooga':
        from ...llm.oobabooga.generate import generate
        return generate(request_json_body)
    elif opts.mode == 'hf-textgen':
        from ...llm.hf_textgen.generate import generate
        return generate(request_json_body)
    else:
        # Same exception type as before, but now with a diagnostic message.
        raise Exception(f'Unsupported backend mode: {opts.mode}')
2023-08-21 21:28:52 -06:00
@bp.route('/generate', methods=['POST'])
@cache_control(-1)
def generate():
    """Validate an incoming generation request, proxy it to the backend,
    log the prompt/response pair, and return a JSON reply.

    Returns a Flask (json, status) tuple. Backend failures still return
    HTTP 200 with a 'code': 500 payload whose 'results' text is a
    SillyTavern-formatted error so the client renders something useful.
    """
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
    with concurrent_semaphore:
        # Prefer proxy-supplied headers so we record the real client
        # address rather than the reverse proxy's.
        if request.headers.get('cf-connecting-ip'):
            client_ip = request.headers.get('cf-connecting-ip')
        elif request.headers.get('x-forwarded-for'):
            client_ip = request.headers.get('x-forwarded-for')
        else:
            client_ip = request.remote_addr
        SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

        token = request.headers.get('X-Api-Key')
        parameters = request_json_body.copy()
        # Fix: pop() instead of del — a request without 'prompt' must not
        # blow up with an unhandled KeyError here.
        parameters.pop('prompt', None)
        prompt = request_json_body.get('prompt', '')

        success, response, error_msg = generator(request_json_body)
        # Fix: when the backend is unreachable, response may be None (or
        # otherwise lack .status_code) — never assume the attribute exists.
        status_code = getattr(response, 'status_code', None)

        if not success:
            if opts.mode == 'oobabooga':
                backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
                response_json_body = {
                    'results': [
                        {
                            'text': backend_response,
                        }
                    ],
                }
            else:
                raise Exception(f'Unsupported backend mode: {opts.mode}')
            log_prompt(opts.database_path, client_ip, token, prompt, backend_response, parameters, dict(request.headers), status_code)
            return jsonify({
                'code': 500,
                'error': 'failed to reach backend',
                **response_json_body
            }), 200

        response_valid_json, response_json_body = validate_json(response)
        if response_valid_json:
            # Global counter of successfully proxied prompts.
            redis.incr('proompts')
            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
            if not backend_response:
                # An empty result usually means bad parameters or a glitched
                # single-threaded backend; substitute a helpful message.
                if opts.mode == 'oobabooga':
                    backend_response = format_sillytavern_err(
                        f'Backend (oobabooga) returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.context_size}. Furthermore, oobabooga does not support concurrent requests so all users have to wait in line and the backend server may have glitched for a moment. Please try again.',
                        'error')
                    response_json_body['results'][0]['text'] = backend_response
                else:
                    raise Exception(f'Unsupported backend mode: {opts.mode}')

            log_prompt(opts.database_path, client_ip, token, prompt, backend_response, parameters, dict(request.headers), status_code)
            return jsonify({
                **response_json_body
            }), 200
        else:
            if opts.mode == 'oobabooga':
                # Fix: plain string — the original was an f-string with no
                # placeholders (same runtime text).
                backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
                response_json_body = {
                    'results': [
                        {
                            'text': backend_response,
                        }
                    ],
                }
            else:
                raise Exception(f'Unsupported backend mode: {opts.mode}')
            log_prompt(opts.database_path, client_ip, token, prompt, backend_response, parameters, dict(request.headers), status_code)
            return jsonify({
                'code': 500,
                'error': 'the backend did not return valid JSON',
                **response_json_body
            }), 200
2023-08-21 21:28:52 -06:00
# @openai_bp.route('/chat/completions', methods=['POST'])
# def generate_openai():
# print(request.data)
# return '', 200