further align openai endpoint with expected responses
commit 320f51e01c
parent 84ea2f8891
@@ -10,7 +10,6 @@ from llm_server.llm.vllm import tokenize
 def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
     prompt_tokens = llm_server.llm.get_token_count(prompt)

     if not is_error:
         if not response_tokens:
             response_tokens = llm_server.llm.get_token_count(response)
@@ -39,7 +39,7 @@ class OobaboogaBackend(LLMBackend):
                 'code': 500,
                 'msg': error_msg,
                 'results': [{'text': backend_response}]
-            }), 200
+            }), 400

             # ===============================================

@@ -67,7 +67,7 @@ class OobaboogaBackend(LLMBackend):
                 'code': 500,
                 'msg': 'the backend did not return valid JSON',
                 'results': [{'text': backend_response}]
-            }), 200
+            }), 400

     def validate_params(self, params_dict: dict):
         # No validation required
@@ -9,8 +9,8 @@ from redis.typing import FieldT
 cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})

+ONE_MONTH_SECONDS = 2678000

 # redis = Redis()

 class RedisWrapper:
     """
@@ -33,9 +33,9 @@ class OobaRequestHandler(RequestHandler):
         log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
         return jsonify({
             'results': [{'text': backend_response}]
-        }), 200
+        }), 429

     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
         return jsonify({
             'results': [{'text': msg}]
-        }), 200
+        }), 400
@@ -31,3 +31,4 @@ def handle_error(e):
 from .models import openai_list_models
 from .chat_completions import openai_chat_completions
 from .info import get_openai_info
+from .simulated import *
@@ -8,6 +8,8 @@ from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler, build_openai_response


+# TODO: add rate-limit headers?
+
 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
     request_valid_json, request_json_body = validate_json(request)
@@ -20,4 +22,4 @@ def openai_chat_completions():
         print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
         traceback.print_exc()
         print(request.data)
-        return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 200
+        return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
@@ -1,10 +1,12 @@
 from flask import Response

 from . import openai_bp
+from ..cache import cache
 from ... import opts


 @openai_bp.route('/prompt', methods=['GET'])
+@cache.cached(timeout=2678000, query_string=True)
 def get_openai_info():
     if opts.expose_openai_system_prompt:
         resp = Response(opts.openai_system_prompt)
@@ -1,13 +1,15 @@
 from flask import jsonify, request

 from . import openai_bp
-from ..cache import cache, redis
+from ..cache import ONE_MONTH_SECONDS, cache, redis
 from ..stats import server_start_time
 from ... import opts
 from ...llm.info import get_running_model
+import openai
+

 @openai_bp.route('/models', methods=['GET'])
 @cache.cached(timeout=60, query_string=True)
 def openai_list_models():
     cache_key = 'openai_model_cache::' + request.url
     cached_response = cache.get(cache_key)
@@ -23,7 +25,8 @@ def openai_list_models():
             'type': error.__class__.__name__
         }), 500 # return 500 so Cloudflare doesn't intercept us
     else:
-        response = jsonify({
+        oai = fetch_openai_models()
+        r = {
             "object": "list",
             "data": [
                 {
@@ -51,7 +54,13 @@ def openai_list_models():
                     "parent": None
                 }
             ]
-        }), 200
+        }
+        response = jsonify({**r, **oai}), 200
         cache.set(cache_key, response, timeout=60)

     return response
+
+
+@cache.memoize(timeout=ONE_MONTH_SECONDS)
+def fetch_openai_models():
+    return openai.Model.list()
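Note on the merge above: {**r, **oai} is plain dict unpacking, where duplicate keys take the right-most value. A minimal sketch of that behaviour, with r and oai standing in for the locally built model list and the memoized openai.Model.list() result; the payloads here are made up for illustration, not real API output:

    # Sketch only: illustrative placeholder payloads.
    r = {"object": "list", "data": [{"id": "local-model", "object": "model"}]}
    oai = {"object": "list", "data": [{"id": "gpt-3.5-turbo", "object": "model"}]}

    merged = {**r, **oai}
    # 'object' and 'data' exist in both dicts, so the values from oai win and
    # the local entry in r['data'] is replaced rather than appended.
    print(merged["data"])  # [{'id': 'gpt-3.5-turbo', 'object': 'model'}]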
@@ -0,0 +1,26 @@
+from flask import jsonify
+
+from . import openai_bp
+from ..cache import ONE_MONTH_SECONDS, cache
+from ..stats import server_start_time
+
+
+@openai_bp.route('/organizations', methods=['GET'])
+@cache.cached(timeout=ONE_MONTH_SECONDS, query_string=True)
+def openai_organizations():
+    return jsonify({
+        "object": "list",
+        "data": [
+            {
+                "object": "organization",
+                "id": "org-abCDEFGHiJklmNOPqrSTUVWX",
+                "created": int(server_start_time.timestamp()),
+                "title": "Personal",
+                "name": "user-abcdefghijklmnopqrstuvwx",
+                "description": "Personal org for bobjoe@0.0.0.0",
+                "personal": True,
+                "is_default": True,
+                "role": "owner"
+            }
+        ]
+    })
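A quick way to sanity-check the new simulated endpoint once the server is running; the host, port, and /api/openai/v1 prefix are assumptions about how openai_bp is mounted, so adjust them to the actual deployment:

    import requests

    # Hypothetical local URL; substitute whatever prefix openai_bp is registered under.
    resp = requests.get('http://localhost:5000/api/openai/v1/organizations')
    resp.raise_for_status()
    org = resp.json()['data'][0]
    print(org['object'], org['role'])  # organization owner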
@@ -25,8 +25,7 @@ class OpenAIRequestHandler(RequestHandler):
         self.prompt = None

     def handle_request(self) -> Tuple[flask.Response, int]:
-        if self.used:
-            raise Exception
+        assert not self.used

         request_valid, invalid_response = self.validate_request()
         if not request_valid:
@@ -69,6 +68,7 @@ class OpenAIRequestHandler(RequestHandler):
         self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
         llm_request = {**self.parameters, 'prompt': self.prompt}
         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
+
         if success:
             return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
         else:
@@ -77,9 +77,10 @@ class OpenAIRequestHandler(RequestHandler):
     def handle_ratelimited(self):
         backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
         log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
-        return build_openai_response(self.prompt, backend_response), 200
+        return build_openai_response(self.prompt, backend_response), 429
+
     def transform_messages_to_prompt(self):
         # TODO: add some way of cutting the user's prompt down so that we can fit the system prompt and moderation endpoint response
         try:
             prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
             for msg in self.request.json['messages']:
@@ -104,7 +105,15 @@ class OpenAIRequestHandler(RequestHandler):
         return prompt

     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
-        return build_openai_response('', msg), 200
+        # return build_openai_response('', msg), 400
+        return jsonify({
+            "error": {
+                "message": "Invalid request, check your parameters and try again.",
+                "type": "invalid_request_error",
+                "param": None,
+                "code": None
+            }
+        }), 400


 def check_moderation_endpoint(prompt: str):
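The error envelope built in handle_error above follows the OpenAI-style error shape (error.type, error.message). A hedged sketch of how a caller might surface it; the URL is an assumption, as is the idea that an empty body fails validate_request() and ends up in handle_error():

    import requests

    # Hypothetical endpoint; an intentionally invalid body is assumed to be rejected.
    resp = requests.post('http://localhost:5000/api/openai/v1/chat/completions', json={})
    if resp.status_code == 400:
        err = resp.json()['error']
        print(f"{err['type']}: {err['message']}")
        # invalid_request_error: Invalid request, check your parameters and try again.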
@@ -20,4 +20,4 @@ def generate():
         print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
         print(traceback.format_exc())
         print(request.data)
-        return format_sillytavern_err(f'Server encountered exception.', 'error'), 200
+        return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
@@ -3,6 +3,7 @@ import sys
 from pathlib import Path
 from threading import Thread

+import openai
 import simplejson as json
 from flask import Flask, jsonify, render_template, request

@@ -84,6 +85,7 @@ opts.openai_system_prompt = config['openai_system_prompt']
 opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
 opts.enable_streaming = config['enable_streaming']
 opts.openai_api_key = config['openai_api_key']
+openai.api_key = opts.openai_api_key
 opts.admin_token = config['admin_token']

 if config['http_host']:
@@ -183,8 +185,8 @@ def home():
                            analytics_tracking_code=analytics_tracking_code,
                            info_html=info_html,
                            current_model=opts.manual_model_name if opts.manual_model_name else running_model,
-                           client_api=stats['endpoints']['blocking'],
-                           ws_client_api=stats['endpoints']['streaming'],
+                           client_api=f'https://{base_client_api}',
+                           ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
                            estimated_wait=estimated_wait_sec,
                            mode_name=mode_ui_names[opts.mode][0],
                            api_input_textbox=mode_ui_names[opts.mode][1],