implement streaming on openai, improve streaming, run DB logging in background thread
parent bbe5d5a8fe
commit 1646a00987
@@ -1,9 +1,9 @@
-from typing import Tuple, Union
+import threading
+from typing import Tuple

 from flask import jsonify
 from vllm import SamplingParams

 from llm_server import opts
 from llm_server.database.database import log_prompt
 from llm_server.llm.llm_backend import LLMBackend
@@ -18,8 +18,16 @@ class VLLMBackend(LLMBackend):
         else:
             # Failsafe
             backend_response = ''
-        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
+
+        r_url = request.url
+
+        def background_task():
+            log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
+                       response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
+
+        # TODO: use async/await instead of threads
+        threading.Thread(target=background_task).start()

         return jsonify({'results': [{'text': backend_response}]}), 200

     def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
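This is the pattern the commit applies throughout: the response is returned to the client immediately and the log_prompt() database write runs in a fire-and-forget thread. A minimal standalone sketch of the idea, with log_to_db standing in for log_prompt and the argument list trimmed (the names here are illustrative, not from the repository):

import threading

def log_to_db(prompt, response):
    # Stand-in for log_prompt(); the real function writes a row to the database.
    print('logged', len(prompt), len(response))

def handle_response(prompt, backend_response):
    # Only plain locals are captured for the thread; nothing request-bound is
    # read after the handler returns, since Flask's request context is gone by then.
    def background_task():
        log_to_db(prompt, backend_response)

    # Fire-and-forget: the client gets its response without waiting on the DB write.
    threading.Thread(target=background_task).start()
    return backend_response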
@@ -1,11 +1,12 @@
 import json
+import traceback
 from functools import wraps
 from typing import Union

 import flask
 import requests
-from flask import make_response, Request
-from flask import request, jsonify
+from flask import Request, make_response
+from flask import jsonify, request

 from llm_server import opts
 from llm_server.database.database import is_valid_api_key
@@ -36,6 +37,16 @@ def require_api_key():
        else:
            return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
+    else:
+        try:
+            # Handle websockets
+            if request.json.get('X-API-KEY'):
+                if is_valid_api_key(request.json.get('X-API-KEY')):
+                    return
+                else:
+                    return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
+        except:
+            # TODO: remove this once we're sure this works as expected
+            traceback.print_exc()
    return jsonify({'code': 401, 'message': 'API key required'}), 401
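The new branch covers websocket clients, which pass the key inside the JSON body rather than in a header. A hypothetical payload such a client might send, with the generation fields reduced to a minimum (the key value and prompt are placeholders):

import json

# The server checks request.json.get('X-API-KEY'), so the key travels
# alongside the normal generation parameters.
payload = json.dumps({
    'X-API-KEY': 'example-key',   # placeholder, not a real key
    'prompt': 'Hello, world.',
    'stream': True,
})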
@@ -1,11 +1,18 @@
 import json
+import threading
 import time
 import traceback

-from flask import jsonify, request
+from flask import Response, jsonify, request

 from . import openai_bp
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
-from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
+from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
+from ... import opts
+from ...database.database import log_prompt
+from ...llm.generator import generator
+from ...llm.vllm import tokenize


 # TODO: add rate-limit headers?
@@ -16,8 +23,86 @@ def openai_chat_completions():
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
     else:
+        handler = OpenAIRequestHandler(request, request_json_body)
+        if request_json_body.get('stream'):
+            if not opts.enable_streaming:
+                # TODO: return a proper OAI error message
+                return 'disabled', 401
+
+            if opts.mode != 'vllm':
+                # TODO: implement other backends
+                raise NotImplementedError
+
+            response_status_code = 0
+            start_time = time.time()
+            request_valid, invalid_response = handler.validate_request()
+            if not request_valid:
+                # TODO: simulate OAI here
+                raise Exception
+            else:
+                handler.prompt = handler.transform_messages_to_prompt()
+                msg_to_backend = {
+                    **handler.parameters,
+                    'prompt': handler.prompt,
+                    'stream': True,
+                }
-        try:
-            return OpenAIRequestHandler(request).handle_request()
+                try:
+                    response = generator(msg_to_backend)
+                    r_headers = dict(request.headers)
+                    r_url = request.url
+                    model = opts.running_model if opts.openai_epose_our_model else request_json_body.get('model')
+
+                    def generate():
+                        generated_text = ''
+                        partial_response = b''
+                        for chunk in response.iter_content(chunk_size=1):
+                            partial_response += chunk
+                            if partial_response.endswith(b'\x00'):
+                                json_strs = partial_response.split(b'\x00')
+                                for json_str in json_strs:
+                                    if json_str:
+                                        try:
+                                            json_obj = json.loads(json_str.decode())
+                                            new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
+                                            print(new)
+                                            generated_text = generated_text + new
+                                            data = {
+                                                "id": f"chatcmpl-{generate_oai_string(30)}",
+                                                "object": "chat.completion.chunk",
+                                                "created": int(time.time()),
+                                                "model": model,
+                                                "choices": [
+                                                    {
+                                                        "index": 0,
+                                                        "delta": {
+                                                            "content": new
+                                                        },
+                                                        "finish_reason": None
+                                                    }
+                                                ]
+                                            }
+                                            yield f'data: {json.dumps(data)}\n\n'
+                                        except IndexError:
+                                            continue
+
+                        yield 'data: [DONE]\n\n'
+                        end_time = time.time()
+                        elapsed_time = end_time - start_time
+
+                        def background_task():
+                            generated_tokens = tokenize(generated_text)
+                            log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
+
+                        # TODO: use async/await instead of threads
+                        threading.Thread(target=background_task).start()
+
+                    return Response(generate(), mimetype='text/event-stream')
+                except:
+                    # TODO: simulate OAI here
+                    raise Exception
+        else:
+            try:
+                return handler.handle_request()
+            except Exception as e:
+                print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
+                traceback.print_exc()
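For reference, a minimal sketch of consuming the new SSE stream from the client side with requests. The URL, model name and prompt are placeholders for whatever the deployment actually exposes; the framing (data: JSON chunks, terminated by data: [DONE]) matches the generator above:

import json
import requests

# Placeholder endpoint; adjust host, port and path to the real deployment.
url = 'http://localhost:5000/api/openai/v1/chat/completions'
payload = {
    'model': 'some-model',  # ignored in favor of opts.running_model when openai_epose_our_model is set
    'messages': [{'role': 'user', 'content': 'Say hello.'}],
    'stream': True,
}

with requests.post(url, json=payload, stream=True) as resp:
    for raw_line in resp.iter_lines():
        if not raw_line:
            continue  # blank separator lines between SSE events
        line = raw_line.decode()
        if not line.startswith('data: '):
            continue
        chunk = line[len('data: '):]
        if chunk == '[DONE]':
            break
        delta = json.loads(chunk)['choices'][0]['delta']
        print(delta.get('content', ''), end='', flush=True)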
@@ -2,6 +2,7 @@ from flask import jsonify

 from . import openai_bp
 from ..cache import ONE_MONTH_SECONDS, cache
+from ..openai_request_handler import generate_oai_string
 from ..stats import server_start_time
@@ -13,7 +14,7 @@ def openai_organizations():
         "data": [
             {
                 "object": "organization",
-                "id": "org-abCDEFGHiJklmNOPqrSTUVWX",
+                "id": f"org-{generate_oai_string(24)}",
                 "created": int(server_start_time.timestamp()),
                 "title": "Personal",
                 "name": "user-abcdefghijklmnopqrstuvwx",
@@ -1,9 +1,10 @@
 import json
 import re
+import secrets
+import string
 import time
 import traceback
 from typing import Tuple
-from uuid import uuid4

 import flask
 import requests
@@ -72,7 +73,7 @@ class OpenAIRequestHandler(RequestHandler):
         model = self.request_json_body.get('model')

         if success:
-            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model), backend_response_status_code
+            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
         else:
             return backend_response, backend_response_status_code
@@ -131,7 +132,7 @@ def check_moderation_endpoint(prompt: str):
     return response['results'][0]['flagged'], offending_categories


-def build_openai_response(prompt, response, model):
+def build_openai_response(prompt, response, model=None):
     # Separate the user's prompt from the context
     x = prompt.split('### USER:')
     if len(x) > 1:
@@ -142,10 +143,11 @@ def build_openai_response(prompt, response, model):
     if len(x) > 1:
         response = re.sub(r'\n$', '', y[0].strip(' '))

+    # TODO: async/await
     prompt_tokens = llm_server.llm.get_token_count(prompt)
     response_tokens = llm_server.llm.get_token_count(response)
     return jsonify({
-        "id": f"chatcmpl-{uuid4()}",
+        "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
         "model": opts.running_model if opts.openai_epose_our_model else model,
@@ -163,3 +165,8 @@ def build_openai_response(prompt, response, model):
             "total_tokens": prompt_tokens + response_tokens
         }
     })
+
+
+def generate_oai_string(length=24):
+    alphabet = string.ascii_letters + string.digits
+    return ''.join(secrets.choice(alphabet) for i in range(length))
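generate_oai_string exists only to make the fake IDs look like OpenAI's. A quick illustration of the values it produces (the output is random on every call):

import secrets
import string

def generate_oai_string(length=24):
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for i in range(length))

print(f"chatcmpl-{generate_oai_string(30)}")  # replaces the uuid4()-based id
print(f"org-{generate_oai_string(24)}")       # replaces the hard-coded organization id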
@@ -1,11 +1,12 @@
 import json
+import threading
 import time
 import traceback

 from flask import request

 from ..helpers.client import format_sillytavern_err
-from ..helpers.http import validate_json
+from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
 from ...database.database import increment_token_uses, log_prompt
@@ -23,6 +24,10 @@ def stream(ws):
         # TODO: return a formatted ST error message
         return 'disabled', 401

+    auth_failure = require_api_key()
+    if auth_failure:
+        return auth_failure
+
     message_num = 0
     while ws.connected:
         message = ws.receive()
@@ -40,7 +45,6 @@ def stream(ws):
                 raise NotImplementedError

             handler = OobaRequestHandler(request, request_json_body)
-            token = request_json_body.get('X-API-KEY')
             generated_text = ''
             input_prompt = None
             response_status_code = 0
@@ -59,7 +63,6 @@ def stream(ws):
                 'prompt': input_prompt,
                 'stream': True,
             }
-
             try:
                 response = generator(msg_to_backend)

@@ -74,8 +77,9 @@ def stream(ws):
                 for chunk in response.iter_content(chunk_size=1):
                     partial_response += chunk
                     if partial_response.endswith(b'\x00'):
-                        json_str = partial_response[:-1].decode()  # Remove the null character and decode the byte string to a string
-                        json_obj = json.loads(json_str)
+                        json_strs = partial_response.split(b'\x00')
+                        for json_str in json_strs:
+                            if json_str:
+                                try:
                                     new = json_obj['text'][0].split(input_prompt + generated_text)[1]
                                 except IndexError:
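The switch from decoding the buffer as a single message to splitting on the delimiter matters when the backend flushes more than one null-terminated JSON frame before the reader catches up: the old code stripped a single trailing null and handed the whole buffer to json.loads, which fails if two frames arrive together. A standalone sketch of the parsing logic, assuming the same b'\x00'-terminated JSON framing used above:

import json

# Two frames arriving in one buffered read, each terminated by a null byte.
partial_response = b'{"text": ["Hello"]}\x00{"text": ["Hello there"]}\x00'

if partial_response.endswith(b'\x00'):
    # split() leaves a trailing empty chunk after the final delimiter,
    # hence the `if json_str` guard mirrored from the handler above.
    for json_str in partial_response.split(b'\x00'):
        if json_str:
            json_obj = json.loads(json_str.decode())
            print(json_obj['text'][0])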
@@ -100,19 +104,29 @@ def stream(ws):

                 end_time = time.time()
                 elapsed_time = end_time - start_time

+                def background_task():
+                    generated_tokens = tokenize(generated_text)
+                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+                # TODO: use async/await instead of threads
+                threading.Thread(target=background_task).start()
             except:
                 generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
-                generated_tokens = tokenize(generated_text)
                 traceback.print_exc()
                 ws.send(json.dumps({
                     'event': 'text_stream',
                     'message_num': message_num,
                     'text': generated_text
                 }))
+
+                def background_task():
+                    generated_tokens = tokenize(generated_text)
+                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+                # TODO: use async/await instead of threads
+                threading.Thread(target=background_task).start()

             ws.send(json.dumps({
                 'event': 'stream_end',
                 'message_num': message_num