# local-llm-server/llm_server/threads.py

import time
from threading import Thread
import requests
from llm_server import opts
from llm_server.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
class MainBackgroundThread(Thread):
    """Daemon thread that polls the configured LLM backend for liveness and
    the running model name, and refreshes rolling performance averages in
    Redis once per minute.
    """

    # NOTE(review): appears unused within this file; liveness is tracked via
    # the 'backend_online' Redis key instead. Kept for backward compatibility
    # in case other modules read this attribute.
    backend_online = False

    # TODO: do I really need to put everything in Redis?
    # TODO: call generate_stats() every minute, cache the results, put results in a DB table, then have other parts of code call this cache

    def __init__(self):
        """Initialize as a daemon thread and seed the Redis stat keys so
        readers never see a missing value before the first poll completes."""
        Thread.__init__(self)
        self.daemon = True
        redis.set('average_generation_elapsed_sec', 0)
        redis.set('average_tps', 0)
        redis.set('average_output_tokens', 0)
        redis.set('backend_online', 0)
        redis.set_dict('backend_info', {})

    def _poll_backend(self):
        """Probe the backend once, updating ``opts.running_model`` and the
        'backend_online' / 'backend_info' Redis keys according to ``opts.mode``.

        Raises:
            ValueError: if ``opts.mode`` is not a recognized backend mode.
        """
        if opts.mode in ('oobabooga', 'vllm'):
            # Both modes share the same status helper, so one branch serves both.
            model, err = get_running_model()
            if err:
                print(err)
                redis.set('backend_online', 0)
            else:
                opts.running_model = model
                redis.set('backend_online', 1)
        elif opts.mode == 'hf-textgen':
            try:
                r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl)
                j = r.json()
                # Normalize the model id so it is safe to use in paths / DB keys.
                opts.running_model = j['model_id'].replace('/', '_')
                redis.set('backend_online', 1)
                redis.set_dict('backend_info', j)
            except Exception as e:
                redis.set('backend_online', 0)
                # TODO: handle error
                print(e)
        else:
            raise ValueError(f'unknown opts.mode: {opts.mode!r}')

    def _refresh_averages(self):
        """Recompute weighted performance averages from the DB and cache them
        in Redis under 'average_generation_elapsed_sec', 'average_output_tokens'
        and 'average_tps'."""
        # exclude_zeros=True filters out rows where an error message was returned.
        # Previously, if there was an error, 0 was entered into the column. The new
        # code enters null instead but we need to be backwards compatible for now.
        average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, exclude_zeros=True) or 0
        redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
        average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, exclude_zeros=True) or 0
        redis.set('average_output_tokens', average_output_tokens)
        # Avoid division by zero when no successful generations have been recorded yet.
        average_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0
        redis.set('average_tps', average_tps)

    def run(self):
        """Loop forever: poll backend status, refresh averages, sleep 60s."""
        while True:
            self._poll_backend()
            self._refresh_averages()
            time.sleep(60)