add rate limiting to the websocket streaming endpoint, fix the queue not decrementing per-IP request counts, add a console printer
parent 43299b32ad
commit e5fbc9545d
@@ -3,8 +3,9 @@ from llm_server import opts
 def generator(request_json_body):
     if opts.mode == 'oobabooga':
-        from .oobabooga.generate import generate
-        return generate(request_json_body)
+        # from .oobabooga.generate import generate
+        # return generate(request_json_body)
+        raise NotImplementedError
     elif opts.mode == 'vllm':
         from .vllm.generate import generate
         r = generate(request_json_body)
@@ -1,13 +1,11 @@
 import json
 import pickle
-import threading
 import time
 from uuid import uuid4
 
 from redis import Redis
 
 from llm_server import opts
-from llm_server.llm.generator import generator
 from llm_server.routes.cache import redis
 
 
@@ -39,6 +37,8 @@ class RedisPriorityQueue:
 
         # Check if the IP is already in the dictionary and if it has reached the limit
         ip_count = self.redis.hget('queued_ip_count', item[1])
+        if ip_count:
+            ip_count = int(ip_count)
         if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
             print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
             return None  # reject the request
@@ -52,11 +52,10 @@ class RedisPriorityQueue:
             data = self.redis.zpopmin('queue')
             if data:
                 item = json.loads(data[0][0])
-                client_ip = item[1][1]
-                # Decrement the count for this IP
+                client_ip = item[0][1]
                 self.decrement_ip_count(client_ip, 'queued_ip_count')
                 return item
-            time.sleep(1)  # wait for an item to be added to the queue
+            time.sleep(0.5)  # wait for something to be added to the queue
 
     def increment_ip_count(self, ip, key):
         self.redis.hincrby(key, ip, 1)
@@ -67,6 +66,13 @@ class RedisPriorityQueue:
     def __len__(self):
         return self.redis.zcard('queue')
 
+    def get_ip_count(self, client_ip: str):
+        x = self.redis.hget('queued_ip_count', client_ip)
+        if x:
+            return x.decode('utf-8')
+        else:
+            return x
+
 
 class DataEvent:
     def __init__(self, event_id=None):
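The three hunks above are the substance of the "queue not decrementing IP requests" fix. A queued item unpacks in the worker as ((request_json_body, client_ip, token, parameters), event_id), so the client IP lives at item[0][1]; the old get() read item[1][1], decremented a nonexistent hash field, and left queued_ip_count inflated. The new get_ip_count() helper simply exposes the per-IP counter. A minimal, self-contained sketch of the same bookkeeping pattern (an illustration, not the project's RedisPriorityQueue; key names and the db number mirror the diff, the serialized shape is simplified):

import json
from redis import Redis

r = Redis(host='localhost', port=6379, db=15)  # db 15 is the queue database flushed in pre_fork() below

def push(payload, client_ip, limit):
    # Refuse to enqueue if this IP already has `limit` requests waiting.
    count = r.hget('queued_ip_count', client_ip)
    if count and int(count) >= limit:
        return None
    r.hincrby('queued_ip_count', client_ip, 1)
    r.zadd('queue', {json.dumps(((payload, client_ip), None)): 1})
    return True

def pop():
    data = r.zpopmin('queue')
    if not data:
        return None
    item = json.loads(data[0][0])
    client_ip = item[0][1]  # the index the fix corrects: request tuple first, event id second
    r.hincrby('queued_ip_count', client_ip, -1)  # decrement the same field push() incremented
    return item[0][0]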
@@ -85,32 +91,3 @@ class DataEvent:
 
 
 priority_queue = RedisPriorityQueue()
-
-
-def worker():
-    while True:
-        (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
-
-        increment_ip_count(client_ip, 'processing_ips')
-        redis.incr('active_gen_workers')
-
-        try:
-            start_time = time.time()
-            success, response, error_msg = generator(request_json_body)
-            end_time = time.time()
-
-            elapsed_time = end_time - start_time
-            redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time)))
-
-            event = DataEvent(event_id)
-            event.set((success, response, error_msg))
-        finally:
-            decrement_ip_count(client_ip, 'processing_ips')
-            redis.decr('active_gen_workers')
-
-
-def start_workers(num_workers: int):
-    for _ in range(num_workers):
-        t = threading.Thread(target=worker)
-        t.daemon = True
-        t.start()
@@ -12,7 +12,6 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.auth import parse_token
 from llm_server.routes.cache import redis
-from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 
@@ -134,7 +133,6 @@ class RequestHandler:
             request_valid, invalid_response = self.validate_request(prompt, do_log=True)
             if not request_valid:
                 return (False, None, None, 0), invalid_response
-
             event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
         else:
             event = None
@@ -193,14 +191,11 @@ class RequestHandler:
         return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
 
     def is_client_ratelimited(self) -> bool:
-        print('queued_ip_count', redis.get_dict('queued_ip_count'))
-        print('processing_ips', redis.get_dict('processing_ips'))
-
-
         queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
         if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0:
            return False
         else:
+            print(f'Rejecting request from {self.client_ip} - {queued_ip_count} queued + processing.')
             return True
 
     def handle_request(self) -> Tuple[flask.Response, int]:
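is_client_ratelimited() now counts a client's queued and in-flight requests together and only prints when it actually rejects, instead of dumping both Redis hashes on every call. A standalone sketch of the same decision (plain dicts stand in for the project's Redis wrappers; the function is illustrative, not part of the commit):

def is_ratelimited(queued_ip_count: dict, processing_ips: dict, client_ip: str,
                   simultaneous_limit: int, token_priority: int) -> bool:
    # An IP is limited once its queued + in-flight total reaches its allowance,
    # unless its token has priority 0 (privileged tokens are never limited).
    total = queued_ip_count.get(client_ip, 0) + processing_ips.get(client_ip, 0)
    if total < simultaneous_limit or token_priority == 0:
        return False
    print(f'Rejecting request from {client_ip} - {total} queued + processing.')
    return True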
@@ -2,12 +2,14 @@ import json
 import threading
 import time
 import traceback
+from typing import Union
 
 from flask import request
 
-from ..helpers.client import format_sillytavern_err
+from ..cache import redis
 from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
+from ..queue import decrement_ip_count, priority_queue
 from ... import opts
 from ...database.database import log_prompt
 from ...llm.generator import generator
@@ -20,8 +22,31 @@ from ...stream import sock
 
 @sock.route('/api/v1/stream')
 def stream(ws):
+    def send_err_and_quit(quitting_err_msg):
+        ws.send(json.dumps({
+            'event': 'text_stream',
+            'message_num': 0,
+            'text': quitting_err_msg
+        }))
+        ws.send(json.dumps({
+            'event': 'stream_end',
+            'message_num': 1
+        }))
+        ws.close()
+        log_in_bg(quitting_err_msg, is_error=True)
+
+    def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
+        def background_task_exception():
+            generated_tokens = tokenize(generated_text_bg)
+            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)
+
+        # TODO: use async/await instead of threads
+        thread = threading.Thread(target=background_task_exception)
+        thread.start()
+        thread.join()
+
     if not opts.enable_streaming:
-        return 'disabled', 401
+        return 'Streaming is disabled', 401
 
     r_headers = dict(request.headers)
     r_url = request.url
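send_err_and_quit() speaks the same two-frame protocol the rest of the endpoint uses: a text_stream frame carrying the message, then a stream_end frame, then the socket closes, with logging pushed off the request path via log_in_bg(). A hypothetical client-side handler for those frames (the frame shapes are taken from the ws.send() calls above; everything else is illustrative):

import json

def handle_frame(raw: str) -> bool:
    # Returns False once the server has signalled the end of the stream.
    frame = json.loads(raw)
    if frame['event'] == 'text_stream':
        print(frame['text'], end='', flush=True)
        return True
    if frame['event'] == 'stream_end':
        return False
    return True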
@@ -30,12 +55,7 @@ def stream(ws):
     message = ws.receive()
     request_valid_json, request_json_body = validate_json(message)
     if not request_valid_json or not request_json_body.get('prompt'):
-        ws.send(json.dumps({
-            'event': 'text_stream',
-            'message_num': message_num,
-            'text': 'Invalid JSON'
-        }))
-        message_num += 1
+        return 'Invalid JSON', 400
     else:
         if opts.mode != 'vllm':
             # TODO: implement other backends
@@ -50,36 +70,44 @@ def stream(ws):
         input_prompt = request_json_body['prompt']
         response_status_code = 0
         start_time = time.time()
 
+        err_msg = None
+        if handler.is_client_ratelimited():
+            r, _ = handler.handle_ratelimited()
+            err_msg = r.json['results'][0]['text']
+        else:
             request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
             if not request_valid:
                 err_msg = invalid_response[0].json['results'][0]['text']
-                ws.send(json.dumps({
-                    'event': 'text_stream',
-                    'message_num': 0,
-                    'text': err_msg
-                }))
-                ws.send(json.dumps({
-                    'event': 'stream_end',
-                    'message_num': 1
-                }))
-                ws.close()  # this is important if we encountered and error and exited early.
-                def background_task():
-                    log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True)
-
-                # TODO: use async/await instead of threads
-                thread = threading.Thread(target=background_task)
-                thread.start()
-                thread.join()
-            else:
-                msg_to_backend = {
+        if err_msg:
+            send_err_and_quit(err_msg)
+            return
+
+        llm_request = {
             **handler.parameters,
             'prompt': input_prompt,
             'stream': True,
         }
-        try:
-            response = generator(msg_to_backend)
 
+        # Add a dummy event to the queue and wait for it to reach a worker
+        event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
+        if not event:
+            r, _ = handler.handle_ratelimited()
+            err_msg = r.json['results'][0]['text']
+            send_err_and_quit(err_msg)
+            return
+        try:
+            response = generator(llm_request)
+            if not response:
+                error_msg = 'Failed to reach backend while streaming.'
+                print('Streaming failed:', error_msg)
+                msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
+                ws.send(json.dumps({
+                    'event': 'text_stream',
+                    'message_num': message_num,
+                    'text': msg
+                }))
+            else:
                 # Be extra careful when getting attributes from the response object
                 try:
                     response_status_code = response.status_code
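This is where the commit's rate limiting for streaming lands. The websocket path cannot hand its work to the blocking queue (it has to write chunks to the socket as they arrive), so it enqueues a dummy item with a None payload purely to reserve a generation slot: priority_queue.put() returns None when the client is over its per-IP allowance, and otherwise a worker picks the dummy up, increments processing_ips and active_gen_workers, and skips generation. The streaming handler then runs the generation itself and releases both counters in the finally: block added further down. A self-contained toy of that handshake (plain dicts stand in for the Redis counters; names mirror the diff):

processing_ips = {}
active_gen_workers = 0

def worker_picks_up_dummy(client_ip):
    # What the blocking worker does when request_json_body is None.
    global active_gen_workers
    processing_ips[client_ip] = processing_ips.get(client_ip, 0) + 1
    active_gen_workers += 1

def streaming_handler_finished(client_ip):
    # Mirrors the finally: block in the websocket handler.
    global active_gen_workers
    processing_ips[client_ip] -= 1
    active_gen_workers -= 1

worker_picks_up_dummy('1.2.3.4')
streaming_handler_finished('1.2.3.4')
assert active_gen_workers == 0 and processing_ips['1.2.3.4'] == 0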
@@ -88,14 +116,6 @@ def stream(ws):
 
                 partial_response = b''
 
-                # TODO: handle when the backend is offline
-                # Traceback (most recent call last):
-                #   File "/srv/server/local-llm-server/llm_server/routes/v1/generate_stream.py", line 91, in stream
-                #     for chunk in response.iter_content(chunk_size=1):
-                #     ^^^^^^^^^^^^^^^^^^^^^
-                # AttributeError: 'NoneType' object has no attribute 'iter_content'
-
-
                 for chunk in response.iter_content(chunk_size=1):
                     partial_response += chunk
                     if partial_response.endswith(b'\x00'):
@@ -116,6 +136,7 @@ def stream(ws):
                                 'text': new
                             }))
                         except:
+                            # The client closed the stream.
                             end_time = time.time()
                             elapsed_time = end_time - start_time
                             log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
@@ -128,41 +149,32 @@ def stream(ws):
                     if not chunk:
                         break
 
+            if response:
                 response.close()
 
            end_time = time.time()
            elapsed_time = end_time - start_time
+            log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
-
-            def background_task_success():
-                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
-
-            # TODO: use async/await instead of threads
-            thread = threading.Thread(target=background_task_success)
-            thread.start()
-            thread.join()
        except:
-            generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].data.decode('utf-8')
            traceback.print_exc()
+            generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
            ws.send(json.dumps({
                'event': 'text_stream',
                'message_num': message_num,
                'text': generated_text
            }))
+            log_in_bg(generated_text, is_error=True, status_code=response_status_code)
-
-            def background_task_exception():
-                generated_tokens = tokenize(generated_text)
-                log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
-
-            # TODO: use async/await instead of threads
-            thread = threading.Thread(target=background_task_exception)
-            thread.start()
-            thread.join()
+        finally:
+            # The worker incremented it, we'll decrement it.
+            decrement_ip_count(handler.client_ip, 'processing_ips')
+            redis.decr('active_gen_workers')
        try:
            ws.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
            }))
        except:
+            # The client closed the stream.
            end_time = time.time()
            elapsed_time = end_time - start_time
            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
@@ -0,0 +1,58 @@
+import json
+import threading
+import time
+
+from llm_server import opts
+from llm_server.llm.generator import generator
+from llm_server.routes.cache import redis
+from llm_server.routes.queue import DataEvent, decrement_ip_count, increment_ip_count, priority_queue
+
+
+def worker():
+    while True:
+        need_to_wait()
+        (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
+        need_to_wait()
+
+        increment_ip_count(client_ip, 'processing_ips')
+        redis.incr('active_gen_workers')
+
+        if not request_json_body:
+            # This was a dummy request from the websocket handler.
+            # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
+            continue
+
+        try:
+            start_time = time.time()
+            success, response, error_msg = generator(request_json_body)
+            end_time = time.time()
+
+            elapsed_time = end_time - start_time
+            # redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time)))
+
+            event = DataEvent(event_id)
+            event.set((success, response, error_msg))
+        finally:
+            decrement_ip_count(client_ip, 'processing_ips')
+            redis.decr('active_gen_workers')
+
+
+def start_workers(num_workers: int):
+    i = 0
+    for _ in range(num_workers):
+        t = threading.Thread(target=worker)
+        t.daemon = True
+        t.start()
+        i += 1
+    print(f'Started {i} inference workers.')
+
+
+def need_to_wait():
+    # We need to check the number of active workers since the streaming endpoint may be doing something.
+    active_workers = redis.get('active_gen_workers', int, 0)
+    s = time.time()
+    while active_workers >= opts.concurrent_gens:
+        time.sleep(0.01)
+    e = time.time()
+    if e - s > 0.5:
+        print(f'Worker was delayed {e - s} seconds.')
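The new blocking-worker module gates each worker on active_gen_workers, so streaming requests, which run their actual generation outside the queue, still count against opts.concurrent_gens. Note that need_to_wait() reads the counter once before its loop; a re-polling variant of the same idea (a hypothetical helper, not part of the commit) would look like:

import time

def wait_for_free_slot(get_active_workers, concurrent_gens, poll=0.01, warn_after=0.5):
    # get_active_workers is any callable returning the current number of
    # in-flight generations, e.g. a Redis read like the one in need_to_wait().
    started = time.time()
    while get_active_workers() >= concurrent_gens:
        time.sleep(poll)
    waited = time.time() - started
    if waited > warn_after:
        print(f'Worker was delayed {waited:.2f} seconds.')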
@@ -0,0 +1,29 @@
+import logging
+import threading
+import time
+
+from llm_server.routes.cache import redis
+
+
+def console_printer():
+    logger = logging.getLogger('console_printer')
+    handler = logging.StreamHandler()
+    handler.setLevel(logging.INFO)
+    logger.setLevel(logging.INFO)
+    formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    while True:
+        queued_ip_count = redis.get_dict('queued_ip_count')
+        queued_ip_count = sum([v for k, v in queued_ip_count.items()])
+        processing_ips = redis.get_dict('processing_ips')
+        processing_count = sum([v for k, v in processing_ips.items()])
+
+        logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}')
+        time.sleep(10)
+
+
+def start_console_printer():
+    t = threading.Thread(target=console_printer)
+    t.daemon = True
+    t.start()
server.py
@@ -1,4 +1,4 @@
-from redis import Redis
+from llm_server.workers.printer import start_console_printer
 
 try:
     import gevent.monkey
@@ -16,6 +16,7 @@ from threading import Thread
 import openai
 import simplejson as json
 from flask import Flask, jsonify, render_template, request
+from redis import Redis
 
 import llm_server
 from llm_server.database.conn import database
@@ -26,6 +27,7 @@ from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio
+from llm_server.workers.blocking import start_workers
 
 # TODO: have the workers handle streaming too
 # TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
@@ -36,6 +38,7 @@ from llm_server.stream import init_socketio
 # TODO: implement RRD backend loadbalancer option
 # TODO: have VLLM reject a request if it already has n == concurrent_gens running
 # TODO: add a way to cancel VLLM gens. Maybe use websockets?
+# TODO: use coloredlogs
 
 # Lower priority
 # TODO: the processing stat showed -1 and I had to restart the server
@@ -65,7 +68,6 @@ from llm_server.helpers import resolve_path, auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import RedisWrapper, flask_cache
 from llm_server.llm import redis
-from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
@@ -166,6 +168,8 @@ def pre_fork(server):
 
     redis.set_dict('processing_ips', {})
     redis.set_dict('queued_ip_count', {})
+
+    # Flush the RedisPriorityQueue database.
     queue_redis = Redis(host='localhost', port=6379, db=15)
     for key in queue_redis.scan_iter('*'):
         queue_redis.delete(key)
@@ -181,8 +185,7 @@ def pre_fork(server):
 
     # Start background processes
     start_workers(opts.concurrent_gens)
-    print(f'Started {opts.concurrent_gens} inference workers.')
-
+    start_console_printer()
     start_moderation_workers(opts.openai_moderation_workers)
     MainBackgroundThread().start()
     SemaphoreCheckerThread().start()