local-llm-server/llm_server/routes/openai/completions.py

import time
import traceback
from flask import jsonify, make_response, request
from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...llm import get_token_count
from ...llm.openai.transform import build_openai_response, generate_oai_string
# TODO: add rate-limit headers?
@openai_bp.route('/completions', methods=['POST'])
def openai_completions():
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('prompt'):
        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
    else:
        try:
            response, status_code = OobaRequestHandler(request).handle_request()
            if status_code != 200:
                # Pass the handler's error response through along with its status code
                # (returning the bare integer is not a valid Flask return value).
                return response, status_code
            output = response.json['results'][0]['text']
            # TODO: async/await
            prompt_tokens = get_token_count(request_json_body['prompt'])
            response_tokens = get_token_count(output)
            running_model = redis.get('running_model', str, 'ERROR')
            response = make_response(jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
"choices": [
{
"text": output,
"index": 0,
"logprobs": None,
"finish_reason": None
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
stats = redis.get('proxy_stats', dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
        except Exception:
            traceback.print_exc()
            return 'Internal Server Error', 500
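
# Usage sketch (illustrative, not part of the original module): the route expects an
# OpenAI-style completion request with a 'prompt' field and returns a "text_completion"
# object with "choices" and "usage", as built above. The URL prefix depends on where
# openai_bp is registered, which is not shown in this file, so the host and prefix
# below are placeholders.
#
#   import requests
#
#   resp = requests.post(
#       'http://localhost:5000/<openai_bp prefix>/completions',  # hypothetical host/prefix
#       json={'prompt': 'Once upon a time'},
#   )
#   data = resp.json()
#   print(data['choices'][0]['text'], data['usage']['total_tokens'])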