diff --git a/README.md b/README.md
index 485cb91..5a712cd 100644
--- a/README.md
+++ b/README.md
@@ -2,12 +2,8 @@ _A HTTP API to serve local LLM Models._
 
-
-
 
 The purpose of this server is to abstract your LLM backend from your frontend API. This enables you to make changes to (or even switch) your backend without affecting your clients.
 
-
-
 
 ### Install
 
 1. `sudo apt install redis`
@@ -16,25 +12,18 @@ The purpose of this server is to abstract your LLM backend from your frontend AP
 4. `pip install -r requirements.txt`
 5. `python3 server.py`
 
-
-
 An example systemctl service file is provided in `other/local-llm.service`.
 
-
-
 ### Configure
 
-First, set up your LLM backend. Currently, only [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) is supported, but eventually [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference) will be the default.
+First, set up your LLM backend. Currently, only [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) is supported, but
+eventually [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference) will be the default.
 
 Then, configure this server. The config file is located at `config/config.yml.sample` so copy it to `config/config.yml`.
 
-
-
 1. Set `backend_url` to the base API URL of your backend.
 2. Set `token_limit` to the configured token limit of the backend. This number is shown to clients and on the home page.
 
-
-
 To set up token auth, add rows to the `token_auth` table in the SQLite database.
 
 `token`: the token/password.
@@ -51,6 +40,11 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
 
 `disabled`: mark the token as disabled.
 
+### Use
+
+**DO NOT** lose your database. It's used to calculate the estimated wait time from the average TPS and response token counts, so if you lose those stats your numbers will be inaccurate until the database fills back up again. If you change
+graphics cards, you should probably clear the `generation_time` column in the `prompts` table.
+
 ### To Do
 
 - Implement streaming
diff --git a/llm_server/config.py b/llm_server/config.py
index 7d2bd4e..b08ceee 100644
--- a/llm_server/config.py
+++ b/llm_server/config.py
@@ -10,12 +10,13 @@ config_default_vars = {
     'show_num_prompts': True,
     'show_uptime': True,
     'analytics_tracking_code': '',
+    'average_generation_time_mode': 'database',
 }
 
 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
 
 mode_ui_names = {
-    'oobabooga': 'Text Gen WebUI (ooba)',
-    'hf-textgen': 'UNDEFINED',
+    'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url'),
+    'hf-textgen': ('UNDEFINED', 'UNDEFINED'),
 }
 
diff --git a/llm_server/database.py b/llm_server/database.py
index f1b0acb..e3bdafd 100644
--- a/llm_server/database.py
+++ b/llm_server/database.py
@@ -92,3 +92,12 @@ def get_number_of_rows(table_name):
     result = cur.fetchone()
     conn.close()
     return result[0]
+
+
+def average_column(table_name, column_name):
+    conn = sqlite3.connect(opts.database_path)
+    cursor = conn.cursor()
+    cursor.execute(f"SELECT AVG({column_name}) FROM {table_name}")
+    result = cursor.fetchone()
+    conn.close()
+    return result[0]
diff --git a/llm_server/opts.py b/llm_server/opts.py
index e550f2b..9a3cb28 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -14,3 +14,4 @@ http_host = None
 verify_ssl = True
 show_num_prompts = True
 show_uptime = True
+average_generation_time_mode = 'database'
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index 8104397..15361a0 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -3,6 +3,7 @@ from datetime import datetime
 
 from llm_server import opts
 from llm_server.llm.info import get_running_model
+from llm_server.routes.cache import redis
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
 
@@ -21,8 +22,18 @@ def generate_stats():
     # waits = [elapsed for end, elapsed in t]
     # estimated_wait = int(sum(waits) / len(waits))
 
-    average_generation_time = int(calculate_avg_gen_time())
     proompters_in_queue = len(priority_queue) + get_active_gen_workers()
+    average_tps = float(redis.get('average_tps'))
+
+    if opts.average_generation_time_mode == 'database':
+        average_generation_time = int(float(redis.get('average_generation_elapsed_sec')))
+        average_output_tokens = int(float(redis.get('average_output_tokens')))
+        estimated_wait_sec = int(((average_output_tokens / average_tps) * proompters_in_queue) / opts.concurrent_gens)
+    elif opts.average_generation_time_mode == 'minute':
+        average_generation_time = int(calculate_avg_gen_time())
+        estimated_wait_sec = int((average_generation_time * proompters_in_queue) / opts.concurrent_gens)
+    else:
+        raise Exception
 
     # TODO: https://stackoverflow.com/questions/22721579/sorting-a-nested-ordereddict-by-key-recursively
     return {
@@ -32,6 +43,7 @@ def generate_stats():
             'total_proompts': get_total_proompts() if opts.show_num_prompts else None,
             'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
             'average_generation_elapsed_sec': average_generation_time,
+            'average_tps': average_tps,
         },
         'online': online,
         'mode': opts.mode,
@@ -39,7 +51,7 @@ def generate_stats():
         'endpoints': {
             'blocking': opts.full_client_api,
         },
-        'estimated_wait_sec': int((average_generation_time * proompters_in_queue) / opts.concurrent_gens),
+        'estimated_wait_sec': estimated_wait_sec,
         'timestamp': int(time.time()),
         'openaiKeys': '∞',
         'anthropicKeys': '∞',
diff --git a/llm_server/routes/v1/proxy.py b/llm_server/routes/v1/proxy.py
index 585ce55..acd0797 100644
--- a/llm_server/routes/v1/proxy.py
+++ b/llm_server/routes/v1/proxy.py
@@ -1,16 +1,8 @@
-import time
-from datetime import datetime
+from flask import jsonify
 
-from flask import jsonify, request
-
-from llm_server import opts
 from . import bp
 from .generate_stats import generate_stats
-from .. import stats
 from ..cache import cache
-from ..queue import priority_queue
-from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers
-from ...llm.info import get_running_model
 
 
 @bp.route('/stats', methods=['GET'])
diff --git a/llm_server/threads.py b/llm_server/threads.py
index d344391..1a89238 100644
--- a/llm_server/threads.py
+++ b/llm_server/threads.py
@@ -4,6 +4,7 @@ from threading import Thread
 import requests
 
 from llm_server import opts
+from llm_server.database import average_column
 from llm_server.routes.cache import redis
 
 
@@ -13,9 +14,21 @@ class BackendHealthCheck(Thread):
     def __init__(self):
         Thread.__init__(self)
         self.daemon = True
+        redis.set('average_generation_elapsed_sec', 0)
+        redis.set('average_tps', 0)
+        redis.set('average_output_tokens', 0)
+        redis.set('backend_online', 0)
 
     def run(self):
         while True:
+            average_generation_elapsed_sec = average_column('prompts', 'generation_time') or 0
+            redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
+
+            average_output_tokens = average_column('prompts', 'response_tokens') or 0
+            redis.set('average_output_tokens', average_output_tokens)
+            average_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec else 0
+            redis.set('average_tps', average_tps)
+
             if opts.mode == 'oobabooga':
                 try:
                     r = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
diff --git a/server.py b/server.py
index 2c630da..77e1928 100644
--- a/server.py
+++ b/server.py
@@ -60,11 +60,15 @@ if not opts.verify_ssl:
 flushed_keys = redis.flush()
 print('Flushed', len(flushed_keys), 'keys from Redis.')
 
-redis.set('backend_online', 0)
 
 if config['load_num_prompts']:
     redis.set('proompts', get_number_of_rows('prompts'))
 
+if config['average_generation_time_mode'] not in ['database', 'minute']:
+    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
+    sys.exit(1)
+opts.average_generation_time_mode = config['average_generation_time_mode']
+
 start_workers(opts.concurrent_gens)
 
 # cleanup_thread = Thread(target=elapsed_times_cleanup)
@@ -113,7 +117,8 @@ def home():
         current_model=running_model,
         client_api=opts.full_client_api,
         estimated_wait=estimated_wait_sec,
-        mode_name=mode_ui_names[opts.mode],
+        mode_name=mode_ui_names[opts.mode][0],
+        api_input_textbox=mode_ui_names[opts.mode][1],
         context_size=opts.context_size,
         stats_json=json.dumps(stats, indent=4, ensure_ascii=False)
     )
diff --git a/templates/home.html b/templates/home.html
index 02afcee..52c8579 100644
--- a/templates/home.html
+++ b/templates/home.html
@@ -14,6 +14,8 @@
             background-color: #ffb6c16e;
             padding: 1em;
             display: inline-block;
+            margin: auto;
+            max-width: 95%;
         }
 
         a, a:visited {
@@ -25,6 +27,12 @@
             text-align: center;
         }
 
+        pre {
+            white-space: pre-wrap;
+            word-wrap: break-word;
+            text-align: justify;
+        }
+
         @media only screen and (max-width: 600px) {
             .container {
                 padding: 1em;
@@ -51,7 +59,7 @@
         Instructions:
         1. Set your API type to {{ mode_name }}
-        2. Enter {{ client_api }} in the Blocking API url textbox.
+        2. Enter {{ client_api }} in the {{ api_input_textbox }} textbox.
         3. Click Connect to test the connection.
         4. Open your preset config and set Context Size to {{ context_size }}.
         5. Follow this guide to get set up: rentry.org/freellamas
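
For reference, a minimal sketch of the averaging the new `BackendHealthCheck` thread performs: it averages columns of the `prompts` table and derives tokens-per-second from them. An in-memory SQLite database with made-up rows stands in for the real stats database here, and the Redis caching step is omitted so the snippet runs on its own.

```python
import sqlite3

# Toy stand-in for the real `prompts` table.
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE prompts (response_tokens INTEGER, generation_time REAL)')
conn.executemany('INSERT INTO prompts VALUES (?, ?)', [(200, 25.0), (120, 16.0), (300, 41.0)])


def average_column(conn, table_name, column_name):
    # Mirrors llm_server/database.py: AVG() returns NULL (None) on an empty table.
    return conn.execute(f'SELECT AVG({column_name}) FROM {table_name}').fetchone()[0]


average_generation_elapsed_sec = average_column(conn, 'prompts', 'generation_time') or 0
average_output_tokens = average_column(conn, 'prompts', 'response_tokens') or 0
average_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec else 0

# ~27.33 sec average generation time, ~206.67 average output tokens, ~7.56 tokens/sec
print(average_generation_elapsed_sec, average_output_tokens, average_tps)
```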
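And a rough sketch of the wait-time estimate that `generate_stats()` now computes from those cached averages. The function and its inputs are illustrative only; the arithmetic mirrors the two modes the patch adds (`database` divides average output tokens by average tokens/sec, `minute` falls back to the recent average generation time), with the queue spread across the concurrent generation workers.

```python
def estimate_wait_sec(mode: str,
                      average_output_tokens: float,
                      average_tps: float,
                      average_generation_time: float,
                      proompters_in_queue: int,
                      concurrent_gens: int) -> int:
    if mode == 'database':
        # Seconds per request derived from the long-term DB averages: tokens / (tokens/sec).
        seconds_per_request = average_output_tokens / average_tps if average_tps else 0
    elif mode == 'minute':
        # Fall back to the average generation time measured over the last minute.
        seconds_per_request = average_generation_time
    else:
        raise ValueError(f'unknown average_generation_time_mode: {mode}')
    # Requests ahead of you are served concurrent_gens at a time.
    return int(seconds_per_request * proompters_in_queue / concurrent_gens)


# Example: 180 average output tokens at 6 tok/s, 4 people queued, 2 concurrent workers
# -> roughly (180 / 6) * 4 / 2 = 60 seconds.
print(estimate_wait_sec('database', 180, 6, 0, 4, 2))
```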