implement streaming for hf-textgen
This commit is contained in:
parent
26b04f364c
commit
bf648f605f
|
@ -52,3 +52,4 @@ should probably clear the `generation_time` time column in the `prompts` table.
|
||||||
- Convince Oobabooga to implement concurrent generation
|
- Convince Oobabooga to implement concurrent generation
|
||||||
- Make sure stats work when starting from an empty database
|
- Make sure stats work when starting from an empty database
|
||||||
- Make sure we're correctly canceling requests when the client cancels
|
- Make sure we're correctly canceling requests when the client cancels
|
||||||
|
- Implement auth and tokens on the websocket endpoint. Maybe add something to the instruct prompt and then remove it before proxying??
|
|
@ -19,8 +19,8 @@ config_default_vars = {
|
||||||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||||
|
|
||||||
mode_ui_names = {
|
mode_ui_names = {
|
||||||
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url'),
|
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
||||||
'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url'),
|
'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
||||||
|
|
||||||
if not is_error:
|
if not is_error:
|
||||||
if not response_tokens:
|
if not response_tokens:
|
||||||
response_tokens = len(tokenizer.encode(response))
|
response_tokens = len(tokenizer.encode(response, disallowed_special=()))
|
||||||
else:
|
else:
|
||||||
response_tokens = None
|
response_tokens = None
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ database_path = './proxy-server.db'
|
||||||
auth_required = False
|
auth_required = False
|
||||||
log_prompts = False
|
log_prompts = False
|
||||||
frontend_api_client = ''
|
frontend_api_client = ''
|
||||||
full_client_api = None
|
base_client_api = None
|
||||||
http_host = None
|
http_host = None
|
||||||
verify_ssl = True
|
verify_ssl = True
|
||||||
show_num_prompts = True
|
show_num_prompts = True
|
||||||
|
|
|
@ -6,18 +6,16 @@ from ... import opts
|
||||||
bp = Blueprint('v1', __name__)
|
bp = Blueprint('v1', __name__)
|
||||||
|
|
||||||
|
|
||||||
# openai_bp = Blueprint('/v1', __name__)
|
|
||||||
|
|
||||||
@bp.before_request
|
@bp.before_request
|
||||||
def before_request():
|
def before_request():
|
||||||
if not opts.http_host:
|
if not opts.http_host:
|
||||||
opts.http_host = request.headers.get("Host")
|
opts.http_host = request.headers.get("Host")
|
||||||
if not opts.full_client_api:
|
if not opts.base_client_api:
|
||||||
opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
|
opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
|
||||||
if request.endpoint != 'v1.get_stats':
|
if request.endpoint != 'v1.get_stats':
|
||||||
response = require_api_key()
|
response = require_api_key()
|
||||||
if response is not None:
|
if response is not None:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
from . import generate, info, proxy
|
from . import generate, info, proxy, generate_stream
|
||||||
|
|
|
@ -81,7 +81,8 @@ def generate_stats():
|
||||||
},
|
},
|
||||||
'online': online,
|
'online': online,
|
||||||
'endpoints': {
|
'endpoints': {
|
||||||
'blocking': opts.full_client_api,
|
'blocking': f'https://{opts.base_client_api}',
|
||||||
|
'streaming': f'wss://{opts.base_client_api}/v1/stream',
|
||||||
},
|
},
|
||||||
'queue': {
|
'queue': {
|
||||||
'processing': active_gen_workers,
|
'processing': active_gen_workers,
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from flask import request
|
||||||
|
|
||||||
|
from ..helpers.client import format_sillytavern_err
|
||||||
|
from ... import opts
|
||||||
|
from ...database import log_prompt
|
||||||
|
from ...helpers import indefinite_article
|
||||||
|
from ...llm.hf_textgen.generate import prepare_json
|
||||||
|
from ...stream import sock
|
||||||
|
|
||||||
|
|
||||||
|
@sock.route('/api/v1/stream') # TODO: use blueprint route???
|
||||||
|
def stream(ws):
|
||||||
|
start_time = time.time()
|
||||||
|
if request.headers.get('cf-connecting-ip'):
|
||||||
|
client_ip = request.headers.get('cf-connecting-ip')
|
||||||
|
elif request.headers.get('x-forwarded-for'):
|
||||||
|
client_ip = request.headers.get('x-forwarded-for').split(',')[0]
|
||||||
|
else:
|
||||||
|
client_ip = request.remote_addr
|
||||||
|
token = request.headers.get('X-Api-Key')
|
||||||
|
|
||||||
|
message_num = 0
|
||||||
|
while ws.connected:
|
||||||
|
message = ws.receive()
|
||||||
|
data = json.loads(message)
|
||||||
|
|
||||||
|
if opts.mode == 'hf-textgen':
|
||||||
|
response = requests.post(f'{opts.backend_url}/generate_stream', json=prepare_json(data), stream=True, verify=False)
|
||||||
|
|
||||||
|
# Be extra careful when getting attributes from the response object
|
||||||
|
try:
|
||||||
|
response_status_code = response.status_code
|
||||||
|
except:
|
||||||
|
response_status_code = 0
|
||||||
|
|
||||||
|
details = {}
|
||||||
|
generated_text = ''
|
||||||
|
|
||||||
|
# Iterate over each line in the response
|
||||||
|
for line in response.iter_lines():
|
||||||
|
# Decode the line to a string
|
||||||
|
line = line.decode('utf-8')
|
||||||
|
# If the line starts with 'data:', remove the prefix and parse the remaining string as JSON
|
||||||
|
if line.startswith('data:'):
|
||||||
|
line = line[5:]
|
||||||
|
json_data = json.loads(line)
|
||||||
|
details = json_data.get('details', {})
|
||||||
|
generated_text = json_data.get('generated_text', '')
|
||||||
|
|
||||||
|
if json_data.get('error'):
|
||||||
|
error_type = json_data.get('error_type')
|
||||||
|
error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
|
||||||
|
generated_text = format_sillytavern_err(
|
||||||
|
f'Backend ({opts.mode}) {error_type_string}: {json_data.get("error")}',
|
||||||
|
f'HTTP CODE {response_status_code}')
|
||||||
|
ws.send(json.dumps({
|
||||||
|
'event': 'text_stream',
|
||||||
|
'message_num': message_num,
|
||||||
|
'text': generated_text
|
||||||
|
}))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
ws.send(json.dumps({
|
||||||
|
'event': 'text_stream',
|
||||||
|
'message_num': message_num,
|
||||||
|
'text': json_data['token']['text']
|
||||||
|
}))
|
||||||
|
message_num += 1
|
||||||
|
|
||||||
|
ws.send(json.dumps({
|
||||||
|
'event': 'stream_end',
|
||||||
|
'message_num': message_num
|
||||||
|
}))
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
elapsed_time = end_time - start_time
|
||||||
|
parameters = data.copy()
|
||||||
|
del parameters['prompt']
|
||||||
|
|
||||||
|
log_prompt(client_ip, token, data['prompt'], generated_text, elapsed_time, parameters, dict(request.headers), response_status_code, response_tokens=details['generated_tokens'])
|
|
@ -0,0 +1,8 @@
|
||||||
|
from flask_sock import Sock
|
||||||
|
|
||||||
|
sock = Sock()
|
||||||
|
|
||||||
|
|
||||||
|
def init_socketio(app):
|
||||||
|
global sock
|
||||||
|
sock.init_app(app)
|
|
@ -8,3 +8,4 @@ gunicorn
|
||||||
redis
|
redis
|
||||||
gevent
|
gevent
|
||||||
async-timeout
|
async-timeout
|
||||||
|
flask-socketio
|
||||||
|
|
10
server.py
10
server.py
|
@ -17,6 +17,7 @@ from llm_server.routes.queue import start_workers
|
||||||
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
|
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
|
||||||
from llm_server.routes.v1 import bp
|
from llm_server.routes.v1 import bp
|
||||||
from llm_server.routes.v1.generate_stats import generate_stats
|
from llm_server.routes.v1.generate_stats import generate_stats
|
||||||
|
from llm_server.stream import init_socketio
|
||||||
from llm_server.threads import MainBackgroundThread
|
from llm_server.threads import MainBackgroundThread
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
@ -88,6 +89,7 @@ SemaphoreCheckerThread().start()
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
cache.init_app(app)
|
cache.init_app(app)
|
||||||
cache.clear() # clear redis cache
|
cache.clear() # clear redis cache
|
||||||
|
init_socketio(app)
|
||||||
# with app.app_context():
|
# with app.app_context():
|
||||||
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
|
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||||
|
@ -100,8 +102,8 @@ app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||||
@app.route('/api')
|
@app.route('/api')
|
||||||
@cache.cached(timeout=10, query_string=True)
|
@cache.cached(timeout=10, query_string=True)
|
||||||
def home():
|
def home():
|
||||||
if not opts.full_client_api:
|
if not opts.base_client_api:
|
||||||
opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
|
opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
|
||||||
stats = generate_stats()
|
stats = generate_stats()
|
||||||
|
|
||||||
if not bool(redis.get('backend_online')) or not stats['online']:
|
if not bool(redis.get('backend_online')) or not stats['online']:
|
||||||
|
@ -131,10 +133,12 @@ def home():
|
||||||
analytics_tracking_code=analytics_tracking_code,
|
analytics_tracking_code=analytics_tracking_code,
|
||||||
info_html=info_html,
|
info_html=info_html,
|
||||||
current_model=running_model,
|
current_model=running_model,
|
||||||
client_api=opts.full_client_api,
|
client_api=f'https://{opts.base_client_api}',
|
||||||
|
ws_client_api=f'wss://{opts.base_client_api}/v1/stream',
|
||||||
estimated_wait=estimated_wait_sec,
|
estimated_wait=estimated_wait_sec,
|
||||||
mode_name=mode_ui_names[opts.mode][0],
|
mode_name=mode_ui_names[opts.mode][0],
|
||||||
api_input_textbox=mode_ui_names[opts.mode][1],
|
api_input_textbox=mode_ui_names[opts.mode][1],
|
||||||
|
streaming_input_textbox=mode_ui_names[opts.mode][2],
|
||||||
context_size=opts.context_size,
|
context_size=opts.context_size,
|
||||||
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
||||||
extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
|
extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
|
||||||
|
|
|
@ -71,6 +71,7 @@
|
||||||
<div class="info-box">
|
<div class="info-box">
|
||||||
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
|
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
|
||||||
<p><strong>Client API URL:</strong> {{ client_api }}</p>
|
<p><strong>Client API URL:</strong> {{ client_api }}</p>
|
||||||
|
<p><strong>Streaming API URL:</strong> {{ ws_client_api }}</p>
|
||||||
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
|
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
|
||||||
{{ info_html|safe }}
|
{{ info_html|safe }}
|
||||||
</div>
|
</div>
|
||||||
|
@ -83,6 +84,10 @@
|
||||||
<ol>
|
<ol>
|
||||||
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
|
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
|
||||||
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
|
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
|
||||||
|
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
|
||||||
|
<li>If using a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
|
||||||
|
API key</kbd> textbox.
|
||||||
|
</li>
|
||||||
<li>Click <kbd>Connect</kbd> to test the connection.</li>
|
<li>Click <kbd>Connect</kbd> to test the connection.</li>
|
||||||
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
|
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
|
||||||
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>
|
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>
|
||||||
|
|
Reference in New Issue