fix background log not doing anything

Cyberes 2023-09-25 18:18:29 -06:00
parent 8184e24bff
commit 8240a1ebbb
5 changed files with 33 additions and 12 deletions
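
The hunks below replace the fire-and-forget threading.Thread(...).start() calls with an explicit start()/join() pair, so each handler waits for its log_prompt() call to complete before returning. The websocket route also captures request.headers / request.url into r_headers / r_url up front. A likely culprit for the "background log not doing anything": the old exception path read Flask's request proxy from inside the worker thread, which fails with "RuntimeError: working outside of request context", so that log entry was never written. A minimal sketch of the resulting pattern, with a hypothetical write_log() standing in for log_prompt():

    import threading

    from flask import Flask, jsonify, request

    app = Flask(__name__)

    def write_log(headers, url):
        # Stand-in for log_prompt(); any logger that takes plain data works here.
        print(url, headers.get('User-Agent'))

    @app.route('/sketch')
    def handler():
        # Capture request data while the request context is still alive;
        # Flask's request proxy is context-local and cannot be read from a new thread.
        r_headers = dict(request.headers)
        r_url = request.url

        def background_task():
            write_log(r_headers, r_url)

        # start()/join() instead of fire-and-forget, so the handler does not
        # return until the log write has actually been attempted.
        thread = threading.Thread(target=background_task)
        thread.start()
        thread.join()
        return jsonify({'results': [{'text': 'ok'}]})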

View File

@@ -14,5 +14,5 @@ def tokenize(prompt: str) -> int:
         j = r.json()
         return j['length']
     except:
-        print(traceback.format_exc())
+        traceback.print_exc()
         return len(tokenizer.encode(prompt)) + 10

View File

@@ -26,7 +26,9 @@ class VLLMBackend(LLMBackend):
                        response_tokens=response_json_body.get('details', {}).get('generated_tokens'))

         # TODO: use async/await instead of threads
-        threading.Thread(target=background_task).start()
+        thread = threading.Thread(target=background_task)
+        thread.start()
+        thread.join()

         return jsonify({'results': [{'text': backend_response}]}), 200

View File

@@ -95,7 +95,9 @@ def openai_chat_completions():
                 log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)

             # TODO: use async/await instead of threads
-            threading.Thread(target=background_task).start()
+            thread = threading.Thread(target=background_task)
+            thread.start()
+            thread.join()

             return Response(generate(), mimetype='text/event-stream')
         except:

View File

@@ -28,6 +28,9 @@ def stream(ws):
     if auth_failure:
         return auth_failure

+    r_headers = dict(request.headers)
+    r_url = request.url
+
     message_num = 0
     while ws.connected:
         message = ws.receive()
@@ -51,11 +54,20 @@ def stream(ws):
             start_time = time.time()
             request_valid, invalid_response = handler.validate_request()
             if not request_valid:
+                err_msg = invalid_response[0].json['results'][0]['text']
                 ws.send(json.dumps({
                     'event': 'text_stream',
                     'message_num': message_num,
-                    'text': invalid_response[0].json['results'][0]['text']
+                    'text': err_msg
                 }))
+
+                def background_task():
+                    log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True)
+
+                # TODO: use async/await instead of threads
+                thread = threading.Thread(target=background_task)
+                thread.start()
+                thread.join()
             else:
                 input_prompt = request_json_body['prompt']
                 msg_to_backend = {
@@ -104,15 +116,15 @@ def stream(ws):
                 end_time = time.time()
                 elapsed_time = end_time - start_time

-                r_headers = dict(request.headers)
-                r_url = request.url
-
-                def background_task():
+                def background_task_success():
                     generated_tokens = tokenize(generated_text)
                     log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)

                 # TODO: use async/await instead of threads
-                threading.Thread(target=background_task).start()
+                thread = threading.Thread(target=background_task_success)
+                thread.start()
+                thread.join()
             except:
                 generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
                 traceback.print_exc()
@@ -122,13 +134,14 @@ def stream(ws):
                     'text': generated_text
                 }))

-                def background_task():
+                def background_task_exception():
                     generated_tokens = tokenize(generated_text)
-                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)

                 # TODO: use async/await instead of threads
-                threading.Thread(target=background_task).start()
+                thread = threading.Thread(target=background_task_exception)
+                thread.start()
+                thread.join()

             ws.send(json.dumps({
                 'event': 'stream_end',
                 'message_num': message_num

View File

@@ -20,6 +20,10 @@ from llm_server.routes.server_error import handle_server_error
 # TODO: add more excluding to SYSTEM__ tokens
 # TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
 # TODO: support turbo-instruct on openai endpoint
+# TODO: option to trim context in openai mode so that we silently fit the model's context
+# TODO: validate system tokens before excluding them
+# TODO: unify logging thread in a function and use async/await instead
+# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests

 try:
     import vllm
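
The new "unify logging thread in a function" TODO suggests factoring the repeated start/join boilerplate above into a single helper. A hypothetical sketch of such a helper (the name and keyword handling are assumptions, not part of this commit):

    import threading
    from typing import Any, Callable

    def log_in_background(log_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        # Hypothetical helper: run a logging callable such as log_prompt in a
        # worker thread and block until it finishes, mirroring the
        # start()/join() pattern used in the hunks above.
        thread = threading.Thread(target=log_fn, args=args, kwargs=kwargs)
        thread.start()
        thread.join()

    # Possible call site, mirroring the success path in stream():
    # log_in_background(log_prompt, handler.client_ip, handler.token, input_prompt,
    #                   generated_text, elapsed_time, handler.parameters, r_headers,
    #                   response_status_code, r_url, response_tokens=generated_tokens)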