# (scraper metadata, converted to a comment: 166 lines, 6.4 KiB, Python)
import sqlite3
|
|
import time
|
|
|
|
from flask import jsonify, request
|
|
|
|
from llm_server.routes.stats import SemaphoreCheckerThread
|
|
from . import bp
|
|
from ..cache import redis
|
|
from ..helpers.client import format_sillytavern_err
|
|
from ..helpers.http import validate_json
|
|
from ..queue import priority_queue
|
|
from ... import opts
|
|
from ...database import log_prompt
|
|
from ...helpers import safe_list_get, indefinite_article
|
|
|
|
DEFAULT_PRIORITY = 9999
|
|
|
|
|
|
# TODO: clean this up and make the ooba vs hf-textgen more object-oriented
|
|
|
|
def _get_client_ip():
    """Best-effort client IP: Cloudflare header, then the first X-Forwarded-For hop, then the socket address."""
    cf_ip = request.headers.get('cf-connecting-ip')
    if cf_ip:
        return cf_ip
    forwarded = request.headers.get('x-forwarded-for')
    if forwarded:
        return forwarded.split(',')[0]
    return request.remote_addr


def _get_token_priority(token):
    """Return the priority stored for an API token, or None if the token is absent or unknown."""
    if not token:
        return None
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        # Parameterized query: the token is untrusted client input.
        cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
        result = cursor.fetchone()
        return result[0] if result else None
    finally:
        # Fix: the original leaked the connection when the query raised.
        conn.close()


def _text_results(text):
    """Build the SillyTavern-style response body: {'results': [{'text': ...}]}."""
    return {'results': [{'text': text}]}


@bp.route('/generate', methods=['POST'])
def generate():
    """Queue a prompt for the configured backend and relay its response.

    Expects a JSON body with a 'prompt' key; every other key is forwarded to
    the backend as a generation parameter. Requests are rate-limited per
    client IP unless the caller's token grants priority 0. Backend errors are
    reported SillyTavern-style: HTTP 200 with the error text in
    results[0].text so the client UI can display it.
    """
    start_time = time.time()
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400

    client_ip = _get_client_ip()
    SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

    # Everything except the prompt itself is treated as generation parameters.
    parameters = request_json_body.copy()
    del parameters['prompt']

    token = request.headers.get('X-Api-Key')
    priority = _get_token_priority(token)
    if priority is None:
        priority = DEFAULT_PRIORITY
    else:
        print(f'Token {token} was given priority {priority}.')

    # Count both queued and in-flight requests for this IP; priority-0 tokens bypass the limit.
    queued_ip_count = redis.get_dict('queued_ip_count').get(client_ip, 0) + redis.get_dict('processing_ips').get(client_ip, 0)
    if queued_ip_count < opts.ip_in_queue_max or priority == 0:
        event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
    else:
        event = None
    if not event:
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters, dict(request.headers), 429, is_error=True)
        return jsonify(_text_results(backend_response)), 200

    # Block until a worker has processed this request.
    event.wait()
    success, response, error_msg = event.data

    elapsed_time = time.time() - start_time

    # The worker may hand back None (or a non-HTTP object) on failure, so
    # derive the status code defensively instead of the original bare except.
    response_status_code = getattr(response, 'status_code', 0)

    if (not success or not response) and opts.mode == 'oobabooga':
        # Ooba doesn't return any error messages, so synthesize one.
        backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
        # Fix: the original passed the raw response object where log_prompt expects a status code.
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters, dict(request.headers), response_status_code, is_error=True)
        return jsonify({
            'code': 500,
            'error': 'failed to reach backend',
            **_text_results(backend_response)
        }), 200

    response_valid_json, response_json_body = validate_json(response)
    if not response_valid_json:
        backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
        # Fix: use the safely derived status code — response may be None here,
        # and the original's response.status_code would raise.
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers), response_status_code, is_error=True)
        return jsonify({
            'code': 500,
            'error': 'the backend did not return valid JSON',
            **_text_results(backend_response)
        }), 200

    backend_err = False
    if opts.mode == 'oobabooga':
        backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
        if not backend_response:
            backend_err = True
            backend_response = format_sillytavern_err(
                f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                'error')
            # Fix: rebuild the body instead of indexing results[0], which may
            # not exist in exactly the case being handled here.
            response_json_body = _text_results(backend_response)
    elif opts.mode == 'hf-textgen':
        backend_response = response_json_body.get('generated_text', '')
        if response_json_body.get('error'):
            backend_err = True
            error_type = response_json_body.get('error_type')
            # (The original's "opts.mode == 'oobabooga'" ternary was dead code
            # inside this branch; the message produced is unchanged.)
            backend_response = format_sillytavern_err(
                f'Backend ({opts.mode}) returned {indefinite_article(error_type)} {error_type} error: {response_json_body.get("error")}',
                f'HTTP CODE {response_status_code}')
        response_json_body = _text_results(backend_response)
    else:
        raise Exception(f'Unsupported backend mode: {opts.mode}')

    if not backend_err:
        redis.incr('proompts')

    log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time if not backend_err else None, parameters, dict(request.headers), response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
    return jsonify(response_json_body), 200
|