Implement streaming on the OpenAI endpoint, improve streaming, and run DB logging in a background thread
This commit is contained in:
parent bbe5d5a8fe
commit 1646a00987
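The recurring change in this commit is taking the log_prompt() database write off the request and streaming paths and running it in a fire-and-forget thread. A minimal sketch of the pattern, assuming log_prompt is the project's existing helper and every other name is illustrative:

    import threading

    def respond_and_log(backend_response, log_prompt, **log_kwargs):
        # Everything the logger needs is captured before the thread starts,
        # because request-bound objects are not safe to touch afterwards.
        def background_task():
            log_prompt(**log_kwargs)

        # The codebase notes this should eventually become async/await.
        threading.Thread(target=background_task).start()

        # The response is returned immediately; the DB write finishes on its own.
        return backend_response
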
@@ -1,9 +1,9 @@
-from typing import Tuple, Union
+import threading
+from typing import Tuple

 from flask import jsonify
 from vllm import SamplingParams

-from llm_server import opts
 from llm_server.database.database import log_prompt
 from llm_server.llm.llm_backend import LLMBackend

@@ -18,8 +18,16 @@ class VLLMBackend(LLMBackend):
         else:
             # Failsafe
             backend_response = ''
-        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
-                   response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
+        r_url = request.url
+
+        def background_task():
+            log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
+                       response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
+
+        # TODO: use async/await instead of threads
+        threading.Thread(target=background_task).start()
+
         return jsonify({'results': [{'text': backend_response}]}), 200

     def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
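The hunk above copies request.url into a local r_url before the thread is started. That is needed because Flask's request object is a context-local proxy: by the time the background thread runs, the request context may be gone and touching the proxy raises a RuntimeError. A small illustration, with a route and names made up for the example:

    import threading
    from flask import Flask, request

    app = Flask(__name__)

    @app.route('/demo')
    def demo():
        url_copy = request.url  # resolved now, while the request context is active

        def worker():
            print(url_copy)       # fine: a plain string captured by the closure
            # print(request.url)  # would raise 'Working outside of request context'

        threading.Thread(target=worker).start()
        return 'ok'
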
@@ -1,11 +1,12 @@
 import json
+import traceback
 from functools import wraps
 from typing import Union

 import flask
 import requests
-from flask import make_response, Request
-from flask import request, jsonify
+from flask import Request, make_response
+from flask import jsonify, request

 from llm_server import opts
 from llm_server.database.database import is_valid_api_key
@@ -36,7 +37,17 @@ def require_api_key():
         else:
             return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
     else:
-        return jsonify({'code': 401, 'message': 'API key required'}), 401
+        try:
+            # Handle websockets
+            if request.json.get('X-API-KEY'):
+                if is_valid_api_key(request.json.get('X-API-KEY')):
+                    return
+                else:
+                    return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
+        except:
+            # TODO: remove this once we're sure this works as expected
+            traceback.print_exc()
+        return jsonify({'code': 401, 'message': 'API key required'}), 401


 def validate_json(data: Union[str, flask.Request, requests.models.Response, flask.Response, dict]):
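From the caller's point of view, require_api_key() now returns nothing when the key is accepted and a (response, status) pair when it is not, whether the key arrived in a header or, for websocket requests, in the JSON body under 'X-API-KEY'. The websocket route later in this commit uses it like this:

    # Sketch of the calling convention; matches the usage added in the stream route below.
    auth_failure = require_api_key()  # None on success, (jsonify(...), status_code) on failure
    if auth_failure:
        return auth_failure           # short-circuit with the 401/403 before streaming
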
@@ -1,11 +1,18 @@
+import json
+import threading
+import time
 import traceback

-from flask import jsonify, request
+from flask import Response, jsonify, request

 from . import openai_bp
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
-from ..openai_request_handler import OpenAIRequestHandler, build_openai_response
+from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
+from ... import opts
+from ...database.database import log_prompt
+from ...llm.generator import generator
+from ...llm.vllm import tokenize


 # TODO: add rate-limit headers?
@@ -16,10 +23,88 @@ def openai_chat_completions():
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
     else:
-        try:
-            return OpenAIRequestHandler(request).handle_request()
-        except Exception as e:
-            print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
-            traceback.print_exc()
-            print(request.data)
-            return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
+        handler = OpenAIRequestHandler(request, request_json_body)
+        if request_json_body.get('stream'):
+            if not opts.enable_streaming:
+                # TODO: return a proper OAI error message
+                return 'disabled', 401
+
+            if opts.mode != 'vllm':
+                # TODO: implement other backends
+                raise NotImplementedError
+
+            response_status_code = 0
+            start_time = time.time()
+            request_valid, invalid_response = handler.validate_request()
+            if not request_valid:
+                # TODO: simulate OAI here
+                raise Exception
+            else:
+                handler.prompt = handler.transform_messages_to_prompt()
+                msg_to_backend = {
+                    **handler.parameters,
+                    'prompt': handler.prompt,
+                    'stream': True,
+                }
+                try:
+                    response = generator(msg_to_backend)
+                    r_headers = dict(request.headers)
+                    r_url = request.url
+                    model = opts.running_model if opts.openai_epose_our_model else request_json_body.get('model')
+
+                    def generate():
+                        generated_text = ''
+                        partial_response = b''
+                        for chunk in response.iter_content(chunk_size=1):
+                            partial_response += chunk
+                            if partial_response.endswith(b'\x00'):
+                                json_strs = partial_response.split(b'\x00')
+                                for json_str in json_strs:
+                                    if json_str:
+                                        try:
+                                            json_obj = json.loads(json_str.decode())
+                                            new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
+                                            print(new)
+                                            generated_text = generated_text + new
+                                            data = {
+                                                "id": f"chatcmpl-{generate_oai_string(30)}",
+                                                "object": "chat.completion.chunk",
+                                                "created": int(time.time()),
+                                                "model": model,
+                                                "choices": [
+                                                    {
+                                                        "index": 0,
+                                                        "delta": {
+                                                            "content": new
+                                                        },
+                                                        "finish_reason": None
+                                                    }
+                                                ]
+                                            }
+                                            yield f'data: {json.dumps(data)}\n\n'
+                                        except IndexError:
+                                            continue
+
+                        yield 'data: [DONE]\n\n'
+                        end_time = time.time()
+                        elapsed_time = end_time - start_time
+
+                        def background_task():
+                            generated_tokens = tokenize(generated_text)
+                            log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
+
+                        # TODO: use async/await instead of threads
+                        threading.Thread(target=background_task).start()
+
+                    return Response(generate(), mimetype='text/event-stream')
+                except:
+                    # TODO: simulate OAI here
+                    raise Exception
+        else:
+            try:
+                return handler.handle_request()
+            except Exception as e:
+                print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
+                traceback.print_exc()
+                print(request.data)
+                return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
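Each chunk yielded by the new generate() function is one server-sent event carrying an OpenAI-style chat.completion.chunk object, and the stream ends with a literal [DONE] sentinel. A hedged sketch of a consumer, where the URL and model name are placeholders for however the blueprint is actually mounted:

    import json
    import requests

    resp = requests.post('http://localhost:5000/api/openai/v1/chat/completions',  # placeholder URL
                         json={'model': 'example-model',
                               'messages': [{'role': 'user', 'content': 'Hi'}],
                               'stream': True},
                         stream=True)

    for line in resp.iter_lines():
        if not line:
            continue  # blank separator lines between events
        payload = line.decode().removeprefix('data: ')
        if payload == '[DONE]':  # terminator emitted by generate()
            break
        chunk = json.loads(payload)  # a chat.completion.chunk object
        print(chunk['choices'][0]['delta'].get('content', ''), end='')
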
@@ -2,6 +2,7 @@ from flask import jsonify

 from . import openai_bp
 from ..cache import ONE_MONTH_SECONDS, cache
+from ..openai_request_handler import generate_oai_string
 from ..stats import server_start_time


@@ -13,7 +14,7 @@ def openai_organizations():
         "data": [
             {
                 "object": "organization",
-                "id": "org-abCDEFGHiJklmNOPqrSTUVWX",
+                "id": f"org-{generate_oai_string(24)}",
                 "created": int(server_start_time.timestamp()),
                 "title": "Personal",
                 "name": "user-abcdefghijklmnopqrstuvwx",
@@ -1,9 +1,10 @@
 import json
 import re
+import secrets
+import string
 import time
 import traceback
 from typing import Tuple
-from uuid import uuid4

 import flask
 import requests
@@ -72,7 +73,7 @@ class OpenAIRequestHandler(RequestHandler):
         model = self.request_json_body.get('model')

         if success:
-            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model), backend_response_status_code
+            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
         else:
             return backend_response, backend_response_status_code

@@ -131,7 +132,7 @@ def check_moderation_endpoint(prompt: str):
     return response['results'][0]['flagged'], offending_categories


-def build_openai_response(prompt, response, model):
+def build_openai_response(prompt, response, model=None):
     # Seperate the user's prompt from the context
     x = prompt.split('### USER:')
     if len(x) > 1:
@@ -142,10 +143,11 @@ def build_openai_response(prompt, response, model):
     if len(x) > 1:
         response = re.sub(r'\n$', '', y[0].strip(' '))

+    # TODO: async/await
     prompt_tokens = llm_server.llm.get_token_count(prompt)
     response_tokens = llm_server.llm.get_token_count(response)
     return jsonify({
-        "id": f"chatcmpl-{uuid4()}",
+        "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
         "model": opts.running_model if opts.openai_epose_our_model else model,
@@ -163,3 +165,8 @@ def build_openai_response(prompt, response, model):
             "total_tokens": prompt_tokens + response_tokens
         }
     })
+
+
+def generate_oai_string(length=24):
+    alphabet = string.ascii_letters + string.digits
+    return ''.join(secrets.choice(alphabet) for i in range(length))
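generate_oai_string() replaces uuid4() so ids look like the opaque alphanumeric strings the OpenAI API returns (chatcmpl-..., org-...). A self-contained usage sketch, with example outputs that are obviously random per call:

    import secrets
    import string

    def generate_oai_string(length=24):
        alphabet = string.ascii_letters + string.digits
        return ''.join(secrets.choice(alphabet) for i in range(length))

    print(f"chatcmpl-{generate_oai_string(30)}")  # e.g. chatcmpl-aZ3kQ9wLm... (30 random characters)
    print(f"org-{generate_oai_string(24)}")       # e.g. org-Xy7Pq... (24 random characters)
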
@@ -1,11 +1,12 @@
 import json
+import threading
 import time
 import traceback

 from flask import request

 from ..helpers.client import format_sillytavern_err
-from ..helpers.http import validate_json
+from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
 from ...database.database import increment_token_uses, log_prompt
@@ -23,6 +24,10 @@ def stream(ws):
         # TODO: return a formatted ST error message
         return 'disabled', 401

+    auth_failure = require_api_key()
+    if auth_failure:
+        return auth_failure
+
     message_num = 0
     while ws.connected:
         message = ws.receive()
@@ -40,7 +45,6 @@ def stream(ws):
             raise NotImplementedError

         handler = OobaRequestHandler(request, request_json_body)
-        token = request_json_body.get('X-API-KEY')
         generated_text = ''
         input_prompt = None
         response_status_code = 0
@@ -59,7 +63,6 @@ def stream(ws):
             'prompt': input_prompt,
             'stream': True,
         }
-
        try:
            response = generator(msg_to_backend)

@@ -74,23 +77,24 @@ def stream(ws):
            for chunk in response.iter_content(chunk_size=1):
                partial_response += chunk
                if partial_response.endswith(b'\x00'):
-                    json_str = partial_response[:-1].decode()  # Remove the null character and decode the byte string to a string
-                    json_obj = json.loads(json_str)
-                    try:
-                        new = json_obj['text'][0].split(input_prompt + generated_text)[1]
-                    except IndexError:
-                        # ????
-                        continue
+                    json_strs = partial_response.split(b'\x00')
+                    for json_str in json_strs:
+                        if json_str:
+                            try:
+                                json_obj = json.loads(json_str.decode())
+                                new = json_obj['text'][0].split(input_prompt + generated_text)[1]
+                            except IndexError:
+                                # ????
+                                continue

                    ws.send(json.dumps({
                        'event': 'text_stream',
                        'message_num': message_num,
                        'text': new
                    }))
                    message_num += 1

                    generated_text = generated_text + new
                    partial_response = b''  # Reset the partial response

                # If there is no more data, break the loop
                if not chunk:
@@ -100,18 +104,28 @@ def stream(ws):

            end_time = time.time()
            elapsed_time = end_time - start_time
-            generated_tokens = tokenize(generated_text)
-            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+            def background_task():
+                generated_tokens = tokenize(generated_text)
+                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+            # TODO: use async/await instead of threads
+            threading.Thread(target=background_task).start()
        except:
            generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
-            generated_tokens = tokenize(generated_text)
            traceback.print_exc()
            ws.send(json.dumps({
                'event': 'text_stream',
                'message_num': message_num,
                'text': generated_text
            }))
-            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+            def background_task():
+                generated_tokens = tokenize(generated_text)
+                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+
+            # TODO: use async/await instead of threads
+            threading.Thread(target=background_task).start()
+
        ws.send(json.dumps({
            'event': 'stream_end',
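Both streaming paths now split the accumulated buffer on the null byte instead of assuming exactly one JSON object per flush, since several frames can arrive before the buffer ends in b'\x00' (this is the framing the vLLM streaming endpoint is assumed to use here). A standalone sketch of that parsing step:

    import json

    def split_vllm_frames(partial_response: bytes):
        # Yield decoded JSON objects from a null-delimited buffer, skipping
        # empty segments such as the one after a trailing delimiter.
        for json_str in partial_response.split(b'\x00'):
            if json_str:
                yield json.loads(json_str.decode())

    # Example: two frames flushed together in a single read.
    buf = b'{"text": ["Hello"]}\x00{"text": ["Hello world"]}\x00'
    for frame in split_vllm_frames(buf):
        print(frame['text'][0])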