implement streaming on the OpenAI endpoint, improve streaming, run DB logging in a background thread

Cyberes 2023-09-25 12:30:40 -06:00
parent bbe5d5a8fe
commit 1646a00987
6 changed files with 169 additions and 43 deletions
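The recurring change in this commit is moving log_prompt() off the request path so a handler can return before the database write finishes. Below is a minimal sketch of that fire-and-forget pattern; only the log_prompt import comes from the diff, the wrapper itself is illustrative and not part of this commit.

import threading

from llm_server.database.database import log_prompt

def log_in_background(**log_kwargs):
    # Illustrative wrapper, not from this commit: run the DB write on a
    # separate thread so the HTTP response is not blocked on logging.
    def background_task():
        log_prompt(**log_kwargs)
    # daemon=True is an assumption; the threads in the diff are non-daemon.
    # TODO (as noted throughout the diff): replace threads with async/await.
    threading.Thread(target=background_task, daemon=True).start()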


@@ -1,9 +1,9 @@
from typing import Tuple, Union
import threading
from typing import Tuple
from flask import jsonify
from vllm import SamplingParams
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
@@ -18,8 +18,16 @@ class VLLMBackend(LLMBackend):
else:
# Failsafe
backend_response = ''
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
r_url = request.url
def background_task():
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
# TODO: use async/await instead of threads
threading.Thread(target=background_task).start()
return jsonify({'results': [{'text': backend_response}]}), 200
def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:


@@ -1,11 +1,12 @@
import json
import traceback
from functools import wraps
from typing import Union
import flask
import requests
from flask import make_response, Request
from flask import request, jsonify
from flask import Request, make_response
from flask import jsonify, request
from llm_server import opts
from llm_server.database.database import is_valid_api_key
@@ -36,7 +37,17 @@ def require_api_key():
else:
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
else:
return jsonify({'code': 401, 'message': 'API key required'}), 401
try:
# Handle websockets
if request.json.get('X-API-KEY'):
if is_valid_api_key(request.json.get('X-API-KEY')):
return
else:
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
except:
# TODO: remove this once we're sure this works as expected
traceback.print_exc()
return jsonify({'code': 401, 'message': 'API key required'}), 401
def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):


@@ -1,11 +1,18 @@
import json
import threading
import time
import traceback
from flask import jsonify, request
from flask import Response, jsonify, request
from . import openai_bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
# TODO: add rate-limit headers?
@@ -16,10 +23,88 @@ def openai_chat_completions():
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else:
try:
return OpenAIRequestHandler(request).handle_request()
except Exception as e:
print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
traceback.print_exc()
print(request.data)
return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
handler = OpenAIRequestHandler(request, request_json_body)
if request_json_body.get('stream'):
if not opts.enable_streaming:
# TODO: return a proper OAI error message
return 'disabled', 401
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception
else:
handler.prompt = handler.transform_messages_to_prompt()
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
try:
response = generator(msg_to_backend)
r_headers = dict(request.headers)
r_url = request.url
model = opts.running_model if opts.openai_epose_our_model else request_json_body.get('model')
def generate():
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
print(new)
generated_text = generated_text + new
data = {
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
except IndexError:
continue
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
threading.Thread(target=background_task).start()
return Response(generate(), mimetype='text/event-stream')
except:
# TODO: simulate OAI here
raise Exception
else:
try:
return handler.handle_request()
except Exception as e:
print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
traceback.print_exc()
print(request.data)
return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
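For reference, a client-side sketch of consuming the stream produced by generate() above. The endpoint path and auth header are assumptions (the blueprint's route is not shown in this hunk); the event format matches the chat.completion.chunk objects and the final data: [DONE] sentinel built in the diff.

import json
import requests

def stream_chat(base_url, payload, api_key=None):
    # base_url and the /v1/chat/completions path are assumptions for illustration.
    headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
    with requests.post(f'{base_url}/v1/chat/completions',
                       json={**payload, 'stream': True},
                       headers=headers, stream=True) as r:
        for line in r.iter_lines(decode_unicode=True):
            # SSE events arrive as lines of the form "data: {...}".
            if not line or not line.startswith('data: '):
                continue
            data = line[len('data: '):]
            if data == '[DONE]':
                break
            chunk = json.loads(data)
            # Each chunk carries the newly generated text in choices[0].delta.content.
            print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)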


@@ -2,6 +2,7 @@ from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, cache
from ..openai_request_handler import generate_oai_string
from ..stats import server_start_time
@@ -13,7 +14,7 @@ def openai_organizations():
"data": [
{
"object": "organization",
"id": "org-abCDEFGHiJklmNOPqrSTUVWX",
"id": f"org-{generate_oai_string(24)}",
"created": int(server_start_time.timestamp()),
"title": "Personal",
"name": "user-abcdefghijklmnopqrstuvwx",


@@ -1,9 +1,10 @@
import json
import re
import secrets
import string
import time
import traceback
from typing import Tuple
from uuid import uuid4
import flask
import requests
@@ -72,7 +73,7 @@ class OpenAIRequestHandler(RequestHandler):
model = self.request_json_body.get('model')
if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model), backend_response_status_code
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
else:
return backend_response, backend_response_status_code
@@ -131,7 +132,7 @@ def check_moderation_endpoint(prompt: str):
return response['results'][0]['flagged'], offending_categories
def build_openai_response(prompt, response, model):
def build_openai_response(prompt, response, model=None):
# Separate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
@@ -142,10 +143,11 @@ def build_openai_response(prompt, response, model):
if len(x) > 1:
response = re.sub(r'\n$', '', y[0].strip(' '))
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt)
response_tokens = llm_server.llm.get_token_count(response)
return jsonify({
"id": f"chatcmpl-{uuid4()}",
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": opts.running_model if opts.openai_epose_our_model else model,
@@ -163,3 +165,8 @@ def build_openai_response(prompt, response, model):
"total_tokens": prompt_tokens + response_tokens
}
})
def generate_oai_string(length=24):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for i in range(length))


@@ -1,11 +1,12 @@
import json
import threading
import time
import traceback
from flask import request
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...database.database import increment_token_uses, log_prompt
@@ -23,6 +24,10 @@ def stream(ws):
# TODO: return a formatted ST error message
return 'disabled', 401
auth_failure = require_api_key()
if auth_failure:
return auth_failure
message_num = 0
while ws.connected:
message = ws.receive()
@@ -40,7 +45,6 @@ def stream(ws):
raise NotImplementedError
handler = OobaRequestHandler(request, request_json_body)
token = request_json_body.get('X-API-KEY')
generated_text = ''
input_prompt = None
response_status_code = 0
@@ -59,7 +63,6 @@ def stream(ws):
'prompt': input_prompt,
'stream': True,
}
try:
response = generator(msg_to_backend)
@@ -74,23 +77,24 @@ def stream(ws):
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_str = partial_response[:-1].decode() # Remove the null character and decode the byte string to a string
json_obj = json.loads(json_str)
try:
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
except IndexError:
# ????
continue
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
except IndexError:
# ????
continue
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': new
}))
message_num += 1
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': new
}))
message_num += 1
generated_text = generated_text + new
partial_response = b'' # Reset the partial response
generated_text = generated_text + new
partial_response = b'' # Reset the partial response
# If there is no more data, break the loop
if not chunk:
@@ -100,18 +104,28 @@ def stream(ws):
end_time = time.time()
elapsed_time = end_time - start_time
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
threading.Thread(target=background_task).start()
except:
generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
generated_tokens = tokenize(generated_text)
traceback.print_exc()
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': generated_text
}))
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
threading.Thread(target=background_task).start()
ws.send(json.dumps({
'event': 'stream_end',