Merge cluster to master #3

Merged
cyberes merged 163 commits from cluster into master 2023-10-27 19:19:22 -06:00
16 changed files with 242 additions and 141 deletions
Showing only changes of commit acf409abfc - Show all commits

View File

@ -1,7 +1,6 @@
import json
import time
import traceback
from threading import Thread
from typing import Union
from llm_server import opts
@ -10,12 +9,10 @@ from llm_server.database.conn import database
from llm_server.llm import get_token_count
def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
assert isinstance(prompt, str)
assert isinstance(backend_url, str)
def background_task():
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error
# Try not to shove JSON into the database.
if isinstance(response, dict) and response.get('results'):
response = response['results'][0]['text']
@ -27,6 +24,8 @@ def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen
pass
prompt_tokens = get_token_count(prompt, backend_url)
print('starting')
if not is_error:
if not response_tokens:
response_tokens = get_token_count(response, backend_url)
@ -65,11 +64,6 @@ def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen
finally:
cursor.close()
# TODO: use async/await instead of threads
thread = Thread(target=background_task)
thread.start()
thread.join()
def is_valid_api_key(api_key):
cursor = database.cursor()

View File

@ -0,0 +1,27 @@
import pickle
from typing import Union
from redis import Redis
def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
r = Redis(host='localhost', port=6379, db=3)
data = {
'function': 'log_prompt',
'args': [],
'kwargs': {
'ip': ip,
'token': token,
'prompt': prompt,
'response': response,
'gen_time': gen_time,
'parameters': parameters,
'headers': headers,
'backend_response_code': backend_response_code,
'request_url': request_url,
'backend_url': backend_url,
'response_tokens': response_tokens,
'is_error': is_error
}
}
r.publish('database-logger', pickle.dumps(data))

View File

