2023-08-21 21:28:52 -06:00
|
|
|
from flask import jsonify, request
|
|
|
|
|
|
|
|
from . import bp
|
|
|
|
from llm_server.routes.stats import concurrent_semaphore, proompts
|
2023-08-21 22:49:44 -06:00
|
|
|
from ..helpers.http import cache_control, validate_json
|
2023-08-21 21:28:52 -06:00
|
|
|
from ... import opts
|
|
|
|
from ...database import log_prompt
|
|
|
|
|
|
|
|
|
2023-08-22 19:58:31 -06:00
|
|
|
def generator(request_json_body):
    """Dispatch a generation request to the backend selected by ``opts.mode``.

    :param request_json_body: parsed JSON payload forwarded verbatim to the
        backend's ``generate`` function.
    :returns: whatever the selected backend's ``generate`` returns (the caller
        unpacks it as ``success, response, error_msg``).
    :raises Exception: if ``opts.mode`` names no known backend.
    """
    if opts.mode == 'oobabooga':
        # Imported lazily so only the configured backend's dependencies load.
        from ...llm.oobabooga.generate import generate
        return generate(request_json_body)
    elif opts.mode == 'hf-textgen':
        from ...llm.hf_textgen.generate import generate
        return generate(request_json_body)
    else:
        # Bug fix: the original raised a bare ``Exception`` with no message,
        # hiding the misconfiguration. Keep the type (callers catching
        # ``Exception`` still work) but say what went wrong.
        raise Exception(f'unknown opts.mode: {opts.mode!r}')
|
2023-08-21 21:28:52 -06:00
|
|
|
|
|
|
|
|
|
|
|
@bp.route('/generate', methods=['POST'])
@cache_control(-1)
def generate():
    """Proxy a text-generation request to the configured LLM backend.

    Validates the request body as JSON, forwards it to the backend under the
    concurrency semaphore, logs the prompt/response pair, and relays the
    backend's JSON back to the client.

    :returns: ``(flask.Response, status)`` — 200 with the backend's JSON body
        on success, 400 for malformed input, 500 when the backend is
        unreachable or returns non-JSON.
    """
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400

    # Bug fix: the original assumed 'prompt' was present and raised KeyError
    # (an unhandled 500) on `del parameters['prompt']` / the log_prompt call
    # when it was missing. Reject such requests up front instead.
    if 'prompt' not in request_json_body:
        return jsonify({'code': 400, 'error': 'missing prompt'}), 400

    # Limit concurrent backend requests.
    with concurrent_semaphore:
        success, response, error_msg = generator(request_json_body)
        if not success:
            return jsonify({
                'code': 500,
                'error': 'failed to reach backend'
            }), 500

        response_valid_json, response_json_body = validate_json(response)
        if not response_valid_json:
            # Backend answered, but not with JSON we can relay.
            return jsonify({
                'code': 500,
                'error': 'failed to reach backend'
            }), 500

        proompts.increment()

        # Resolve the real client IP behind Cloudflare / reverse proxies:
        # prefer cf-connecting-ip, then x-forwarded-for, then the socket peer.
        # (Example CF headers: cf-connecting-ip, x-forwarded-for, cf-ray, ...)
        client_ip = (request.headers.get('cf-connecting-ip')
                     or request.headers.get('x-forwarded-for')
                     or request.remote_addr)

        # Log the sampling parameters without the (potentially huge) prompt.
        parameters = request_json_body.copy()
        del parameters['prompt']

        token = request.headers.get('X-Api-Key')
        # NOTE(review): assumes the backend response JSON has the oobabooga
        # shape {'results': [{'text': ...}]} and that `response` exposes
        # requests-style `.status_code` — confirm for the hf-textgen backend.
        log_prompt(opts.database_path, client_ip, token,
                   request_json_body['prompt'],
                   response_json_body['results'][0]['text'],
                   parameters, dict(request.headers), response.status_code)

        return jsonify({
            **response_json_body
        }), 200
|
|
|
|
|
|
|
|
# @openai_bp.route('/chat/completions', methods=['POST'])
|
|
|
|
# def generate_openai():
|
|
|
|
# print(request.data)
|
|
|
|
# return '', 200
|