Implement streaming for the OpenAI endpoint, improve streaming, and run DB logging in a background thread

This commit is contained in:
Cyberes 2023-09-25 12:30:40 -06:00
parent bbe5d5a8fe
commit 1646a00987
6 changed files with 169 additions and 43 deletions

View File

@@ -1,9 +1,9 @@
-from typing import Tuple, Union
+import threading
+from typing import Tuple

 from flask import jsonify
 from vllm import SamplingParams

-from llm_server import opts
 from llm_server.database.database import log_prompt
 from llm_server.llm.llm_backend import LLMBackend
@@ -18,8 +18,16 @@ class VLLMBackend(LLMBackend):
         else:
             # Failsafe
             backend_response = ''
-        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
-                   response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
+
+        r_url = request.url
+
+        def background_task():
+            log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
+                       response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
+
+        # TODO: use async/await instead of threads
+        threading.Thread(target=background_task).start()
+
         return jsonify({'results': [{'text': backend_response}]}), 200

     def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
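The same fire-and-forget pattern is repeated in the other handlers below: capture the request data first, then hand the blocking DB write to a thread so the response returns immediately. As a standalone sketch (the run_in_background helper is hypothetical, not part of this repo):

import threading

def run_in_background(func, *args, **kwargs):
    # Hypothetical helper: run a blocking call (e.g. a DB insert) off the request thread.
    # daemon=True is this sketch's choice; the commit uses plain non-daemon threads,
    # which keep the process alive until pending log writes finish.
    thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
    thread.start()
    return thread

# Usage sketch, mirroring the commit: run_in_background(log_prompt, ip=client_ip, token=token, ...)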

View File

@@ -1,11 +1,12 @@
 import json
+import traceback
 from functools import wraps
 from typing import Union

 import flask
 import requests
-from flask import make_response, Request
-from flask import request, jsonify
+from flask import Request, make_response
+from flask import jsonify, request

 from llm_server import opts
 from llm_server.database.database import is_valid_api_key
@@ -36,7 +37,17 @@ def require_api_key():
         else:
             return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
     else:
-        return jsonify({'code': 401, 'message': 'API key required'}), 401
+        try:
+            # Handle websockets
+            if request.json.get('X-API-KEY'):
+                if is_valid_api_key(request.json.get('X-API-KEY')):
+                    return
+                else:
+                    return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
+        except:
+            # TODO: remove this once we're sure this works as expected
+            traceback.print_exc()
+        return jsonify({'code': 401, 'message': 'API key required'}), 401


 def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
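Websocket messages carry no HTTP headers (headers only exist at the handshake), so the key is also accepted inside the JSON body. A minimal client-side sketch (the prompt field is illustrative; only the 'X-API-KEY' field name comes from the code above):

import json

payload = {
    'X-API-KEY': 'your-api-key-here',  # read server-side via request.json.get('X-API-KEY')
    'prompt': 'Hello!',                # illustrative field, not part of the auth check
}
# ws.send(json.dumps(payload))  # 'ws' is an already-connected websocket client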

View File

@@ -1,11 +1,18 @@
+import json
+import threading
+import time
 import traceback

-from flask import jsonify, request
+from flask import Response, jsonify, request

 from . import openai_bp
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
-from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
+from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
+from ... import opts
+from ...database.database import log_prompt
+from ...llm.generator import generator
+from ...llm.vllm import tokenize

 # TODO: add rate-limit headers?
@@ -16,10 +23,88 @@ def openai_chat_completions():
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
     else:
-        try:
-            return OpenAIRequestHandler(request).handle_request()
-        except Exception as e:
-            print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
-            traceback.print_exc()
-            print(request.data)
-            return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
+        handler = OpenAIRequestHandler(request, request_json_body)
+
+        if request_json_body.get('stream'):
+            if not opts.enable_streaming:
+                # TODO: return a proper OAI error message
+                return 'disabled', 401
+
+            if opts.mode != 'vllm':
+                # TODO: implement other backends
+                raise NotImplementedError
+
+            response_status_code = 0
+            start_time = time.time()
+            request_valid, invalid_response = handler.validate_request()
+            if not request_valid:
+                # TODO: simulate OAI here
+                raise Exception
+            else:
+                handler.prompt = handler.transform_messages_to_prompt()
+                msg_to_backend = {
+                    **handler.parameters,
+                    'prompt': handler.prompt,
+                    'stream': True,
+                }
+                try:
+                    response = generator(msg_to_backend)
+                    r_headers = dict(request.headers)
+                    r_url = request.url
+                    model = opts.running_model if opts.openai_epose_our_model else request_json_body.get('model')
+
+                    def generate():
+                        generated_text = ''
+                        partial_response = b''
+                        for chunk in response.iter_content(chunk_size=1):
+                            partial_response += chunk
+                            if partial_response.endswith(b'\x00'):
+                                json_strs = partial_response.split(b'\x00')
+                                for json_str in json_strs:
+                                    if json_str:
+                                        try:
+                                            json_obj = json.loads(json_str.decode())
+                                            new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
+                                            print(new)
+                                            generated_text = generated_text + new
+                                            data = {
+                                                "id": f"chatcmpl-{generate_oai_string(30)}",
+                                                "object": "chat.completion.chunk",
+                                                "created": int(time.time()),
+                                                "model": model,
+                                                "choices": [
+                                                    {
+                                                        "index": 0,
+                                                        "delta": {
+                                                            "content": new
+                                                        },
+                                                        "finish_reason": None
+                                                    }
+                                                ]
+                                            }
+                                            yield f'data: {json.dumps(data)}\n\n'
+                                        except IndexError:
+                                            continue
+
+                        yield 'data: [DONE]\n\n'
+                        end_time = time.time()
+                        elapsed_time = end_time - start_time
+
+                        def background_task():
+                            generated_tokens = tokenize(generated_text)
+                            log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
+
+                        # TODO: use async/await instead of threads
+                        threading.Thread(target=background_task).start()
+
+                    return Response(generate(), mimetype='text/event-stream')
+                except:
+                    # TODO: simulate OAI here
+                    raise Exception
+        else:
+            try:
+                return handler.handle_request()
+            except Exception as e:
+                print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
+                traceback.print_exc()
+                print(request.data)
+                return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
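The streaming branch returns an OpenAI-style server-sent-event stream of chat.completion.chunk objects terminated by 'data: [DONE]'. A rough client-side sketch of consuming it (the base URL and request payload are placeholders, not part of this commit):

import json
import requests

# Placeholder endpoint; the payload mirrors the fields this route checks ('model', 'messages', 'stream').
resp = requests.post('http://localhost:5000/api/openai/v1/chat/completions',
                     json={'model': 'gpt-3.5-turbo', 'stream': True,
                           'messages': [{'role': 'user', 'content': 'Hello!'}]},
                     stream=True)
for line in resp.iter_lines():
    if not line:
        continue
    payload = line.decode().removeprefix('data: ')
    if payload == '[DONE]':
        break
    chunk = json.loads(payload)
    print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)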

View File

@@ -2,6 +2,7 @@ from flask import jsonify

 from . import openai_bp
 from ..cache import ONE_MONTH_SECONDS, cache
+from ..openai_request_handler import generate_oai_string
 from ..stats import server_start_time
@@ -13,7 +14,7 @@ def openai_organizations():
         "data": [
             {
                 "object": "organization",
-                "id": "org-abCDEFGHiJklmNOPqrSTUVWX",
+                "id": f"org-{generate_oai_string(24)}",
                 "created": int(server_start_time.timestamp()),
                 "title": "Personal",
                 "name": "user-abcdefghijklmnopqrstuvwx",

View File

@ -1,9 +1,10 @@
import json import json
import re import re
import secrets
import string
import time import time
import traceback import traceback
from typing import Tuple from typing import Tuple
from uuid import uuid4
import flask import flask
import requests import requests
@@ -72,7 +73,7 @@ class OpenAIRequestHandler(RequestHandler):
         model = self.request_json_body.get('model')

         if success:
-            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model), backend_response_status_code
+            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
         else:
             return backend_response, backend_response_status_code
@@ -131,7 +132,7 @@ def check_moderation_endpoint(prompt: str):
     return response['results'][0]['flagged'], offending_categories


-def build_openai_response(prompt, response, model):
+def build_openai_response(prompt, response, model=None):
     # Seperate the user's prompt from the context
     x = prompt.split('### USER:')
     if len(x) > 1:
@@ -142,10 +143,11 @@ def build_openai_response(prompt, response, model):
     if len(x) > 1:
         response = re.sub(r'\n$', '', y[0].strip(' '))

+    # TODO: async/await
     prompt_tokens = llm_server.llm.get_token_count(prompt)
     response_tokens = llm_server.llm.get_token_count(response)

     return jsonify({
-        "id": f"chatcmpl-{uuid4()}",
+        "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
@@ -163,3 +165,8 @@ def build_openai_response(prompt, response, model):
             "total_tokens": prompt_tokens + response_tokens
         }
     })
+
+
+def generate_oai_string(length=24):
+    alphabet = string.ascii_letters + string.digits
+    return ''.join(secrets.choice(alphabet) for i in range(length))
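The new helper draws from the secrets module rather than random, so generated IDs are unpredictable. A quick usage illustration (output values are made up):

# generate_oai_string(24) returns 24 characters from [A-Za-z0-9].
print(f"org-{generate_oai_string(24)}")       # e.g. org-Qz7BtKm2... (illustrative)
print(f"chatcmpl-{generate_oai_string(30)}")  # matches the response id format above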

View File

@@ -1,11 +1,12 @@
 import json
+import threading
 import time
 import traceback

 from flask import request

 from ..helpers.client import format_sillytavern_err
-from ..helpers.http import validate_json
+from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
 from ...database.database import increment_token_uses, log_prompt
@@ -23,6 +24,10 @@ def stream(ws):
         # TODO: return a formatted ST error message
         return 'disabled', 401

+    auth_failure = require_api_key()
+    if auth_failure:
+        return auth_failure
+
     message_num = 0
     while ws.connected:
         message = ws.receive()
@@ -40,7 +45,6 @@ def stream(ws):
             raise NotImplementedError

         handler = OobaRequestHandler(request, request_json_body)
-        token = request_json_body.get('X-API-KEY')
         generated_text = ''
         input_prompt = None
         response_status_code = 0
@@ -59,7 +63,6 @@ def stream(ws):
             'prompt': input_prompt,
             'stream': True,
         }
-
         try:
             response = generator(msg_to_backend)
@@ -74,23 +77,24 @@ def stream(ws):
                 for chunk in response.iter_content(chunk_size=1):
                     partial_response += chunk
                     if partial_response.endswith(b'\x00'):
-                        json_str = partial_response[:-1].decode()  # Remove the null character and decode the byte string to a string
-                        json_obj = json.loads(json_str)
-                        try:
-                            new = json_obj['text'][0].split(input_prompt + generated_text)[1]
-                        except IndexError:
-                            # ????
-                            continue
-
-                        ws.send(json.dumps({
-                            'event': 'text_stream',
-                            'message_num': message_num,
-                            'text': new
-                        }))
-                        message_num += 1
-
-                        generated_text = generated_text + new
+                        json_strs = partial_response.split(b'\x00')
+                        for json_str in json_strs:
+                            if json_str:
+                                try:
+                                    json_obj = json.loads(json_str.decode())
+                                    new = json_obj['text'][0].split(input_prompt + generated_text)[1]
+                                except IndexError:
+                                    # ????
+                                    continue
+
+                                ws.send(json.dumps({
+                                    'event': 'text_stream',
+                                    'message_num': message_num,
+                                    'text': new
+                                }))
+                                message_num += 1
+
+                                generated_text = generated_text + new
                         partial_response = b''  # Reset the partial response

                     # If there is no more data, break the loop
                     if not chunk:
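The buffer is now split on the null delimiter instead of assuming exactly one message per flush, so several back-to-back JSON messages are all handled. A standalone sketch of the same parsing idea (the sample bytes are made up):

import json

buffer = b'{"text": ["Hello"]}\x00{"text": ["Hello world"]}\x00'  # made-up example of two flushed messages
if buffer.endswith(b'\x00'):
    for json_str in buffer.split(b'\x00'):
        if json_str:                     # split() leaves a trailing empty element after the last \x00
            obj = json.loads(json_str.decode())
            print(obj['text'][0])
    buffer = b''                         # reset once every complete message is consumed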
@@ -100,18 +104,28 @@ def stream(ws):
                 end_time = time.time()
                 elapsed_time = end_time - start_time

-                generated_tokens = tokenize(generated_text)
-                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+                def background_task():
+                    generated_tokens = tokenize(generated_text)
+                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+                # TODO: use async/await instead of threads
+                threading.Thread(target=background_task).start()
             except:
                 generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
-                generated_tokens = tokenize(generated_text)
                 traceback.print_exc()
                 ws.send(json.dumps({
                     'event': 'text_stream',
                     'message_num': message_num,
                     'text': generated_text
                 }))
-                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+                def background_task():
+                    generated_tokens = tokenize(generated_text)
+                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+                # TODO: use async/await instead of threads
+                threading.Thread(target=background_task).start()

             ws.send(json.dumps({
                 'event': 'stream_end',