This repository has been archived on 2024-10-27. You can view files and clone it, but cannot push or open issues or pull requests.
local-llm-server/other/gradio/gradio_chat.py

96 lines
3.0 KiB
Python
Raw Normal View History

2023-10-09 18:12:12 -06:00
import os
import sys
2023-10-11 12:56:04 -06:00
import time
import traceback
2023-10-09 18:12:12 -06:00
import warnings
2023-10-11 12:56:04 -06:00
from threading import Thread
2023-10-09 18:12:12 -06:00
import gradio as gr
import openai
2023-10-11 12:56:04 -06:00
import requests
2023-10-09 18:12:12 -06:00
# Silence all Python warnings for this process.
warnings.filterwarnings("ignore")

# Root URL of the backend API. It is mandatory: the script exits with
# status 1 when it is unset or empty. Leading/trailing slashes are trimmed
# so that path segments can be appended with an explicit '/'.
API_BASE = os.environ.get('API_BASE')
if not API_BASE:
    print('Must set the secret variable API_BASE to your https://your-site/api')
    sys.exit(1)
API_BASE = API_BASE.strip('/')

# Optional title passed to the chat interface.
APP_TITLE = os.environ.get('APP_TITLE')

# Optional model name; when set, a background poller routes requests to it.
PRIMARY_MODEL_CHOICE = os.environ.get('PRIMARY_MODEL_CHOICE')
def background():
    """Poll the backend forever and point the OpenAI client at the primary model.

    Every 10 seconds this fetches ``API_BASE + '/stats'``. When
    ``PRIMARY_MODEL_CHOICE`` appears in ``models.choices`` of the JSON reply,
    ``openai.api_base`` is set to that model's endpoint; otherwise, or when
    the request or parsing fails, it is set to the generic ``/openai/v1``
    endpoint. A line is printed whenever the selected endpoint changes.

    This function never returns.
    """
    fallback_base = API_BASE + '/openai/v1'
    while True:
        previous = openai.api_base
        try:
            # Without a timeout a stalled backend would block this loop
            # indefinitely and the endpoint would never be re-evaluated.
            stats = requests.get(API_BASE + '/stats', timeout=10).json()
            if PRIMARY_MODEL_CHOICE in stats['models']['choices']:
                openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1'
            else:
                openai.api_base = fallback_base
        except Exception:
            # Best effort: log the failure and fall back to the generic
            # endpoint so chat requests keep working.
            traceback.print_exc()
            openai.api_base = fallback_base
        if openai.api_base != previous:
            print('Set primary model to', openai.api_base)
        time.sleep(10)
# A system prompt can be injected into the very first spot in the context.
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
# the content in CONTEXT_TRIGGER_INJECTION will be injected.
# Setting CONTEXT_TRIGGER_PHRASE will also add it to the selectable examples.
CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE')
CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')

# Set the OpenAI client defaults BEFORE the poller thread is started.
# Assigning the default endpoint after the thread is running could overwrite
# an endpoint the thread had already selected, until its next poll.
openai.api_key = 'null'
openai.api_base = API_BASE + '/openai/v1'

# The poller is only needed when a primary model was requested.
if PRIMARY_MODEL_CHOICE:
    t = Thread(target=background)
    # Daemon thread: it must not keep the process alive at shutdown.
    t.daemon = True
    t.start()
    print('Started the background thread.')
2023-10-09 18:12:12 -06:00
def stream_response(prompt, history):
    """Stream the assistant's reply to ``prompt``, yielding the text so far.

    Args:
        prompt: The newest user message.
        history: Previous turns as an iterable of ``(human, assistant)``
            pairs; each side is converted with ``str()`` before sending.

    Yields:
        The reply accumulated so far, once for every streamed chunk whose
        delta is non-empty.

    Raises:
        gr.Error: If the chat completion request cannot be created.
    """
    # Injection needs both settings. Checking the phrase against None avoids
    # the TypeError that `None in <str>` raises when only the injection text
    # is configured.
    injection_enabled = bool(CONTEXT_TRIGGER_INJECTION) and CONTEXT_TRIGGER_PHRASE is not None

    messages = []
    do_injection = False
    for human, assistant in history:
        human_text = str(human)
        messages.append({'role': 'user', 'content': human_text})
        messages.append({'role': 'assistant', 'content': str(assistant)})
        if injection_enabled and CONTEXT_TRIGGER_PHRASE in human_text:
            do_injection = True
    messages.append({'role': 'user', 'content': prompt})

    if injection_enabled and CONTEXT_TRIGGER_PHRASE in prompt:
        do_injection = True
    if do_injection:
        # The system prompt goes in the very first slot of the context.
        messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})

    try:
        response = openai.ChatCompletion.create(
            model='0',
            messages=messages,
            temperature=0,
            max_tokens=300,
            stream=True,
            headers={'LLM-Source': 'huggingface-demo'},
        )
    except Exception as err:
        # Keep the original cause attached for server-side debugging.
        raise gr.Error("Failed to reach inference endpoint.") from err

    message = ''
    for chunk in response:
        delta = chunk['choices'][0]['delta']
        if delta:
            # A delta may carry only a role (no 'content' key) or a null
            # content; treat both as an empty text fragment.
            message += delta.get('content') or ''
            yield message
2023-10-11 09:09:41 -06:00
# Example prompts offered in the UI; the trigger phrase, when configured,
# is listed first.
examples = [CONTEXT_TRIGGER_PHRASE, "hello"] if CONTEXT_TRIGGER_PHRASE else ["hello"]

# Build the chat UI, enable the request queue, and start serving.
gr.ChatInterface(
    stream_response,
    examples=examples,
    title=APP_TITLE,
    analytics_enabled=False,
    cache_examples=False,
    css='#component-0{height:100%!important}',
).queue(
    concurrency_count=1,
    api_open=False,
).launch(
    show_api=False,
)