fix background log not doing anything

Cyberes 2023-09-25 18:18:29 -06:00
parent 8184e24bff
commit 8240a1ebbb
5 changed files with 33 additions and 12 deletions

View File

@@ -14,5 +14,5 @@ def tokenize(prompt: str) -> int:
         j = r.json()
         return j['length']
     except:
-        print(traceback.format_exc())
+        traceback.print_exc()
         return len(tokenizer.encode(prompt)) + 10

View File

@@ -26,7 +26,9 @@ class VLLMBackend(LLMBackend):
                        response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
         # TODO: use async/await instead of threads
-        threading.Thread(target=background_task).start()
+        thread = threading.Thread(target=background_task)
+        thread.start()
+        thread.join()
         return jsonify({'results': [{'text': backend_response}]}), 200
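
The hunk above shows the pattern this commit applies in every handler: the fire-and-forget threading.Thread(target=background_task).start() becomes an explicit start() followed by join(), so the request now waits for the logging thread to finish before the response is returned. As a minimal sketch (not code from this repository), the repeated three lines could be folded into one helper along the lines of the new "unify logging thread in a function" TODO; the helper name and signature below are hypothetical:

    import threading

    def run_log_task(task, *args, **kwargs):
        # Run the logging callable in a thread and block until it completes,
        # mirroring the start()/join() pattern introduced in this commit.
        thread = threading.Thread(target=task, args=args, kwargs=kwargs)
        thread.start()
        thread.join()

A handler would then call run_log_task(background_task) in place of the three thread lines.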

View File

@@ -95,7 +95,9 @@ def openai_chat_completions():
                log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
            # TODO: use async/await instead of threads
-           threading.Thread(target=background_task).start()
+           thread = threading.Thread(target=background_task)
+           thread.start()
+           thread.join()
        return Response(generate(), mimetype='text/event-stream')
    except:

View File

@@ -28,6 +28,9 @@ def stream(ws):
     if auth_failure:
         return auth_failure
+    r_headers = dict(request.headers)
+    r_url = request.url
     message_num = 0
     while ws.connected:
         message = ws.receive()
@@ -51,11 +54,20 @@ def stream(ws):
             start_time = time.time()
             request_valid, invalid_response = handler.validate_request()
             if not request_valid:
+                err_msg = invalid_response[0].json['results'][0]['text']
                 ws.send(json.dumps({
                     'event': 'text_stream',
                     'message_num': message_num,
-                    'text': invalid_response[0].json['results'][0]['text']
+                    'text': err_msg
                 }))
+                def background_task():
+                    log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True)
+                # TODO: use async/await instead of threads
+                thread = threading.Thread(target=background_task)
+                thread.start()
+                thread.join()
             else:
                 input_prompt = request_json_body['prompt']
                 msg_to_backend = {
@@ -104,15 +116,15 @@ def stream(ws):
                     end_time = time.time()
                     elapsed_time = end_time - start_time
-                    r_headers = dict(request.headers)
-                    r_url = request.url
-                    def background_task():
+                    def background_task_success():
                         generated_tokens = tokenize(generated_text)
                         log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
                     # TODO: use async/await instead of threads
-                    threading.Thread(target=background_task).start()
+                    thread = threading.Thread(target=background_task_success)
+                    thread.start()
+                    thread.join()
                 except:
                     generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
                     traceback.print_exc()
@@ -122,13 +134,14 @@ def stream(ws):
                         'text': generated_text
                     }))
-                    def background_task():
+                    def background_task_exception():
                         generated_tokens = tokenize(generated_text)
-                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
                     # TODO: use async/await instead of threads
-                    threading.Thread(target=background_task).start()
+                    thread = threading.Thread(target=background_task_exception)
+                    thread.start()
+                    thread.join()
                 ws.send(json.dumps({
                     'event': 'stream_end',
                     'message_num': message_num

View File

@@ -20,6 +20,10 @@ from llm_server.routes.server_error import handle_server_error
 # TODO: add more excluding to SYSTEM__ tokens
 # TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
 # TODO: support turbo-instruct on openai endpoint
+# TODO: option to trim context in openai mode so that we silently fit the model's context
+# TODO: validate system tokens before excluding them
+# TODO: unify logging thread in a function and use async/await instead
+# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
 try:
     import vllm
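
Several of the new TODOs point at replacing these per-request logging threads with async/await. A rough, illustrative sketch of that direction, not code from this repository (log_prompt below is a stub standing in for the project's real function, the handler is hypothetical, and the routes would first need to run under an async-capable server):

    import asyncio

    def log_prompt(*args, **kwargs):
        # Stub standing in for llm_server's blocking log_prompt().
        print('logged:', args, kwargs)

    async def log_prompt_async(*args, **kwargs):
        # Push the blocking log write onto a worker thread so the event loop
        # is not blocked, while still letting the caller await completion.
        await asyncio.to_thread(log_prompt, *args, **kwargs)

    async def example_handler():
        # Awaiting gives the same "log before responding" guarantee as the
        # thread.start(); thread.join() pairs in this commit.
        await log_prompt_async('127.0.0.1', 'token', 'prompt', 'response')

    asyncio.run(example_handler())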