finish openai endpoints
This commit is contained in:
parent
2a3ff7e21e
commit
f7e9687527
|
@ -43,12 +43,10 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
|
|||
|
||||
### Use
|
||||
|
||||
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis.
|
||||
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database.
|
||||
|
||||
Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
|
||||
|
||||
|
||||
|
||||
### To Do
|
||||
|
||||
- [x] Implement streaming
|
||||
|
|
|
@ -14,8 +14,11 @@ def test_backend(backend_url: str, test_prompt: bool = False):
|
|||
"temperature": 0,
|
||||
"max_new_tokens": 3,
|
||||
}
|
||||
success, response, err = generator(data, backend_url, timeout=10)
|
||||
if not success or not response or err:
|
||||
try:
|
||||
success, response, err = generator(data, backend_url, timeout=10)
|
||||
if not success or not response or err:
|
||||
return False, {}
|
||||
except:
|
||||
return False, {}
|
||||
i = get_info(backend_url, backend_info['mode'])
|
||||
if not i.get('model'):
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
|
||||
|
||||
def generator(request_json_body, cluster_backend, timeout: int = None):
|
||||
if opts.mode == 'oobabooga':
|
||||
mode = cluster_config.get_backend(cluster_backend)['mode']
|
||||
if mode == 'ooba':
|
||||
# from .oobabooga.generate import generate
|
||||
# return generate(request_json_body)
|
||||
raise NotImplementedError
|
||||
elif opts.mode == 'vllm':
|
||||
elif mode == 'vllm':
|
||||
from .vllm.generate import generate
|
||||
return generate(request_json_body, cluster_backend, timeout=timeout)
|
||||
else:
|
||||
|
|
|
@ -12,6 +12,7 @@ class LLMBackend:
|
|||
|
||||
def __init__(self, backend_url: str):
|
||||
self.backend_url = backend_url
|
||||
self.backend_info = cluster_config.get_backend(self.backend_url)
|
||||
|
||||
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
|
||||
raise NotImplementedError
|
||||
|
@ -44,8 +45,7 @@ class LLMBackend:
|
|||
|
||||
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
|
||||
prompt_len = get_token_count(prompt, self.backend_url)
|
||||
token_limit = cluster_config.get_backend(self.backend_url)['model_config']['max_position_embeddings']
|
||||
token_limit = self.backend_info['model_config']['max_position_embeddings']
|
||||
if prompt_len > token_limit - 10:
|
||||
model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
|
||||
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {model_name}). Please lower your context size'
|
||||
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size'
|
||||
return True, None
|
||||
|
|
|
@ -20,19 +20,17 @@ def generate_oai_string(length=24):
|
|||
|
||||
|
||||
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def get_token_count_tiktoken_thread(msg):
|
||||
return len(tokenizer.encode(msg["content"]))
|
||||
def get_token_count_thread(msg):
|
||||
return get_token_count(msg["content"], backend_url)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
|
||||
token_counts = list(executor.map(get_token_count_thread, prompt))
|
||||
|
||||
total_tokens = sum(token_counts)
|
||||
formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
|
||||
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
|
||||
|
||||
# If total tokens exceed the limit, start trimming
|
||||
if total_tokens > context_token_limit:
|
||||
if total_tokens + formatting_tokens > context_token_limit:
|
||||
while True:
|
||||
while total_tokens + formatting_tokens > context_token_limit:
|
||||
# Calculate the index to start removing messages from
|
||||
|
@ -45,15 +43,11 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int,
|
|||
if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
|
||||
break
|
||||
|
||||
def get_token_count_thread(msg):
|
||||
return get_token_count(msg["content"], backend_url)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
token_counts = list(executor.map(get_token_count_thread, prompt))
|
||||
|
||||
total_tokens = sum(token_counts)
|
||||
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
|
||||
|
||||
if total_tokens + formatting_tokens > context_token_limit:
|
||||
# Start over, but this time calculate the token count using the backend
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
|
@ -65,11 +59,7 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int,
|
|||
|
||||
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def get_token_count_tiktoken_thread(msg):
|
||||
return len(tokenizer.encode(msg))
|
||||
|
||||
token_count = get_token_count_tiktoken_thread(prompt)
|
||||
token_count = get_token_count(prompt, backend_url)
|
||||
|
||||
# If total tokens exceed the limit, start trimming
|
||||
if token_count > context_token_limit:
|
||||
|
@ -80,21 +70,17 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
|
|||
|
||||
while remove_index < len(prompt):
|
||||
prompt = prompt[:remove_index] + prompt[remove_index + 100:]
|
||||
token_count = get_token_count_tiktoken_thread(prompt)
|
||||
token_count = len(tokenizer.encode(prompt))
|
||||
if token_count <= context_token_limit or remove_index == len(prompt):
|
||||
break
|
||||
|
||||
def get_token_count_thread(msg):
|
||||
return get_token_count(msg, backend_url)
|
||||
|
||||
token_count = get_token_count_thread(prompt)
|
||||
|
||||
token_count = get_token_count(prompt, backend_url)
|
||||
if token_count > context_token_limit:
|
||||
# Start over, but this time calculate the token count using the backend
|
||||
token_count = get_token_count_thread(prompt)
|
||||
token_count = get_token_count(prompt, backend_url)
|
||||
else:
|
||||
break
|
||||
|
||||
print(token_count)
|
||||
return prompt
|
||||
|
||||
|
||||
|
|
|
@ -1,29 +1,35 @@
|
|||
import requests
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
import tiktoken
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
|
||||
|
||||
def tokenize(prompt: str, backend_url: str) -> int:
|
||||
assert backend_url
|
||||
if not prompt:
|
||||
# The tokenizers have issues when the prompt is None.
|
||||
return 0
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
token_limit = cluster_config.get_backend(backend_url)['model_config']['max_position_embeddings']
|
||||
|
||||
# First we tokenize it locally to determine if it's worth sending it to the backend.
|
||||
initial_estimate = len(tokenizer.encode(prompt))
|
||||
if initial_estimate <= token_limit + 200:
|
||||
try:
|
||||
r = requests.post(f'{backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
j = r.json()
|
||||
return j['length']
|
||||
except Exception as e:
|
||||
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
|
||||
return len(tokenizer.encode(prompt)) + 10
|
||||
else:
|
||||
# If the result was greater than our context size, return the estimate.
|
||||
# We won't be sending it through the backend so it does't need to be accurage.
|
||||
return initial_estimate
|
||||
async def run():
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
async def send_chunk(chunk):
|
||||
try:
|
||||
async with session.post(f'{backend_url}/tokenize', json={'input': chunk}, verify_ssl=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) as response:
|
||||
j = await response.json()
|
||||
return j['length']
|
||||
except Exception as e:
|
||||
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
|
||||
return len(tokenizer.encode(chunk)) + 10
|
||||
|
||||
chunk_size = 300
|
||||
chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
tasks = [send_chunk(chunk) for chunk in chunks]
|
||||
lengths = await asyncio.gather(*tasks)
|
||||
|
||||
return sum(lengths)
|
||||
|
||||
return asyncio.run(run())
|
||||
|
|
|
@ -13,7 +13,7 @@ class OobaRequestHandler(RequestHandler):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def handle_request(self):
|
||||
def handle_request(self, return_ok: bool = True):
|
||||
assert not self.used
|
||||
|
||||
request_valid, invalid_response = self.validate_request()
|
||||
|
@ -25,14 +25,19 @@ class OobaRequestHandler(RequestHandler):
|
|||
llm_request = {**self.parameters, 'prompt': prompt}
|
||||
|
||||
_, backend_response = self.generate_response(llm_request)
|
||||
return backend_response
|
||||
if return_ok:
|
||||
# Always return 200 so ST displays our error messages
|
||||
return backend_response[0], 200
|
||||
else:
|
||||
# The OpenAI route needs to detect 429 errors.
|
||||
return backend_response
|
||||
|
||||
def handle_ratelimited(self, do_log: bool = True):
|
||||
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
|
||||
backend_response = self.handle_error(msg)
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
|
||||
return backend_response[0], 429
|
||||
|
||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||
disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
|
||||
|
|
|
@ -8,10 +8,11 @@ from llm_server.custom_redis import redis
|
|||
from . import openai_bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..openai_request_handler import OpenAIRequestHandler
|
||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...llm.generator import generator
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
|
||||
|
||||
|
@ -24,11 +25,6 @@ def openai_chat_completions():
|
|||
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
|
||||
else:
|
||||
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
|
||||
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
if not request_json_body.get('stream'):
|
||||
try:
|
||||
return handler.handle_request()
|
||||
|
@ -37,30 +33,51 @@ def openai_chat_completions():
|
|||
return 'Internal server error', 500
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
# TODO: return a proper OAI error message
|
||||
return 'disabled', 401
|
||||
return 'DISABLED', 401
|
||||
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
|
||||
else:
|
||||
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
|
||||
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
else:
|
||||
if opts.openai_silent_trim:
|
||||
oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
else:
|
||||
oai_messages = handler.request.json['messages']
|
||||
|
||||
handler.prompt = transform_messages_to_prompt(oai_messages)
|
||||
handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
|
||||
msg_to_backend = {
|
||||
**handler.parameters,
|
||||
'prompt': handler.prompt,
|
||||
'stream': True,
|
||||
}
|
||||
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
|
||||
if not event:
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
None,
|
||||
None,
|
||||
handler.parameters,
|
||||
request.headers,
|
||||
response_status_code,
|
||||
request.url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return handler.handle_ratelimited()
|
||||
|
||||
# Wait for a worker to get our request and discard it.
|
||||
_, _, _ = event.wait()
|
||||
|
||||
try:
|
||||
response = generator(msg_to_backend, handler.backend_url)
|
||||
r_headers = dict(request.headers)
|
||||
|
@ -69,57 +86,61 @@ def openai_chat_completions():
|
|||
oai_string = generate_oai_string(30)
|
||||
|
||||
def generate():
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
try:
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
|
||||
data = {
|
||||
"id": f"chatcmpl-{oai_string}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": new
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
data = {
|
||||
"id": f"chatcmpl-{oai_string}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": new
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
response_status_code,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
response_status_code,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
finally:
|
||||
# The worker incremented it, we'll decrement it.
|
||||
decrement_ip_count(handler.client_ip, 'processing_ips')
|
||||
decr_active_workers(handler.selected_model, handler.backend_url)
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except:
|
||||
# TODO: simulate OAI here
|
||||
raise Exception
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'INTERNAL SERVER', 500
|
||||
|
|
|
@ -8,6 +8,7 @@ from llm_server.custom_redis import redis
|
|||
from . import openai_bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...llm import get_token_count
|
||||
|
@ -24,80 +25,98 @@ def openai_completions():
|
|||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
else:
|
||||
try:
|
||||
handler = OobaRequestHandler(incoming_request=request)
|
||||
handler = OobaRequestHandler(incoming_request=request)
|
||||
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
# Convert parameters to the selected backend type
|
||||
if opts.openai_silent_trim:
|
||||
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
else:
|
||||
# The handle_request() call below will load the prompt so we don't have
|
||||
# to do anything else here.
|
||||
pass
|
||||
if opts.openai_silent_trim:
|
||||
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
else:
|
||||
# The handle_request() call below will load the prompt so we don't have
|
||||
# to do anything else here.
|
||||
pass
|
||||
|
||||
if not request_json_body.get('stream'):
|
||||
response, status_code = handler.handle_request()
|
||||
if status_code != 200:
|
||||
return status_code
|
||||
output = response.json['results'][0]['text']
|
||||
if not request_json_body.get('stream'):
|
||||
response, status_code = handler.handle_request(return_ok=False)
|
||||
if status_code == 429:
|
||||
return handler.handle_ratelimited()
|
||||
output = response.json['results'][0]['text']
|
||||
|
||||
# TODO: async/await
|
||||
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
|
||||
response_tokens = get_token_count(output, handler.backend_url)
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
# TODO: async/await
|
||||
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
|
||||
response_tokens = get_token_count(output, handler.backend_url)
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
response = jsonify({
|
||||
"id": f"cmpl-{generate_oai_string(30)}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
|
||||
"choices": [
|
||||
{
|
||||
"text": output,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
response = jsonify({
|
||||
"id": f"cmpl-{generate_oai_string(30)}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
|
||||
"choices": [
|
||||
{
|
||||
"text": output,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
})
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
})
|
||||
|
||||
stats = redis.get('proxy_stats', dtype=dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response, 200
|
||||
stats = redis.get('proxy_stats', dtype=dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response, 200
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
return 'DISABLED', 401
|
||||
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
# TODO: return a proper OAI error message
|
||||
return 'disabled', 401
|
||||
handler.prompt = handler.request_json_body['prompt']
|
||||
msg_to_backend = {
|
||||
**handler.parameters,
|
||||
'prompt': handler.prompt,
|
||||
'stream': True,
|
||||
}
|
||||
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
|
||||
if not event:
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
None,
|
||||
None,
|
||||
handler.parameters,
|
||||
request.headers,
|
||||
response_status_code,
|
||||
request.url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return handler.handle_ratelimited()
|
||||
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
# TODO: simulate OAI here
|
||||
raise Exception('TODO: simulate OAI here')
|
||||
else:
|
||||
handler.prompt = handler.request_json_body['prompt']
|
||||
msg_to_backend = {
|
||||
**handler.parameters,
|
||||
'prompt': handler.prompt,
|
||||
'stream': True,
|
||||
}
|
||||
# Wait for a worker to get our request and discard it.
|
||||
_, _, _ = event.wait()
|
||||
|
||||
try:
|
||||
response = generator(msg_to_backend, handler.backend_url)
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
|
@ -105,57 +124,61 @@ def openai_completions():
|
|||
oai_string = generate_oai_string(30)
|
||||
|
||||
def generate():
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
try:
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
|
||||
data = {
|
||||
"id": f"chatcmpl-{oai_string}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": new
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
data = {
|
||||
"id": f"cmpl-{oai_string}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": new
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
response_status_code,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
log_prompt(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
response_status_code,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
finally:
|
||||
# The worker incremented it, we'll decrement it.
|
||||
decrement_ip_count(handler.client_ip, 'processing_ips')
|
||||
decr_active_workers(handler.selected_model, handler.backend_url)
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'Internal Server Error', 500
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'INTERNAL SERVER', 500
|
||||
|
|
|
@ -10,8 +10,9 @@ from flask import Response, jsonify, make_response
|
|||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.model_choices import get_model_choices
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import is_api_key_moderated
|
||||
from llm_server.database.database import is_api_key_moderated, log_prompt
|
||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
||||
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
@ -70,9 +71,24 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
return backend_response, backend_response_status_code
|
||||
|
||||
def handle_ratelimited(self, do_log: bool = True):
|
||||
# TODO: return a simulated OpenAI error message
|
||||
# Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
|
||||
return 'Ratelimited', 429
|
||||
_, default_backend_info = get_model_choices()
|
||||
w = int(default_backend_info['estimated_wait']) if default_backend_info['estimated_wait'] > 0 else 2
|
||||
response = jsonify({
|
||||
"error": {
|
||||
"message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
|
||||
"type": "rate_limit_exceeded",
|
||||
"param": None,
|
||||
"code": None
|
||||
}
|
||||
})
|
||||
response.headers['x-ratelimit-limit-requests'] = '2'
|
||||
response.headers['x-ratelimit-remaining-requests'] = '0'
|
||||
response.headers['x-ratelimit-reset-requests'] = f"{w}s"
|
||||
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
|
||||
return response, 429
|
||||
|
||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||
return jsonify({
|
||||
|
|
|
@ -209,7 +209,7 @@ class RequestHandler:
|
|||
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
|
||||
return False
|
||||
else:
|
||||
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
|
||||
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.')
|
||||
return True
|
||||
|
||||
def handle_request(self) -> Tuple[flask.Response, int]:
|
||||
|
|
|
@ -115,6 +115,10 @@ def do_stream(ws, model_name):
|
|||
err_msg = r.json['results'][0]['text']
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
|
||||
# Wait for a worker to get our request and discard it.
|
||||
_, _, _ = event.wait()
|
||||
|
||||
try:
|
||||
response = generator(llm_request, handler.backend_url)
|
||||
if not response:
|
||||
|
|
|
@ -6,6 +6,7 @@ from llm_server.custom_redis import flask_cache
|
|||
from . import bp
|
||||
from ... import opts
|
||||
from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model
|
||||
from ...cluster.cluster_config import cluster_config
|
||||
|
||||
|
||||
@bp.route('/v1/model', methods=['GET'])
|
||||
|
|
|
@ -21,8 +21,10 @@ def worker():
|
|||
incr_active_workers(selected_model, backend_url)
|
||||
|
||||
if not request_json_body:
|
||||
# This was a dummy request from the websocket handler.
|
||||
# This was a dummy request from the websocket handlers.
|
||||
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
|
||||
event = DataEvent(event_id)
|
||||
event.set((True, None, None))
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
|
@ -13,4 +13,6 @@ openai~=0.28.0
|
|||
urllib3~=2.0.4
|
||||
flask-sock==0.6.0
|
||||
gunicorn==21.2.0
|
||||
redis==5.0.1
|
||||
redis==5.0.1
|
||||
aiohttp==3.8.5
|
||||
asyncio==3.4.3
|
Reference in New Issue