Merge cluster to master #3
|
@ -3,7 +3,6 @@ from typing import Tuple, Union
|
||||||
import flask
|
import flask
|
||||||
|
|
||||||
from llm_server.cluster.cluster_config import cluster_config
|
from llm_server.cluster.cluster_config import cluster_config
|
||||||
from llm_server.custom_redis import redis
|
|
||||||
from llm_server.llm import get_token_count
|
from llm_server.llm import get_token_count
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
|
|
@ -3,7 +3,7 @@ from typing import Tuple
|
||||||
import flask
|
import flask
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
|
|
||||||
from llm_server import opts
|
from llm_server import messages, opts
|
||||||
from llm_server.database.log_to_db import log_to_db
|
from llm_server.database.log_to_db import log_to_db
|
||||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||||
from llm_server.routes.request_handler import RequestHandler
|
from llm_server.routes.request_handler import RequestHandler
|
||||||
|
@ -16,9 +16,8 @@ class OobaRequestHandler(RequestHandler):
|
||||||
def handle_request(self, return_ok: bool = True):
|
def handle_request(self, return_ok: bool = True):
|
||||||
assert not self.used
|
assert not self.used
|
||||||
if self.offline:
|
if self.offline:
|
||||||
msg = 'The model you requested is not a valid choice. Please retry your query.'
|
print(messages.BACKEND_OFFLINE)
|
||||||
print(msg)
|
self.handle_error(messages.BACKEND_OFFLINE)
|
||||||
self.handle_error(msg)
|
|
||||||
|
|
||||||
request_valid, invalid_response = self.validate_request()
|
request_valid, invalid_response = self.validate_request()
|
||||||
if not request_valid:
|
if not request_valid:
|
||||||
|
|
|
@ -12,7 +12,7 @@ from ..queue import priority_queue
|
||||||
from ... import opts
|
from ... import opts
|
||||||
from ...database.log_to_db import log_to_db
|
from ...database.log_to_db import log_to_db
|
||||||
from ...llm.generator import generator
|
from ...llm.generator import generator
|
||||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
|
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
|
||||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,6 +106,10 @@ def openai_chat_completions(model_name=None):
|
||||||
break
|
break
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Double check the model is still online
|
||||||
|
if not handler.check_online():
|
||||||
|
return return_invalid_model_err(handler.request_json_body['model'])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r_headers = dict(request.headers)
|
r_headers = dict(request.headers)
|
||||||
r_url = request.url
|
r_url = request.url
|
||||||
|
|
|
@ -13,7 +13,7 @@ from ... import opts
|
||||||
from ...database.log_to_db import log_to_db
|
from ...database.log_to_db import log_to_db
|
||||||
from ...llm import get_token_count
|
from ...llm import get_token_count
|
||||||
from ...llm.generator import generator
|
from ...llm.generator import generator
|
||||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
|
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
|
||||||
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
||||||
|
|
||||||
|
|
||||||
|
@ -131,6 +131,10 @@ def openai_completions(model_name=None):
|
||||||
break
|
break
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Double check the model is still online
|
||||||
|
if not handler.check_online():
|
||||||
|
return return_invalid_model_err(handler.request_json_body['model'])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = generator(msg_to_backend, handler.backend_url)
|
response = generator(msg_to_backend, handler.backend_url)
|
||||||
r_headers = dict(request.headers)
|
r_headers = dict(request.headers)
|
||||||
|
|
|
@ -58,6 +58,10 @@ class RequestHandler:
|
||||||
# "recent_prompters" is only used for stats.
|
# "recent_prompters" is only used for stats.
|
||||||
redis.zadd('recent_prompters', {self.client_ip: time.time()})
|
redis.zadd('recent_prompters', {self.client_ip: time.time()})
|
||||||
|
|
||||||
|
def check_online(self) -> bool:
|
||||||
|
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||||
|
return self.cluster_backend_info['online']
|
||||||
|
|
||||||
def get_auth_token(self):
|
def get_auth_token(self):
|
||||||
if self.request_json_body.get('X-API-KEY'):
|
if self.request_json_body.get('X-API-KEY'):
|
||||||
return self.request_json_body['X-API-KEY']
|
return self.request_json_body['X-API-KEY']
|
||||||
|
|
|
@ -8,7 +8,7 @@ from . import bp
|
||||||
from ..helpers.http import require_api_key, validate_json
|
from ..helpers.http import require_api_key, validate_json
|
||||||
from ..ooba_request_handler import OobaRequestHandler
|
from ..ooba_request_handler import OobaRequestHandler
|
||||||
from ..queue import priority_queue
|
from ..queue import priority_queue
|
||||||
from ... import opts
|
from ... import messages, opts
|
||||||
from ...custom_redis import redis
|
from ...custom_redis import redis
|
||||||
from ...database.log_to_db import log_to_db
|
from ...database.log_to_db import log_to_db
|
||||||
from ...llm.generator import generator
|
from ...llm.generator import generator
|
||||||
|
@ -147,9 +147,12 @@ def do_stream(ws, model_name):
|
||||||
break
|
break
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Double check the model is still online
|
||||||
|
if not handler.check_online():
|
||||||
|
return messages.BACKEND_OFFLINE, 404 # TODO: format this error
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = generator(llm_request, handler.backend_url)
|
response = generator(llm_request, handler.backend_url)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
error_msg = 'Failed to reach backend while streaming.'
|
error_msg = 'Failed to reach backend while streaming.'
|
||||||
print('Streaming failed:', error_msg)
|
print('Streaming failed:', error_msg)
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
|
from llm_server import messages
|
||||||
from llm_server.cluster.cluster_config import cluster_config
|
from llm_server.cluster.cluster_config import cluster_config
|
||||||
from llm_server.custom_redis import redis, RedisCustom
|
from llm_server.custom_redis import redis
|
||||||
from llm_server.llm.generator import generator
|
from llm_server.llm.generator import generator
|
||||||
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, RedisPriorityQueue, PriorityQueue, priority_queue
|
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
|
||||||
|
|
||||||
|
|
||||||
def worker(backend_url):
|
def worker(backend_url):
|
||||||
|
@ -14,14 +14,18 @@ def worker(backend_url):
|
||||||
while True:
|
while True:
|
||||||
(request_json_body, client_ip, token, parameters), event_id, selected_model = queue.get()
|
(request_json_body, client_ip, token, parameters), event_id, selected_model = queue.get()
|
||||||
backend_info = cluster_config.get_backend(backend_url)
|
backend_info = cluster_config.get_backend(backend_url)
|
||||||
|
|
||||||
|
if not backend_info['online']:
|
||||||
|
event = DataEvent(event_id)
|
||||||
|
event.set((False, None, messages.BACKEND_OFFLINE))
|
||||||
|
return
|
||||||
|
|
||||||
if not selected_model:
|
if not selected_model:
|
||||||
selected_model = backend_info['model']
|
selected_model = backend_info['model']
|
||||||
|
|
||||||
increment_ip_count(client_ip, 'processing_ips')
|
increment_ip_count(client_ip, 'processing_ips')
|
||||||
incr_active_workers(selected_model, backend_url)
|
incr_active_workers(selected_model, backend_url)
|
||||||
|
|
||||||
print('Worker starting processing for', client_ip)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not request_json_body:
|
if not request_json_body:
|
||||||
# This was a dummy request from the streaming handlers.
|
# This was a dummy request from the streaming handlers.
|
||||||
|
@ -48,7 +52,6 @@ def worker(backend_url):
|
||||||
finally:
|
finally:
|
||||||
decrement_ip_count(client_ip, 'processing_ips')
|
decrement_ip_count(client_ip, 'processing_ips')
|
||||||
decr_active_workers(selected_model, backend_url)
|
decr_active_workers(selected_model, backend_url)
|
||||||
print('Worker finished processing for', client_ip)
|
|
||||||
|
|
||||||
|
|
||||||
def start_workers(cluster: dict):
|
def start_workers(cluster: dict):
|
||||||
|
|
|
@ -40,7 +40,7 @@ from llm_server.sock import init_socketio
|
||||||
# TODO: if a backend is at its limit of concurrent requests, choose a different one
|
# TODO: if a backend is at its limit of concurrent requests, choose a different one
|
||||||
|
|
||||||
# Lower priority
|
# Lower priority
|
||||||
# TODO: fix moderation freezing after a while
|
# TODO: make error messages consistent
|
||||||
# TODO: support logit_bias on OpenAI and Ooba endpoints.
|
# TODO: support logit_bias on OpenAI and Ooba endpoints.
|
||||||
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
|
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
|
||||||
# TODO: validate openai_silent_trim works as expected and only when enabled
|
# TODO: validate openai_silent_trim works as expected and only when enabled
|
||||||
|
|
Reference in New Issue