local-llm-server/llm_server/routes/v1/generate.py

from flask import jsonify, request
from llm_server.routes.stats import concurrent_semaphore, proompts
from . import bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import cache_control, validate_json
from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get


def generator(request_json_body):
    # Dispatch the request to whichever backend the server is configured for.
    if opts.mode == 'oobabooga':
        from ...llm.oobabooga.generate import generate
        return generate(request_json_body)
    elif opts.mode == 'hf-textgen':
        from ...llm.hf_textgen.generate import generate
        return generate(request_json_body)
    else:
        raise Exception(f'Unknown backend mode: {opts.mode}')


@bp.route('/generate', methods=['POST'])
@cache_control(-1)
def generate():
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
    with concurrent_semaphore:
        # Resolve the client IP, preferring proxy headers when they are present.
        if request.headers.get('cf-connecting-ip'):
            client_ip = request.headers.get('cf-connecting-ip')
        elif request.headers.get('x-forwarded-for'):
            client_ip = request.headers.get('x-forwarded-for')
        else:
            client_ip = request.remote_addr
        token = request.headers.get('X-Api-Key')
        # Log everything except the prompt itself as the request parameters.
        parameters = request_json_body.copy()
        parameters.pop('prompt', None)  # pop() avoids a KeyError on prompt-less bodies

        success, response, error_msg = generator(request_json_body)
        if not success:
            # The backend could not be reached. Surface the error as generated
            # text so SillyTavern displays it in the chat window.
            if opts.mode == 'oobabooga':
                backend_response = format_sillytavern_err(f'Failed to reach the backend: {error_msg}', 'error')
                response_json_body = {
                    'results': [
                        {
                            'text': backend_response,
                        }
                    ],
                }
            else:
                raise Exception(f'Unknown backend mode: {opts.mode}')
            # response may be None when the backend was unreachable.
            log_prompt(opts.database_path, client_ip, token, request_json_body.get('prompt', ''), backend_response, parameters, dict(request.headers), response.status_code if response is not None else None)
            return jsonify({
                'code': 500,
                'error': 'failed to reach backend',
                **response_json_body
            }), 200

        response_valid_json, response_json_body = validate_json(response)
        if response_valid_json:
            proompts.increment()
            # Pull the generated text out of the backend response.
            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
            if not backend_response:
                if opts.mode == 'oobabooga':
                    backend_response = format_sillytavern_err(f'Backend returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.token_limit}.', 'error')
                    # Guard against an empty results list before writing the error text back.
                    if response_json_body.get('results'):
                        response_json_body['results'][0]['text'] = backend_response
                    else:
                        response_json_body['results'] = [{'text': backend_response}]
                else:
                    raise Exception(f'Unknown backend mode: {opts.mode}')
            log_prompt(opts.database_path, client_ip, token, request_json_body.get('prompt', ''), backend_response, parameters, dict(request.headers), response.status_code)
            print(response_json_body)
            return jsonify({
                **response_json_body
            }), 200
        else:
            # The backend replied, but not with valid JSON.
            if opts.mode == 'oobabooga':
                backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
                response_json_body = {
                    'results': [
                        {
                            'text': backend_response,
                        }
                    ],
                }
            else:
                raise Exception(f'Unknown backend mode: {opts.mode}')
            log_prompt(opts.database_path, client_ip, token, request_json_body.get('prompt', ''), backend_response, parameters, dict(request.headers), response.status_code)
            return jsonify({
                'code': 500,
                'error': 'the backend did not return valid JSON',
                **response_json_body
            }), 200


# @openai_bp.route('/chat/completions', methods=['POST'])
# def generate_openai():
# print(request.data)
# return '', 200
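
# A minimal sketch of how a client might call this endpoint. The /api/v1 mount
# point, port, and the max_new_tokens parameter are assumptions for
# illustration, not confirmed by this file:
#
#   curl -X POST http://localhost:5000/api/v1/generate \
#        -H 'Content-Type: application/json' \
#        -H 'X-Api-Key: <key>' \
#        -d '{"prompt": "Hello, world.", "max_new_tokens": 16}'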