fix exception when not valid model
This commit is contained in:
parent
acf409abfc
commit
08df52a4fd
|
@ -24,7 +24,6 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
|
|||
pass
|
||||
|
||||
prompt_tokens = get_token_count(prompt, backend_url)
|
||||
print('starting')
|
||||
|
||||
if not is_error:
|
||||
if not response_tokens:
|
||||
|
|
|
@ -34,11 +34,12 @@ class RequestHandler:
|
|||
self.client_ip = self.get_client_ip()
|
||||
self.token = self.get_auth_token()
|
||||
self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
|
||||
self.backend_url = get_a_cluster_backend(selected_model)
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||
self.parameters = None
|
||||
self.used = False
|
||||
|
||||
self.backend_url = get_a_cluster_backend(selected_model)
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||
|
||||
if not self.cluster_backend_info.get('mode'):
|
||||
print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
|
||||
if not self.cluster_backend_info.get('model'):
|
||||
|
|
|
@ -9,7 +9,6 @@ from ..helpers.http import require_api_key, validate_json
|
|||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||
from ... import opts
|
||||
from ...database.database import do_db_log
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm.generator import generator
|
||||
from ...sock import sock
|
||||
|
@ -47,18 +46,18 @@ def do_stream(ws, model_name):
|
|||
'message_num': 1
|
||||
}))
|
||||
log_to_db(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=quitting_err_msg,
|
||||
gen_time=None,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.cluster_backend_info,
|
||||
response_tokens=None,
|
||||
is_error=True
|
||||
)
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=quitting_err_msg,
|
||||
gen_time=None,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.cluster_backend_info,
|
||||
response_tokens=None,
|
||||
is_error=True
|
||||
)
|
||||
|
||||
if not opts.enable_streaming:
|
||||
return 'Streaming is disabled', 500
|
||||
|
@ -79,6 +78,17 @@ def do_stream(ws, model_name):
|
|||
return auth_failure
|
||||
|
||||
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
|
||||
if handler.offline:
|
||||
msg = f'{handler.selected_model} is not a valid model choice.'
|
||||
print(msg)
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': 0,
|
||||
'text': msg
|
||||
}))
|
||||
return
|
||||
|
||||
assert not handler.offline
|
||||
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
|
@ -199,16 +209,16 @@ def do_stream(ws, model_name):
|
|||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_to_db(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url
|
||||
)
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
# Must close the connection or greenlets will complain.
|
||||
|
|
|
@ -25,4 +25,3 @@ def db_logger():
|
|||
|
||||
if function_name == 'log_prompt':
|
||||
do_db_log(*args, **kwargs)
|
||||
print('finished log')
|
||||
|
|
Reference in New Issue