fix background logger, add gradio chat example

Cyberes 2023-10-04 19:24:47 -06:00
parent 1670594908
commit acf409abfc
16 changed files with 242 additions and 141 deletions

View File

@@ -1,7 +1,6 @@
 import json
 import time
 import traceback
-from threading import Thread
 from typing import Union
 
 from llm_server import opts
@@ -10,65 +9,60 @@ from llm_server.database.conn import database
 from llm_server.llm import get_token_count
 
 
-def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
+def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
     assert isinstance(prompt, str)
     assert isinstance(backend_url, str)
 
-    def background_task():
-        nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error
-        # Try not to shove JSON into the database.
-        if isinstance(response, dict) and response.get('results'):
-            response = response['results'][0]['text']
-        try:
-            j = json.loads(response)
-            if j.get('results'):
-                response = j['results'][0]['text']
-        except:
-            pass
-
-        prompt_tokens = get_token_count(prompt, backend_url)
-        if not is_error:
-            if not response_tokens:
-                response_tokens = get_token_count(response, backend_url)
-        else:
-            response_tokens = None
-
-        # Sometimes we may want to insert null into the DB, but
-        # usually we want to insert a float.
-        if gen_time:
-            gen_time = round(gen_time, 3)
-        if is_error:
-            gen_time = None
-
-        if not opts.log_prompts:
-            prompt = None
-
-        if not opts.log_prompts and not is_error:
-            # TODO: test and verify this works as expected
-            response = None
-
-        if token:
-            increment_token_uses(token)
-
-        backend_info = cluster_config.get_backend(backend_url)
-        running_model = backend_info.get('model')
-        backend_mode = backend_info['mode']
-        timestamp = int(time.time())
-        cursor = database.cursor()
-        try:
-            cursor.execute("""
-                INSERT INTO prompts
-                (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
-                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
-                """,
-                           (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
-        finally:
-            cursor.close()
-
-    # TODO: use async/await instead of threads
-    thread = Thread(target=background_task)
-    thread.start()
-    thread.join()
+    # Try not to shove JSON into the database.
+    if isinstance(response, dict) and response.get('results'):
+        response = response['results'][0]['text']
+    try:
+        j = json.loads(response)
+        if j.get('results'):
+            response = j['results'][0]['text']
+    except:
+        pass
+
+    prompt_tokens = get_token_count(prompt, backend_url)
+    print('starting')
+
+    if not is_error:
+        if not response_tokens:
+            response_tokens = get_token_count(response, backend_url)
+    else:
+        response_tokens = None
+
+    # Sometimes we may want to insert null into the DB, but
+    # usually we want to insert a float.
+    if gen_time:
+        gen_time = round(gen_time, 3)
+    if is_error:
+        gen_time = None
+
+    if not opts.log_prompts:
+        prompt = None
+
+    if not opts.log_prompts and not is_error:
+        # TODO: test and verify this works as expected
+        response = None
+
+    if token:
+        increment_token_uses(token)
+
+    backend_info = cluster_config.get_backend(backend_url)
+    running_model = backend_info.get('model')
+    backend_mode = backend_info['mode']
+    timestamp = int(time.time())
+    cursor = database.cursor()
+    try:
+        cursor.execute("""
+            INSERT INTO prompts
+            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
+            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+            """,
+                       (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
+    finally:
+        cursor.close()
 
 
 def is_valid_api_key(api_key):

View File

@@ -0,0 +1,27 @@
import pickle
from typing import Union

from redis import Redis


def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
    r = Redis(host='localhost', port=6379, db=3)
    data = {
        'function': 'log_prompt',
        'args': [],
        'kwargs': {
            'ip': ip,
            'token': token,
            'prompt': prompt,
            'response': response,
            'gen_time': gen_time,
            'parameters': parameters,
            'headers': headers,
            'backend_response_code': backend_response_code,
            'request_url': request_url,
            'backend_url': backend_url,
            'response_tokens': response_tokens,
            'is_error': is_error
        }
    }
    r.publish('database-logger', pickle.dumps(data))
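For context on the message format, here is a minimal sketch (not part of this commit) of how a subscriber decodes the pickled payload, assuming the same localhost Redis, db 3, and 'database-logger' channel used above; the db_logger worker added later in this commit does essentially this.

import pickle

from redis import Redis

r = Redis(host='localhost', port=6379, db=3)
p = r.pubsub()
p.subscribe('database-logger')

for message in p.listen():
    if message['type'] == 'message':
        # The payload is {'function': 'log_prompt', 'args': [...], 'kwargs': {...}}.
        payload = pickle.loads(message['data'])
        print(payload['function'], payload['kwargs'].get('backend_response_code'))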

View File

@@ -2,7 +2,7 @@ from flask import jsonify
 from llm_server.custom_redis import redis
 from ..llm_backend import LLMBackend
-from ...database.database import log_prompt
+from ...database.database import do_db_log
 from ...helpers import safe_list_get
 from ...routes.helpers.client import format_sillytavern_err
 from ...routes.helpers.http import validate_json
@@ -34,7 +34,7 @@ class OobaboogaBackend(LLMBackend):
             else:
                 error_msg = error_msg.strip('.') + '.'
             backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url)
-            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
+            log_to_db(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
             return jsonify({
                 'code': 500,
                 'msg': error_msg,
@@ -57,13 +57,13 @@ class OobaboogaBackend(LLMBackend):
             if not backend_err:
                 redis.incr('proompts')
 
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
+            log_to_db(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
             return jsonify({
                 **response_json_body
             }), 200
         else:
             backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url)
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
+            log_to_db(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
             return jsonify({
                 'code': 500,
                 'msg': 'the backend did not return valid JSON',

View File

@@ -3,17 +3,19 @@ from flask import jsonify
 from llm_server import opts
 
 
-def oai_to_vllm(request_json_body, hashes: bool, mode):
+def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
     if not request_json_body.get('stop'):
         request_json_body['stop'] = []
     if not isinstance(request_json_body['stop'], list):
         # It is a string, so create a list with the existing element.
         request_json_body['stop'] = [request_json_body['stop']]
 
-    if hashes:
-        request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT', '### RESPONSE'])
+    if stop_hashes:
         if opts.openai_force_no_hashes:
-            request_json_body['stop'].append('### ')
+            request_json_body['stop'].append('###')
+        else:
+            # TODO: make stopping strings a configurable
+            request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
     else:
         request_json_body['stop'].extend(['user:', 'assistant:'])
@@ -41,6 +43,11 @@ def format_oai_err(err_msg):
 
 def validate_oai(parameters):
+    if parameters.get('messages'):
+        for m in parameters['messages']:
+            if m['role'].lower() not in ['assistant', 'user', 'system']:
+                return format_oai_err('messages role must be assistant, user, or system')
+
     if parameters.get('temperature', 0) > 2:
         return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
     if parameters.get('temperature', 0) < 0:
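To make the reworked stop-string handling concrete, here is a hypothetical standalone restatement of just the branch visible above; normalize_stop and force_no_hashes are illustrative stand-ins for the real function and opts.openai_force_no_hashes, and the rest of oai_to_vllm (which uses mode) is not shown in this diff.

def normalize_stop(body: dict, stop_hashes: bool, force_no_hashes: bool = False) -> dict:
    # Normalize 'stop' to a list, mirroring the code above.
    if not body.get('stop'):
        body['stop'] = []
    if not isinstance(body['stop'], list):
        body['stop'] = [body['stop']]
    if stop_hashes:
        if force_no_hashes:
            body['stop'].append('###')
        else:
            body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
    else:
        body['stop'].extend(['user:', 'assistant:'])
    return body

print(normalize_stop({'stop': 'Observation:'}, stop_hashes=True))
# -> {'stop': ['Observation:', '### INSTRUCTION', '### USER', '### ASSISTANT']}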

View File

@@ -96,7 +96,7 @@ def transform_messages_to_prompt(oai_messages):
             elif msg['role'] == 'assistant':
                 prompt += f'### ASSISTANT: {msg["content"]}\n\n'
             else:
-                return False
+                raise Exception(f'Unknown role: {msg["role"]}')
     except Exception as e:
         # TODO: use logging
         traceback.print_exc()

View File

@@ -1,24 +1,16 @@
 """
 This file is used by the worker that processes requests.
 """
-import json
-import time
-from uuid import uuid4
-
 import requests
 
-import llm_server
 from llm_server import opts
-from llm_server.custom_redis import redis
 
 # TODO: make the VLMM backend return TPS and time elapsed
 # https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
 
 
 def prepare_json(json_data: dict):
-    # logit_bias is not currently supported
-    # del json_data['logit_bias']
     # Convert back to VLLM.
     json_data['max_tokens'] = json_data.pop('max_new_tokens')
     return json_data

View File

@@ -3,7 +3,7 @@ from typing import Tuple
 from flask import jsonify
 from vllm import SamplingParams
 
-from llm_server.database.database import log_prompt
+from llm_server.database.log_to_db import log_to_db
 from llm_server.llm.llm_backend import LLMBackend
@@ -18,8 +18,8 @@ class VLLMBackend(LLMBackend):
             # Failsafe
             backend_response = ''
-        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
+        log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
                   response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
         return jsonify({'results': [{'text': backend_response}]}), 200
@@ -29,14 +29,15 @@ class VLLMBackend(LLMBackend):
         top_k = parameters.get('top_k', self._default_params['top_k'])
         if top_k <= 0:
             top_k = -1
         sampling_params = SamplingParams(
             temperature=parameters.get('temperature', self._default_params['temperature']),
             top_p=parameters.get('top_p', self._default_params['top_p']),
             top_k=top_k,
             use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
-            stop=list(set(parameters.get('stopping_strings', self._default_params['stop']) or parameters.get('stop', self._default_params['stop']))),
+            stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
             ignore_eos=parameters.get('ban_eos_token', False),
-            max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) or parameters.get('max_tokens', self._default_params['max_tokens']),
+            max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
             presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
             frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty'])
         )
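One detail worth spelling out in the two changed get() calls above: moving the default out of the first lookup changes which value wins when a client only sends the OpenAI-style key. A minimal illustration in plain Python (the dict names are hypothetical, not from the repo):

defaults = {'stop': ['###']}
params = {'stop': ['Observation:']}  # client sent 'stop' but not 'stopping_strings'

# Old form: the non-empty default is returned for the missing key, so `or`
# short-circuits and the client's 'stop' value is ignored.
old_stop = params.get('stopping_strings', defaults['stop']) or params.get('stop', defaults['stop'])
print(old_stop)  # ['###']

# New form: a missing 'stopping_strings' yields None, so `or` falls through
# to the client-supplied value (the same applies to max_new_tokens / max_tokens).
new_stop = params.get('stopping_strings') or params.get('stop', defaults['stop'])
print(new_stop)  # ['Observation:']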

View File

@@ -4,7 +4,8 @@ import flask
 from flask import jsonify, request
 
 from llm_server import opts
-from llm_server.database.database import log_prompt
+from llm_server.database.database import do_db_log
+from llm_server.database.log_to_db import log_to_db
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
@@ -40,7 +41,7 @@ class OobaRequestHandler(RequestHandler):
             msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
             backend_response = self.handle_error(msg)
             if do_log:
-                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
+                log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
             return backend_response[0], 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:

View File

@@ -3,6 +3,7 @@ import time
 import traceback
 
 from flask import Response, jsonify, request
+from redis import Redis
 
 from llm_server.custom_redis import redis
 from . import openai_bp
@@ -10,7 +11,7 @@ from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
 from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
-from ...database.database import log_prompt
+from ...database.log_to_db import log_to_db
 from ...llm.generator import generator
 from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@@ -18,6 +19,7 @@ from ...llm.openai.transform import generate_oai_string, transform_messages_to_p
 
 # TODO: add rate-limit headers?
 
+
 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
     request_valid_json, request_json_body = validate_json(request)
@@ -36,12 +38,20 @@ def openai_chat_completions():
             return 'Internal server error', 500
     else:
         if not opts.enable_streaming:
-            return 'DISABLED', 401
+            return
+
+        handler.parameters, _ = handler.get_parameters()
+        handler.request_json_body = {
+            'messages': handler.request_json_body['messages'],
+            'model': handler.request_json_body['model'],
+            **handler.parameters
+        }
 
         invalid_oai_err_msg = validate_oai(handler.request_json_body)
         if invalid_oai_err_msg:
             return invalid_oai_err_msg
-        handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
+
+        handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
 
         if opts.openai_silent_trim:
             handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
@@ -64,7 +74,7 @@ def openai_chat_completions():
         # Add a dummy event to the queue and wait for it to reach a worker
         event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
         if not event:
-            log_prompt(
+            log_to_db(
                 handler.client_ip,
                 handler.token,
                 handler.prompt,
@@ -82,7 +92,6 @@ def openai_chat_completions():
         _, _, _ = event.wait()
 
         try:
-            response = generator(msg_to_backend, handler.backend_url)
             r_headers = dict(request.headers)
             r_url = request.url
             model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
@@ -90,6 +99,7 @@ def openai_chat_completions():
 
             def generate():
                 try:
+                    response = generator(msg_to_backend, handler.backend_url)
                     generated_text = ''
                     partial_response = b''
                     for chunk in response.iter_content(chunk_size=1):
@@ -125,8 +135,7 @@ def openai_chat_completions():
                     yield 'data: [DONE]\n\n'
                     end_time = time.time()
                     elapsed_time = end_time - start_time
-
-                    log_prompt(
+                    log_to_db(
                         handler.client_ip,
                         handler.token,
                         handler.prompt,

View File

@@ -10,7 +10,8 @@ from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
-from ...database.database import log_prompt
+from ...database.database import do_db_log
+from ...database.log_to_db import log_to_db
 from ...llm import get_token_count
 from ...llm.generator import generator
 from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
@@ -34,7 +35,7 @@ def openai_completions():
         invalid_oai_err_msg = validate_oai(handler.request_json_body)
         if invalid_oai_err_msg:
             return invalid_oai_err_msg
-        handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
+        handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
 
         if opts.openai_silent_trim:
             handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
@@ -102,7 +103,7 @@ def openai_completions():
             # Add a dummy event to the queue and wait for it to reach a worker
             event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
             if not event:
-                log_prompt(
+                log_to_db(
                     handler.client_ip,
                     handler.token,
                     handler.prompt,
@@ -164,7 +165,7 @@ def openai_completions():
                 end_time = time.time()
                 elapsed_time = end_time - start_time
-                log_prompt(
+                log_to_db(
                     handler.client_ip,
                     handler.token,
                     handler.prompt,

View File

@@ -11,7 +11,8 @@ from flask import Response, jsonify, make_response
 from llm_server import opts
 from llm_server.cluster.backend import get_model_choices
 from llm_server.custom_redis import redis
-from llm_server.database.database import is_api_key_moderated, log_prompt
+from llm_server.database.database import is_api_key_moderated, do_db_log
+from llm_server.database.log_to_db import log_to_db
 from llm_server.llm import get_token_count
 from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
 from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@@ -58,7 +59,7 @@ class OpenAIRequestHandler(RequestHandler):
             traceback.print_exc()
 
         # TODO: support Ooba
-        self.parameters = oai_to_vllm(self.parameters, hashes=True, mode=self.cluster_backend_info['mode'])
+        self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
         llm_request = {**self.parameters, 'prompt': self.prompt}
 
         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
@@ -88,7 +89,7 @@ class OpenAIRequestHandler(RequestHandler):
         response.headers['x-ratelimit-reset-requests'] = f"{w}s"
 
         if do_log:
-            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
+            log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
 
         return response, 429
@@ -146,6 +147,6 @@ class OpenAIRequestHandler(RequestHandler):
         invalid_oai_err_msg = validate_oai(self.request_json_body)
         if invalid_oai_err_msg:
             return False, invalid_oai_err_msg
-        self.request_json_body = oai_to_vllm(self.request_json_body, hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+        self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
         # If the parameters were invalid, let the superclass deal with it.
         return super().validate_request(prompt, do_log)

View File

@@ -7,7 +7,8 @@ from flask import Response, request
 from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
 from llm_server.custom_redis import redis
-from llm_server.database.database import get_token_ratelimit, log_prompt
+from llm_server.database.database import get_token_ratelimit, do_db_log
+from llm_server.database.log_to_db import log_to_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
@@ -41,9 +42,11 @@ class RequestHandler:
         if not self.cluster_backend_info.get('mode'):
             print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
         if not self.cluster_backend_info.get('model'):
-            print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
+            print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
+        if not self.cluster_backend_info.get('model_config'):
+            print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
 
-        if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model'):
+        if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
             self.offline = True
         else:
             self.offline = False
@@ -74,8 +77,6 @@ class RequestHandler:
         return self.request.remote_addr
 
     def get_parameters(self):
-        if self.request_json_body.get('max_tokens'):
-            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
         parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
         return parameters, parameters_invalid_msg
@@ -117,7 +118,7 @@ class RequestHandler:
             backend_response = self.handle_error(combined_error_message, 'Validation Error')
             if do_log:
-                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
+                log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
             return False, backend_response
         return True, (None, 0)
@@ -160,17 +161,17 @@ class RequestHandler:
             else:
                 error_msg = error_msg.strip('.') + '.'
             backend_response = self.handle_error(error_msg)
-            log_prompt(ip=self.client_ip,
+            log_to_db(ip=self.client_ip,
                       token=self.token,
                       prompt=prompt,
                       response=backend_response[0].data.decode('utf-8'),
                       gen_time=None,
                       parameters=self.parameters,
                       headers=dict(self.request.headers),
                      backend_response_code=response_status_code,
                       request_url=self.request.url,
                       backend_url=self.backend_url,
                       is_error=True)
             return (False, None, None, 0), backend_response
 
         # ===============================================
@@ -190,7 +191,7 @@ class RequestHandler:
         if return_json_err:
             error_msg = 'The backend did not return valid JSON.'
             backend_response = self.handle_error(error_msg)
-            log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
+            log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
             return (False, None, None, 0), backend_response
 
         # ===============================================

View File

@@ -9,7 +9,8 @@ from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
-from ...database.database import log_prompt
+from ...database.database import do_db_log
+from ...database.log_to_db import log_to_db
 from ...llm.generator import generator
 from ...sock import sock
@@ -34,38 +35,38 @@ def stream_with_model(ws, model_name=None):
 
 def do_stream(ws, model_name):
-    def send_err_and_quit(quitting_err_msg):
-        ws.send(json.dumps({
-            'event': 'text_stream',
-            'message_num': 0,
-            'text': quitting_err_msg
-        }))
-        ws.send(json.dumps({
-            'event': 'stream_end',
-            'message_num': 1
-        }))
-        log_prompt(ip=handler.client_ip,
-                   token=handler.token,
-                   prompt=input_prompt,
-                   response=quitting_err_msg,
-                   gen_time=None,
-                   parameters=handler.parameters,
-                   headers=r_headers,
-                   backend_response_code=response_status_code,
-                   request_url=r_url,
-                   backend_url=handler.cluster_backend_info,
-                   response_tokens=None,
-                   is_error=True
-                   )
-
-    if not opts.enable_streaming:
-        return 'Streaming is disabled', 500
-
-    r_headers = dict(request.headers)
-    r_url = request.url
-    message_num = 0
     try:
+        def send_err_and_quit(quitting_err_msg):
+            ws.send(json.dumps({
+                'event': 'text_stream',
+                'message_num': 0,
+                'text': quitting_err_msg
+            }))
+            ws.send(json.dumps({
+                'event': 'stream_end',
+                'message_num': 1
+            }))
+            log_to_db(ip=handler.client_ip,
+                      token=handler.token,
+                      prompt=input_prompt,
+                      response=quitting_err_msg,
+                      gen_time=None,
+                      parameters=handler.parameters,
+                      headers=r_headers,
+                      backend_response_code=response_status_code,
+                      request_url=r_url,
+                      backend_url=handler.cluster_backend_info,
+                      response_tokens=None,
+                      is_error=True
+                      )
+
+        if not opts.enable_streaming:
+            return 'Streaming is disabled', 500
+
+        r_headers = dict(request.headers)
+        r_url = request.url
+        message_num = 0
         while ws.connected:
             message = ws.receive()
             request_valid_json, request_json_body = validate_json(message)
@@ -197,7 +198,7 @@ def do_stream(ws, model_name):
                 pass
             end_time = time.time()
             elapsed_time = end_time - start_time
-            log_prompt(ip=handler.client_ip,
+            log_to_db(ip=handler.client_ip,
                       token=handler.token,
                       prompt=input_prompt,
                       response=generated_text,

View File

@@ -0,0 +1,28 @@
import pickle

import redis

from llm_server.database.database import do_db_log


def db_logger():
    """
    We don't want the logging operation to be blocking, so we will use a background worker
    to do the logging.
    :return:
    """
    r = redis.Redis(host='localhost', port=6379, db=3)
    p = r.pubsub()
    p.subscribe('database-logger')

    for message in p.listen():
        if message['type'] == 'message':
            data = pickle.loads(message['data'])
            function_name = data['function']
            args = data['args']
            kwargs = data['kwargs']
            if function_name == 'log_prompt':
                do_db_log(*args, **kwargs)
                print('finished log')

View File

@@ -2,11 +2,11 @@ import time
 from threading import Thread
 
 from llm_server import opts
-from llm_server.cluster.cluster_config import cluster_config
 from llm_server.cluster.stores import redis_running_models
 from llm_server.cluster.worker import cluster_worker
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.inferencer import start_workers
+from llm_server.workers.logger import db_logger
 from llm_server.workers.mainer import main_background_thread
 from llm_server.workers.moderator import start_moderation_workers
 from llm_server.workers.printer import console_printer
@@ -49,3 +49,8 @@ def start_background():
     t.daemon = True
     t.start()
     print('Started the cluster worker.')
+
+    t = Thread(target=db_logger)
+    t.daemon = True
+    t.start()
+    print('Started background logger')

other/gradio_chat.py Normal file · 33 lines
View File

@@ -0,0 +1,33 @@
import warnings

import gradio as gr
import openai

warnings.filterwarnings("ignore")

openai.api_key = 'null'
openai.api_base = 'http://localhost:5000/api/openai/v1'


def stream_response(prompt, history):
    messages = []
    for x in history:
        messages.append({'role': 'user', 'content': x[0]})
        messages.append({'role': 'assistant', 'content': x[1]})
    messages.append({'role': 'user', 'content': prompt})

    response = openai.ChatCompletion.create(
        model='0',
        messages=messages,
        temperature=0,
        max_tokens=300,
        stream=True
    )

    message = ''
    for chunk in response:
        message += chunk['choices'][0]['delta']['content']
        yield message


gr.ChatInterface(stream_response, examples=["hello", "hola", "merhaba"], title="Chatbot Demo", analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}').queue().launch()
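A usage note, not part of the commit: the script assumes the proxy's OpenAI-compatible endpoint is reachable at http://localhost:5000/api/openai/v1 and that no real API key is needed (it is set to 'null'). It also relies on the pre-1.0 openai Python package, since the openai.api_base / openai.ChatCompletion interface was removed in openai 1.x. With gradio and a pre-1.0 openai installed, running python other/gradio_chat.py serves the chat UI on Gradio's default port (7860).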