further align openai endpoint with expected responses

Cyberes 2023-09-24 21:45:30 -06:00
parent 84ea2f8891
commit 320f51e01c
12 changed files with 67 additions and 17 deletions

View File

@@ -10,7 +10,6 @@ from llm_server.llm.vllm import tokenize
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
prompt_tokens = llm_server.llm.get_token_count(prompt)
if not is_error:
if not response_tokens:
response_tokens = llm_server.llm.get_token_count(response)

View File

@@ -39,7 +39,7 @@ class OobaboogaBackend(LLMBackend):
'code': 500,
'msg': error_msg,
'results': [{'text': backend_response}]
}), 200
}), 400
# ===============================================
@@ -67,7 +67,7 @@ class OobaboogaBackend(LLMBackend):
'code': 500,
'msg': 'the backend did not return valid JSON',
'results': [{'text': backend_response}]
}), 200
}), 400
def validate_params(self, params_dict: dict):
# No validation required
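
With backend failures now surfaced as HTTP 400 instead of 200, callers can branch on the status code rather than sniffing the JSON body. A minimal client-side sketch, assuming a hypothetical proxy URL; the 'msg' and 'results' keys match the error body shown above:

```python
import requests

API_URL = 'http://localhost:5000/api/v1/generate'  # hypothetical address for the proxy

def generate(prompt: str) -> str:
    r = requests.post(API_URL, json={'prompt': prompt}, timeout=120)
    payload = r.json()
    if r.status_code != 200:
        # Failures now arrive with a non-200 status; the body still carries
        # the formatted error text in 'results' plus a 'msg' field.
        raise RuntimeError(f"backend error {r.status_code}: {payload.get('msg', payload)}")
    return payload['results'][0]['text']
```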

View File

@@ -9,8 +9,8 @@ from redis.typing import FieldT
cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
ONE_MONTH_SECONDS = 2678000
# redis = Redis()
class RedisWrapper:
"""

View File

@@ -33,9 +33,9 @@ class OobaRequestHandler(RequestHandler):
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
return jsonify({
'results': [{'text': backend_response}]
}), 200
}), 429
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
return jsonify({
'results': [{'text': msg}]
}), 200
}), 400
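
Ratelimited prompts now come back as 429, so a client can back off and retry instead of mistaking the error text for a completion. A rough sketch with exponential backoff, again against a hypothetical endpoint:

```python
import time
import requests

API_URL = 'http://localhost:5000/api/v1/generate'  # hypothetical address for the proxy

def generate_with_backoff(prompt: str, retries: int = 5) -> str:
    delay = 1.0
    for _ in range(retries):
        r = requests.post(API_URL, json={'prompt': prompt}, timeout=120)
        if r.status_code == 429:
            # Too many simultaneous requests from this IP; wait and try again.
            time.sleep(delay)
            delay *= 2
            continue
        r.raise_for_status()
        return r.json()['results'][0]['text']
    raise RuntimeError('still ratelimited after several retries')
```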

View File

@@ -31,3 +31,4 @@ def handle_error(e):
from .models import openai_list_models
from .chat_completions import openai_chat_completions
from .info import get_openai_info
from .simulated import *
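
The wildcard import is there for its side effect: importing `simulated` at the bottom of the package executes its `@openai_bp.route(...)` decorators, which registers the stub endpoints on the blueprint. A stripped-down sketch of the pattern, with a hypothetical package name and URL prefix:

```python
# openai_pkg/__init__.py  (hypothetical layout)
from flask import Blueprint

openai_bp = Blueprint('openai', __name__, url_prefix='/api/openai/v1')

# Imported last, purely for the side effect of running the
# @openai_bp.route(...) decorators defined in simulated.py;
# importing it earlier would hit a circular import on openai_bp.
from .simulated import *  # noqa: E402,F401,F403
```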

View File

@@ -8,6 +8,8 @@ from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
# TODO: add rate-limit headers?
@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
request_valid_json, request_json_body = validate_json(request)
@@ -20,4 +22,4 @@ def openai_chat_completions():
print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
traceback.print_exc()
print(request.data)
return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 200
return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
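
On the TODO about rate-limit headers, one approach would be to stamp them onto the Flask response before returning it. The helper below is hypothetical (nothing like it exists in this diff); the header names follow the x-ratelimit-* convention the upstream OpenAI API uses:

```python
from flask import Response

def with_ratelimit_headers(resp: Response, limit: int, remaining: int, reset_epoch: int) -> Response:
    # Hypothetical helper: the values would have to come from the proxy's
    # own per-IP accounting, which this sketch does not implement.
    resp.headers['x-ratelimit-limit-requests'] = str(limit)
    resp.headers['x-ratelimit-remaining-requests'] = str(remaining)
    resp.headers['x-ratelimit-reset-requests'] = str(reset_epoch)
    return resp
```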

View File

@@ -1,10 +1,12 @@
from flask import Response
from . import openai_bp
from ..cache import cache
from ... import opts
@openai_bp.route('/prompt', methods=['GET'])
@cache.cached(timeout=2678000, query_string=True)
def get_openai_info():
if opts.expose_openai_system_prompt:
resp = Response(opts.openai_system_prompt)

View File

@@ -1,13 +1,15 @@
from flask import jsonify, request
from . import openai_bp
from ..cache import cache, redis
from ..cache import ONE_MONTH_SECONDS, cache, redis
from ..stats import server_start_time
from ... import opts
from ...llm.info import get_running_model
import openai
@openai_bp.route('/models', methods=['GET'])
@cache.cached(timeout=60, query_string=True)
def openai_list_models():
cache_key = 'openai_model_cache::' + request.url
cached_response = cache.get(cache_key)
@@ -23,7 +25,8 @@ def openai_list_models():
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
response = jsonify({
oai = fetch_openai_models()
r = {
"object": "list",
"data": [
{
@@ -51,7 +54,13 @@ def openai_list_models():
"parent": None
}
]
}), 200
}
response = jsonify({**r, **oai}), 200
cache.set(cache_key, response, timeout=60)
return response
@cache.memoize(timeout=ONE_MONTH_SECONDS)
def fetch_openai_models():
return openai.Model.list()
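
Two cache layers are in play here: a manual `cache.get`/`cache.set` keyed on the request URL for the assembled response, and `cache.memoize` on `fetch_openai_models()` so `openai.Model.list()` (the pre-1.0 `openai` client API) is called at most once per month. A self-contained sketch of the memoize half with Flask-Caching, using SimpleCache and a placeholder fetch so it runs without Redis or an API key:

```python
from flask import Flask
from flask_caching import Cache

app = Flask(__name__)
cache = Cache(app, config={'CACHE_TYPE': 'SimpleCache'})  # the real config uses RedisCache

ONE_MONTH_SECONDS = 2678000

@cache.memoize(timeout=ONE_MONTH_SECONDS)
def fetch_upstream_models():
    # Stand-in for the expensive upstream call (openai.Model.list() in the
    # diff); the decorated result is cached per argument set for a month.
    return {'object': 'list', 'data': []}

@app.route('/v1/models')
def list_models():
    # Flask serializes the returned dict to JSON automatically.
    return fetch_upstream_models()
```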

View File

@@ -0,0 +1,26 @@
from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache
from ..stats import server_start_time
@openai_bp.route('/organizations', methods=['GET'])
@cache.cached(timeout=ONE_MONTH_SECONDS, query_string=True)
def openai_organizations():
return jsonify({
"object": "list",
"data": [
{
"object": "organization",
"id": "org-abCDEFGHiJklmNOPqrSTUVWX",
"created": int(server_start_time.timestamp()),
"title": "Personal",
"name": "user-abcdefghijklmnopqrstuvwx",
"description": "Personal org for bobjoe@0.0.0.0",
"personal": True,
"is_default": True,
"role": "owner"
}
]
})
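
This new simulated route returns a static, OpenAI-shaped organization record so clients that probe /organizations get a plausible answer. A quick check from Python, assuming a hypothetical mount point for the blueprint:

```python
import requests

# Hypothetical prefix; adjust to wherever the openai blueprint is mounted.
r = requests.get('http://localhost:5000/api/openai/organizations', timeout=10)
r.raise_for_status()
org = r.json()['data'][0]
print(org['id'], org['role'])  # -> org-abCDEFGHiJklmNOPqrSTUVWX owner
```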

View File

@@ -25,8 +25,7 @@ class OpenAIRequestHandler(RequestHandler):
self.prompt = None
def handle_request(self) -> Tuple[flask.Response, int]:
if self.used:
raise Exception
assert not self.used
request_valid, invalid_response = self.validate_request()
if not request_valid:
@@ -69,6 +68,7 @@
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
else:
@@ -77,9 +77,10 @@
def handle_ratelimited(self):
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
return build_openai_response(self.prompt, backend_response), 200
return build_openai_response(self.prompt, backend_response), 429
def transform_messages_to_prompt(self):
# TODO: add some way of cutting the user's prompt down so that we can fit the system prompt and moderation endpoint response
try:
prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
for msg in self.request.json['messages']:
@@ -104,7 +105,15 @@
return prompt
def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
return build_openai_response('', msg), 200
# return build_openai_response('', msg), 400
return jsonify({
"error": {
"message": "Invalid request, check your parameters and try again.",
"type": "invalid_request_error",
"param": None,
"code": None
}
}), 400
def check_moderation_endpoint(prompt: str):
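
Because handle_error now answers with the standard `error` envelope and a 400 status, the official `openai` Python client (the 0.x series this diff imports) raises `InvalidRequestError` rather than treating the reply as a completion. A hedged sketch of what a caller would see, with a hypothetical api_base and a request assumed to fail the proxy's validation:

```python
import openai

openai.api_key = 'placeholder'  # the proxy decides whether to check this
openai.api_base = 'http://localhost:5000/api/openai/v1'  # hypothetical mount point

try:
    openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=[{'role': 'user', 'content': 'hi'}],
        # assume something in this request trips the proxy's parameter validation
    )
except openai.error.InvalidRequestError as e:
    # A 400 response with an {"error": {...}} body maps to InvalidRequestError
    # in the 0.x openai client.
    print('rejected:', e)
```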

View File

@@ -20,4 +20,4 @@ def generate():
print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
print(traceback.format_exc())
print(request.data)
return format_sillytavern_err(f'Server encountered exception.', 'error'), 200
return format_sillytavern_err(f'Server encountered exception.', 'error'), 500

View File

@@ -3,6 +3,7 @@ import sys
from pathlib import Path
from threading import Thread
import openai
import simplejson as json
from flask import Flask, jsonify, render_template, request
@@ -84,6 +85,7 @@ opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token']
if config['http_host']:
@@ -183,8 +185,8 @@ def home():
analytics_tracking_code=analytics_tracking_code,
info_html=info_html,
current_model=opts.manual_model_name if opts.manual_model_name else running_model,
client_api=stats['endpoints']['blocking'],
ws_client_api=stats['endpoints']['streaming'],
client_api=f'https://{base_client_api}',
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
estimated_wait=estimated_wait_sec,
mode_name=mode_ui_names[opts.mode][0],
api_input_textbox=mode_ui_names[opts.mode][1],
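
Finally, the homepage now derives both client URLs from a single base string instead of reading them from the stats dict. The derivation mirrors the two f-strings above; the example host is made up:

```python
def build_client_urls(base_client_api: str, enable_streaming: bool):
    # base_client_api is expected to be host plus path, without a scheme.
    client_api = f'https://{base_client_api}'
    ws_client_api = f'wss://{base_client_api}/v1/stream' if enable_streaming else None
    return client_api, ws_client_api

print(build_client_urls('proxy.example.com/api', True))
# ('https://proxy.example.com/api', 'wss://proxy.example.com/api/v1/stream')
```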