From 169e216a38693ecafbd41ccb379c843183c78201 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Wed, 11 Oct 2023 12:56:04 -0600 Subject: [PATCH] add background thread to gradio --- other/gradio/gradio_chat.py | 64 +++++++++++++++++++++++------------ other/gradio/requirements.txt | 3 ++ 2 files changed, 45 insertions(+), 22 deletions(-) create mode 100644 other/gradio/requirements.txt diff --git a/other/gradio/gradio_chat.py b/other/gradio/gradio_chat.py index fa1b892..af4aeeb 100644 --- a/other/gradio/gradio_chat.py +++ b/other/gradio/gradio_chat.py @@ -1,22 +1,48 @@ import os import sys +import time +import traceback import warnings +from threading import Thread import gradio as gr import openai +import requests warnings.filterwarnings("ignore") API_BASE = os.getenv('API_BASE') if not API_BASE: - print('Must set the secret variable API_BASE to your https://your-site/api/openai/v1') + print('Must set the secret variable API_BASE to your https://your-site/api') sys.exit(1) - -BACKUP_API_BASE = os.getenv('BACKUP_API_BASE') -if BACKUP_API_BASE: - print('Using BACKUP_API_BASE:', BACKUP_API_BASE) +API_BASE = API_BASE.strip('/') APP_TITLE = os.getenv('APP_TITLE') +PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE') + + +def background(): + while True: + previous = openai.api_base + try: + r = requests.get(API_BASE + '/stats').json() + if PRIMARY_MODEL_CHOICE in r['models']['choices'].keys(): + openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1' + else: + openai.api_base = API_BASE + '/openai/v1' + except Exception: + traceback.print_exc() + openai.api_base = API_BASE + '/openai/v1' + if openai.api_base != previous: + print('Set primary model to', openai.api_base) + time.sleep(10) + + +if PRIMARY_MODEL_CHOICE: + t = Thread(target=background) + t.daemon = True + t.start() + print('Started the background thread.') # A system prompt can be injected into the very first spot in the context.
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE, @@ -26,7 +52,7 @@ CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE') CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION') openai.api_key = 'null' -openai.api_base = API_BASE +openai.api_base = API_BASE + '/openai/v1' def stream_response(prompt, history): @@ -43,22 +69,16 @@ def stream_response(prompt, history): if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt): messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION}) - for _ in range(2): - try: - response = openai.ChatCompletion.create( - model='0', - messages=messages, - temperature=0, - max_tokens=300, - stream=True - ) - break - except Exception: - openai.api_base = BACKUP_API_BASE - raise gr.Error("Failed to reach inference endpoint.") - - # Go back to the default endpoint - openai.api_base = API_BASE + try: + response = openai.ChatCompletion.create( + model='0', + messages=messages, + temperature=0, + max_tokens=300, + stream=True + ) + except Exception: + raise gr.Error("Failed to reach inference endpoint.") message = '' for chunk in response: diff --git a/other/gradio/requirements.txt b/other/gradio/requirements.txt new file mode 100644 index 0000000..eb4baac --- /dev/null +++ b/other/gradio/requirements.txt @@ -0,0 +1,3 @@ +gradio +openai +requests \ No newline at end of file