import time
import traceback

from flask import jsonify, make_response, request

from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...llm import get_token_count
from ...llm.openai.transform import build_openai_response, generate_oai_string

# TODO: add rate-limit headers?


@openai_bp.route('/completions', methods=['POST'])
def openai_completions():
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('prompt'):
        # Reject both malformed JSON and valid JSON that lacks a 'prompt' field.
        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

    try:
        response, status_code = OobaRequestHandler(request).handle_request()
        if status_code != 200:
            # Pass the handler's error response through rather than returning
            # the bare status code, which Flask cannot serialize on its own.
            return response, status_code

        output = response.json['results'][0]['text']

        # TODO: async/await
        prompt_tokens = get_token_count(request_json_body['prompt'])
        response_tokens = get_token_count(output)
        running_model = redis.get('running_model', str, 'ERROR')

        # Shape the backend result like an OpenAI /v1/completions response.
        response = make_response(jsonify({
            "id": f"cmpl-{generate_oai_string(30)}",
            "object": "text_completion",
            "created": int(time.time()),
            # Only reveal the real backend model if the operator opted in;
            # otherwise echo back whatever model name the client sent.
            "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
            "choices": [
                {
                    "text": output,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": None
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": response_tokens,
                "total_tokens": prompt_tokens + response_tokens
            }
        }), 200)

        stats = redis.get('proxy_stats', dict)
        if stats:
            # Surface the queue's estimated wait as a rate-limit hint.
            response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
        return response
    except Exception:
        traceback.print_exc()
        return 'Internal Server Error', 500
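
# Example client call, as a sketch: the mount point of openai_bp and the server
# port are assumptions (shown here as /api/openai/v1 on localhost:5000), not
# confirmed by this file. Any OpenAI-compatible client should work the same way.
#
#   import requests
#
#   resp = requests.post(
#       'http://localhost:5000/api/openai/v1/completions',
#       json={'prompt': 'Once upon a time', 'model': 'gpt-3.5-turbo'},
#   )
#   data = resp.json()
#   print(data['choices'][0]['text'], data['usage'])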