implement streaming for hf-textgen
This commit is contained in:
parent
26b04f364c
commit
bf648f605f
|
@ -52,3 +52,4 @@ should probably clear the `generation_time` time column in the `prompts` table.
|
|||
- Convince Oobabooga to implement concurrent generation
|
||||
- Make sure stats work when starting from an empty database
|
||||
- Make sure we're correctly canceling requests when the client cancels
|
||||
- Implement auth and tokens on the websocket endpoint. Maybe add something to the instruct prompt and then remove it before proxying?
|
|
@ -19,8 +19,8 @@ config_default_vars = {
|
|||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||
|
||||
mode_ui_names = {
|
||||
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url'),
|
||||
'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url'),
|
||||
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
||||
'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
|
||||
if not is_error:
|
||||
if not response_tokens:
|
||||
response_tokens = len(tokenizer.encode(response))
|
||||
response_tokens = len(tokenizer.encode(response, disallowed_special=()))
|
||||
else:
|
||||
response_tokens = None
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ database_path = './proxy-server.db'
|
|||
auth_required = False
|
||||
log_prompts = False
|
||||
frontend_api_client = ''
|
||||
full_client_api = None
|
||||
base_client_api = None
|
||||
http_host = None
|
||||
verify_ssl = True
|
||||
show_num_prompts = True
|
||||
|
|
|
@ -6,18 +6,16 @@ from ... import opts
|
|||
bp = Blueprint('v1', __name__)
|
||||
|
||||
|
||||
# openai_bp = Blueprint('/v1', __name__)
|
||||
|
||||
@bp.before_request
|
||||
def before_request():
|
||||
if not opts.http_host:
|
||||
opts.http_host = request.headers.get("Host")
|
||||
if not opts.full_client_api:
|
||||
opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
|
||||
if not opts.base_client_api:
|
||||
opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
|
||||
if request.endpoint != 'v1.get_stats':
|
||||
response = require_api_key()
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
|
||||
from . import generate, info, proxy
|
||||
from . import generate, info, proxy, generate_stream
|
||||
|
|
|
@ -81,7 +81,8 @@ def generate_stats():
|
|||
},
|
||||
'online': online,
|
||||
'endpoints': {
|
||||
'blocking': opts.full_client_api,
|
||||
'blocking': f'https://{opts.base_client_api}',
|
||||
'streaming': f'wss://{opts.base_client_api}/v1/stream',
|
||||
},
|
||||
'queue': {
|
||||
'processing': active_gen_workers,
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
from flask import request
|
||||
|
||||
from ..helpers.client import format_sillytavern_err
|
||||
from ... import opts
|
||||
from ...database import log_prompt
|
||||
from ...helpers import indefinite_article
|
||||
from ...llm.hf_textgen.generate import prepare_json
|
||||
from ...stream import sock
|
||||
|
||||
|
||||
@sock.route('/api/v1/stream')  # TODO: use blueprint route???
def stream(ws):
    """Relay streamed generations from the backend to a websocket client.

    Every message received on ``ws`` is parsed as a JSON request body and
    posted to the backend's ``/generate_stream`` endpoint. Each ``data:``
    line of the backend response is forwarded to the client as a
    ``text_stream`` event; a single ``stream_end`` event follows, and the
    exchange is then recorded with ``log_prompt``.

    Requests are only proxied when ``opts.mode`` is ``'hf-textgen'``; in any
    other mode the message is read and dropped.

    :param ws: the websocket connection handed in by ``sock.route``.
    """
    # Work out the caller's address, preferring the proxy-supplied headers.
    if request.headers.get('cf-connecting-ip'):
        client_ip = request.headers.get('cf-connecting-ip')
    elif request.headers.get('x-forwarded-for'):
        client_ip = request.headers.get('x-forwarded-for').split(',')[0]
    else:
        client_ip = request.remote_addr
    token = request.headers.get('X-Api-Key')

    message_num = 0
    while ws.connected:
        message = ws.receive()
        # Time each request separately so that idle time on the socket (and
        # earlier requests on the same connection) is not logged as
        # generation time.
        start_time = time.time()
        data = json.loads(message)

        if opts.mode != 'hf-textgen':
            continue

        details = {}
        generated_text = ''
        streamed_tokens = []

        # The context manager releases the backend connection even when we
        # stop reading early because of an error event.
        with requests.post(f'{opts.backend_url}/generate_stream', json=prepare_json(data), stream=True, verify=opts.verify_ssl) as response:
            # Be extra careful when getting attributes from the response object
            response_status_code = getattr(response, 'status_code', 0)

            for raw_line in response.iter_lines():
                line = raw_line.decode('utf-8')
                # Only lines carrying a 'data:' prefix hold a JSON payload.
                if not line.startswith('data:'):
                    continue
                json_data = json.loads(line[5:])

                # Either field may be missing or null; normalise them so the
                # logging below never receives None.
                details = json_data.get('details') or {}
                generated_text = json_data.get('generated_text') or ''

                if json_data.get('error'):
                    error_type = json_data.get('error_type')
                    error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
                    generated_text = format_sillytavern_err(
                        f'Backend ({opts.mode}) {error_type_string}: {json_data.get("error")}',
                        f'HTTP CODE {response_status_code}')
                    ws.send(json.dumps({
                        'event': 'text_stream',
                        'message_num': message_num,
                        'text': generated_text
                    }))
                    break

                token_text = json_data['token']['text']
                streamed_tokens.append(token_text)
                ws.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': token_text
                }))
                message_num += 1

        ws.send(json.dumps({
            'event': 'stream_end',
            'message_num': message_num
        }))

        # If the backend never supplied the complete text, log what was
        # actually streamed to the client instead of an empty response.
        if not generated_text:
            generated_text = ''.join(streamed_tokens)

        elapsed_time = time.time() - start_time
        parameters = data.copy()
        prompt = parameters.pop('prompt')

        # `generated_tokens` is absent on the error path or when the stream
        # ends early; log_prompt counts the tokens itself when it is given a
        # falsy value.
        log_prompt(client_ip, token, prompt, generated_text, elapsed_time, parameters, dict(request.headers), response_status_code, response_tokens=details.get('generated_tokens'))
