diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py
index dd9553b..8d7bffc 100644
--- a/llm_server/llm/vllm/tokenize.py
+++ b/llm_server/llm/vllm/tokenize.py
@@ -14,5 +14,5 @@ def tokenize(prompt: str) -> int:
         j = r.json()
         return j['length']
     except:
-        print(traceback.format_exc())
+        traceback.print_exc()
         return len(tokenizer.encode(prompt)) + 10
diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index 2c48bf0..e5b0fad 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -26,7 +26,9 @@ class VLLMBackend(LLMBackend):
                        response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
 
         # TODO: use async/await instead of threads
-        threading.Thread(target=background_task).start()
+        thread = threading.Thread(target=background_task)
+        thread.start()
+        thread.join()
 
         return jsonify({'results': [{'text': backend_response}]}), 200
 
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 567d380..54ef871 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -95,7 +95,9 @@ def openai_chat_completions():
                     log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
 
                 # TODO: use async/await instead of threads
-                threading.Thread(target=background_task).start()
+                thread = threading.Thread(target=background_task)
+                thread.start()
+                thread.join()
 
             return Response(generate(), mimetype='text/event-stream')
         except:
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 58d24f8..488107c 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -28,6 +28,9 @@ def stream(ws):
     if auth_failure:
         return auth_failure
 
+    r_headers = dict(request.headers)
+    r_url = request.url
+
     message_num = 0
     while ws.connected:
         message = ws.receive()
@@ -51,11 +54,20 @@ def stream(ws):
            start_time = time.time()
            request_valid, invalid_response = handler.validate_request()
            if not request_valid:
+                err_msg = invalid_response[0].json['results'][0]['text']
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
-                    'text': invalid_response[0].json['results'][0]['text']
+                    'text': err_msg
                }))
+
+                def background_task():
+                    log_prompt(handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True)
+
+                # TODO: use async/await instead of threads
+                thread = threading.Thread(target=background_task)
+                thread.start()
+                thread.join()
            else:
                input_prompt = request_json_body['prompt']
                msg_to_backend = {
@@ -104,15 +116,15 @@ def stream(ws):
                    end_time = time.time()
                    elapsed_time = end_time - start_time
 
-                    r_headers = dict(request.headers)
-                    r_url = request.url
-                    def background_task():
+                    def background_task_success():
                        generated_tokens = tokenize(generated_text)
                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
 
                    # TODO: use async/await instead of threads
-                    threading.Thread(target=background_task).start()
+                    thread = threading.Thread(target=background_task_success)
+                    thread.start()
+                    thread.join()
 
                except:
                    generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
                    traceback.print_exc()
@@ -122,13 +134,14 @@ def stream(ws):
                        'text': generated_text
                    }))
 
-                    def background_task():
+                    def background_task_exception():
                        generated_tokens = tokenize(generated_text)
-                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, dict(request.headers), response_status_code, request.url, response_tokens=generated_tokens)
+                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
 
                    # TODO: use async/await instead of threads
-                    threading.Thread(target=background_task).start()
-
+                    thread = threading.Thread(target=background_task_exception)
+                    thread.start()
+                    thread.join()
            ws.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
diff --git a/server.py b/server.py
index 2cff053..41fa879 100644
--- a/server.py
+++ b/server.py
@@ -20,6 +20,10 @@ from llm_server.routes.server_error import handle_server_error
 # TODO: add more excluding to SYSTEM__ tokens
 # TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
 # TODO: support turbo-instruct on openai endpoint
+# TODO: option to trim context in openai mode so that we silently fit the model's context
+# TODO: validate system tokens before excluding them
+# TODO: unify logging thread in a function and use async/await instead
+# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
 
 try:
     import vllm
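
The new TODO in server.py about unifying the logging thread points at the pattern this diff repeats at every call site. A minimal sketch of such a helper is below; the name run_in_background() is an assumption (not part of the repository), and it simply wraps the thread-create/start()/join() sequence around an arbitrary callable such as the existing log_prompt().

    import threading

    def run_in_background(target, *args, **kwargs):
        # Hypothetical helper (name is an assumption): run `target` -- e.g. log_prompt --
        # on a worker thread and wait for it to finish, mirroring the
        # start()/join() pattern used at each call site in this diff.
        thread = threading.Thread(target=target, args=args, kwargs=kwargs)
        thread.start()
        thread.join()

A call site could then shrink to, for example, run_in_background(log_prompt, handler.client_ip, handler.token, input_prompt, err_msg, None, handler.parameters, r_headers, response_status_code, r_url, is_error=True).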