diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index 2c13131..f45756a 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -3,7 +3,6 @@ import time import traceback from flask import Response, jsonify, request -from redis import Redis from llm_server.custom_redis import redis from . import openai_bp @@ -98,62 +97,60 @@ def openai_chat_completions(): oai_string = generate_oai_string(30) def generate(): - try: - response = generator(msg_to_backend, handler.backend_url) - generated_text = '' - partial_response = b'' - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(handler.prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? - continue + response = generator(msg_to_backend, handler.backend_url) + generated_text = '' + partial_response = b'' + for chunk in response.iter_content(chunk_size=1): + partial_response += chunk + if partial_response.endswith(b'\x00'): + json_strs = partial_response.split(b'\x00') + for json_str in json_strs: + if json_str: + try: + json_obj = json.loads(json_str.decode()) + new = json_obj['text'][0].split(handler.prompt + generated_text)[1] + generated_text = generated_text + new + except IndexError: + # ???? + continue - data = { - "id": f"chatcmpl-{oai_string}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": new - }, - "finish_reason": None - } - ] - } - yield f'data: {json.dumps(data)}\n\n' - yield 'data: [DONE]\n\n' - end_time = time.time() - elapsed_time = end_time - start_time - log_to_db( - handler.client_ip, - handler.token, - handler.prompt, - generated_text, - elapsed_time, - handler.parameters, - r_headers, - response_status_code, - r_url, - handler.backend_url, - ) - finally: - # The worker incremented it, we'll decrement it. - decrement_ip_count(handler.client_ip, 'processing_ips') - decr_active_workers(handler.selected_model, handler.backend_url) - print('cleaned up') + data = { + "id": f"chatcmpl-{oai_string}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": new + }, + "finish_reason": None + } + ] + } + yield f'data: {json.dumps(data)}\n\n' + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + generated_text, + elapsed_time, + handler.parameters, + r_headers, + response_status_code, + r_url, + handler.backend_url, + ) return Response(generate(), mimetype='text/event-stream') except Exception: traceback.print_exc() return 'INTERNAL SERVER', 500 + finally: + # The worker incremented it, we'll decrement it. + decrement_ip_count(handler.client_ip, 'processing_ips') + decr_active_workers(handler.selected_model, handler.backend_url)