fix the queue??

Cyberes 2023-10-05 21:37:18 -06:00
parent ea61766838
commit e8964fcfd2
10 changed files with 195 additions and 157 deletions

View File

@@ -3,13 +3,12 @@ import sys
 import time
 from pathlib import Path

+from redis import Redis
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.cluster.redis_cycle import redis_cycler_db
-from llm_server.cluster.stores import redis_running_models
 from llm_server.config.load import load_config, parse_backends
 from llm_server.custom_redis import redis
 from llm_server.database.create import create_db
-from llm_server.routes.queue import priority_queue
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.threader import start_background
@@ -21,11 +20,8 @@ else:
     config_path = Path(script_path, 'config', 'config.yml')

 if __name__ == "__main__":
-    flushed_keys = redis.flush()
-    print('Flushed', len(flushed_keys), 'keys from Redis.')
-    redis_cycler_db.flushall()
-    redis_running_models.flush()
+    Redis().flushall()
+    print('Flushed Redis.')

     success, config, msg = load_config(config_path)
     if not success:
@@ -34,7 +30,6 @@ if __name__ == "__main__":
     create_db()

-    priority_queue.flush()
-
     cluster_config.clear()
     cluster_config.load(parse_backends(config))

View File

@@ -3,11 +3,13 @@ import sys
 import openai

+import llm_server
 from llm_server import opts
 from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
 from llm_server.custom_redis import redis
 from llm_server.database.conn import database
 from llm_server.database.database import get_number_of_rows
+from llm_server.routes.queue import PriorityQueue


 def load_config(config_path):
@@ -54,6 +56,8 @@ def load_config(config_path):
     for item in config['cluster']:
         opts.cluster_workers += item['concurrent_gens']

+    llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
+
     if opts.openai_expose_our_model and not opts.openai_api_key:
         print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
         sys.exit(1)

View File

@@ -74,7 +74,7 @@ def openai_chat_completions():
             event = None
             if not handler.is_client_ratelimited():
                 # Add a dummy event to the queue and wait for it to reach a worker
-                event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
+                event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
             if not event:
                 log_to_db(
                     handler.client_ip,
@@ -107,63 +107,64 @@ def openai_chat_completions():
         oai_string = generate_oai_string(30)

         def generate():
-            response = generator(msg_to_backend, handler.backend_url)
-            generated_text = ''
-            partial_response = b''
-            for chunk in response.iter_content(chunk_size=1):
-                partial_response += chunk
-                if partial_response.endswith(b'\x00'):
-                    json_strs = partial_response.split(b'\x00')
-                    for json_str in json_strs:
-                        if json_str:
-                            try:
-                                json_obj = json.loads(json_str.decode())
-                                new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
-                                generated_text = generated_text + new
-                            except IndexError:
-                                # ????
-                                continue
-                            data = {
-                                "id": f"chatcmpl-{oai_string}",
-                                "object": "chat.completion.chunk",
-                                "created": int(time.time()),
-                                "model": model,
-                                "choices": [
-                                    {
-                                        "index": 0,
-                                        "delta": {
-                                            "content": new
-                                        },
-                                        "finish_reason": None
-                                    }
-                                ]
-                            }
-                            yield f'data: {json.dumps(data)}\n\n'
-            yield 'data: [DONE]\n\n'
-            end_time = time.time()
-            elapsed_time = end_time - start_time
-            log_to_db(
-                handler.client_ip,
-                handler.token,
-                handler.prompt,
-                generated_text,
-                elapsed_time,
-                handler.parameters,
-                r_headers,
-                response_status_code,
-                r_url,
-                handler.backend_url,
-            )
+            try:
+                response = generator(msg_to_backend, handler.backend_url)
+                generated_text = ''
+                partial_response = b''
+                for chunk in response.iter_content(chunk_size=1):
+                    partial_response += chunk
+                    if partial_response.endswith(b'\x00'):
+                        json_strs = partial_response.split(b'\x00')
+                        for json_str in json_strs:
+                            if json_str:
+                                try:
+                                    json_obj = json.loads(json_str.decode())
+                                    new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
+                                    generated_text = generated_text + new
+                                except IndexError:
+                                    # ????
+                                    continue
+                                data = {
+                                    "id": f"chatcmpl-{oai_string}",
+                                    "object": "chat.completion.chunk",
+                                    "created": int(time.time()),
+                                    "model": model,
+                                    "choices": [
+                                        {
+                                            "index": 0,
+                                            "delta": {
+                                                "content": new
+                                            },
+                                            "finish_reason": None
+                                        }
+                                    ]
+                                }
+                                yield f'data: {json.dumps(data)}\n\n'
+                yield 'data: [DONE]\n\n'
+                end_time = time.time()
+                elapsed_time = end_time - start_time
+                log_to_db(
+                    handler.client_ip,
+                    handler.token,
+                    handler.prompt,
+                    generated_text,
+                    elapsed_time,
+                    handler.parameters,
+                    r_headers,
+                    response_status_code,
+                    r_url,
+                    handler.backend_url,
+                )
+            finally:
+                # After completing inference, we need to tell the worker we
+                # are finished.
+                if event_id:  # may be None if ratelimited.
+                    redis.publish(event_id, 'finished')
+                else:
+                    print('event_id was None!')

         return Response(generate(), mimetype='text/event-stream')
     except Exception:
         traceback.print_exc()
         return 'INTERNAL SERVER', 500
-    finally:
-        # After completing inference, we need to tell the worker we
-        # are finished.
-        if event_id:  # may be None if ratelimited.
-            redis.publish(event_id, 'finished')
-        else:
-            print('event_id was None!')
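Note on the change above: the 'finished' publish used to sit in the route's outer finally, which runs as soon as the view returns the Response object and before the streaming generator body has executed, so the worker was released too early. Moving the publish into the generator's own finally ties it to the actual end of the stream. A minimal, self-contained sketch of that ordering (publish below stands in for redis.publish(event_id, 'finished'); everything else is illustrative, not the repo's code):

    def make_stream(publish):
        # Cleanup lives in the generator's own finally, so it only runs once the
        # response body has actually been consumed (or the client disconnects).
        def generate():
            try:
                for i in range(3):
                    yield f'data: {i}\n\n'
            finally:
                publish('finished')
        return generate()

    if __name__ == '__main__':
        events = []
        stream = make_stream(events.append)
        assert events == []            # creating/returning the generator runs nothing yet
        for _ in stream:
            pass
        assert events == ['finished']  # fires only after the stream is drained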

View File

@@ -102,7 +102,7 @@ def openai_completions():
             event = None
             if not handler.is_client_ratelimited():
                 # Add a dummy event to the queue and wait for it to reach a worker
-                event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
+                event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
             if not event:
                 log_to_db(
                     handler.client_ip,
@@ -135,61 +135,62 @@ def openai_completions():
         oai_string = generate_oai_string(30)

         def generate():
-            generated_text = ''
-            partial_response = b''
-            for chunk in response.iter_content(chunk_size=1):
-                partial_response += chunk
-                if partial_response.endswith(b'\x00'):
-                    json_strs = partial_response.split(b'\x00')
-                    for json_str in json_strs:
-                        if json_str:
-                            try:
-                                json_obj = json.loads(json_str.decode())
-                                new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
-                                generated_text = generated_text + new
-                            except IndexError:
-                                # ????
-                                continue
-                            data = {
-                                "id": f"cmpl-{oai_string}",
-                                "object": "text_completion",
-                                "created": int(time.time()),
-                                "model": model,
-                                "choices": [
-                                    {
-                                        "index": 0,
-                                        "delta": {
-                                            "content": new
-                                        },
-                                        "finish_reason": None
-                                    }
-                                ]
-                            }
-                            yield f'data: {json.dumps(data)}\n\n'
-            yield 'data: [DONE]\n\n'
-            end_time = time.time()
-            elapsed_time = end_time - start_time
-            log_to_db(
-                handler.client_ip,
-                handler.token,
-                handler.prompt,
-                generated_text,
-                elapsed_time,
-                handler.parameters,
-                r_headers,
-                response_status_code,
-                r_url,
-                handler.backend_url,
-            )
+            try:
+                generated_text = ''
+                partial_response = b''
+                for chunk in response.iter_content(chunk_size=1):
+                    partial_response += chunk
+                    if partial_response.endswith(b'\x00'):
+                        json_strs = partial_response.split(b'\x00')
+                        for json_str in json_strs:
+                            if json_str:
+                                try:
+                                    json_obj = json.loads(json_str.decode())
+                                    new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
+                                    generated_text = generated_text + new
+                                except IndexError:
+                                    # ????
+                                    continue
+                                data = {
+                                    "id": f"cmpl-{oai_string}",
+                                    "object": "text_completion",
+                                    "created": int(time.time()),
+                                    "model": model,
+                                    "choices": [
+                                        {
+                                            "index": 0,
+                                            "delta": {
+                                                "content": new
+                                            },
+                                            "finish_reason": None
+                                        }
+                                    ]
+                                }
+                                yield f'data: {json.dumps(data)}\n\n'
+                yield 'data: [DONE]\n\n'
+                end_time = time.time()
+                elapsed_time = end_time - start_time
+                log_to_db(
+                    handler.client_ip,
+                    handler.token,
+                    handler.prompt,
+                    generated_text,
+                    elapsed_time,
+                    handler.parameters,
+                    r_headers,
+                    response_status_code,
+                    r_url,
+                    handler.backend_url,
+                )
+            finally:
+                if event_id:
+                    redis.publish(event_id, 'finished')
+                else:
+                    print('event_id was None!')

         return Response(generate(), mimetype='text/event-stream')
     except Exception:
         traceback.print_exc()
         return 'INTERNAL SERVER', 500
-    finally:
-        if event_id:
-            redis.publish(event_id, 'finished')
-        else:
-            print('event_id was None!')

View File

@@ -1,10 +1,12 @@
 import json
 import pickle
 import time
+from typing import Tuple
 from uuid import uuid4

 from redis import Redis

+from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import RedisCustom, redis
 from llm_server.database.database import get_token_ratelimit
@@ -20,7 +22,7 @@ def decrement_ip_count(client_ip: str, redis_key):

 class RedisPriorityQueue:
-    def __init__(self, name: str = 'priority_queue', db: int = 12):
+    def __init__(self, name, db: int = 12):
         self.redis = RedisCustom(name, db=db)

     def put(self, item, priority, selected_model):
@@ -98,9 +100,6 @@ class DataEvent:
         return pickle.loads(item['data'])


-priority_queue = RedisPriorityQueue()
-
-
 def update_active_workers(key: str, operation: str):
     if operation == 'incr':
         redis.incr(f'active_gen_workers:{key}')
@@ -118,3 +117,60 @@ def incr_active_workers(selected_model: str, backend_url: str):
 def decr_active_workers(selected_model: str, backend_url: str):
     update_active_workers(selected_model, 'decr')
     update_active_workers(backend_url, 'decr')
+
+
+class PriorityQueue:
+    def __init__(self, backends: list = None):
+        """
+        Only have to load the backends once.
+        :param backends:
+        """
+        self.redis = Redis(host='localhost', port=6379, db=9)
+        if backends:
+            for item in backends:
+                self.redis.lpush('backends', item)
+
+    def get_backends(self):
+        return [x.decode('utf-8') for x in self.redis.lrange('backends', 0, -1)]
+
+    def get_queued_ip_count(self, client_ip: str):
+        count = 0
+        for backend_url in self.get_backends():
+            queue = RedisPriorityQueue(backend_url)
+            count += queue.get_queued_ip_count(client_ip)
+        return count
+
+    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str):
+        queue = RedisPriorityQueue(backend_url)
+        return queue.put(item, priority, selected_model)
+
+    def len(self, model_name):
+        count = 0
+        backends_with_models = []
+        for k in self.get_backends():
+            info = cluster_config.get_backend(k)
+            if info.get('model') == model_name:
+                backends_with_models.append(k)
+        for backend_url in backends_with_models:
+            queue = RedisPriorityQueue(backend_url)
+            count += queue.len(model_name)
+        return count
+
+    def __len__(self):
+        count = 0
+        for backend_url in self.get_backends():
+            queue = RedisPriorityQueue(backend_url)
+            count += len(queue)
+        return count
+
+    def flush(self):
+        for k in self.redis.keys():
+            q = json.loads(self.redis.get(k))
+            q.flush()
+            self.redis.set(k, json.dumps(q))
+
+    def flush_db(self):
+        self.redis.flushdb()
+
+
+priority_queue = PriorityQueue()
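For orientation: the new PriorityQueue is a thin dispatcher. It keeps the list of backend URLs in Redis and fans each put/len call out to a per-backend RedisPriorityQueue, so callers choose the backend up front instead of the worker resolving it from the queued item. A rough in-memory analogue of that fan-out (illustrative only; the class below is a stand-in, not the repo's Redis-backed implementation):

    import itertools
    import queue
    from typing import Dict, Tuple

    class PerBackendQueues:
        def __init__(self, backends):
            # One priority queue per backend URL, mirroring RedisPriorityQueue(backend_url).
            self.queues: Dict[str, queue.PriorityQueue] = {b: queue.PriorityQueue() for b in backends}
            self._tie = itertools.count()  # tie-breaker so equal priorities never compare payloads

        def put(self, backend_url: str, item: Tuple, priority: int):
            # Callers name the backend explicitly, as in priority_queue.put(backend_url, item, ...).
            self.queues[backend_url].put((priority, next(self._tie), item))

        def __len__(self):
            # Aggregate depth across every backend queue, like PriorityQueue.__len__.
            return sum(q.qsize() for q in self.queues.values())

    if __name__ == '__main__':
        q = PerBackendQueues(['http://backend-a:5000', 'http://backend-b:5000'])
        q.put('http://backend-a:5000', (None, '1.2.3.4', None, None), priority=1)
        print(len(q))  # -> 1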

View File

@@ -7,7 +7,7 @@ from flask import Response, request
 from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
 from llm_server.custom_redis import redis
-from llm_server.database.database import get_token_ratelimit, do_db_log
+from llm_server.database.database import get_token_ratelimit
 from llm_server.database.log_to_db import log_to_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
@@ -131,7 +131,7 @@ class RequestHandler:
             request_valid, invalid_response = self.validate_request(prompt, do_log=True)
             if not request_valid:
                 return (False, None, None, 0), invalid_response
-            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.backend_url), self.token_priority, self.selected_model)
+            event = priority_queue.put(self.backend_url, (llm_request, self.client_ip, self.token, self.parameters), self.token_priority, self.selected_model)
         else:
             event = None

View File

@@ -122,7 +122,7 @@ def do_stream(ws, model_name):
             event = None
             if not handler.is_client_ratelimited():
                 # Add a dummy event to the queue and wait for it to reach a worker
-                event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
+                event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
             if not event:
                 log_to_db(
                     handler.client_ip,

View File

@@ -1,29 +1,24 @@
 import threading
 import time
-from uuid import uuid4

-from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
-from llm_server.custom_redis import redis
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.custom_redis import redis, RedisCustom
 from llm_server.llm.generator import generator
-from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
+from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, RedisPriorityQueue, PriorityQueue, priority_queue


-def worker():
+def worker(backend_url):
+    queue = RedisPriorityQueue(backend_url)
     while True:
-        (request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get()
-
-        if not backend_url:
-            backend_url = get_a_cluster_backend(selected_model)
-        else:
-            backend_url = cluster_config.validate_backend(backend_url)
+        (request_json_body, client_ip, token, parameters), event_id, selected_model = queue.get()
         backend_info = cluster_config.get_backend(backend_url)

         if not selected_model:
             selected_model = backend_info['model']

         increment_ip_count(client_ip, 'processing_ips')
         incr_active_workers(selected_model, backend_url)
-        need_to_wait(backend_url)

         try:
             if not request_json_body:
                 # This was a dummy request from the streaming handlers.
@@ -37,7 +32,6 @@ def worker():
                 redis.publish(event_id, 'begin')
                 for item in pubsub.listen():
                     if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
-                        # Once the handler is complete, move on.
                         break
                     time.sleep(0.1)
             else:
@@ -50,23 +44,12 @@ def worker():
             decr_active_workers(selected_model, backend_url)


-def start_workers(num_workers: int):
+def start_workers(cluster: dict):
     i = 0
-    for _ in range(num_workers):
-        t = threading.Thread(target=worker)
-        t.daemon = True
-        t.start()
-        i += 1
+    for item in cluster:
+        for _ in range(item['concurrent_gens']):
+            t = threading.Thread(target=worker, args=(item['backend_url'],))
+            t.daemon = True
+            t.start()
+            i += 1
     print(f'Started {i} inference workers.')
-
-
-def need_to_wait(backend_url: str):
-    # We need to check the number of active workers since the streaming endpoint may be doing something.
-    active_workers = redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int)
-    concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1)
-    s = time.time()
-    while active_workers >= concurrent_gens:
-        time.sleep(0.01)
-    e = time.time()
-    if e - s > 0.1:
-        print(f'Worker was delayed {e - s} seconds.')
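With this change start_workers no longer spawns a flat pool of opts.cluster_workers threads pulling from one shared queue; it spawns item['concurrent_gens'] threads per backend, each bound to that backend's own queue, which is why the need_to_wait throttle can be dropped. A small stand-alone sketch of that startup shape (in-memory queues instead of Redis; the names here are assumptions for illustration):

    import queue
    import threading

    def worker(backend_url: str, q: queue.Queue):
        # Each thread consumes only its own backend's queue.
        while True:
            job = q.get()
            print(f'{backend_url} handling {job!r}')
            q.task_done()

    def start_workers(cluster):
        queues = {item['backend_url']: queue.Queue() for item in cluster}
        started = 0
        for item in cluster:
            for _ in range(item['concurrent_gens']):
                threading.Thread(target=worker, args=(item['backend_url'], queues[item['backend_url']]), daemon=True).start()
                started += 1
        print(f'Started {started} inference workers.')
        return queues

    if __name__ == '__main__':
        qs = start_workers([{'backend_url': 'http://backend-a:5000', 'concurrent_gens': 2}])
        qs['http://backend-a:5000'].put('dummy request')
        qs['http://backend-a:5000'].join()  # returns once a worker has processed the job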

View File

@@ -25,6 +25,4 @@ def console_printer():
             processing_count += redis.get(k, default=0, dtype=int)
         backends = [k for k, v in cluster_config.all().items() if v['online']]
         logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
-        priority_queue.print_all_items()
-        print('============================')
         time.sleep(1)

View File

@@ -20,7 +20,7 @@ def cache_stats():

 def start_background():
-    start_workers(opts.cluster_workers)
+    start_workers(opts.cluster)

     t = Thread(target=main_background_thread)
     t.daemon = True