local-llm-server/llm_server/routes/v1/generate.py

import sqlite3
import time

from flask import jsonify, request

from llm_server.routes.stats import SemaphoreCheckerThread
from . import bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..queue import priority_queue
from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get

# Fallback priority for requests without a recognized API token.
DEFAULT_PRIORITY = 9999


@bp.route('/generate', methods=['POST'])
def generate():
    start_time = time.time()
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
    if 'prompt' not in request_json_body:
        # A missing prompt would otherwise raise a KeyError further down.
        return jsonify({'code': 400, 'error': 'Missing prompt'}), 400

    # Resolve the real client IP: prefer the Cloudflare header, then the
    # generic reverse-proxy header, then the socket address.
    if request.headers.get('cf-connecting-ip'):
        client_ip = request.headers.get('cf-connecting-ip')
    elif request.headers.get('x-forwarded-for'):
        client_ip = request.headers.get('x-forwarded-for')
    else:
        client_ip = request.remote_addr

    # Mark this IP as an active prompter for the stats thread.
    SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

    # Everything in the body except the prompt itself is forwarded to the
    # backend as generation parameters.
    parameters = request_json_body.copy()
    del parameters['prompt']
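
    # A typical request body might look like this (a sketch; the exact
    # parameter names come from the client, not from this server):
    #   {"prompt": "Hello,", "max_new_tokens": 200, "temperature": 0.7}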

    # If the client supplied an API key, look up its queue priority in the
    # token_auth table.
    token = request.headers.get('X-Api-Key')
    priority = None
    if token:
        conn = sqlite3.connect(opts.database_path)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
            result = cursor.fetchone()
            if result:
                priority = result[0]
        finally:
            # Close the connection even if the query raises.
            conn.close()

    if priority is None:
        priority = DEFAULT_PRIORITY
    else:
        print(f'Token {token} was given priority {priority}.')
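
    # The lookup above assumes a table shaped roughly like this (a sketch,
    # not defined in this file):
    #   CREATE TABLE token_auth (token TEXT PRIMARY KEY, priority INTEGER);
    # Given DEFAULT_PRIORITY = 9999, lower numbers appear to be served first.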

    # Enqueue the request and block until a worker publishes
    # (success, response, error_msg) on the event.
    event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
    event.wait()
    success, response, error_msg = event.data

    # Total time spent waiting in the queue plus generating.
    end_time = time.time()
    elapsed_time = end_time - start_time
    # Add the elapsed time to a global list:
    # print('elapsed:', elapsed_time)
    # with wait_in_queue_elapsed_lock:
    #     wait_in_queue_elapsed.append((end_time, elapsed_time))

    if not success:
        # The worker never got a usable reply from the backend.
        if opts.mode == 'oobabooga':
            backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
            response_json_body = {
                'results': [
                    {
                        'text': backend_response,
                    }
                ],
            }
        else:
            raise Exception(f'Unsupported backend mode: {opts.mode}')
        # response may be None on a failed request, so guard the status code.
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers),
                   response.status_code if response is not None else None)
        # HTTP 200 is returned so the client (SillyTavern) still renders the
        # error text; the real status lives in the body's 'code' field.
        return jsonify({
            'code': 500,
            'error': 'failed to reach backend',
            **response_json_body
        }), 200

    response_valid_json, response_json_body = validate_json(response)
    if response_valid_json:
        # Count every successfully handled prompt.
        redis.incr('proompts')

        backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
        if not backend_response:
            if opts.mode == 'oobabooga':
                backend_response = format_sillytavern_err(
                    f'Backend (oobabooga) returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.context_size}. Also, oobabooga does not support concurrent requests, so all users have to wait in line; the backend may simply have glitched for a moment. Please try again.',
                    'error')
                # 'results' may be empty or missing, so rebuild it before
                # writing the error text into the first entry.
                if not response_json_body.get('results'):
                    response_json_body['results'] = [{}]
                response_json_body['results'][0]['text'] = backend_response
            else:
                raise Exception(f'Unsupported backend mode: {opts.mode}')

        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers), response.status_code)
        return jsonify({
            **response_json_body
        }), 200
    else:
        # The backend replied, but not with valid JSON.
        if opts.mode == 'oobabooga':
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            response_json_body = {
                'results': [
                    {
                        'text': backend_response,
                    }
                ],
            }
        else:
            raise Exception(f'Unsupported backend mode: {opts.mode}')
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers), response.status_code)
        return jsonify({
            'code': 500,
            'error': 'the backend did not return valid JSON',
            **response_json_body
        }), 200
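

# ---------------------------------------------------------------------------
# Example client call (a sketch: the URL prefix, port, and API key below are
# assumptions, not defined in this file):
#
#   import requests
#   resp = requests.post(
#       'http://127.0.0.1:5000/api/v1/generate',
#       headers={'X-Api-Key': 'example-token'},
#       json={'prompt': 'Hello,', 'max_new_tokens': 200},
#   )
#   print(resp.json()['results'][0]['text'])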