fix exception

This commit is contained in:
Cyberes 2023-10-03 13:47:18 -06:00
parent 32ad97e57c
commit 581a0fec99
4 changed files with 11 additions and 9 deletions

View File

@ -1,7 +1,6 @@
import yaml
config_default_vars = {
'frontend_api_mode': 'ooba',
'log_prompts': False,
'database_path': './proxy-server.db',
'auth_required': False,
@ -38,7 +37,7 @@ config_default_vars = {
'background_homepage_cacher': True,
'openai_moderation_timeout': 5
}
config_required_vars = ['cluster', 'llm_middleware_name']
config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']
mode_ui_names = {
'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),

View File

@ -50,7 +50,8 @@ def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen
if token:
increment_token_uses(token)
running_model = cluster_config.get_backend(backend_url).get('model')
backend_info = cluster_config.get_backend(backend_url)
running_model = backend_info.get('model')
timestamp = int(time.time())
cursor = database.cursor()
try:
@ -59,7 +60,7 @@ def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(ip, token, running_model, opts.mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
(ip, token, running_model, cluster_config['mode'], backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()

View File

@ -1,3 +1,4 @@
from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import oobabooga, vllm
from llm_server.custom_redis import redis
@ -6,7 +7,7 @@ def get_token_count(prompt: str, backend_url: str):
assert isinstance(prompt, str)
assert isinstance(backend_url, str)
backend_mode = redis.get('backend_mode', dtype=str)
backend_mode = cluster_config.get_backend(backend_url)['mode']
if backend_mode == 'vllm':
return vllm.tokenize(prompt, backend_url)
elif backend_mode == 'ooba':

View File

@ -73,15 +73,16 @@ def do_stream(ws, model_name):
if not request_valid_json or not request_json_body.get('prompt'):
return 'Invalid JSON', 400
else:
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
auth_failure = require_api_key(request_json_body)
if auth_failure:
return auth_failure
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0