224 lines
10 KiB
Python
224 lines
10 KiB
Python
import time
|
|
import traceback
|
|
|
|
import simplejson as json
|
|
import ujson
|
|
from flask import Response, jsonify, request
|
|
from redis import Redis
|
|
|
|
from llm_server.custom_redis import redis
|
|
from . import openai_bp, openai_model_bp
|
|
from ..helpers.http import validate_json
|
|
from ..ooba_request_handler import OobaRequestHandler
|
|
from ..queue import priority_queue
|
|
from ...config.global_config import GlobalConfig
|
|
from ...database.log_to_db import log_to_db
|
|
from ...llm import get_token_count
|
|
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
|
|
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
|
from ...logging import create_logger
|
|
|
|
# TODO: add rate-limit headers?
|
|
|
|
# Module logger used by the /completions view below (start-of-stream
# failures, stream errors and stream-timeout debug messages).
_logger = create_logger('OpenAICompletions')
|
|
|
|
|
|
@openai_bp.route('/completions', methods=['POST'])
@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
def openai_completions(model_name=None):
    """Serve an OpenAI-compatible ``/v1/completions`` request.

    The request JSON is validated, converted to the vLLM parameter format and
    then either answered with a single ``text_completion`` object or, when the
    client sets ``"stream"``, streamed back as server-sent events that end
    with ``data: [DONE]``.

    :param model_name: Model chosen via the ``/<model_name>/v1/completions``
        route, or ``None`` for the plain ``/completions`` route.
    :return: A Flask response or a ``(body, status)`` tuple.
    :raises NotImplementedError: If the selected backend's mode is not ``vllm``.
    """
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('prompt'):
        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

    handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
    if handler.offline:
        return return_invalid_model_err(model_name)

    if handler.cluster_backend_info['mode'] != 'vllm':
        # TODO: implement other backends
        raise NotImplementedError

    invalid_oai_err_msg = validate_oai(handler.request_json_body)
    if invalid_oai_err_msg:
        return invalid_oai_err_msg
    handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])

    if GlobalConfig.get().openai_silent_trim:
        handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
    # Otherwise the handle_request() call in the non-streaming path will load
    # the prompt, so there is nothing else to do here.

    handler.request_json_body['prompt'] = handler.prompt

    if request_json_body.get('stream'):
        return _stream_completion(handler, request_json_body)
    return _complete_non_streaming(handler, request_json_body)


