diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index 784939e..1c8f0ae 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -10,7 +10,6 @@ from llm_server.llm.vllm import tokenize
 
 def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
     prompt_tokens = llm_server.llm.get_token_count(prompt)
-
     if not is_error:
         if not response_tokens:
             response_tokens = llm_server.llm.get_token_count(response)
diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py
index 578f663..4336756 100644
--- a/llm_server/llm/oobabooga/ooba_backend.py
+++ b/llm_server/llm/oobabooga/ooba_backend.py
@@ -39,7 +39,7 @@ class OobaboogaBackend(LLMBackend):
                 'code': 500,
                 'msg': error_msg,
                 'results': [{'text': backend_response}]
-            }), 200
+            }), 400
 
             # ===============================================
 
@@ -67,7 +67,7 @@ class OobaboogaBackend(LLMBackend):
                 'code': 500,
                 'msg': 'the backend did not return valid JSON',
                 'results': [{'text': backend_response}]
-            }), 200
+            }), 400
 
     def validate_params(self, params_dict: dict):
         # No validation required
diff --git a/llm_server/routes/cache.py b/llm_server/routes/cache.py
index daaf2da..ae4e3cf 100644
--- a/llm_server/routes/cache.py
+++ b/llm_server/routes/cache.py
@@ -9,8 +9,8 @@ from redis.typing import FieldT
 
 cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
+ONE_MONTH_SECONDS = 2678000
 
-# redis = Redis()
 
 
 class RedisWrapper:
     """
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index d3ca482..3f021b5 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -33,9 +33,9 @@ class OobaRequestHandler(RequestHandler):
         log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
         return jsonify({
             'results': [{'text': backend_response}]
-        }), 200
+        }), 429
 
     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
         return jsonify({
             'results': [{'text': msg}]
-        }), 200
+        }), 400
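With the Oobabooga error and ratelimit paths above now returning real 4xx codes instead of 200, callers can branch on the HTTP status instead of sniffing the response body. A rough client-side sketch, assuming a requests-based caller and a hypothetical proxy URL (the actual route prefix is configured elsewhere and is not shown in this diff):

import requests

# Hypothetical base URL and endpoint path; only the response shape below comes from this diff.
resp = requests.post('https://proxy.example.com/api/v1/generate',
                     json={'prompt': 'Hello', 'max_new_tokens': 16})
if resp.status_code == 429:
    # handle_ratelimited() now reports the limit explicitly
    print('ratelimited:', resp.json()['results'][0]['text'])
elif not resp.ok:
    # handle_error() and backend failures now surface as 400
    print('error:', resp.json()['results'][0]['text'])
else:
    print(resp.json()['results'][0]['text'])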
diff --git a/llm_server/routes/openai/__init__.py b/llm_server/routes/openai/__init__.py
index c4f521a..dc9f979 100644
--- a/llm_server/routes/openai/__init__.py
+++ b/llm_server/routes/openai/__init__.py
@@ -31,3 +31,4 @@ def handle_error(e):
 from .models import openai_list_models
 from .chat_completions import openai_chat_completions
 from .info import get_openai_info
+from .simulated import *
\ No newline at end of file
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index d0be27a..442412c 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -8,6 +8,8 @@ from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
 
 
+# TODO: add rate-limit headers?
+
 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
     request_valid_json, request_json_body = validate_json(request)
@@ -20,4 +22,4 @@ def openai_chat_completions():
         print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
         traceback.print_exc()
         print(request.data)
-        return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 200
+        return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
diff --git a/llm_server/routes/openai/info.py b/llm_server/routes/openai/info.py
index 3e98e2f..f0f112e 100644
--- a/llm_server/routes/openai/info.py
+++ b/llm_server/routes/openai/info.py
@@ -1,10 +1,12 @@
 from flask import Response
 
 from . import openai_bp
+from ..cache import cache
 from ... import opts
 
 
 @openai_bp.route('/prompt', methods=['GET'])
+@cache.cached(timeout=2678000, query_string=True)
 def get_openai_info():
     if opts.expose_openai_system_prompt:
         resp = Response(opts.openai_system_prompt)
diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py
index 9d6a223..1aa9266 100644
--- a/llm_server/routes/openai/models.py
+++ b/llm_server/routes/openai/models.py
@@ -1,13 +1,15 @@
 from flask import jsonify, request
 
 from . import openai_bp
-from ..cache import cache, redis
+from ..cache import ONE_MONTH_SECONDS, cache, redis
 from ..stats import server_start_time
 from ... import opts
 from ...llm.info import get_running_model
+import openai
 
 
 @openai_bp.route('/models', methods=['GET'])
+@cache.cached(timeout=60, query_string=True)
 def openai_list_models():
     cache_key = 'openai_model_cache::' + request.url
     cached_response = cache.get(cache_key)
@@ -23,7 +25,8 @@ def openai_list_models():
             'type': error.__class__.__name__
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
-        response = jsonify({
+        oai = fetch_openai_models()
+        r = {
             "object": "list",
             "data": [
                 {
@@ -51,7 +54,13 @@ def openai_list_models():
                     "parent": None
                 }
             ]
-        }), 200
+        }
+        response = jsonify({**r, **oai}), 200
         cache.set(cache_key, response, timeout=60)
 
     return response
+
+
+@cache.memoize(timeout=ONE_MONTH_SECONDS)
+def fetch_openai_models():
+    return openai.Model.list()
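One note on the merge in models.py: with {**r, **oai} the later unpack wins, so any top-level keys that fetch_openai_models() shares with the locally built payload ("object", "data") replace the local values rather than being combined with them. A minimal plain-dict illustration of that precedence:

# Plain dicts only, to show the unpacking semantics behind jsonify({**r, **oai}).
r = {'object': 'list', 'data': [{'id': 'local-model'}]}
oai = {'object': 'list', 'data': [{'id': 'gpt-3.5-turbo'}]}

merged = {**r, **oai}
print(merged['data'])  # [{'id': 'gpt-3.5-turbo'}] -- the local entry is replaced, not appended

If the goal is to list the local model alongside the upstream ones, the two 'data' lists would need to be concatenated explicitly.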
diff --git a/llm_server/routes/openai/simulated.py b/llm_server/routes/openai/simulated.py
new file mode 100644
index 0000000..7e80f25
--- /dev/null
+++ b/llm_server/routes/openai/simulated.py
@@ -0,0 +1,26 @@
+from flask import jsonify
+
+from . import openai_bp
+from ..cache import ONE_MONTH_SECONDS, cache
+from ..stats import server_start_time
+
+
+@openai_bp.route('/organizations', methods=['GET'])
+@cache.cached(timeout=ONE_MONTH_SECONDS, query_string=True)
+def openai_organizations():
+    return jsonify({
+        "object": "list",
+        "data": [
+            {
+                "object": "organization",
+                "id": "org-abCDEFGHiJklmNOPqrSTUVWX",
+                "created": int(server_start_time.timestamp()),
+                "title": "Personal",
+                "name": "user-abcdefghijklmnopqrstuvwx",
+                "description": "Personal org for bobjoe@0.0.0.0",
+                "personal": True,
+                "is_default": True,
+                "role": "owner"
+            }
+        ]
+    })
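simulated.py, info.py and models.py lean on two different Flask-Caching decorators: cache.cached stores the rendered view response, keyed by the request path (plus the query string when query_string=True), while cache.memoize keys the cached value on the function and its arguments, which is why the standalone fetch_openai_models() uses it. A minimal self-contained sketch of the distinction, with an in-memory backend purely for illustration (the proxy itself configures RedisCache):

from flask import Flask
from flask_caching import Cache

app = Flask(__name__)
cache = Cache(app, config={'CACHE_TYPE': 'SimpleCache'})  # illustration only

@app.route('/expensive')
@cache.cached(timeout=60, query_string=True)  # whole response cached per path + query string
def expensive_view():
    return 'rendered at most once per minute'

@cache.memoize(timeout=60)  # return value cached per (function, arguments)
def expensive_lookup(name):
    return f'looked up {name}'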
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 6a09b87..98df308 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -25,8 +25,7 @@ class OpenAIRequestHandler(RequestHandler):
         self.prompt = None
 
     def handle_request(self) -> Tuple[flask.Response, int]:
-        if self.used:
-            raise Exception
+        assert not self.used
 
         request_valid, invalid_response = self.validate_request()
         if not request_valid:
@@ -69,6 +68,7 @@ class OpenAIRequestHandler(RequestHandler):
         self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
         llm_request = {**self.parameters, 'prompt': self.prompt}
         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
+
         if success:
             return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
         else:
@@ -77,9 +77,10 @@ class OpenAIRequestHandler(RequestHandler):
     def handle_ratelimited(self):
         backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
         log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
-        return build_openai_response(self.prompt, backend_response), 200
+        return build_openai_response(self.prompt, backend_response), 429
 
     def transform_messages_to_prompt(self):
+        # TODO: add some way of cutting the user's prompt down so that we can fit the system prompt and moderation endpoint response
         try:
             prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
             for msg in self.request.json['messages']:
@@ -104,7 +105,15 @@ class OpenAIRequestHandler(RequestHandler):
         return prompt
 
     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
-        return build_openai_response('', msg), 200
+        # return build_openai_response('', msg), 400
+        return jsonify({
+            "error": {
+                "message": "Invalid request, check your parameters and try again.",
+                "type": "invalid_request_error",
+                "param": None,
+                "code": None
+            }
+        }), 400
 
 
 def check_moderation_endpoint(prompt: str):
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index 7456acb..b2d52a1 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -20,4 +20,4 @@ def generate():
         print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
         print(traceback.format_exc())
         print(request.data)
-        return format_sillytavern_err(f'Server encountered exception.', 'error'), 200
+        return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
diff --git a/server.py b/server.py
index 69abb5f..73f6f35 100644
--- a/server.py
+++ b/server.py
@@ -3,6 +3,7 @@ import sys
 from pathlib import Path
 from threading import Thread
 
+import openai
 import simplejson as json
 from flask import Flask, jsonify, render_template, request
 
@@ -84,6 +85,7 @@ opts.openai_system_prompt = config['openai_system_prompt']
 opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
 opts.enable_streaming = config['enable_streaming']
 opts.openai_api_key = config['openai_api_key']
+openai.api_key = opts.openai_api_key
 opts.admin_token = config['admin_token']
 
 if config['http_host']:
@@ -183,8 +185,8 @@ def home():
                            analytics_tracking_code=analytics_tracking_code,
                            info_html=info_html,
                            current_model=opts.manual_model_name if opts.manual_model_name else running_model,
-                           client_api=stats['endpoints']['blocking'],
-                           ws_client_api=stats['endpoints']['streaming'],
+                           client_api=f'https://{base_client_api}',
+                           ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
                            estimated_wait=estimated_wait_sec,
                            mode_name=mode_ui_names[opts.mode][0],
                            api_input_textbox=mode_ui_names[opts.mode][1],
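A caveat on the handle_request change: assert not self.used behaves like the old guard during normal runs, but assert statements are stripped when Python is started with -O or -OO, so the double-use check silently disappears under optimized bytecode. If the guard should always fire, a small sketch of an explicit equivalent (not part of this diff):

# Runs regardless of interpreter optimization flags, unlike a bare assert.
if self.used:
    raise RuntimeError('handle_request() called twice on the same OpenAIRequestHandler')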