diff --git a/llm_server/config/config.py b/llm_server/config/config.py
index b33a9f2..11092c0 100644
--- a/llm_server/config/config.py
+++ b/llm_server/config/config.py
@@ -1,7 +1,6 @@
 import yaml

 config_default_vars = {
-    'frontend_api_mode': 'ooba',
     'log_prompts': False,
     'database_path': './proxy-server.db',
     'auth_required': False,
@@ -38,7 +37,7 @@ config_default_vars = {
     'background_homepage_cacher': True,
     'openai_moderation_timeout': 5
 }
-config_required_vars = ['cluster', 'llm_middleware_name']
+config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']

 mode_ui_names = {
     'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index 27a059c..1cd5389 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -50,7 +50,8 @@ def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen
     if token:
         increment_token_uses(token)

-    running_model = cluster_config.get_backend(backend_url).get('model')
+    backend_info = cluster_config.get_backend(backend_url)
+    running_model = backend_info.get('model')
     timestamp = int(time.time())
     cursor = database.cursor()
     try:
@@ -59,7 +60,7 @@ def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen
            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
-                       (ip, token, running_model, opts.mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
+                       (ip, token, running_model, cluster_config['mode'], backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
     finally:
         cursor.close()

diff --git a/llm_server/llm/__init__.py b/llm_server/llm/__init__.py
index ba46635..09f1ad7 100644
--- a/llm_server/llm/__init__.py
+++ b/llm_server/llm/__init__.py
@@ -1,3 +1,4 @@
+from llm_server.cluster.cluster_config import cluster_config
 from llm_server.llm import oobabooga, vllm
 from llm_server.custom_redis import redis

@@ -6,7 +7,7 @@ def get_token_count(prompt: str, backend_url: str):
     assert isinstance(prompt, str)
     assert isinstance(backend_url, str)

-    backend_mode = redis.get('backend_mode', dtype=str)
+    backend_mode = cluster_config.get_backend(backend_url)['mode']
     if backend_mode == 'vllm':
         return vllm.tokenize(prompt, backend_url)
     elif backend_mode == 'ooba':
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index ac148dd..55fb6e4 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -73,15 +73,16 @@ def do_stream(ws, model_name):
         if not request_valid_json or not request_json_body.get('prompt'):
             return 'Invalid JSON', 400
         else:
-            if opts.mode != 'vllm':
-                # TODO: implement other backends
-                raise NotImplementedError
-
             auth_failure = require_api_key(request_json_body)
             if auth_failure:
                 return auth_failure

             handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
+
+            if handler.cluster_backend_info['mode'] != 'vllm':
+                # TODO: implement other backends
+                raise NotImplementedError
+
             generated_text = ''
             input_prompt = request_json_body['prompt']
             response_status_code = 0