Merge cluster to master #3
|
@ -5,20 +5,20 @@ class DatabaseConnection:
|
|||
host: str = None
|
||||
username: str = None
|
||||
password: str = None
|
||||
database: str = None
|
||||
database_name: str = None
|
||||
|
||||
def init_db(self, host, username, password, database):
|
||||
def init_db(self, host, username, password, database_name):
|
||||
self.host = host
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.database = database
|
||||
self.database_name = database_name
|
||||
|
||||
def cursor(self):
|
||||
db = pymysql.connect(
|
||||
host=self.host,
|
||||
user=self.username,
|
||||
password=self.password,
|
||||
database=self.database,
|
||||
database=self.database_name,
|
||||
charset='utf8mb4',
|
||||
autocommit=True,
|
||||
)
|
||||
|
|
|
@ -4,7 +4,6 @@ import flask
|
|||
from flask import jsonify, request
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database.database import do_db_log
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
|
|
@ -35,6 +35,7 @@ def worker(backend_url):
|
|||
redis.publish(event_id, 'begin')
|
||||
for item in pubsub.listen():
|
||||
if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
|
||||
# The streaming endpoint has said that it has finished
|
||||
break
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
|
@ -47,6 +48,7 @@ def worker(backend_url):
|
|||
finally:
|
||||
decrement_ip_count(client_ip, 'processing_ips')
|
||||
decr_active_workers(selected_model, backend_url)
|
||||
print('Worker finished processing for', client_ip)
|
||||
|
||||
|
||||
def start_workers(cluster: dict):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import pickle
|
||||
import traceback
|
||||
|
||||
import redis
|
||||
|
||||
|
@ -17,11 +18,14 @@ def db_logger():
|
|||
p.subscribe('database-logger')
|
||||
|
||||
for message in p.listen():
|
||||
if message['type'] == 'message':
|
||||
data = pickle.loads(message['data'])
|
||||
function_name = data['function']
|
||||
args = data['args']
|
||||
kwargs = data['kwargs']
|
||||
try:
|
||||
if message['type'] == 'message':
|
||||
data = pickle.loads(message['data'])
|
||||
function_name = data['function']
|
||||
args = data['args']
|
||||
kwargs = data['kwargs']
|
||||
|
||||
if function_name == 'log_prompt':
|
||||
do_db_log(*args, **kwargs)
|
||||
if function_name == 'log_prompt':
|
||||
do_db_log(*args, **kwargs)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
|
|
@ -25,4 +25,4 @@ def console_printer():
|
|||
processing_count += redis.get(k, default=0, dtype=int)
|
||||
backends = [k for k, v in cluster_config.all().items() if v['online']]
|
||||
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
|
||||
time.sleep(1)
|
||||
time.sleep(10)
|
||||
|
|
|
@ -75,7 +75,8 @@ def stream_response(prompt, history):
|
|||
messages=messages,
|
||||
temperature=0,
|
||||
max_tokens=300,
|
||||
stream=True
|
||||
stream=True,
|
||||
headers={'LLM-Source': 'huggingface-demo'}
|
||||
)
|
||||
except Exception:
|
||||
raise gr.Error("Failed to reach inference endpoint.")
|
||||
|
|
Reference in New Issue