|
|
@ -0,0 +1,8 @@
|
|||
from flask_sock import Sock
|
||||
|
||||
sock = Sock()
|
||||
|
||||
|
||||
def init_socketio(app):
    """Bind the module-level ``sock`` extension to the given application."""
    # ``sock`` is only used here, never rebound, so no ``global`` declaration
    # is needed to reach the module-level instance.
    sock.init_app(app)
|
|
@ -7,4 +7,5 @@ tiktoken
|
|||
gunicorn
|
||||
redis
|
||||
gevent
|
||||
async-timeout
|
||||
async-timeout
|
||||
flask-sock
|
||||
|
|
10
server.py
10
server.py
|
@ -17,6 +17,7 @@ from llm_server.routes.queue import start_workers
|
|||
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
|
||||
from llm_server.routes.v1 import bp
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.stream import init_socketio
|
||||
from llm_server.threads import MainBackgroundThread
|
||||
|
||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
@ -88,6 +89,7 @@ SemaphoreCheckerThread().start()
|
|||
app = Flask(__name__)
|
||||
cache.init_app(app)
|
||||
cache.clear() # clear redis cache
|
||||
init_socketio(app)
|
||||
# with app.app_context():
|
||||
# current_app.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||
|
@ -100,8 +102,8 @@ app.register_blueprint(bp, url_prefix='/api/v1/')
|
|||
@app.route('/api')
|
||||
@cache.cached(timeout=10, query_string=True)
|
||||
def home():
|
||||
if not opts.full_client_api:
|
||||
opts.full_client_api = f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
|
||||
if not opts.base_client_api:
|
||||
opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
|
||||
stats = generate_stats()
|
||||
|
||||
if not bool(redis.get('backend_online')) or not stats['online']:
|
||||
|
@ -131,10 +133,12 @@ def home():
|
|||
analytics_tracking_code=analytics_tracking_code,
|
||||
info_html=info_html,
|
||||
current_model=running_model,
|
||||
client_api=opts.full_client_api,
|
||||
client_api=f'https://{opts.base_client_api}',
|
||||
ws_client_api=f'wss://{opts.base_client_api}/v1/stream',
|
||||
estimated_wait=estimated_wait_sec,
|
||||
mode_name=mode_ui_names[opts.mode][0],
|
||||
api_input_textbox=mode_ui_names[opts.mode][1],
|
||||
streaming_input_textbox=mode_ui_names[opts.mode][2],
|
||||
context_size=opts.context_size,
|
||||
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
||||
extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
|
||||
|
|
|
@ -71,6 +71,7 @@
|
|||
<div class="info-box">
|
||||
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
|
||||
<p><strong>Client API URL:</strong> {{ client_api }}</p>
|
||||
<p><strong>Streaming API URL:</strong> {{ ws_client_api }}</p>
|
||||
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
|
||||
{{ info_html|safe }}
|
||||
</div>
|
||||
|
@ -83,6 +84,10 @@
|
|||
<ol>
|
||||
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
|
||||
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
|
||||
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
|
||||
<li>If using a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
|
||||
API key</kbd> textbox.
|
||||
</li>
|
||||
<li>Click <kbd>Connect</kbd> to test the connection.</li>
|
||||
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
|
||||
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>
|
||||
|
|
Reference in New Issue