def _complete_non_streaming(handler, request_json_body):
    """Run the request to completion and return one ``text_completion`` object.

    :param handler: The ``OobaRequestHandler`` prepared by ``openai_completions``.
    :param request_json_body: The client's original JSON body.
    :return: ``(response, 200)`` on success, otherwise the validation error or
        the handler's rate-limit response.
    """
    invalid_oai_err_msg = validate_oai(request_json_body)
    if invalid_oai_err_msg:
        return invalid_oai_err_msg
    response, status_code = handler.handle_request(return_ok=False)
    if status_code == 429:
        return handler.handle_ratelimited()
    output = response.json['results'][0]['text']

    prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
    response_tokens = get_token_count(output, handler.backend_url)
    # Only hit redis when we actually report our own model name.
    if GlobalConfig.get().openai_expose_our_model:
        model = redis.get('running_model', 'ERROR', dtype=str)
    else:
        model = request_json_body.get('model')

    response = jsonify({
        "id": f"cmpl-{generate_oai_string(30)}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "text": output,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    })

    # TODO:
    # stats = redis.get('proxy_stats', dtype=dict)
    # if stats:
    #     response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
    return response, 200


def _stream_completion(handler, request_json_body):
    """Queue a streaming request and return a ``text/event-stream`` response.

    Tokens are read from the Redis stream named by the queued event and
    forwarded to the client as OpenAI-style SSE chunks.

    :param handler: The ``OobaRequestHandler`` prepared by ``openai_completions``.
    :param request_json_body: The client's original JSON body.
    :return: The SSE ``Response``, or an error / rate-limit response.
    """
    if not GlobalConfig.get().enable_streaming:
        return 'Streaming disabled', 403

    request_valid, invalid_response = handler.validate_request()
    if not request_valid:
        return invalid_response

    handler.parameters, _ = handler.get_parameters()
    handler.request_json_body = {
        'prompt': handler.request_json_body['prompt'],
        'model': handler.request_json_body['model'],
        **handler.parameters
    }

    invalid_oai_err_msg = validate_oai(handler.request_json_body)
    if invalid_oai_err_msg:
        return invalid_oai_err_msg

    if GlobalConfig.get().openai_silent_trim:
        # Trim by tokens with the same helper the shared path uses. The previous
        # `prompt[:max_position_embeddings]` treated a token limit as a
        # character count, cutting prompts that fit and dropping their tail.
        handler.request_json_body['prompt'] = trim_string_to_fit(handler.request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
    if not handler.prompt:
        # Prevent issues on the backend.
        return 'Invalid prompt', 400

    start_time = time.time()

    request_valid, invalid_response = handler.validate_request()
    if not request_valid:
        return invalid_response

    event = None
    if not handler.is_client_ratelimited():
        event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
    if not event:
        log_to_db(
            handler.client_ip,
            handler.token,
            handler.prompt,
            None,
            None,
            handler.parameters,
            request.headers,
            429,
            request.url,
            handler.backend_url,
        )
        return handler.handle_ratelimited()

    try:
        # Capture request data now; the generator runs after the view returns.
        r_headers = dict(request.headers)
        r_url = request.url
        model = redis.get('running_model', 'ERROR', dtype=str) if GlobalConfig.get().openai_expose_our_model else request_json_body.get('model')
        oai_string = generate_oai_string(30)

        _, stream_name, error_msg = event.wait()
        if error_msg:
            _logger.error(f'OAI failed to start streaming: {error_msg}')
            return 'Request Timeout', 408

        def generate():
            """Relay backend tokens from the Redis stream as SSE chunks."""
            stream_redis = Redis(db=8)
            generated_chunks = []
            try:
                last_id = '0-0'
                while True:
                    stream_data = stream_redis.xread({stream_name: last_id}, block=GlobalConfig.get().redis_stream_timeout)
                    if not stream_data:
                        _logger.debug(f"No message received in {GlobalConfig.get().redis_stream_timeout / 1000} seconds, closing stream.")
                        yield 'data: [DONE]\n\n'
                        # Stop here: continuing would block on xread again and
                        # emit another [DONE] after every timeout, indefinitely.
                        return
                    for stream_index, item in stream_data[0][1]:
                        last_id = stream_index
                        # Auto-generated Redis stream IDs are "<milliseconds>-<seq>";
                        # OpenAI's `created` (and our non-streaming path) use seconds.
                        timestamp = int(stream_index.decode('utf-8').split('-')[0]) // 1000
                        data = ujson.loads(item[b'data'])
                        if data['error']:
                            _logger.error(f'OAI streaming encountered error: {data["error"]}')
                            yield 'data: [DONE]\n\n'
                            return
                        elif data['new']:
                            chunk = {
                                "id": f"cmpl-{oai_string}",
                                "object": "text_completion",
                                "created": timestamp,
                                "model": model,
                                "choices": [
                                    {
                                        "index": 0,
                                        # Legacy completions clients read `text`; `delta`
                                        # is kept for clients already relying on it.
                                        "text": data['new'],
                                        "logprobs": None,
                                        "delta": {
                                            "content": data['new']
                                        },
                                        "finish_reason": None
                                    }
                                ]
                            }
                            generated_chunks.append(data['new'])
                            yield f'data: {json.dumps(chunk)}\n\n'
                        elif data['completed']:
                            yield 'data: [DONE]\n\n'
                            elapsed_time = time.time() - start_time
                            log_to_db(
                                handler.client_ip,
                                handler.token,
                                handler.prompt,
                                ''.join(generated_chunks),
                                elapsed_time,
                                handler.parameters,
                                r_headers,
                                200,
                                r_url,
                                handler.backend_url,
                            )
                            return
            except GeneratorExit:
                # This should be triggered if a client disconnects early.
                return
            except Exception:
                traceback.print_exc()
                yield 'data: [DONE]\n\n'
            finally:
                # Tell the worker to stop and drop the per-request stream.
                if event:
                    redis.publish(f'notifications:{event.event_id}', 'canceled')
                if stream_name:
                    stream_redis.delete(stream_name)

        return Response(generate(), mimetype='text/event-stream')
    except Exception:
        traceback.print_exc()
        return 'INTERNAL SERVER', 500
|