add background thread to gradio

This commit is contained in:
Cyberes 2023-10-11 12:56:04 -06:00
parent 69b8c1e35c
commit 169e216a38
2 changed files with 45 additions and 22 deletions

View File

@ -1,22 +1,48 @@
import os import os
import sys import sys
import time
import traceback
import warnings import warnings
from threading import Thread
import gradio as gr import gradio as gr
import openai import openai
import requests
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
API_BASE = os.getenv('API_BASE') API_BASE = os.getenv('API_BASE')
if not API_BASE: if not API_BASE:
print('Must set the secret variable API_BASE to your https://your-site/api/openai/v1') print('Must set the secret variable API_BASE to your https://your-site/api')
sys.exit(1) sys.exit(1)
API_BASE = API_BASE.strip('/')
BACKUP_API_BASE = os.getenv('BACKUP_API_BASE')
if BACKUP_API_BASE:
print('Using BACKUP_API_BASE:', BACKUP_API_BASE)
APP_TITLE = os.getenv('APP_TITLE') APP_TITLE = os.getenv('APP_TITLE')
PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE')
def background():
    """Poll the inference server's /stats endpoint every 10 seconds and
    repoint ``openai.api_base`` at the PRIMARY_MODEL_CHOICE backend when
    that model is currently offered, falling back to the generic
    ``/openai/v1`` endpoint otherwise (or on any polling error).

    Runs forever; intended to be started as a daemon thread.
    """
    while True:
        previous = openai.api_base
        try:
            # timeout so a hung/unreachable endpoint can't stall this
            # polling loop forever (requests.get has no default timeout)
            r = requests.get(API_BASE + '/stats', timeout=10).json()
            # NOTE(review): assumes /stats returns
            # {'models': {'choices': {<model-name>: ...}}} — confirm schema
            if PRIMARY_MODEL_CHOICE in r['models']['choices'].keys():
                openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1'
            else:
                openai.api_base = API_BASE + '/openai/v1'
        except Exception:
            # was a bare `except:`, which would also swallow
            # SystemExit/KeyboardInterrupt; Exception is the right scope here
            traceback.print_exc()
            openai.api_base = API_BASE + '/openai/v1'
        if openai.api_base != previous:
            print('Set primary model to', openai.api_base)
        time.sleep(10)


# Only start the poller when a preferred model is configured; daemon=True
# so the thread never blocks interpreter shutdown.
if PRIMARY_MODEL_CHOICE:
    t = Thread(target=background)
    t.daemon = True
    t.start()
    print('Started the background thread.')
# A system prompt can be injected into the very first spot in the context. # A system prompt can be injected into the very first spot in the context.
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE, # If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
@ -26,7 +52,7 @@ CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE')
CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION') CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')
openai.api_key = 'null' openai.api_key = 'null'
openai.api_base = API_BASE openai.api_base = API_BASE + '/openai/v1'
def stream_response(prompt, history): def stream_response(prompt, history):
@ -43,7 +69,6 @@ def stream_response(prompt, history):
if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt): if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt):
messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION}) messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})
for _ in range(2):
try: try:
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model='0', model='0',
@ -52,14 +77,9 @@ def stream_response(prompt, history):
max_tokens=300, max_tokens=300,
stream=True stream=True
) )
break
except Exception: except Exception:
openai.api_base = BACKUP_API_BASE
raise gr.Error("Failed to reach inference endpoint.") raise gr.Error("Failed to reach inference endpoint.")
# Go back to the default endpoint
openai.api_base = API_BASE
message = '' message = ''
for chunk in response: for chunk in response:
if len(chunk['choices'][0]['delta']) != 0: if len(chunk['choices'][0]['delta']) != 0:

View File

@ -0,0 +1,3 @@
gradio
openai
requests