From 1646a00987a64c56f3d75bcab713348b799438da Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Mon, 25 Sep 2023 12:30:40 -0600
Subject: [PATCH] implement streaming on openai, improve streaming, run DB logging in background thread

---
 llm_server/llm/vllm/vllm_backend.py          |  16 ++-
 llm_server/routes/helpers/http.py            |  17 ++-
 llm_server/routes/openai/chat_completions.py | 103 +++++++++++++++++--
 llm_server/routes/openai/simulated.py        |   3 +-
 llm_server/routes/openai_request_handler.py  |  15 ++-
 llm_server/routes/v1/generate_stream.py      |  58 +++++++----
 6 files changed, 169 insertions(+), 43 deletions(-)

diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index 5bb9c46..2c48bf0 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -1,9 +1,9 @@
-from typing import Tuple, Union
+import threading
+from typing import Tuple
 
 from flask import jsonify
 from vllm import SamplingParams
 
-from llm_server import opts
 from llm_server.database.database import log_prompt
 from llm_server.llm.llm_backend import LLMBackend
 
@@ -18,8 +18,16 @@ class VLLMBackend(LLMBackend):
         else:
             # Failsafe
             backend_response = ''
-        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
-                   response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
+
+        r_url = request.url
+
+        def background_task():
+            log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
+                       response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
+
+        # TODO: use async/await instead of threads
+        threading.Thread(target=background_task).start()
+
         return jsonify({'results': [{'text': backend_response}]}), 200
 
     def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
diff --git a/llm_server/routes/helpers/http.py b/llm_server/routes/helpers/http.py
index a756eb3..26cf34b 100644
--- a/llm_server/routes/helpers/http.py
+++ b/llm_server/routes/helpers/http.py
@@ -1,11 +1,12 @@
 import json
+import traceback
 from functools import wraps
 from typing import Union
 
 import flask
 import requests
-from flask import make_response, Request
-from flask import request, jsonify
+from flask import Request, make_response
+from flask import jsonify, request
 
 from llm_server import opts
 from llm_server.database.database import is_valid_api_key
@@ -36,7 +37,17 @@ def require_api_key():
         else:
             return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
     else:
-        return jsonify({'code': 401, 'message': 'API key required'}), 401
+        try:
+            # Handle websockets
+            if request.json.get('X-API-KEY'):
+                if is_valid_api_key(request.json.get('X-API-KEY')):
+                    return
+                else:
+                    return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
+        except:
+            # TODO: remove this once we're sure this works as expected
+            traceback.print_exc()
+        return jsonify({'code': 401, 'message': 'API key required'}), 401
 
 
 def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
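
Note on the vllm_backend.py hunks above: the log_prompt() call moves into a fire-and-forget thread so the database write no longer blocks the response, and request.url is copied into r_url first because Flask's request proxy is context-local and cannot be read from the worker thread. A minimal, self-contained sketch of the same pattern; log_prompt_stub and the /demo route are placeholders, not code from this repository:

    import threading

    from flask import Flask, jsonify, request

    app = Flask(__name__)

    def log_prompt_stub(request_url):
        # Stand-in for llm_server.database.database.log_prompt; it only receives
        # plain values that were captured before the thread started.
        print('logging for', request_url)

    @app.route('/demo')
    def demo():
        # Capture request-bound data on the request thread; flask.request is not
        # available inside the worker thread.
        r_url = request.url

        def background_task():
            log_prompt_stub(r_url)

        # Same pattern as the patch: the DB write runs in the background so it
        # does not delay the response. (TODO in the patch: replace with async/await.)
        threading.Thread(target=background_task).start()
        return jsonify({'results': [{'text': 'ok'}]}), 200
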
import traceback -from flask import jsonify, request +from flask import Response, jsonify, request from . import openai_bp from ..helpers.client import format_sillytavern_err from ..helpers.http import validate_json -from ..openai_request_handler import OpenAIRequestHandler, build_openai_response +from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string +from ... import opts +from ...database.database import log_prompt +from ...llm.generator import generator +from ...llm.vllm import tokenize # TODO: add rate-limit headers? @@ -16,10 +23,88 @@ def openai_chat_completions(): if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'): return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 else: - try: - return OpenAIRequestHandler(request).handle_request() - except Exception as e: - print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}') - traceback.print_exc() - print(request.data) - return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500 + handler = OpenAIRequestHandler(request, request_json_body) + if request_json_body.get('stream'): + if not opts.enable_streaming: + # TODO: return a proper OAI error message + return 'disabled', 401 + + if opts.mode != 'vllm': + # TODO: implement other backends + raise NotImplementedError + + response_status_code = 0 + start_time = time.time() + request_valid, invalid_response = handler.validate_request() + if not request_valid: + # TODO: simulate OAI here + raise Exception + else: + handler.prompt = handler.transform_messages_to_prompt() + msg_to_backend = { + **handler.parameters, + 'prompt': handler.prompt, + 'stream': True, + } + try: + response = generator(msg_to_backend) + r_headers = dict(request.headers) + r_url = request.url + model = opts.running_model if opts.openai_epose_our_model else request_json_body.get('model') + + def generate(): + generated_text = '' + partial_response = b'' + for chunk in response.iter_content(chunk_size=1): + partial_response += chunk + if partial_response.endswith(b'\x00'): + json_strs = partial_response.split(b'\x00') + for json_str in json_strs: + if json_str: + try: + json_obj = json.loads(json_str.decode()) + new = json_obj['text'][0].split(handler.prompt + generated_text)[1] + print(new) + generated_text = generated_text + new + data = { + "id": f"chatcmpl-{generate_oai_string(30)}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": new + }, + "finish_reason": None + } + ] + } + yield f'data: {json.dumps(data)}\n\n' + except IndexError: + continue + + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time + + def background_task(): + generated_tokens = tokenize(generated_text) + log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens) + + # TODO: use async/await instead of threads + threading.Thread(target=background_task).start() + + return Response(generate(), mimetype='text/event-stream') + except: + # TODO: simulate OAI here + raise Exception + else: + try: + return handler.handle_request() + except Exception as e: + print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}') + traceback.print_exc() + print(request.data) + return build_openai_response('', format_sillytavern_err(f'Server 
diff --git a/llm_server/routes/openai/simulated.py b/llm_server/routes/openai/simulated.py
index 7e80f25..58bfbdf 100644
--- a/llm_server/routes/openai/simulated.py
+++ b/llm_server/routes/openai/simulated.py
@@ -2,6 +2,7 @@ from flask import jsonify
 
 from . import openai_bp
 from ..cache import ONE_MONTH_SECONDS, cache
+from ..openai_request_handler import generate_oai_string
 from ..stats import server_start_time
 
 
@@ -13,7 +14,7 @@ def openai_organizations():
         "data": [
             {
                 "object": "organization",
-                "id": "org-abCDEFGHiJklmNOPqrSTUVWX",
+                "id": f"org-{generate_oai_string(24)}",
                 "created": int(server_start_time.timestamp()),
                 "title": "Personal",
                 "name": "user-abcdefghijklmnopqrstuvwx",
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 600de95..354c261 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -1,9 +1,10 @@
 import json
 import re
+import secrets
+import string
 import time
 import traceback
 from typing import Tuple
-from uuid import uuid4
 
 import flask
 import requests
@@ -72,7 +73,7 @@ class OpenAIRequestHandler(RequestHandler):
         model = self.request_json_body.get('model')
 
         if success:
-            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model), backend_response_status_code
+            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
         else:
             return backend_response, backend_response_status_code
 
@@ -131,7 +132,7 @@ def check_moderation_endpoint(prompt: str):
     return response['results'][0]['flagged'], offending_categories
 
 
-def build_openai_response(prompt, response, model):
+def build_openai_response(prompt, response, model=None):
     # Seperate the user's prompt from the context
     x = prompt.split('### USER:')
     if len(x) > 1:
@@ -142,10 +143,11 @@ def build_openai_response(prompt, response, model):
     if len(x) > 1:
         response = re.sub(r'\n$', '', y[0].strip(' '))
 
+    # TODO: async/await
     prompt_tokens = llm_server.llm.get_token_count(prompt)
     response_tokens = llm_server.llm.get_token_count(response)
     return jsonify({
-        "id": f"chatcmpl-{uuid4()}",
+        "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
         "model": opts.running_model if opts.openai_epose_our_model else model,
@@ -163,3 +165,8 @@ def build_openai_response(prompt, response, model):
             "total_tokens": prompt_tokens + response_tokens
         }
     })
+
+
+def generate_oai_string(length=24):
+    alphabet = string.ascii_letters + string.digits
+    return ''.join(secrets.choice(alphabet) for i in range(length))
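
Note on the id changes above: swapping uuid4() for generate_oai_string() makes ids look like OpenAI's ("chatcmpl-" or "org-" followed by random alphanumerics) instead of a hyphenated UUID. A quick sanity-check sketch using the same helper:

    import re
    import secrets
    import string

    def generate_oai_string(length=24):
        # Same helper as added in openai_request_handler.py above.
        alphabet = string.ascii_letters + string.digits
        return ''.join(secrets.choice(alphabet) for _ in range(length))

    chat_id = f'chatcmpl-{generate_oai_string(30)}'
    org_id = f'org-{generate_oai_string(24)}'
    assert re.fullmatch(r'chatcmpl-[A-Za-z0-9]{30}', chat_id)
    assert re.fullmatch(r'org-[A-Za-z0-9]{24}', org_id)
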
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 0c92bb9..ed112f6 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -1,11 +1,12 @@
 import json
+import threading
 import time
 import traceback
 
 from flask import request
 
 from ..helpers.client import format_sillytavern_err
-from ..helpers.http import validate_json
+from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
 from ...database.database import increment_token_uses, log_prompt
@@ -23,6 +24,10 @@ def stream(ws):
         # TODO: return a formatted ST error message
         return 'disabled', 401
 
+    auth_failure = require_api_key()
+    if auth_failure:
+        return auth_failure
+
     message_num = 0
     while ws.connected:
         message = ws.receive()
@@ -40,7 +45,6 @@ def stream(ws):
                 raise NotImplementedError
 
             handler = OobaRequestHandler(request, request_json_body)
-            token = request_json_body.get('X-API-KEY')
            generated_text = ''
            input_prompt = None
            response_status_code = 0
@@ -59,7 +63,6 @@ def stream(ws):
                'prompt': input_prompt,
                'stream': True,
            }
-
            try:
                response = generator(msg_to_backend)
 
@@ -74,23 +77,24 @@ def stream(ws):
                for chunk in response.iter_content(chunk_size=1):
                    partial_response += chunk
                    if partial_response.endswith(b'\x00'):
-                        json_str = partial_response[:-1].decode()  # Remove the null character and decode the byte string to a string
-                        json_obj = json.loads(json_str)
-                        try:
-                            new = json_obj['text'][0].split(input_prompt + generated_text)[1]
-                        except IndexError:
-                            # ????
-                            continue
+                        json_strs = partial_response.split(b'\x00')
+                        for json_str in json_strs:
+                            if json_str:
+                                try:
+                                    new = json.loads(json_str.decode())['text'][0].split(input_prompt + generated_text)[1]
+                                except IndexError:
+                                    # ????
+                                    continue
 
-                        ws.send(json.dumps({
-                            'event': 'text_stream',
-                            'message_num': message_num,
-                            'text': new
-                        }))
-                        message_num += 1
+                                ws.send(json.dumps({
+                                    'event': 'text_stream',
+                                    'message_num': message_num,
+                                    'text': new
+                                }))
+                                message_num += 1
 
-                        generated_text = generated_text + new
-                        partial_response = b''  # Reset the partial response
+                                generated_text = generated_text + new
+                                partial_response = b''  # Reset the partial response
 
                    # If there is no more data, break the loop
                    if not chunk:
@@ -100,18 +104,28 @@ def stream(ws):
 
                end_time = time.time()
                elapsed_time = end_time - start_time
-                generated_tokens = tokenize(generated_text)
-                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+                def background_task():
+                    generated_tokens = tokenize(generated_text)
+                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+                # TODO: use async/await instead of threads
+                threading.Thread(target=background_task).start()
            except:
                generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
-                generated_tokens = tokenize(generated_text)
                traceback.print_exc()
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': generated_text
                }))
-                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+                def background_task():
+                    generated_tokens = tokenize(generated_text)
+                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+                # TODO: use async/await instead of threads
+                threading.Thread(target=background_task).start()
 
            ws.send(json.dumps({
                'event': 'stream_end',
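
Across the patch, the same define-background_task-then-Thread(...).start() pattern now appears four times (vllm_backend.py, chat_completions.py, and twice in generate_stream.py). A possible follow-up sketch that would collapse that duplication until the async/await TODO is tackled; run_in_background is a hypothetical helper, not part of this commit:

    import threading

    def run_in_background(func, *args, **kwargs):
        # Hypothetical helper (not part of this commit): one place for the repeated
        # background_task/threading.Thread pattern used around log_prompt() calls.
        # daemon=True is an assumption so a pending DB write never keeps the process
        # alive at shutdown; drop it if the write must always complete.
        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
        thread.start()
        return thread

    # Example call site, using the same names as generate_stream.py above:
    #
    # def task():
    #     log_prompt(handler.client_ip, handler.token, input_prompt, generated_text,
    #                elapsed_time, handler.parameters, dict(request.headers),
    #                response_status_code, request.url,
    #                response_tokens=tokenize(generated_text))
    #
    # run_in_background(task)
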