diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py
index 6dd5874..a118845 100644
--- a/llm_server/llm/llm_backend.py
+++ b/llm_server/llm/llm_backend.py
@@ -4,7 +4,7 @@ import flask


 class LLMBackend:
-    default_params: dict
+    _default_params: dict

     def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
         raise NotImplementedError
diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index f2fc82d..a00afe7 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -10,7 +10,7 @@ from llm_server.routes.helpers.http import validate_json


 class VLLMBackend(LLMBackend):
-    default_params = vars(SamplingParams())
+    _default_params = vars(SamplingParams())

     def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
         if len(response_json_body.get('text', [])):
@@ -25,14 +25,18 @@ class VLLMBackend(LLMBackend):

     def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         try:
+            # top_k == -1 means disabled
+            top_k = parameters.get('top_k', self._default_params['top_k'])
+            if top_k <= 0:
+                top_k = -1
             sampling_params = SamplingParams(
-                temperature=parameters.get('temperature', self.default_params['temperature']),
-                top_p=parameters.get('top_p', self.default_params['top_p']),
-                top_k=parameters.get('top_k', self.default_params['top_k']),
+                temperature=parameters.get('temperature', self._default_params['temperature']),
+                top_p=parameters.get('top_p', self._default_params['top_p']),
+                top_k=top_k,
                 use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
-                stop=parameters.get('stopping_strings', self.default_params['stop']),
+                stop=parameters.get('stopping_strings', self._default_params['stop']),
                 ignore_eos=parameters.get('ban_eos_token', False),
-                max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens'])
+                max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
             )
         except ValueError as e:
             return None, str(e).strip('.')
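
The top_k handling added to VLLMBackend.get_parameters() above is easy to misread, so here is a minimal, self-contained sketch of what the clamp does (the helper name and sample values are illustrative, not part of the patch). Ooba/text-generation-webui style clients typically send top_k=0 to mean "disabled", while vLLM's SamplingParams uses -1 as its "disabled" sentinel and rejects 0, hence the remapping:

def normalize_top_k(top_k):
    """Map an ooba-style top_k of 0 ("disabled") onto vLLM's -1 sentinel."""
    return -1 if top_k <= 0 else top_k


assert normalize_top_k(0) == -1    # "disabled" in ooba terms -> vLLM's sentinel
assert normalize_top_k(-1) == -1   # already disabled, passes through
assert normalize_top_k(40) == 40   # real truncation values are left alone
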
diff --git a/llm_server/opts.py b/llm_server/opts.py
index 8c484aa..82f27dd 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -30,4 +30,4 @@ expose_openai_system_prompt = True
 enable_streaming = True
 openai_api_key = None
 backend_request_timeout = 30
-backend_generate_request_timeout = 120
+backend_generate_request_timeout = 95
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index 9577480..44b740c 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -1,12 +1,11 @@
-import time
+from typing import Tuple

+import flask
 from flask import jsonify

 from llm_server import opts
 from llm_server.database import log_prompt
 from llm_server.routes.helpers.client import format_sillytavern_err
-from llm_server.routes.helpers.http import validate_json
-from llm_server.routes.queue import priority_queue
 from llm_server.routes.request_handler import RequestHandler
@@ -35,3 +34,8 @@ class OobaRequestHandler(RequestHandler):
         return jsonify({
             'results': [{'text': backend_response}]
         }), 200
+
+    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+        return jsonify({
+            'results': [{'text': msg}]
+        }), 200
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 3d611e7..cd34240 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -10,7 +10,6 @@ from ..openai_request_handler import OpenAIRequestHandler, build_openai_response

 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
-    # TODO: make this work with oobabooga
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('messages'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 3c9b341..d3292ed 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -48,9 +48,11 @@ class OpenAIRequestHandler(RequestHandler):
         # Reconstruct the request JSON with the validated parameters and prompt.
         self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
         llm_request = {**self.parameters, 'prompt': self.prompt}
-
-        _, (backend_response, backend_response_status_code) = self.generate_response(llm_request)
-        return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
+        (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
+        if success:
+            return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code
+        else:
+            return backend_response, backend_response_status_code

     def handle_ratelimited(self):
         backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
@@ -81,13 +83,16 @@ class OpenAIRequestHandler(RequestHandler):
         prompt += '\n\n### RESPONSE: '
         return prompt

+    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+        return build_openai_response('', msg), 200
+

 def check_moderation_endpoint(prompt: str):
     headers = {
         'Content-Type': 'application/json',
         'Authorization': f"Bearer {opts.openai_api_key}",
     }
-    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}).json()
+    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10).json()
     offending_categories = []
     for k, v in response['results'][0]['categories'].items():
         if v:
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index a6596cc..d5d16a1 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -1,6 +1,10 @@
-import heapq
+import json
+import pickle
 import threading
 import time
+from uuid import uuid4
+
+from redis import Redis

 from llm_server import opts
 from llm_server.llm.generator import generator
@@ -27,58 +31,77 @@ def decrement_ip_count(client_ip: int, redis_key):
     return ip_count


-class PriorityQueue:
+class RedisPriorityQueue:
     def __init__(self):
-        self._queue = []
         self._index = 0
-        self._cv = threading.Condition()
         self._lock = threading.Lock()
-        redis.set_dict('queued_ip_count', {})
+        self.redis = Redis(host='localhost', port=6379, db=15)
+
+        # Clear the DB
+        for key in self.redis.scan_iter('*'):
+            self.redis.delete(key)
+
+        self.pubsub = self.redis.pubsub()
+        self.pubsub.subscribe('events')

     def put(self, item, priority):
         event = DataEvent()
-        with self._cv:
-            # Check if the IP is already in the dictionary and if it has reached the limit
-            ip_count = redis.get_dict('queued_ip_count')
-            if item[1] in ip_count and ip_count[item[1]] >= opts.simultaneous_requests_per_ip and priority != 0:
-                return None  # reject the request
-            heapq.heappush(self._queue, (-priority, self._index, item, event))
-            self._index += 1
-            # Increment the count for this IP
-            with self._lock:
-                increment_ip_count(item[1], 'queued_ip_count')
-            self._cv.notify()
+        # Check if the IP is already in the dictionary and if it has reached the limit
+        ip_count = self.redis.hget('queued_ip_count', item[1])
+        if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
+            return None  # reject the request
+        self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
+        self._index += 1
+        # Increment the count for this IP
+        with self._lock:
+            self.increment_ip_count(item[1], 'queued_ip_count')
         return event

     def get(self):
-        with self._cv:
-            while len(self._queue) == 0:
-                self._cv.wait()
-            _, _, item, event = heapq.heappop(self._queue)
-            # Decrement the count for this IP
-            with self._lock:
-                decrement_ip_count(item[1], 'queued_ip_count')
-            return item, event
+        while True:
+            data = self.redis.zpopmin('queue')
+            if data:
+                item = json.loads(data[0][0])
+                client_ip = item[1][1]
+                # Decrement the count for this IP
+                with self._lock:
+                    self.decrement_ip_count(client_ip, 'queued_ip_count')
+                return item
+            time.sleep(1)  # wait for an item to be added to the queue
+
+    def increment_ip_count(self, ip, key):
+        self.redis.hincrby(key, ip, 1)
+
+    def decrement_ip_count(self, ip, key):
+        self.redis.hincrby(key, ip, -1)

     def __len__(self):
-        return len(self._queue)
+        return self.redis.zcard('queue')


-priority_queue = PriorityQueue()
+class DataEvent:
+    def __init__(self, event_id=None):
+        self.event_id = event_id if event_id else str(uuid4())
+        self.redis = Redis(host='localhost', port=6379, db=14)
+        self.pubsub = self.redis.pubsub()
+        self.pubsub.subscribe(self.event_id)
+
+    def set(self, data):
+        self.redis.publish(self.event_id, pickle.dumps(data))
+
+    def wait(self):
+        for item in self.pubsub.listen():
+            if item['type'] == 'message':
+                return pickle.loads(item['data'])


-class DataEvent(threading.Event):
-    def __init__(self):
-        super().__init__()
-        self.data = None
+priority_queue = RedisPriorityQueue()


 def worker():
-    global processing_ips_lock
     while True:
-        (request_json_body, client_ip, token, parameters), event = priority_queue.get()
+        index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()

-        # redis.sadd('processing_ips', client_ip)
         increment_ip_count(client_ip, 'processing_ips')
         redis.incr('active_gen_workers')
@@ -91,10 +114,9 @@ def worker():
         with generation_elapsed_lock:
             generation_elapsed.append((end_time, elapsed_time))

-        event.data = (success, response, error_msg)
-        event.set()
+        event = DataEvent(event_id)
+        event.set((success, response, error_msg))

-        # redis.srem('processing_ips', client_ip)
         decrement_ip_count(client_ip, 'processing_ips')
         redis.decr('active_gen_workers')
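
Since the diff above only shows the pieces in isolation, here is a condensed, self-contained sketch of the round trip the new RedisPriorityQueue / DataEvent pair is meant to perform: the web worker enqueues a job into a sorted set scored by negative priority and subscribes to a UUID-named channel, and the generation worker pops the best-scored job and publishes the pickled result back on that channel. The prompt, priority, and databases (15 for the queue, 14 for events, matching the patch) are placeholders; a local Redis on the default port is assumed.

import json
import pickle
from uuid import uuid4

from redis import Redis

queue_redis = Redis(host='localhost', port=6379, db=15)
event_redis = Redis(host='localhost', port=6379, db=14)

# Web-worker side: enqueue the job (scored by negative priority) and subscribe
# to a reply channel named after a fresh UUID.
event_id = str(uuid4())
queue_redis.zadd('queue', {json.dumps((0, {'prompt': 'hello'}, event_id)): -1})
pubsub = event_redis.pubsub()
pubsub.subscribe(event_id)

# Generation-worker side: pop the best-scored item and publish the result back
# on the channel the producer is listening to.
index, item, reply_channel = json.loads(queue_redis.zpopmin('queue')[0][0])
event_redis.publish(reply_channel, pickle.dumps((True, 'generated text', None)))

# Web-worker side again: block until the pickled result arrives.
for message in pubsub.listen():
    if message['type'] == 'message':
        success, response, error_msg = pickle.loads(message['data'])
        break
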
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 0b57149..e075df6 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -3,7 +3,7 @@ import time
 from typing import Tuple, Union

 import flask
-from flask import Response, jsonify
+from flask import Response

 from llm_server import opts
 from llm_server.database import log_prompt
@@ -27,7 +27,7 @@ class RequestHandler:
         self.token = self.request.headers.get('X-Api-Key')
         self.priority = self.get_priority()
         self.backend = get_backend()
-        self.parameters = self.parameters_invalid_msg = None
+        self.parameters = None
         self.used = False
         SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
@@ -50,31 +50,26 @@ class RequestHandler:
             return result[0]
         return DEFAULT_PRIORITY

-    def load_parameters(self):
-        # Handle OpenAI
+    def get_parameters(self):
         if self.request_json_body.get('max_tokens'):
             self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
-        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
+        parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
+        return parameters, parameters_invalid_msg

     def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
-        self.load_parameters()
-        params_valid = False
+        self.parameters, parameters_invalid_msg = self.get_parameters()
         request_valid = False
+        invalid_request_err_msg = None
         if self.parameters:
-            params_valid = True
             request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
-        if not request_valid or not params_valid:
-            error_messages = [msg for valid, msg in [request_valid, params_valid] if not valid and msg]
+        if not request_valid:
+            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (not bool(parameters_invalid_msg), parameters_invalid_msg)] if not valid and msg]
             combined_error_message = ', '.join(error_messages)
-            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
-            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+            backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
+            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
             # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
-            return False, (jsonify({
-                'code': 400,
-                'msg': 'parameter validation error',
-                'results': [{'text': err}]
-            }), 200)
+            return False, self.handle_error(backend_response)
         return True, (None, 0)

     def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
@@ -88,9 +83,7 @@ class RequestHandler:

         prompt = llm_request['prompt']

-        event.wait()
-        success, response, error_msg = event.data
-
+        success, response, error_msg = event.wait()
         end_time = time.time()
         elapsed_time = end_time - self.start_time
@@ -113,11 +106,7 @@ class RequestHandler:
             error_msg = error_msg.strip('.') + '.'
             backend_response = format_sillytavern_err(error_msg, 'error')
             log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
-            return (False, None, None, 0), (jsonify({
-                'code': 500,
-                'msg': error_msg,
-                'results': [{'text': backend_response}]
-            }), 200)
+            return (False, None, None, 0), self.handle_error(backend_response)

         # ===============================================
@@ -137,11 +126,7 @@ class RequestHandler:
             error_msg = 'The backend did not return valid JSON.'
             backend_response = format_sillytavern_err(error_msg, 'error')
             log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
-            return (False, None, None, 0), (jsonify({
-                'code': 500,
-                'msg': error_msg,
-                'results': [{'text': backend_response}]
-            }), 200)
+            return (False, None, None, 0), self.handle_error(backend_response)

         # ===============================================
@@ -164,6 +149,9 @@ class RequestHandler:
     def handle_ratelimited(self) -> Tuple[flask.Response, int]:
         raise NotImplementedError

+    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+        raise NotImplementedError
+

 def get_backend():
     if opts.mode == 'oobabooga':
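
For context on the handle_error() hook introduced above: the base RequestHandler now builds a single error string and defers the response shape to its subclasses, so the Ooba-flavored and OpenAI-flavored endpoints can each serialize errors their own way instead of the old hard-coded jsonify() blocks. A stripped-down sketch of that split, using simplified stand-in names rather than the project's real classes:

from typing import Tuple


class BaseHandler:
    def handle_error(self, msg: str) -> Tuple[dict, int]:
        raise NotImplementedError

    def validate_request(self, params: dict) -> Tuple[bool, Tuple[dict, int] | None]:
        # The base class decides *when* a request is invalid...
        if 'prompt' not in params:
            return False, self.handle_error('Validation Error: no prompt.')
        return True, None


class OobaStyleHandler(BaseHandler):
    def handle_error(self, msg: str) -> Tuple[dict, int]:
        # ...and each API flavor decides *how* the error is serialized.
        # SillyTavern-style clients read results[0].text, which is why errors
        # still come back with HTTP 200.
        return {'results': [{'text': msg}]}, 200
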
diff --git a/server.py b/server.py
index d891cf9..f77c558 100644
--- a/server.py
+++ b/server.py
@@ -186,4 +186,4 @@ def server_error(e):


 if __name__ == "__main__":
-    app.run(host='0.0.0.0')
+    app.run(host='0.0.0.0', threaded=False, processes=15)
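
A note on the server.py change: with threaded=False and processes=15, Werkzeug's development server forks a separate process per request, so module-level state such as the old in-memory heapq and threading.Event objects would no longer be shared between the request handlers and the generation worker. That is the motivation for moving the queue and events into Redis above. The snippet below illustrates that difference only; it is not project code and assumes a local Redis, using db 15 purely as a scratch database.

import multiprocessing

from redis import Redis

counter = 0  # ordinary module-level state: every forked worker gets its own copy


def bump(_):
    global counter
    counter += 1  # only visible inside this worker process
    Redis(host='localhost', port=6379, db=15).incr('shared_counter')  # visible everywhere
    return counter


if __name__ == '__main__':
    Redis(host='localhost', port=6379, db=15).delete('shared_counter')
    with multiprocessing.Pool(4) as pool:
        print(pool.map(bump, range(8)))  # per-process counts, e.g. [1, 2, 1, 1, 2, 3, 2, 3]
    print(Redis(host='localhost', port=6379, db=15).get('shared_counter'))  # b'8'
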