2023-09-26 22:09:11 -06:00
|
|
|
import time
|
|
|
|
import traceback
|
|
|
|
|
2023-09-26 23:59:22 -06:00
|
|
|
from flask import jsonify, make_response, request
|
2023-09-26 22:09:11 -06:00
|
|
|
|
|
|
|
from . import openai_bp
|
|
|
|
from ..cache import redis
|
|
|
|
from ..helpers.client import format_sillytavern_err
|
|
|
|
from ..helpers.http import validate_json
|
|
|
|
from ..ooba_request_handler import OobaRequestHandler
|
|
|
|
from ... import opts
|
|
|
|
from ...llm import get_token_count
|
2023-09-26 23:59:22 -06:00
|
|
|
from ...llm.openai.transform import build_openai_response, generate_oai_string
|
2023-09-26 22:09:11 -06:00
|
|
|
|
|
|
|
|
|
|
|
# TODO: add rate-limit headers?
|
|
|
|
|
|
|
|
@openai_bp.route('/completions', methods=['POST'])
def openai_completions():
    """Handle an OpenAI-style legacy ``/completions`` request.

    Validates the incoming JSON body, forwards the prompt to the backend via
    ``OobaRequestHandler``, and wraps the backend's generated text in an
    OpenAI ``text_completion``-shaped JSON response.

    Returns:
        - 400 with ``{'code': 400, 'msg': 'Invalid JSON'}`` when the body is
          not valid JSON or lacks a ``prompt`` field.
        - The handler's own response and status when the backend call fails
          (non-200).
        - 200 with the OpenAI-format completion payload on success.
        - 500 plain-text on any unexpected exception (traceback printed).
    """
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('prompt'):
        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

    try:
        response, status_code = OobaRequestHandler(request).handle_request()
        if status_code != 200:
            # BUG FIX: previously `return status_code` returned a bare int,
            # which is not a valid Flask view return value and discarded the
            # handler's error body. Propagate both response and status.
            return response, status_code
        output = response.json['results'][0]['text']

        # TODO: async/await
        prompt_tokens = get_token_count(request_json_body['prompt'])
        response_tokens = get_token_count(output)
        # Falls back to the literal string 'ERROR' if the key is unset.
        running_model = redis.get('running_model', str, 'ERROR')

        response = make_response(jsonify({
            "id": f"cmpl-{generate_oai_string(30)}",
            "object": "text_completion",
            "created": int(time.time()),
            # Only expose our real model name when configured; otherwise echo
            # back whatever model the client asked for.
            "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
            "choices": [
                {
                    "text": output,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": None
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": response_tokens,
                "total_tokens": prompt_tokens + response_tokens
            }
        }), 200)

        # Best-effort rate-limit hint header; skipped when stats are missing.
        stats = redis.get('proxy_stats', dict)
        if stats:
            response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
        return response
    except Exception:
        # Top-level boundary for this route: log the traceback and return a
        # generic 500 rather than leaking internals to the client.
        traceback.print_exc()
        return 'Internal Server Error', 500
|