# llm_server/routes/v1/generate_stream.py
import json
import time
import requests
from flask import request
from ..helpers.client import format_sillytavern_err
from ... import opts
from ...database import log_prompt
from ...helpers import indefinite_article
from ...stream import sock
# TODO: have workers process streaming requests
@sock.route('/api/v1/stream')  # TODO: use blueprint route???
def stream(ws):
    """Websocket endpoint for token-by-token generation streaming.

    Streaming is currently disabled: the handler returns immediately
    without reading from the websocket.

    NOTE(review): the ``('disabled', 401)`` tuple follows Flask view
    conventions, but this is a flask-sock websocket handler — confirm
    that flask-sock actually surfaces the status code to the client
    rather than just closing the connection.

    The commented-out body below is the previous hf-textgen streaming
    implementation, kept for reference until workers handle streaming
    (see module-level TODO). Known issues to address before reviving it:
      * ``prepare_json`` is not defined/imported in this module.
      * ``details['generated_tokens']`` in the final ``log_prompt`` call
        raises KeyError when the backend never sent a ``details`` dict —
        use ``details.get('generated_tokens')`` instead.
      * the bare ``except:`` around ``response.status_code`` should be
        narrowed (e.g. ``except AttributeError:``).
    """
    return 'disabled', 401
    # start_time = time.time()
    # if request.headers.get('cf-connecting-ip'):
    #     client_ip = request.headers.get('cf-connecting-ip')
    # elif request.headers.get('x-forwarded-for'):
    #     client_ip = request.headers.get('x-forwarded-for').split(',')[0]
    # else:
    #     client_ip = request.remote_addr
    # token = request.headers.get('X-Api-Key')
    #
    # message_num = 0
    # while ws.connected:
    #     message = ws.receive()
    #     data = json.loads(message)
    #
    #     if opts.mode == 'hf-textgen':
    #         response = requests.post(f'{opts.backend_url}/generate_stream', json=prepare_json(data), stream=True, verify=False)
    #
    #         # Be extra careful when getting attributes from the response object
    #         try:
    #             response_status_code = response.status_code
    #         except:
    #             response_status_code = 0
    #
    #         details = {}
    #         generated_text = ''
    #
    #         # Iterate over each line in the response
    #         for line in response.iter_lines():
    #             # Decode the line to a string
    #             line = line.decode('utf-8')
    #             # If the line starts with 'data:', remove the prefix and parse the remaining string as JSON
    #             if line.startswith('data:'):
    #                 line = line[5:]
    #                 json_data = json.loads(line)
    #                 details = json_data.get('details', {})
    #                 generated_text = json_data.get('generated_text', '')
    #
    #                 if json_data.get('error'):
    #                     error_type = json_data.get('error_type')
    #                     error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
    #                     generated_text = format_sillytavern_err(
    #                         f'Backend ({opts.mode}) {error_type_string}: {json_data.get("error")}',
    #                         f'HTTP CODE {response_status_code}')
    #                     ws.send(json.dumps({
    #                         'event': 'text_stream',
    #                         'message_num': message_num,
    #                         'text': generated_text
    #                     }))
    #                     break
    #                 else:
    #                     ws.send(json.dumps({
    #                         'event': 'text_stream',
    #                         'message_num': message_num,
    #                         'text': json_data['token']['text']
    #                     }))
    #                 message_num += 1
    #
    #         ws.send(json.dumps({
    #             'event': 'stream_end',
    #             'message_num': message_num
    #         }))
    #
    #         end_time = time.time()
    #         elapsed_time = end_time - start_time
    #         parameters = data.copy()
    #         del parameters['prompt']
    #
    #         log_prompt(client_ip, token, data['prompt'], generated_text, elapsed_time, parameters, dict(request.headers), response_status_code, response_tokens=details['generated_tokens'])