diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index f4e6c9c..fc800a2 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -1,7 +1,6 @@
 import json
 import time
 import traceback
-from threading import Thread
 from typing import Union

 from llm_server import opts
@@ -10,65 +9,60 @@ from llm_server.database.conn import database
 from llm_server.llm import get_token_count


-def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
+def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
     assert isinstance(prompt, str)
     assert isinstance(backend_url, str)

-    def background_task():
-        nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error
-        # Try not to shove JSON into the database.
-        if isinstance(response, dict) and response.get('results'):
-            response = response['results'][0]['text']
-        try:
-            j = json.loads(response)
-            if j.get('results'):
-                response = j['results'][0]['text']
-        except:
-            pass
+    # Try not to shove JSON into the database.
+    if isinstance(response, dict) and response.get('results'):
+        response = response['results'][0]['text']
+    try:
+        j = json.loads(response)
+        if j.get('results'):
+            response = j['results'][0]['text']
+    except:
+        pass

-        prompt_tokens = get_token_count(prompt, backend_url)
-        if not is_error:
-            if not response_tokens:
-                response_tokens = get_token_count(response, backend_url)
-        else:
-            response_tokens = None
+    prompt_tokens = get_token_count(prompt, backend_url)
+    print('starting')

-        # Sometimes we may want to insert null into the DB, but
-        # usually we want to insert a float.
-        if gen_time:
-            gen_time = round(gen_time, 3)
-        if is_error:
-            gen_time = None
+    if not is_error:
+        if not response_tokens:
+            response_tokens = get_token_count(response, backend_url)
+    else:
+        response_tokens = None

-        if not opts.log_prompts:
-            prompt = None
+    # Sometimes we may want to insert null into the DB, but
+    # usually we want to insert a float.
+    if gen_time:
+        gen_time = round(gen_time, 3)
+    if is_error:
+        gen_time = None

-        if not opts.log_prompts and not is_error:
-            # TODO: test and verify this works as expected
-            response = None
+    if not opts.log_prompts:
+        prompt = None

-        if token:
-            increment_token_uses(token)
+    if not opts.log_prompts and not is_error:
+        # TODO: test and verify this works as expected
+        response = None

-        backend_info = cluster_config.get_backend(backend_url)
-        running_model = backend_info.get('model')
-        backend_mode = backend_info['mode']
-        timestamp = int(time.time())
-        cursor = database.cursor()
-        try:
-            cursor.execute("""
-                INSERT INTO prompts
-                (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
-                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
-                """,
-                (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
-        finally:
-            cursor.close()
+    if token:
+        increment_token_uses(token)

-    # TODO: use async/await instead of threads
-    thread = Thread(target=background_task)
-    thread.start()
-    thread.join()
+    backend_info = cluster_config.get_backend(backend_url)
+    running_model = backend_info.get('model')
+    backend_mode = backend_info['mode']
+    timestamp = int(time.time())
+    cursor = database.cursor()
+    try:
+        cursor.execute("""
+            INSERT INTO prompts
+            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
+            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+            """,
+            (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
+    finally:
+        cursor.close()


 def is_valid_api_key(api_key):
diff --git a/llm_server/database/log_to_db.py b/llm_server/database/log_to_db.py
new file mode 100644
index 0000000..fa97ad7
--- /dev/null
+++ b/llm_server/database/log_to_db.py
@@ -0,0 +1,27 @@
+import pickle
+from typing import Union
+
+from redis import Redis
+
+
+def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
+    r = Redis(host='localhost', port=6379, db=3)
+    data = {
+        'function': 'log_prompt',
+        'args': [],
+        'kwargs': {
+            'ip': ip,
+            'token': token,
+            'prompt': prompt,
+            'response': response,
+            'gen_time': gen_time,
+            'parameters': parameters,
+            'headers': headers,
+            'backend_response_code': backend_response_code,
+            'request_url': request_url,
+            'backend_url': backend_url,
+            'response_tokens': response_tokens,
+            'is_error': is_error
+        }
+    }
+    r.publish('database-logger', pickle.dumps(data))
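# The two files above replace the old per-request logging thread with a fire-and-forget
# hand-off over Redis pub/sub: the request path pickles the call and publishes it, and a
# separate worker performs the blocking database write. Below is a minimal, self-contained
# sketch of that pattern; the channel name and db number mirror the diff, while the handler
# and the wiring under __main__ are illustrative assumptions, not code from this repository.
import pickle
import threading
import time

from redis import Redis

CHANNEL = 'database-logger'


def publish_log(**kwargs):
    # Non-blocking from the caller's perspective: serialize the call and return immediately.
    Redis(host='localhost', port=6379, db=3).publish(
        CHANNEL, pickle.dumps({'function': 'log_prompt', 'args': [], 'kwargs': kwargs}))


def consume_logs(handle):
    # Runs forever in a background worker; `handle` would do the actual database insert.
    pubsub = Redis(host='localhost', port=6379, db=3).pubsub()
    pubsub.subscribe(CHANNEL)
    for message in pubsub.listen():
        if message['type'] == 'message':
            data = pickle.loads(message['data'])
            handle(*data['args'], **data['kwargs'])


if __name__ == '__main__':
    threading.Thread(target=consume_logs, args=(lambda **kw: print('logged:', kw),), daemon=True).start()
    time.sleep(0.5)  # give the subscriber a moment to attach before publishing
    publish_log(ip='127.0.0.1', token=None, prompt='hi', response='hello', is_error=False)
    time.sleep(0.5)  # let the daemon thread drain the message before exiting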
diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py
index fe450bf..0e2b2d8 100644
--- a/llm_server/llm/oobabooga/ooba_backend.py
+++ b/llm_server/llm/oobabooga/ooba_backend.py
@@ -2,7 +2,7 @@ from flask import jsonify

 from llm_server.custom_redis import redis
 from ..llm_backend import LLMBackend
-from ...database.database import log_prompt
+from ...database.database import do_db_log
 from ...helpers import safe_list_get
 from ...routes.helpers.client import format_sillytavern_err
 from ...routes.helpers.http import validate_json
@@ -34,7 +34,7 @@ class OobaboogaBackend(LLMBackend):
             else:
                 error_msg = error_msg.strip('.') + '.'
             backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url)
-            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
+            log_to_db(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
             return jsonify({
                 'code': 500,
                 'msg': error_msg,
@@ -57,13 +57,13 @@ class OobaboogaBackend(LLMBackend):

             if not backend_err:
                 redis.incr('proompts')
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
+            log_to_db(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
             return jsonify({
                 **response_json_body
             }), 200
         else:
             backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url)
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
+            log_to_db(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
             return jsonify({
                 'code': 500,
                 'msg': 'the backend did not return valid JSON',
diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py
index 7d26467..35c9f30 100644
--- a/llm_server/llm/openai/oai_to_vllm.py
+++ b/llm_server/llm/openai/oai_to_vllm.py
@@ -3,17 +3,19 @@ from flask import jsonify
 from llm_server import opts


-def oai_to_vllm(request_json_body, hashes: bool, mode):
+def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
     if not request_json_body.get('stop'):
         request_json_body['stop'] = []
     if not isinstance(request_json_body['stop'], list):
         # It is a string, so create a list with the existing element.
         request_json_body['stop'] = [request_json_body['stop']]

-    if hashes:
-        request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT', '### RESPONSE'])
+    if stop_hashes:
         if opts.openai_force_no_hashes:
-            request_json_body['stop'].append('### ')
+            request_json_body['stop'].append('###')
+        else:
+            # TODO: make stopping strings a configurable
+            request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
     else:
         request_json_body['stop'].extend(['user:', 'assistant:'])

@@ -41,6 +43,11 @@ def format_oai_err(err_msg):


 def validate_oai(parameters):
+    if parameters.get('messages'):
+        for m in parameters['messages']:
+            if m['role'].lower() not in ['assistant', 'user', 'system']:
+                return format_oai_err('messages role must be assistant, user, or system')
+
     if parameters.get('temperature', 0) > 2:
         return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
     if parameters.get('temperature', 0) < 0:
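# A rough usage illustration of the reworked stop-string handling in oai_to_vllm.py above.
# It assumes llm_server is importable, that opts.openai_force_no_hashes is left False, and
# that 'vllm' is a valid mode value; the expected lists in the comments are approximate and
# only cover the branch shown in this hunk.
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm

body = {'messages': [{'role': 'user', 'content': 'hi'}], 'stop': 'STOP'}

# Chat-style request: the bare string is first wrapped in a list, then the instruction
# markers are appended, giving roughly ['STOP', '### INSTRUCTION', '### USER', '### ASSISTANT'].
chat_body = oai_to_vllm(dict(body), stop_hashes=True, mode='vllm')

# Plain completion request: 'user:' and 'assistant:' are appended instead.
completion_body = oai_to_vllm(dict(body), stop_hashes=False, mode='vllm')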
diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py
index 39f942a..0c2946b 100644
--- a/llm_server/llm/openai/transform.py
+++ b/llm_server/llm/openai/transform.py
@@ -96,7 +96,7 @@ def transform_messages_to_prompt(oai_messages):
             elif msg['role'] == 'assistant':
                 prompt += f'### ASSISTANT: {msg["content"]}\n\n'
             else:
-                return False
+                raise Exception(f'Unknown role: {msg["role"]}')
     except Exception as e:
         # TODO: use logging
         traceback.print_exc()
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py
index 72b0243..31cd511 100644
--- a/llm_server/llm/vllm/generate.py
+++ b/llm_server/llm/vllm/generate.py
@@ -1,24 +1,16 @@
 """
 This file is used by the worker that processes requests.
 """
-import json
-import time
-from uuid import uuid4

 import requests

-import llm_server
 from llm_server import opts
-from llm_server.custom_redis import redis


 # TODO: make the VLMM backend return TPS and time elapsed
 # https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py


 def prepare_json(json_data: dict):
-    # logit_bias is not currently supported
-    # del json_data['logit_bias']
-    # Convert back to VLLM.
     json_data['max_tokens'] = json_data.pop('max_new_tokens')
     return json_data
diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index abc1cbb..a9ec821 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -3,7 +3,7 @@ from typing import Tuple
 from flask import jsonify
 from vllm import SamplingParams

-from llm_server.database.database import log_prompt
+from llm_server.database.log_to_db import log_to_db
 from llm_server.llm.llm_backend import LLMBackend


@@ -18,8 +18,8 @@ class VLLMBackend(LLMBackend):
             # Failsafe
             backend_response = ''

-        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
-                   response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
+        log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
+                  response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)

         return jsonify({'results': [{'text': backend_response}]}), 200

@@ -29,14 +29,15 @@ class VLLMBackend(LLMBackend):
         top_k = parameters.get('top_k', self._default_params['top_k'])
         if top_k <= 0:
             top_k = -1
+
         sampling_params = SamplingParams(
             temperature=parameters.get('temperature', self._default_params['temperature']),
             top_p=parameters.get('top_p', self._default_params['top_p']),
             top_k=top_k,
             use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
-            stop=list(set(parameters.get('stopping_strings', self._default_params['stop']) or parameters.get('stop', self._default_params['stop']))),
+            stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
             ignore_eos=parameters.get('ban_eos_token', False),
-            max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) or parameters.get('max_tokens', self._default_params['max_tokens']),
+            max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
             presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
             frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty'])
         )
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index 350621f..c01bfed 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -4,7 +4,8 @@ import flask
 from flask import jsonify, request

 from llm_server import opts
-from llm_server.database.database import log_prompt
+from llm_server.database.database import do_db_log
+from llm_server.database.log_to_db import log_to_db
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler

@@ -40,7 +41,7 @@ class OobaRequestHandler(RequestHandler):
             msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
         backend_response = self.handle_error(msg)
         if do_log:
-            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
+            log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
         return backend_response[0], 429

     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index c46e89f..d10bdf6 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -3,6 +3,7 @@ import time
 import traceback

 from flask import Response, jsonify, request
+from redis import Redis

 from llm_server.custom_redis import redis
 from . import openai_bp
@@ -10,7 +11,7 @@ from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
 from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
-from ...database.database import log_prompt
+from ...database.log_to_db import log_to_db
 from ...llm.generator import generator
 from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@@ -18,6 +19,7 @@ from ...llm.openai.transform import generate_oai_string, transform_messages_to_p

 # TODO: add rate-limit headers?

+
 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
     request_valid_json, request_json_body = validate_json(request)
@@ -36,12 +38,20 @@ def openai_chat_completions():
             return 'Internal server error', 500
     else:
         if not opts.enable_streaming:
-            return 'DISABLED', 401
+            return
+
+        handler.parameters, _ = handler.get_parameters()
+        handler.request_json_body = {
+            'messages': handler.request_json_body['messages'],
+            'model': handler.request_json_body['model'],
+            **handler.parameters
+        }

         invalid_oai_err_msg = validate_oai(handler.request_json_body)
         if invalid_oai_err_msg:
             return invalid_oai_err_msg
-        handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
+
+        handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])

         if opts.openai_silent_trim:
             handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
@@ -64,7 +74,7 @@ def openai_chat_completions():
             # Add a dummy event to the queue and wait for it to reach a worker
             event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
             if not event:
-                log_prompt(
+                log_to_db(
                     handler.client_ip,
                     handler.token,
                     handler.prompt,
@@ -82,7 +92,6 @@
             _, _, _ = event.wait()

             try:
-                response = generator(msg_to_backend, handler.backend_url)
                 r_headers = dict(request.headers)
                 r_url = request.url
                 model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
@@ -90,6 +99,7 @@

                 def generate():
                     try:
+                        response = generator(msg_to_backend, handler.backend_url)
                         generated_text = ''
                         partial_response = b''
                         for chunk in response.iter_content(chunk_size=1):
@@ -125,8 +135,7 @@
                             yield 'data: [DONE]\n\n'
                             end_time = time.time()
                             elapsed_time = end_time - start_time
-
-                            log_prompt(
+                            log_to_db(
                                 handler.client_ip,
                                 handler.token,
                                 handler.prompt,
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 6904348..1843226 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -10,7 +10,8 @@ from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
-from ...database.database import log_prompt
+from ...database.database import do_db_log
+from ...database.log_to_db import log_to_db
 from ...llm import get_token_count
 from ...llm.generator import generator
 from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
@@ -34,7 +35,7 @@ def openai_completions():
     invalid_oai_err_msg = validate_oai(handler.request_json_body)
     if invalid_oai_err_msg:
         return invalid_oai_err_msg
-    handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
+    handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])

     if opts.openai_silent_trim:
         handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
@@ -102,7 +103,7 @@ def openai_completions():
             # Add a dummy event to the queue and wait for it to reach a worker
             event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
             if not event:
-                log_prompt(
+                log_to_db(
                     handler.client_ip,
                     handler.token,
                     handler.prompt,
@@ -164,7 +165,7 @@
                 end_time = time.time()
                 elapsed_time = end_time - start_time

-                log_prompt(
+                log_to_db(
                     handler.client_ip,
                     handler.token,
                     handler.prompt,
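# The chat-completions change above defers the upstream generator() call into the
# generate() body, so the backend request is only issued once the client actually starts
# consuming the event stream. Below is a minimal, self-contained sketch of that pattern
# with Flask and requests; the route, upstream URL, and payload are placeholders, not
# values from this repository.
import requests
from flask import Flask, Response

app = Flask(__name__)


@app.route('/sketch/stream')
def sketch_stream():
    def generate():
        # Issued lazily, on first read of the response, instead of at request-handling time.
        upstream = requests.post('http://localhost:7000/generate', json={'prompt': 'hi'}, stream=True)
        try:
            for chunk in upstream.iter_content(chunk_size=1):
                yield chunk
            yield b'data: [DONE]\n\n'
        finally:
            upstream.close()

    return Response(generate(), mimetype='text/event-stream')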
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 84b2c76..037de27 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -11,7 +11,8 @@ from flask import Response, jsonify, make_response
 from llm_server import opts
 from llm_server.cluster.backend import get_model_choices
 from llm_server.custom_redis import redis
-from llm_server.database.database import is_api_key_moderated, log_prompt
+from llm_server.database.database import is_api_key_moderated, do_db_log
+from llm_server.database.log_to_db import log_to_db
 from llm_server.llm import get_token_count
 from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
 from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@@ -58,7 +59,7 @@ class OpenAIRequestHandler(RequestHandler):
             traceback.print_exc()

         # TODO: support Ooba
-        self.parameters = oai_to_vllm(self.parameters, hashes=True, mode=self.cluster_backend_info['mode'])
+        self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
         llm_request = {**self.parameters, 'prompt': self.prompt}

         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
@@ -88,7 +89,7 @@ class OpenAIRequestHandler(RequestHandler):
         response.headers['x-ratelimit-reset-requests'] = f"{w}s"

         if do_log:
-            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
+            log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)

         return response, 429

@@ -146,6 +147,6 @@ class OpenAIRequestHandler(RequestHandler):
         invalid_oai_err_msg = validate_oai(self.request_json_body)
         if invalid_oai_err_msg:
             return False, invalid_oai_err_msg
-        self.request_json_body = oai_to_vllm(self.request_json_body, hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+        self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
         # If the parameters were invalid, let the superclass deal with it.
         return super().validate_request(prompt, do_log)
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index d981be8..90da0b1 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -7,7 +7,8 @@ from flask import Response, request
 from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
 from llm_server.custom_redis import redis
-from llm_server.database.database import get_token_ratelimit, log_prompt
+from llm_server.database.database import get_token_ratelimit, do_db_log
+from llm_server.database.log_to_db import log_to_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
@@ -41,9 +42,11 @@ class RequestHandler:
         if not self.cluster_backend_info.get('mode'):
             print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
         if not self.cluster_backend_info.get('model'):
-            print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
+            print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
+        if not self.cluster_backend_info.get('model_config'):
+            print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)

-        if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model'):
+        if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
             self.offline = True
         else:
             self.offline = False
@@ -74,8 +77,6 @@ class RequestHandler:
         return self.request.remote_addr

     def get_parameters(self):
-        if self.request_json_body.get('max_tokens'):
-            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
         parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
         return parameters, parameters_invalid_msg

@@ -117,7 +118,7 @@ class RequestHandler:
             backend_response = self.handle_error(combined_error_message, 'Validation Error')

             if do_log:
-                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
+                log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)

             return False, backend_response
         return True, (None, 0)
@@ -160,17 +161,17 @@ class RequestHandler:
             else:
                 error_msg = error_msg.strip('.') + '.'
             backend_response = self.handle_error(error_msg)
-            log_prompt(ip=self.client_ip,
-                       token=self.token,
-                       prompt=prompt,
-                       response=backend_response[0].data.decode('utf-8'),
-                       gen_time=None,
-                       parameters=self.parameters,
-                       headers=dict(self.request.headers),
-                       backend_response_code=response_status_code,
-                       request_url=self.request.url,
-                       backend_url=self.backend_url,
-                       is_error=True)
+            log_to_db(ip=self.client_ip,
+                      token=self.token,
+                      prompt=prompt,
+                      response=backend_response[0].data.decode('utf-8'),
+                      gen_time=None,
+                      parameters=self.parameters,
+                      headers=dict(self.request.headers),
+                      backend_response_code=response_status_code,
+                      request_url=self.request.url,
+                      backend_url=self.backend_url,
+                      is_error=True)
             return (False, None, None, 0), backend_response

         # ===============================================
@@ -190,7 +191,7 @@ class RequestHandler:
             if return_json_err:
                 error_msg = 'The backend did not return valid JSON.'
                 backend_response = self.handle_error(error_msg)
-                log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
+                log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
                 return (False, None, None, 0), backend_response

         # ===============================================
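# With the get_parameters() change above and the earlier VLLMBackend change, parameter
# precedence is now settled in one place: the backend prefers the Ooba-style names and
# falls back to the OpenAI-style ones, instead of the request handler renaming keys up
# front. A small standalone illustration of that precedence (the default values here are
# made up, not the project's defaults):
DEFAULTS = {'max_tokens': 16, 'stop': []}


def resolve(parameters: dict):
    max_tokens = parameters.get('max_new_tokens') or parameters.get('max_tokens', DEFAULTS['max_tokens'])
    stop = list(set(parameters.get('stopping_strings') or parameters.get('stop', DEFAULTS['stop'])))
    return max_tokens, stop


print(resolve({'max_new_tokens': 200}))                           # (200, [])
print(resolve({'max_tokens': 100}))                               # (100, [])
print(resolve({'stopping_strings': ['###'], 'max_tokens': 100}))  # (100, ['###'])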
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 6cd98c0..9962ff8 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -9,7 +9,8 @@ from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
-from ...database.database import log_prompt
+from ...database.database import do_db_log
+from ...database.log_to_db import log_to_db
 from ...llm.generator import generator
 from ...sock import sock

@@ -34,38 +35,38 @@ def stream_with_model(ws, model_name=None):


 def do_stream(ws, model_name):
-    def send_err_and_quit(quitting_err_msg):
-        ws.send(json.dumps({
-            'event': 'text_stream',
-            'message_num': 0,
-            'text': quitting_err_msg
-        }))
-        ws.send(json.dumps({
-            'event': 'stream_end',
-            'message_num': 1
-        }))
-        log_prompt(ip=handler.client_ip,
-                   token=handler.token,
-                   prompt=input_prompt,
-                   response=quitting_err_msg,
-                   gen_time=None,
-                   parameters=handler.parameters,
-                   headers=r_headers,
-                   backend_response_code=response_status_code,
-                   request_url=r_url,
-                   backend_url=handler.cluster_backend_info,
-                   response_tokens=None,
-                   is_error=True
-                   )
-
-    if not opts.enable_streaming:
-        return 'Streaming is disabled', 500
-
-    r_headers = dict(request.headers)
-    r_url = request.url
-    message_num = 0
-    try:
+    def send_err_and_quit(quitting_err_msg):
+        ws.send(json.dumps({
+            'event': 'text_stream',
+            'message_num': 0,
+            'text': quitting_err_msg
+        }))
+        ws.send(json.dumps({
+            'event': 'stream_end',
+            'message_num': 1
+        }))
+        log_to_db(ip=handler.client_ip,
+                  token=handler.token,
+                  prompt=input_prompt,
+                  response=quitting_err_msg,
+                  gen_time=None,
+                  parameters=handler.parameters,
+                  headers=r_headers,
+                  backend_response_code=response_status_code,
+                  request_url=r_url,
+                  backend_url=handler.cluster_backend_info,
+                  response_tokens=None,
+                  is_error=True
+                  )
+
+    if not opts.enable_streaming:
+        return 'Streaming is disabled', 500
+
+    r_headers = dict(request.headers)
+    r_url = request.url
+    message_num = 0
+    while ws.connected:
         message = ws.receive()
         request_valid_json, request_json_body = validate_json(message)

@@ -197,7 +198,7 @@ def do_stream(ws, model_name):
                     pass
             end_time = time.time()
             elapsed_time = end_time - start_time
-            log_prompt(ip=handler.client_ip,
+            log_to_db(ip=handler.client_ip,
                        token=handler.token,
                        prompt=input_prompt,
                        response=generated_text,
diff --git a/llm_server/workers/logger.py b/llm_server/workers/logger.py
new file mode 100644
index 0000000..2707615
--- /dev/null
+++ b/llm_server/workers/logger.py
@@ -0,0 +1,28 @@
+import pickle
+
+import redis
+
+from llm_server.database.database import do_db_log
+
+
+def db_logger():
+    """
+    We don't want the logging operation to be blocking, so we will use a background worker
+    to do the logging.
+    :return:
+    """
+
+    r = redis.Redis(host='localhost', port=6379, db=3)
+    p = r.pubsub()
+    p.subscribe('database-logger')
+
+    for message in p.listen():
+        if message['type'] == 'message':
+            data = pickle.loads(message['data'])
+            function_name = data['function']
+            args = data['args']
+            kwargs = data['kwargs']
+
+            if function_name == 'log_prompt':
+                do_db_log(*args, **kwargs)
+                print('finished log')
diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py
index fa6c252..bf14d60 100644
--- a/llm_server/workers/threader.py
+++ b/llm_server/workers/threader.py
@@ -2,11 +2,11 @@ import time
 from threading import Thread

 from llm_server import opts
-from llm_server.cluster.cluster_config import cluster_config
 from llm_server.cluster.stores import redis_running_models
 from llm_server.cluster.worker import cluster_worker
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.inferencer import start_workers
+from llm_server.workers.logger import db_logger
 from llm_server.workers.mainer import main_background_thread
 from llm_server.workers.moderator import start_moderation_workers
 from llm_server.workers.printer import console_printer
@@ -49,3 +49,8 @@ def start_background():
     t.daemon = True
     t.start()
     print('Started the cluster worker.')
+
+    t = Thread(target=db_logger)
+    t.daemon = True
+    t.start()
+    print('Started background logger')
diff --git a/other/gradio_chat.py b/other/gradio_chat.py
new file mode 100644
index 0000000..eb10d26
--- /dev/null
+++ b/other/gradio_chat.py
@@ -0,0 +1,33 @@
+import warnings
+
+import gradio as gr
+import openai
+
+warnings.filterwarnings("ignore")
+
+openai.api_key = 'null'
+openai.api_base = 'http://localhost:5000/api/openai/v1'
+
+
+def stream_response(prompt, history):
+    messages = []
+    for x in history:
+        messages.append({'role': 'user', 'content': x[0]})
+        messages.append({'role': 'assistant', 'content': x[1]})
+    messages.append({'role': 'user', 'content': prompt})
+
+    response = openai.ChatCompletion.create(
+        model='0',
+        messages=messages,
+        temperature=0,
+        max_tokens=300,
+        stream=True
+    )
+
+    message = ''
+    for chunk in response:
+        message += chunk['choices'][0]['delta']['content']
+        yield message
+
+
+gr.ChatInterface(stream_response, examples=["hello", "hola", "merhaba"], title="Chatbot Demo", analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}').queue().launch()
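# One caveat for the Gradio demo above: with OpenAI-style streaming, the first (role-only)
# and last delta chunks may not carry a 'content' key, so guarding the accumulation avoids
# a KeyError. A self-contained helper the demo's generator could call instead; the fake
# chunks below are illustrative, not captured output:
def accumulate_deltas(chunks):
    """Yield the running message text from a stream of chat-completion chunks."""
    message = ''
    for chunk in chunks:
        message += chunk['choices'][0]['delta'].get('content', '')
        yield message


fake = [{'choices': [{'delta': {'role': 'assistant'}}]},
        {'choices': [{'delta': {'content': 'Hello'}}]},
        {'choices': [{'delta': {}}]}]
print(list(accumulate_deltas(fake)))  # ['', 'Hello', 'Hello']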