add background thread to gradio
This commit is contained in:
parent
69b8c1e35c
commit
169e216a38
|
@ -1,22 +1,48 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import openai
|
import openai
|
||||||
|
import requests
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
API_BASE = os.getenv('API_BASE')
|
API_BASE = os.getenv('API_BASE')
|
||||||
if not API_BASE:
|
if not API_BASE:
|
||||||
print('Must set the secret variable API_BASE to your https://your-site/api/openai/v1')
|
print('Must set the secret variable API_BASE to your https://your-site/api')
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
API_BASE = API_BASE.strip('/')
|
||||||
BACKUP_API_BASE = os.getenv('BACKUP_API_BASE')
|
|
||||||
if BACKUP_API_BASE:
|
|
||||||
print('Using BACKUP_API_BASE:', BACKUP_API_BASE)
|
|
||||||
|
|
||||||
APP_TITLE = os.getenv('APP_TITLE')
|
APP_TITLE = os.getenv('APP_TITLE')
|
||||||
|
PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE')
|
||||||
|
|
||||||
|
|
||||||
|
def background():
    """Poll the API's ``/stats`` endpoint every 10 seconds and repoint
    ``openai.api_base`` at the primary model's endpoint when that model is
    currently available, falling back to the default ``/openai/v1`` endpoint
    otherwise.

    Runs forever; intended to be launched in a daemon thread (see the
    ``Thread(target=background)`` startup code guarded by
    ``PRIMARY_MODEL_CHOICE``). Reads module globals ``API_BASE`` and
    ``PRIMARY_MODEL_CHOICE``; mutates the module-global ``openai.api_base``.
    """
    while True:
        previous = openai.api_base
        try:
            # timeout= so a hung /stats endpoint cannot stall this loop
            # forever (a bare requests.get has no default timeout).
            r = requests.get(API_BASE + '/stats', timeout=10).json()
            # Membership test directly on the dict; `.keys()` was redundant.
            # NOTE(review): assumes /stats returns
            # {'models': {'choices': {...}}} — confirm against the server.
            if PRIMARY_MODEL_CHOICE in r['models']['choices']:
                openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1'
            else:
                openai.api_base = API_BASE + '/openai/v1'
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed; log the failure and fall back to the
            # default endpoint rather than keeping a stale primary.
            traceback.print_exc()
            openai.api_base = API_BASE + '/openai/v1'
        if openai.api_base != previous:
            print('Set primary model to', openai.api_base)
        time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
|
if PRIMARY_MODEL_CHOICE:
|
||||||
|
t = Thread(target=background)
|
||||||
|
t.daemon = True
|
||||||
|
t.start()
|
||||||
|
print('Started the background thread.')
|
||||||
|
|
||||||
# A system prompt can be injected into the very first spot in the context.
|
# A system prompt can be injected into the very first spot in the context.
|
||||||
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
|
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
|
||||||
|
@ -26,7 +52,7 @@ CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE')
|
||||||
CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')
|
CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')
|
||||||
|
|
||||||
openai.api_key = 'null'
|
openai.api_key = 'null'
|
||||||
openai.api_base = API_BASE
|
openai.api_base = API_BASE + '/openai/v1'
|
||||||
|
|
||||||
|
|
||||||
def stream_response(prompt, history):
|
def stream_response(prompt, history):
|
||||||
|
@ -43,7 +69,6 @@ def stream_response(prompt, history):
|
||||||
if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt):
|
if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt):
|
||||||
messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})
|
messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})
|
||||||
|
|
||||||
for _ in range(2):
|
|
||||||
try:
|
try:
|
||||||
response = openai.ChatCompletion.create(
|
response = openai.ChatCompletion.create(
|
||||||
model='0',
|
model='0',
|
||||||
|
@ -52,14 +77,9 @@ def stream_response(prompt, history):
|
||||||
max_tokens=300,
|
max_tokens=300,
|
||||||
stream=True
|
stream=True
|
||||||
)
|
)
|
||||||
break
|
|
||||||
except Exception:
|
except Exception:
|
||||||
openai.api_base = BACKUP_API_BASE
|
|
||||||
raise gr.Error("Failed to reach inference endpoint.")
|
raise gr.Error("Failed to reach inference endpoint.")
|
||||||
|
|
||||||
# Go back to the default endpoint
|
|
||||||
openai.api_base = API_BASE
|
|
||||||
|
|
||||||
message = ''
|
message = ''
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if len(chunk['choices'][0]['delta']) != 0:
|
if len(chunk['choices'][0]['delta']) != 0:
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
gradio
|
||||||
|
openai
|
||||||
|
requests
|
Reference in New Issue