This commit is contained in:
Cyberes 2023-10-11 18:04:15 -06:00
parent 169e216a38
commit 74cf8f309b
6 changed files with 20 additions and 14 deletions

View File

@@ -5,20 +5,20 @@ class DatabaseConnection:
host: str = None
username: str = None
password: str = None
database: str = None
database_name: str = None
def init_db(self, host, username, password, database):
def init_db(self, host, username, password, database_name):
self.host = host
self.username = username
self.password = password
self.database = database
self.database_name = database_name
def cursor(self):
db = pymysql.connect(
host=self.host,
user=self.username,
password=self.password,
database=self.database,
database=self.database_name,
charset='utf8mb4',
autocommit=True,
)

View File

@@ -4,7 +4,6 @@ import flask
from flask import jsonify, request
from llm_server import opts
from llm_server.database.database import do_db_log
from llm_server.database.log_to_db import log_to_db
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler

View File

@@ -35,6 +35,7 @@ def worker(backend_url):
redis.publish(event_id, 'begin')
for item in pubsub.listen():
if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
# The streaming endpoint has said that it has finished
break
time.sleep(0.1)
else:
@@ -47,6 +48,7 @@ def worker(backend_url):
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url)
print('Worker finished processing for', client_ip)
def start_workers(cluster: dict):

View File

@@ -1,4 +1,5 @@
import pickle
import traceback
import redis
@@ -17,11 +18,14 @@ def db_logger():
p.subscribe('database-logger')
for message in p.listen():
if message['type'] == 'message':
data = pickle.loads(message['data'])
function_name = data['function']
args = data['args']
kwargs = data['kwargs']
try:
if message['type'] == 'message':
data = pickle.loads(message['data'])
function_name = data['function']
args = data['args']
kwargs = data['kwargs']
if function_name == 'log_prompt':
do_db_log(*args, **kwargs)
if function_name == 'log_prompt':
do_db_log(*args, **kwargs)
except:
traceback.print_exc()

View File

@@ -25,4 +25,4 @@ def console_printer():
processing_count += redis.get(k, default=0, dtype=int)
backends = [k for k, v in cluster_config.all().items() if v['online']]
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
time.sleep(1)
time.sleep(10)

View File

@@ -75,7 +75,8 @@ def stream_response(prompt, history):
messages=messages,
temperature=0,
max_tokens=300,
stream=True
stream=True,
headers={'LLM-Source': 'huggingface-demo'}
)
except Exception:
raise gr.Error("Failed to reach inference endpoint.")