diff --git a/config/config.yml b/config/config.yml
index c0c3083..3f2d42e 100644
--- a/config/config.yml
+++ b/config/config.yml
@@ -5,7 +5,7 @@ log_prompts: true
 mode: oobabooga
 auth_required: false
 concurrent_gens: 3
-token_limit: 5555
+token_limit: 7777
 backend_url: http://172.0.0.2:9104
diff --git a/llm_server/routes/cache.py b/llm_server/routes/cache.py
index 72c3ee7..58e74f3 100644
--- a/llm_server/routes/cache.py
+++ b/llm_server/routes/cache.py
@@ -1,3 +1,5 @@
 from flask_caching import Cache
+from redis import Redis
 
 cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local-llm'})
+redis = Redis()
diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py
index 47077e0..975b0ce 100644
--- a/llm_server/routes/stats.py
+++ b/llm_server/routes/stats.py
@@ -6,13 +6,22 @@ from threading import Semaphore, Thread
 from llm_server import opts
 from llm_server.integer import ThreadSafeInteger
 from llm_server.opts import concurrent_gens
+from llm_server.routes.cache import redis
 
 # proompters_1_min = 0
 concurrent_semaphore = Semaphore(concurrent_gens)
-proompts = ThreadSafeInteger(0)
 start_time = datetime.now()
 
 
+def get_count():
+    count = redis.get('proompts')
+    if count is None:
+        count = 0
+    else:
+        count = int(count)
+    return count
+
+
 class SemaphoreCheckerThread(Thread):
     proompters_1_min = 0
     recent_prompters = {}
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index dc50c04..b284d46 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -2,8 +2,9 @@ import time
 
 from flask import jsonify, request
 
-from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore, proompts
+from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
 from . import bp
+from ..cache import redis
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import cache_control, validate_json
 from ... import opts
@@ -65,7 +66,7 @@ def generate():
             }), 200
     response_valid_json, response_json_body = validate_json(response)
     if response_valid_json:
-        proompts.increment()
+        redis.incr('proompts')
         backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
         if not backend_response:
             if opts.mode == 'oobabooga':
diff --git a/llm_server/routes/v1/proxy.py b/llm_server/routes/v1/proxy.py
index 76945d9..207ba92 100644
--- a/llm_server/routes/v1/proxy.py
+++ b/llm_server/routes/v1/proxy.py
@@ -15,7 +15,6 @@ from ...llm.info import get_running_model
 
 @bp.route('/stats', methods=['GET'])
 @cache.cached(timeout=5, query_string=True)
-@cache_control(5)
 def get_stats():
     model_list = get_running_model()  # will return False when the fetch fails
     if isinstance(model_list, bool):
@@ -27,7 +26,7 @@ def get_stats():
         'stats': {
             'proompters_now': opts.concurrent_gens - concurrent_semaphore._value,
             'proompters_1_min': SemaphoreCheckerThread.proompters_1_min,
-            'total_proompts': stats.proompts.value,
+            'total_proompts': stats.get_count(),
             'uptime': int((datetime.now() - stats.start_time).total_seconds()),
         },
         'online': online,
diff --git a/other/local-llm.service b/other/local-llm.service
index 2f9fa9b..b20e19b 100644
--- a/other/local-llm.service
+++ b/other/local-llm.service
@@ -7,7 +7,9 @@ After=basic.target network.target
 User=server
 Group=server
 WorkingDirectory=/srv/server/local-llm-server
-ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn --workers 3 --bind 0.0.0.0:5000 server:app
+# Need a lot of workers since we have long-running requests
+# Takes about 3.5G memory
+ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn --workers 20 --bind 0.0.0.0:5000 server:app --timeout 60 --worker-class gevent
 Restart=always
 RestartSec=2
 
diff --git a/requirements.txt b/requirements.txt
index 695028f..ca4dfdb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ flask_caching
 requests
 tiktoken
 gunicorn
-redis
\ No newline at end of file
+redis
+gevent
\ No newline at end of file
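
Note on the counter change: it is what makes the jump from 3 gunicorn workers to 20 safe. The old proompts value was a ThreadSafeInteger, which only counts requests seen by a single process; moving it into Redis gives every worker the same atomic counter. Below is a minimal standalone sketch of that pattern, not part of the patch: the 'proompts' key and the bare Redis() constructor mirror the diff, while the demo lines at the bottom are purely illustrative.

from redis import Redis

# Redis() defaults to localhost:6379, db 0 -- the same instance the Flask cache points at.
redis = Redis()


def get_count() -> int:
    # GET returns bytes, or None if the key has never been incremented.
    count = redis.get('proompts')
    return int(count) if count is not None else 0


# INCR is atomic on the Redis server, so concurrent gunicorn/gevent workers
# can all bump the same key without any in-process locking.
redis.incr('proompts')
print(get_count())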