add background thread to gradio

This commit is contained in:
Cyberes 2023-10-11 12:56:04 -06:00
parent 69b8c1e35c
commit 169e216a38
2 changed files with 45 additions and 22 deletions

View File

@ -1,22 +1,48 @@
import os
import sys
import time
import traceback
import warnings
from threading import Thread
import gradio as gr
import openai
import requests
warnings.filterwarnings("ignore")
API_BASE = os.getenv('API_BASE')
if not API_BASE:
print('Must set the secret variable API_BASE to your https://your-site/api/openai/v1')
print('Must set the secret variable API_BASE to your https://your-site/api')
sys.exit(1)
BACKUP_API_BASE = os.getenv('BACKUP_API_BASE')
if BACKUP_API_BASE:
print('Using BACKUP_API_BASE:', BACKUP_API_BASE)
API_BASE = API_BASE.strip('/')
APP_TITLE = os.getenv('APP_TITLE')
PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE')
def background():
while True:
previous = openai.api_base
try:
r = requests.get(API_BASE + '/stats').json()
if PRIMARY_MODEL_CHOICE in r['models']['choices'].keys():
openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1'
else:
openai.api_base = API_BASE + '/openai/v1'
except:
traceback.print_exc()
openai.api_base = API_BASE + '/openai/v1'
if openai.api_base != previous:
print('Set primary model to', openai.api_base)
time.sleep(10)
if PRIMARY_MODEL_CHOICE:
t = Thread(target=background)
t.daemon = True
t.start()
print('Started the background thread.')
# A system prompt can be injected into the very first spot in the context.
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
@ -26,7 +52,7 @@ CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE')
CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')
openai.api_key = 'null'
openai.api_base = API_BASE
openai.api_base = API_BASE + '/openai/v1'
def stream_response(prompt, history):
@ -43,7 +69,6 @@ def stream_response(prompt, history):
if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt):
messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})
for _ in range(2):
try:
response = openai.ChatCompletion.create(
model='0',
@ -52,14 +77,9 @@ def stream_response(prompt, history):
max_tokens=300,
stream=True
)
break
except Exception:
openai.api_base = BACKUP_API_BASE
raise gr.Error("Failed to reach inference endpoint.")
# Go back to the default endpoint
openai.api_base = API_BASE
message = ''
for chunk in response:
if len(chunk['choices'][0]['delta']) != 0:

View File

@ -0,0 +1,3 @@
gradio
openai
requests