This repository has been archived on 2024-10-27. You can view files and clone it, but cannot push or open issues or pull requests.
local-llm-server/llm_server/routes/openai/completions.py

162 lines
7.3 KiB
Python
Raw Normal View History

2023-09-26 22:09:11 -06:00
import time
import traceback
2023-10-01 14:15:01 -06:00
import simplejson as json
from flask import Response, jsonify, request
2023-09-26 22:09:11 -06:00
from llm_server.custom_redis import redis
2023-10-01 14:15:01 -06:00
from . import openai_bp
2023-09-26 22:09:11 -06:00
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
2023-10-01 14:15:01 -06:00
from ...database.database import log_prompt
2023-09-26 22:09:11 -06:00
from ...llm import get_token_count
2023-10-01 14:15:01 -06:00
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
2023-09-26 22:09:11 -06:00
# TODO: add rate-limit headers?


def _iter_backend_frames(chunks):
    """Yield each complete NUL-terminated frame from an iterable of byte chunks.

    Incomplete trailing data is buffered until its terminator arrives; data
    left without a terminator when ``chunks`` is exhausted is dropped.
    """
    buffer = bytearray()
    for chunk in chunks:
        buffer += chunk
        # Only split when a terminator actually arrived in this chunk, so a
        # byte-by-byte stream does not re-scan the whole buffer per byte.
        if b'\x00' not in chunk:
            continue
        # The last element is whatever follows the final terminator (possibly
        # empty); keep it as the start of the next frame.
        *frames, buffer = buffer.split(b'\x00')
        for frame in frames:
            if frame:
                yield frame


def _extract_new_text(full_text, already_seen):
    """Return the part of ``full_text`` that follows ``already_seen``.

    Each backend frame's text is treated as cumulative (prompt + everything
    generated so far), so the new delta is what follows the prefix we have
    already forwarded. Returns None when the prefix is not found.
    """
    if not already_seen:
        return full_text
    _before, found, after = full_text.partition(already_seen)
    # partition() keeps everything after the first match, unlike
    # split(...)[1], which would cut the delta short if the new text happened
    # to contain the prefix again.
    return after if found else None


def _complete_blocking(handler, request_json_body):
    """Run a non-streaming completion and wrap it in an OpenAI ``text_completion`` body."""
    response, status_code = handler.handle_request()
    if status_code != 200:
        # Pass the handler's error response through unchanged. A bare int
        # status is not a valid Flask view return value.
        return response, status_code
    output = response.json['results'][0]['text']
    # Count the prompt that was actually sent to the backend; it may have been
    # trimmed by opts.openai_silent_trim.
    sent_prompt = handler.request_json_body['prompt']
    # TODO: async/await
    prompt_tokens = get_token_count(sent_prompt, handler.backend_url)
    response_tokens = get_token_count(output, handler.backend_url)
    running_model = redis.get('running_model', 'ERROR', dtype=str)
    response = jsonify({
        "id": f"cmpl-{generate_oai_string(30)}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
        "choices": [
            {
                "text": output,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    })
    stats = redis.get('proxy_stats', dtype=dict)
    if stats:
        response.headers['x-ratelimit-reset-requests'] = str(stats['queue']['estimated_wait_sec'])
    return response, 200


def _complete_streaming(handler, request_json_body):
    """Stream a completion from the backend as server-sent events, then log it."""
    if not opts.enable_streaming:
        # TODO: return a proper OAI error message
        return 'disabled', 401
    # NOTE(review): never updated below, so 0 is always logged — confirm intent.
    response_status_code = 0
    start_time = time.time()
    request_valid, _invalid_response = handler.validate_request()
    if not request_valid:
        # TODO: simulate OAI here
        raise Exception('TODO: simulate OAI here')
    handler.prompt = handler.request_json_body['prompt']
    msg_to_backend = {
        **handler.parameters,
        'prompt': handler.prompt,
        'stream': True,
    }
    response = generator(msg_to_backend, handler.backend_url)
    # Capture request data now: generate() runs after this view has returned,
    # when the Flask request context is no longer available.
    r_headers = dict(request.headers)
    r_url = request.url
    model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
    oai_string = generate_oai_string(30)

    def generate():
        generated_text = ''
        try:
            # presumably `response` is a requests.Response opened with
            # stream=True (it exposes iter_content) — TODO confirm.
            for frame in _iter_backend_frames(response.iter_content(chunk_size=1)):
                json_obj = json.loads(frame.decode())
                try:
                    full_text = json_obj['text'][0]
                except IndexError:
                    # Frame carried no text; nothing to forward.
                    continue
                new = _extract_new_text(full_text, handler.prompt + generated_text)
                if not new:
                    continue
                generated_text += new
                # NOTE(review): this uses the chat-style `delta` field; OpenAI's
                # legacy completions stream uses choices[].text. Kept as-is
                # for existing clients — confirm which format is intended.
                data = {
                    "id": f"chatcmpl-{oai_string}",
                    "object": "text_completion",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {
                                "content": new
                            },
                            "finish_reason": None
                        }
                    ]
                }
                yield f'data: {json.dumps(data)}\n\n'
            yield 'data: [DONE]\n\n'
        finally:
            # Runs on normal completion and when the client disconnects
            # (GeneratorExit): release the backend connection and still log
            # whatever was generated.
            response.close()
            elapsed_time = time.time() - start_time
            log_prompt(
                handler.client_ip,
                handler.token,
                handler.prompt,
                generated_text,
                elapsed_time,
                handler.parameters,
                r_headers,
                response_status_code,
                r_url,
                handler.backend_url,
            )

    return Response(generate(), mimetype='text/event-stream')


@openai_bp.route('/completions', methods=['POST'])
def openai_completions():
    """Serve an OpenAI-compatible ``POST /completions`` request.

    Validates the JSON body and converts the OpenAI-style parameters for the
    vLLM backend. If ``opts.openai_silent_trim`` is set, the prompt is trimmed
    to fit the model context. The request is then either answered in one
    response or streamed as server-sent events, depending on the body's
    ``stream`` flag.

    Returns:
        A Flask response:
        - 400 for invalid JSON or a missing prompt.
        - The result of ``validate_oai`` when it reports an error.
        - The completion (or the backend's error response) otherwise.
        - 500 on any unexpected exception raised before streaming begins.
    """
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('prompt'):
        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
    try:
        handler = OobaRequestHandler(incoming_request=request)
        if handler.cluster_backend_info['mode'] != 'vllm':
            # TODO: implement other backends
            raise NotImplementedError
        invalid_oai_err_msg = validate_oai(handler.request_json_body)
        if invalid_oai_err_msg:
            return invalid_oai_err_msg
        # Convert parameters to the selected backend type.
        handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
        if opts.openai_silent_trim:
            handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
        # Otherwise handle_request()/validate_request() load the prompt as-is.
        if not request_json_body.get('stream'):
            return _complete_blocking(handler, request_json_body)
        return _complete_streaming(handler, request_json_body)
    except Exception:
        # Note: exceptions raised while streaming happen after this view
        # returns and are not caught here.
        traceback.print_exc()
        return 'Internal Server Error', 500