@ -2,7 +2,7 @@ from flask import jsonify
from llm_server.custom_redis import redis
from ..llm_backend import LLMBackend
from ...database.database import log_prompt
from ...database.database import do_db_log
from ...helpers import safe_list_get
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json
@ -34,7 +34,7 @@ class OobaboogaBackend(LLMBackend):
else:
error_msg = error_msg.strip('.') + '.'
backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url)
log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
log_to_db(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
return jsonify({
'code': 500,
'msg': error_msg,
@ -57,13 +57,13 @@ class OobaboogaBackend(LLMBackend):
if not backend_err:
redis.incr('proompts')
log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
log_to_db(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify({
**response_json_body
}), 200
else:
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url)
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
log_to_db(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
return jsonify({
'code': 500,
'msg': 'the backend did not return valid JSON',

View File

@ -3,17 +3,19 @@ from flask import jsonify
from llm_server import opts
def oai_to_vllm(request_json_body, hashes: bool, mode):
def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
if not request_json_body.get('stop'):
request_json_body['stop'] = []
if not isinstance(request_json_body['stop'], list):
# It is a string, so create a list with the existing element.
request_json_body['stop'] = [request_json_body['stop']]
if hashes:
request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT', '### RESPONSE'])
if stop_hashes:
if opts.openai_force_no_hashes:
request_json_body['stop'].append('### ')
request_json_body['stop'].append('###')
else:
# TODO: make stopping strings a configurable
request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
else:
request_json_body['stop'].extend(['user:', 'assistant:'])
@ -41,6 +43,11 @@ def format_oai_err(err_msg):
def validate_oai(parameters):
if parameters.get('messages'):
for m in parameters['messages']:
if m['role'].lower() not in ['assistant', 'user', 'system']:
return format_oai_err('messages role must be assistant, user, or system')
if parameters.get('temperature', 0) > 2:
return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
if parameters.get('temperature', 0) < 0:

View File

@ -96,7 +96,7 @@ def transform_messages_to_prompt(oai_messages):
elif msg['role'] == 'assistant':
prompt += f'### ASSISTANT: {msg["content"]}\n\n'
else:
return False
raise Exception(f'Unknown role: {msg["role"]}')
except Exception as e:
# TODO: use logging
traceback.print_exc()

View File

@ -1,24 +1,16 @@
"""
This file is used by the worker that processes requests.
"""
import json
import time
from uuid import uuid4
import requests
import llm_server
from llm_server import opts
from llm_server.custom_redis import redis
# TODO: make the VLMM backend return TPS and time elapsed
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
def prepare_json(json_data: dict):
# logit_bias is not currently supported
# del json_data['logit_bias']
# Convert back to VLLM.
json_data['max_tokens'] = json_data.pop('max_new_tokens')
return json_data

View File

@ -3,7 +3,7 @@ from typing import Tuple
from flask import jsonify
from vllm import SamplingParams
from llm_server.database.database import log_prompt
from llm_server.database.log_to_db import log_to_db
from llm_server.llm.llm_backend import LLMBackend
@ -18,7 +18,7 @@ class VLLMBackend(LLMBackend):
# Failsafe
backend_response = ''
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
return jsonify({'results': [{'text': backend_response}]}), 200
@ -29,14 +29,15 @@ class VLLMBackend(LLMBackend):
top_k = parameters.get('top_k', self._default_params['top_k'])
if top_k <= 0:
top_k = -1
sampling_params = SamplingParams(
temperature=parameters.get('temperature', self._default_params['temperature']),
top_p=parameters.get('top_p', self._default_params['top_p']),
top_k=top_k,
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
stop=list(set(parameters.get('stopping_strings', self._default_params['stop']) or parameters.get('stop', self._default_params['stop']))),
stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) or parameters.get('max_tokens', self._default_params['max_tokens']),
max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty'])
)

View File

@ -4,7 +4,8 @@ import flask
from flask import jsonify, request
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.database.database import do_db_log
from llm_server.database.log_to_db import log_to_db
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
@ -40,7 +41,7 @@ class OobaRequestHandler(RequestHandler):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg)
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return backend_response[0], 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:

View File

@ -3,6 +3,7 @@ import time
import traceback
from flask import Response, jsonify, request
from redis import Redis
from llm_server.custom_redis import redis
from . import openai_bp
@ -10,7 +11,7 @@ from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import log_prompt
from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@ -18,6 +19,7 @@ from ...llm.openai.transform import generate_oai_string, transform_messages_to_p
# TODO: add rate-limit headers?
@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
request_valid_json, request_json_body = validate_json(request)
@ -36,12 +38,20 @@ def openai_chat_completions():
return 'Internal server error', 500
else:
if not opts.enable_streaming:
return 'DISABLED', 401
return
handler.parameters, _ = handler.get_parameters()
handler.request_json_body = {
'messages': handler.request_json_body['messages'],
'model': handler.request_json_body['model'],
**handler.parameters
}
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
if opts.openai_silent_trim:
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
@ -64,7 +74,7 @@ def openai_chat_completions():
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
if not event:
log_prompt(
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
@ -82,7 +92,6 @@ def openai_chat_completions():
_, _, _ = event.wait()
try:
response = generator(msg_to_backend, handler.backend_url)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
@ -90,6 +99,7 @@ def openai_chat_completions():
def generate():
try:
response = generator(msg_to_backend, handler.backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
@ -125,8 +135,7 @@ def openai_chat_completions():
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,

View File

@ -10,7 +10,8 @@ from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import log_prompt
from ...database.database import do_db_log
from ...database.log_to_db import log_to_db
from ...llm import get_token_count
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
@ -34,7 +35,7 @@ def openai_completions():
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
@ -102,7 +103,7 @@ def openai_completions():
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
if not event:
log_prompt(
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
@ -164,7 +165,7 @@ def openai_completions():
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,

View File

@ -11,7 +11,8 @@ from flask import Response, jsonify, make_response
from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated, log_prompt
from llm_server.database.database import is_api_key_moderated, do_db_log
from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@ -58,7 +59,7 @@ class OpenAIRequestHandler(RequestHandler):
traceback.print_exc()
# TODO: support Ooba
self.parameters = oai_to_vllm(self.parameters, hashes=True, mode=self.cluster_backend_info['mode'])
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
@ -88,7 +89,7 @@ class OpenAIRequestHandler(RequestHandler):
response.headers['x-ratelimit-reset-requests'] = f"{w}s"
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return response, 429
@ -146,6 +147,6 @@ class OpenAIRequestHandler(RequestHandler):
invalid_oai_err_msg = validate_oai(self.request_json_body)
if invalid_oai_err_msg:
return False, invalid_oai_err_msg
self.request_json_body = oai_to_vllm(self.request_json_body, hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
# If the parameters were invalid, let the superclass deal with it.
return super().validate_request(prompt, do_log)

View File

@ -7,7 +7,8 @@ from flask import Response, request
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.custom_redis import redis
from llm_server.database.database import get_token_ratelimit, log_prompt
from llm_server.database.database import get_token_ratelimit, do_db_log
from llm_server.database.log_to_db import log_to_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
@ -41,9 +42,11 @@ class RequestHandler:
if not self.cluster_backend_info.get('mode'):
print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('model'):
print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('model_config'):
print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model'):
if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
self.offline = True
else:
self.offline = False
@ -74,8 +77,6 @@ class RequestHandler:
return self.request.remote_addr
def get_parameters(self):
if self.request_json_body.get('max_tokens'):
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
return parameters, parameters_invalid_msg
@ -117,7 +118,7 @@ class RequestHandler:
backend_response = self.handle_error(combined_error_message, 'Validation Error')
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
return False, backend_response
return True, (None, 0)
@ -160,7 +161,7 @@ class RequestHandler:
else:
error_msg = error_msg.strip('.') + '.'
backend_response = self.handle_error(error_msg)
log_prompt(ip=self.client_ip,
log_to_db(ip=self.client_ip,
token=self.token,
prompt=prompt,
response=backend_response[0].data.decode('utf-8'),
@ -190,7 +191,7 @@ class RequestHandler:
if return_json_err:
error_msg = 'The backend did not return valid JSON.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================

View File

@ -9,7 +9,8 @@ from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import log_prompt
from ...database.database import do_db_log
from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...sock import sock
@ -34,6 +35,7 @@ def stream_with_model(ws, model_name=None):
def do_stream(ws, model_name):
try:
def send_err_and_quit(quitting_err_msg):
ws.send(json.dumps({
'event': 'text_stream',
@ -44,7 +46,7 @@ def do_stream(ws, model_name):
'event': 'stream_end',
'message_num': 1
}))
log_prompt(ip=handler.client_ip,
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=quitting_err_msg,
@ -65,7 +67,6 @@ def do_stream(ws, model_name):
r_url = request.url
message_num = 0
try:
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
@ -197,7 +198,7 @@ def do_stream(ws, model_name):
pass
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(ip=handler.client_ip,
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,

View File

@ -0,0 +1,28 @@
import pickle
import redis
from llm_server.database.database import do_db_log
def db_logger():
"""
We don't want the logging operation to be blocking, so we will use a background worker
to do the logging.
:return:
"""
r = redis.Redis(host='localhost', port=6379, db=3)
p = r.pubsub()
p.subscribe('database-logger')
for message in p.listen():
if message['type'] == 'message':
data = pickle.loads(message['data'])
function_name = data['function']
args = data['args']
kwargs = data['kwargs']
if function_name == 'log_prompt':
do_db_log(*args, **kwargs)
print('finished log')

View File

@ -2,11 +2,11 @@ import time
from threading import Thread
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.stores import redis_running_models
from llm_server.cluster.worker import cluster_worker
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers
from llm_server.workers.logger import db_logger
from llm_server.workers.mainer import main_background_thread
from llm_server.workers.moderator import start_moderation_workers
from llm_server.workers.printer import console_printer
@ -49,3 +49,8 @@ def start_background():
t.daemon = True
t.start()
print('Started the cluster worker.')
t = Thread(target=db_logger)
t.daemon = True
t.start()
print('Started background logger')

33
other/gradio_chat.py Normal file
View File

@ -0,0 +1,33 @@
import warnings
import gradio as gr
import openai
warnings.filterwarnings("ignore")
openai.api_key = 'null'
openai.api_base = 'http://localhost:5000/api/openai/v1'
def stream_response(prompt, history):
messages = []
for x in history:
messages.append({'role': 'user', 'content': x[0]})
messages.append({'role': 'assistant', 'content': x[1]})
messages.append({'role': 'user', 'content': prompt})
response = openai.ChatCompletion.create(
model='0',
messages=messages,
temperature=0,
max_tokens=300,
stream=True
)
message = ''
for chunk in response:
message += chunk['choices'][0]['delta']['content']
yield message
gr.ChatInterface(stream_response, examples=["hello", "hola", "merhaba"], title="Chatbot Demo", analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}').queue().launch()