finish openai endpoints

parent 2a3ff7e21e
commit f7e9687527
@@ -43,12 +43,10 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
 
 ### Use
 
-If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis.
+If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database.
 
 Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
 
-
-
 ### To Do
 
 - [x] Implement streaming
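
For reference, adding an access token means inserting a row into `token_auth`. The actual schema is not shown in this diff, so the column names below are assumptions for illustration only; check the real table definition before using this sketch.

import sqlite3

# Hypothetical column names -- verify against the actual token_auth schema.
conn = sqlite3.connect('database.db')  # path to the proxy's SQLite database (assumed)
conn.execute(
    "INSERT INTO token_auth (token, priority, simultaneous_ip) VALUES (?, ?, ?)",
    ('example-token', 100, 3),
)
conn.commit()
conn.close()
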
@@ -14,8 +14,11 @@ def test_backend(backend_url: str, test_prompt: bool = False):
         "temperature": 0,
         "max_new_tokens": 3,
     }
-    success, response, err = generator(data, backend_url, timeout=10)
-    if not success or not response or err:
+    try:
+        success, response, err = generator(data, backend_url, timeout=10)
+        if not success or not response or err:
+            return False, {}
+    except:
         return False, {}
     i = get_info(backend_url, backend_info['mode'])
     if not i.get('model'):
@@ -1,12 +1,14 @@
 from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config
 
 
 def generator(request_json_body, cluster_backend, timeout: int = None):
-    if opts.mode == 'oobabooga':
+    mode = cluster_config.get_backend(cluster_backend)['mode']
+    if mode == 'ooba':
         # from .oobabooga.generate import generate
         # return generate(request_json_body)
         raise NotImplementedError
-    elif opts.mode == 'vllm':
+    elif mode == 'vllm':
         from .vllm.generate import generate
         return generate(request_json_body, cluster_backend, timeout=timeout)
     else:
@@ -12,6 +12,7 @@ class LLMBackend:
 
     def __init__(self, backend_url: str):
         self.backend_url = backend_url
+        self.backend_info = cluster_config.get_backend(self.backend_url)
 
     def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
         raise NotImplementedError
@@ -44,8 +45,7 @@ class LLMBackend:
 
     def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
         prompt_len = get_token_count(prompt, self.backend_url)
-        token_limit = cluster_config.get_backend(self.backend_url)['model_config']['max_position_embeddings']
+        token_limit = self.backend_info['model_config']['max_position_embeddings']
         if prompt_len > token_limit - 10:
-            model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
-            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {model_name}). Please lower your context size'
+            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size'
         return True, None
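
Not repository code, but for orientation: the keys used above imply that `cluster_config.get_backend()` returns a mapping shaped roughly like the sketch below. The concrete values are made up.

backend_info = {
    'mode': 'vllm',            # dispatch key used by generator(); 'ooba' is the other branch
    'model': 'example-13b',    # hypothetical model name, shown in validate_prompt() errors
    'model_config': {
        'max_position_embeddings': 4096,  # used as the context/token limit
    },
}
token_limit = backend_info['model_config']['max_position_embeddings']
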
@@ -20,19 +20,17 @@ def generate_oai_string(length=24):
 
 
 def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
-    tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    def get_token_count_tiktoken_thread(msg):
-        return len(tokenizer.encode(msg["content"]))
+    def get_token_count_thread(msg):
+        return get_token_count(msg["content"], backend_url)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
-        token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
+        token_counts = list(executor.map(get_token_count_thread, prompt))
 
     total_tokens = sum(token_counts)
-    formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
+    formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
 
     # If total tokens exceed the limit, start trimming
-    if total_tokens > context_token_limit:
+    if total_tokens + formatting_tokens > context_token_limit:
         while True:
             while total_tokens + formatting_tokens > context_token_limit:
                 # Calculate the index to start removing messages from
@@ -45,15 +43,11 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int,
                 if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
                     break
 
-            def get_token_count_thread(msg):
-                return get_token_count(msg["content"], backend_url)
-
             with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                 token_counts = list(executor.map(get_token_count_thread, prompt))
 
             total_tokens = sum(token_counts)
             formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
 
             if total_tokens + formatting_tokens > context_token_limit:
                 # Start over, but this time calculate the token count using the backend
                 with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
@@ -65,11 +59,7 @@ def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int,
 
 def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
     tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    def get_token_count_tiktoken_thread(msg):
-        return len(tokenizer.encode(msg))
-
-    token_count = get_token_count_tiktoken_thread(prompt)
+    token_count = get_token_count(prompt, backend_url)
 
     # If total tokens exceed the limit, start trimming
     if token_count > context_token_limit:
@@ -80,21 +70,17 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
 
             while remove_index < len(prompt):
                 prompt = prompt[:remove_index] + prompt[remove_index + 100:]
-                token_count = get_token_count_tiktoken_thread(prompt)
+                token_count = len(tokenizer.encode(prompt))
                 if token_count <= context_token_limit or remove_index == len(prompt):
                     break
 
-            def get_token_count_thread(msg):
-                return get_token_count(msg, backend_url)
-
-            token_count = get_token_count_thread(prompt)
+            token_count = get_token_count(prompt, backend_url)
 
             if token_count > context_token_limit:
                 # Start over, but this time calculate the token count using the backend
-                token_count = get_token_count_thread(prompt)
+                token_count = get_token_count(prompt, backend_url)
             else:
                 break
 
+    print(token_count)
     return prompt
 
 
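
A minimal usage sketch of the trimming helpers after this change, assuming a reachable backend; the URL and token limit below are placeholders, not values from this diff.

from llm_server.llm.openai.transform import transform_messages_to_prompt, trim_messages_to_fit

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "A very long question ..."},
]
# Counts tokens through get_token_count() and removes messages until the
# conversation plus prompt formatting fits under the limit.
trimmed = trim_messages_to_fit(messages, context_token_limit=4096, backend_url="http://10.0.0.50:7000")
prompt = transform_messages_to_prompt(trimmed)
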
@@ -1,29 +1,35 @@
-import requests
+import asyncio
 
+import aiohttp
 import tiktoken
 
 from llm_server import opts
-from llm_server.cluster.cluster_config import cluster_config
 
 
 def tokenize(prompt: str, backend_url: str) -> int:
     assert backend_url
     if not prompt:
-        # The tokenizers have issues when the prompt is None.
         return 0
-    tokenizer = tiktoken.get_encoding("cl100k_base")
-    token_limit = cluster_config.get_backend(backend_url)['model_config']['max_position_embeddings']
 
-    # First we tokenize it locally to determine if it's worth sending it to the backend.
-    initial_estimate = len(tokenizer.encode(prompt))
-    if initial_estimate <= token_limit + 200:
-        try:
-            r = requests.post(f'{backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
-            j = r.json()
-            return j['length']
-        except Exception as e:
-            print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
-            return len(tokenizer.encode(prompt)) + 10
-    else:
-        # If the result was greater than our context size, return the estimate.
-        # We won't be sending it through the backend so it does't need to be accurage.
-        return initial_estimate
+    async def run():
+        tokenizer = tiktoken.get_encoding("cl100k_base")
+
+        async def send_chunk(chunk):
+            try:
+                async with session.post(f'{backend_url}/tokenize', json={'input': chunk}, verify_ssl=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) as response:
+                    j = await response.json()
+                    return j['length']
+            except Exception as e:
+                print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
+                return len(tokenizer.encode(chunk)) + 10
+
+        chunk_size = 300
+        chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
+
+        async with aiohttp.ClientSession() as session:
+            tasks = [send_chunk(chunk) for chunk in chunks]
+            lengths = await asyncio.gather(*tasks)
+
+        return sum(lengths)
+
+    return asyncio.run(run())
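
Worth noting about the new chunked tokenizer: summing per-chunk counts is an approximation, because a token that would span a 300-character boundary is counted in both chunks. A rough local illustration of that effect using tiktoken only (no backend involved; the sample text is arbitrary):

import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")
text = "An example sentence that repeats a few words. " * 100

# Token count of the whole string in one pass.
whole = len(tokenizer.encode(text))

# Token count summed over fixed-size character chunks, as tokenize() now does.
chunk_size = 300
chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
chunked = sum(len(tokenizer.encode(c)) for c in chunks)

# The chunked total can come out slightly higher than the single-pass count,
# which is generally tolerable for the limit checks and trimming it feeds.
print(whole, chunked)
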
@@ -13,7 +13,7 @@ class OobaRequestHandler(RequestHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def handle_request(self):
+    def handle_request(self, return_ok: bool = True):
         assert not self.used
 
         request_valid, invalid_response = self.validate_request()
@@ -25,14 +25,19 @@ class OobaRequestHandler(RequestHandler):
         llm_request = {**self.parameters, 'prompt': prompt}
 
         _, backend_response = self.generate_response(llm_request)
-        return backend_response
+        if return_ok:
+            # Always return 200 so ST displays our error messages
+            return backend_response[0], 200
+        else:
+            # The OpenAI route needs to detect 429 errors.
+            return backend_response
 
     def handle_ratelimited(self, do_log: bool = True):
         msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
         backend_response = self.handle_error(msg)
         if do_log:
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
-        return backend_response[0], 200  # We only return the response from handle_error(), not the error code
+        return backend_response[0], 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
         disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
@@ -8,10 +8,11 @@ from llm_server.custom_redis import redis
 from . import openai_bp
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
+from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
 from ...database.database import log_prompt
 from ...llm.generator import generator
-from ...llm.openai.oai_to_vllm import oai_to_vllm
+from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 
 
@@ -24,11 +25,6 @@ def openai_chat_completions():
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
     else:
         handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
-
-        if handler.cluster_backend_info['mode'] != 'vllm':
-            # TODO: implement other backends
-            raise NotImplementedError
-
         if not request_json_body.get('stream'):
             try:
                 return handler.handle_request()
@@ -37,30 +33,51 @@ def openai_chat_completions():
                 return 'Internal server error', 500
         else:
             if not opts.enable_streaming:
-                # TODO: return a proper OAI error message
-                return 'disabled', 401
+                return 'DISABLED', 401
+
+            invalid_oai_err_msg = validate_oai(handler.request_json_body)
+            if invalid_oai_err_msg:
+                return invalid_oai_err_msg
+            handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
 
             if opts.openai_silent_trim:
-                handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
+                handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
+            else:
+                handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
 
             response_status_code = 0
             start_time = time.time()
 
             request_valid, invalid_response = handler.validate_request()
             if not request_valid:
                 return invalid_response
             else:
-                if opts.openai_silent_trim:
-                    oai_messages = trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
-                else:
-                    oai_messages = handler.request.json['messages']
-
-                handler.prompt = transform_messages_to_prompt(oai_messages)
-                handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
                 msg_to_backend = {
                     **handler.parameters,
                     'prompt': handler.prompt,
                     'stream': True,
                 }
+
+                # Add a dummy event to the queue and wait for it to reach a worker
+                event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
+                if not event:
+                    log_prompt(
+                        handler.client_ip,
+                        handler.token,
+                        handler.prompt,
+                        None,
+                        None,
+                        handler.parameters,
+                        request.headers,
+                        response_status_code,
+                        request.url,
+                        handler.backend_url,
+                    )
+                    return handler.handle_ratelimited()
+
+                # Wait for a worker to get our request and discard it.
+                _, _, _ = event.wait()
+
                 try:
                     response = generator(msg_to_backend, handler.backend_url)
                     r_headers = dict(request.headers)
@@ -69,57 +86,61 @@ def openai_chat_completions():
                     oai_string = generate_oai_string(30)
 
                     def generate():
+                        try:
                             generated_text = ''
                             partial_response = b''
                             for chunk in response.iter_content(chunk_size=1):
                                 partial_response += chunk
                                 if partial_response.endswith(b'\x00'):
                                     json_strs = partial_response.split(b'\x00')
                                     for json_str in json_strs:
                                         if json_str:
                                             try:
                                                 json_obj = json.loads(json_str.decode())
                                                 new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
                                                 generated_text = generated_text + new
                                             except IndexError:
                                                 # ????
                                                 continue
 
                                             data = {
                                                 "id": f"chatcmpl-{oai_string}",
                                                 "object": "chat.completion.chunk",
                                                 "created": int(time.time()),
                                                 "model": model,
                                                 "choices": [
                                                     {
                                                         "index": 0,
                                                         "delta": {
                                                             "content": new
                                                         },
                                                         "finish_reason": None
                                                     }
                                                 ]
                                             }
                                             yield f'data: {json.dumps(data)}\n\n'
                             yield 'data: [DONE]\n\n'
                             end_time = time.time()
                             elapsed_time = end_time - start_time
 
                             log_prompt(
                                 handler.client_ip,
                                 handler.token,
                                 handler.prompt,
                                 generated_text,
                                 elapsed_time,
                                 handler.parameters,
                                 r_headers,
                                 response_status_code,
                                 r_url,
                                 handler.backend_url,
                             )
+                        finally:
+                            # The worker incremented it, we'll decrement it.
+                            decrement_ip_count(handler.client_ip, 'processing_ips')
+                            decr_active_workers(handler.selected_model, handler.backend_url)
 
                     return Response(generate(), mimetype='text/event-stream')
-                except:
-                    # TODO: simulate OAI here
-                    raise Exception
+                except Exception:
+                    traceback.print_exc()
+                    return 'INTERNAL SERVER', 500
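
For testing the streaming path above, a client only has to read `data: ...` SSE lines until `data: [DONE]`. A rough sketch of such a client; the host, port, and route prefix are assumptions, not taken from this diff:

import json
import requests

url = 'http://localhost:5000/api/openai/v1/chat/completions'  # assumed route
payload = {
    'model': 'some-model',  # echoed back unless the server exposes its own model name
    'messages': [{'role': 'user', 'content': 'Hello!'}],
    'stream': True,
}

with requests.post(url, json=payload, stream=True) as r:
    for line in r.iter_lines():
        if not line or not line.startswith(b'data: '):
            continue
        chunk = line[len(b'data: '):]
        if chunk == b'[DONE]':
            break
        delta = json.loads(chunk)['choices'][0]['delta']
        print(delta.get('content', ''), end='', flush=True)
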
@@ -8,6 +8,7 @@ from llm_server.custom_redis import redis
 from . import openai_bp
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
+from ..queue import decr_active_workers, decrement_ip_count, priority_queue
 from ... import opts
 from ...database.database import log_prompt
 from ...llm import get_token_count
@@ -24,80 +25,98 @@ def openai_completions():
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
-        try:
-            handler = OobaRequestHandler(incoming_request=request)
+        handler = OobaRequestHandler(incoming_request=request)
 
         if handler.cluster_backend_info['mode'] != 'vllm':
             # TODO: implement other backends
             raise NotImplementedError
 
         invalid_oai_err_msg = validate_oai(handler.request_json_body)
         if invalid_oai_err_msg:
             return invalid_oai_err_msg
         handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
 
-        # Convert parameters to the selected backend type
         if opts.openai_silent_trim:
             handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
         else:
             # The handle_request() call below will load the prompt so we don't have
             # to do anything else here.
             pass
 
         if not request_json_body.get('stream'):
-            response, status_code = handler.handle_request()
-            if status_code != 200:
-                return status_code
+            response, status_code = handler.handle_request(return_ok=False)
+            if status_code == 429:
+                return handler.handle_ratelimited()
             output = response.json['results'][0]['text']
 
             # TODO: async/await
             prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
             response_tokens = get_token_count(output, handler.backend_url)
             running_model = redis.get('running_model', 'ERROR', dtype=str)
 
             response = jsonify({
                 "id": f"cmpl-{generate_oai_string(30)}",
                 "object": "text_completion",
                 "created": int(time.time()),
                 "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
                 "choices": [
                     {
                         "text": output,
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "stop"
                     }
                 ],
                 "usage": {
                     "prompt_tokens": prompt_tokens,
                     "completion_tokens": response_tokens,
                     "total_tokens": prompt_tokens + response_tokens
                 }
             })
 
             stats = redis.get('proxy_stats', dtype=dict)
             if stats:
                 response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
             return response, 200
         else:
             if not opts.enable_streaming:
-                # TODO: return a proper OAI error message
-                return 'disabled', 401
+                return 'DISABLED', 401
 
             response_status_code = 0
             start_time = time.time()
 
             request_valid, invalid_response = handler.validate_request()
             if not request_valid:
-                # TODO: simulate OAI here
-                raise Exception('TODO: simulate OAI here')
+                return invalid_response
             else:
                 handler.prompt = handler.request_json_body['prompt']
                 msg_to_backend = {
                     **handler.parameters,
                     'prompt': handler.prompt,
                     'stream': True,
                 }
+
+                # Add a dummy event to the queue and wait for it to reach a worker
+                event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority, handler.backend_url)
+                if not event:
+                    log_prompt(
+                        handler.client_ip,
+                        handler.token,
+                        handler.prompt,
+                        None,
+                        None,
+                        handler.parameters,
+                        request.headers,
+                        response_status_code,
+                        request.url,
+                        handler.backend_url,
+                    )
+                    return handler.handle_ratelimited()
+
+                # Wait for a worker to get our request and discard it.
+                _, _, _ = event.wait()
+
+                try:
                     response = generator(msg_to_backend, handler.backend_url)
                     r_headers = dict(request.headers)
                     r_url = request.url
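
The non-streaming branch above now returns a complete OpenAI-style `text_completion` object with a `usage` block. A minimal request sketch; the host and route prefix are assumed, and the field names simply follow the OpenAI API:

import requests

r = requests.post(
    'http://localhost:5000/api/openai/v1/completions',  # assumed route
    json={'prompt': 'Once upon a time', 'max_tokens': 64},
)
body = r.json()
print(body['choices'][0]['text'])
print(body['usage'])  # prompt_tokens / completion_tokens / total_tokens
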
@@ -105,57 +124,61 @@ def openai_completions():
                     oai_string = generate_oai_string(30)
 
                     def generate():
+                        try:
                             generated_text = ''
                             partial_response = b''
                             for chunk in response.iter_content(chunk_size=1):
                                 partial_response += chunk
                                 if partial_response.endswith(b'\x00'):
                                     json_strs = partial_response.split(b'\x00')
                                     for json_str in json_strs:
                                         if json_str:
                                             try:
                                                 json_obj = json.loads(json_str.decode())
                                                 new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
                                                 generated_text = generated_text + new
                                             except IndexError:
                                                 # ????
                                                 continue
 
                                             data = {
-                                                "id": f"chatcmpl-{oai_string}",
+                                                "id": f"cmpl-{oai_string}",
                                                 "object": "text_completion",
                                                 "created": int(time.time()),
                                                 "model": model,
                                                 "choices": [
                                                     {
                                                         "index": 0,
                                                         "delta": {
                                                             "content": new
                                                         },
                                                         "finish_reason": None
                                                     }
                                                 ]
                                             }
                                             yield f'data: {json.dumps(data)}\n\n'
                             yield 'data: [DONE]\n\n'
                             end_time = time.time()
                             elapsed_time = end_time - start_time
 
                             log_prompt(
                                 handler.client_ip,
                                 handler.token,
                                 handler.prompt,
                                 generated_text,
                                 elapsed_time,
                                 handler.parameters,
                                 r_headers,
                                 response_status_code,
                                 r_url,
                                 handler.backend_url,
                             )
+                        finally:
+                            # The worker incremented it, we'll decrement it.
+                            decrement_ip_count(handler.client_ip, 'processing_ips')
+                            decr_active_workers(handler.selected_model, handler.backend_url)
 
                     return Response(generate(), mimetype='text/event-stream')
                 except Exception:
                     traceback.print_exc()
-                    return 'Internal Server Error', 500
+                    return 'INTERNAL SERVER', 500
@@ -10,8 +10,9 @@ from flask import Response, jsonify, make_response
 
 import llm_server
 from llm_server import opts
+from llm_server.cluster.model_choices import get_model_choices
 from llm_server.custom_redis import redis
-from llm_server.database.database import is_api_key_moderated
+from llm_server.database.database import is_api_key_moderated, log_prompt
 from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
 from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 from llm_server.routes.request_handler import RequestHandler
@@ -70,9 +71,24 @@ class OpenAIRequestHandler(RequestHandler):
         return backend_response, backend_response_status_code
 
     def handle_ratelimited(self, do_log: bool = True):
-        # TODO: return a simulated OpenAI error message
-        # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
-        return 'Ratelimited', 429
+        _, default_backend_info = get_model_choices()
+        w = int(default_backend_info['estimated_wait']) if default_backend_info['estimated_wait'] > 0 else 2
+        response = jsonify({
+            "error": {
+                "message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
+                "type": "rate_limit_exceeded",
+                "param": None,
+                "code": None
+            }
+        })
+        response.headers['x-ratelimit-limit-requests'] = '2'
+        response.headers['x-ratelimit-remaining-requests'] = '0'
+        response.headers['x-ratelimit-reset-requests'] = f"{w}s"
+
+        if do_log:
+            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
+
+        return response, 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
         return jsonify({
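
Since `handle_ratelimited()` now mimics OpenAI's rate-limit payload and sets `x-ratelimit-reset-requests` to the estimated wait, a client can back off on it. A sketch of that pattern; the endpoint URL and retry policy are assumptions:

import time
import requests

def post_with_backoff(url, payload, max_attempts=5):
    for _ in range(max_attempts):
        r = requests.post(url, json=payload)
        if r.status_code != 429:
            return r
        # The handler sets this header to something like "2s".
        reset = r.headers.get('x-ratelimit-reset-requests', '2s')
        time.sleep(float(reset.rstrip('s')))
    return r
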
@@ -209,7 +209,7 @@ class RequestHandler:
         if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
             return False
         else:
-            print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
+            print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.')
             return True
 
     def handle_request(self) -> Tuple[flask.Response, int]:
@@ -115,6 +115,10 @@ def do_stream(ws, model_name):
             err_msg = r.json['results'][0]['text']
             send_err_and_quit(err_msg)
             return
+
+        # Wait for a worker to get our request and discard it.
+        _, _, _ = event.wait()
+
         try:
             response = generator(llm_request, handler.backend_url)
             if not response:
@@ -6,6 +6,7 @@ from llm_server.custom_redis import flask_cache
 from . import bp
 from ... import opts
 from ...cluster.backend import get_a_cluster_backend, get_backends_from_model, is_valid_model
+from ...cluster.cluster_config import cluster_config
 
 
 @bp.route('/v1/model', methods=['GET'])
@@ -21,8 +21,10 @@ def worker():
         incr_active_workers(selected_model, backend_url)
 
         if not request_json_body:
-            # This was a dummy request from the websocket handler.
+            # This was a dummy request from the websocket handlers.
             # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
+            event = DataEvent(event_id)
+            event.set((True, None, None))
             continue
 
         try:
@@ -13,4 +13,6 @@ openai~=0.28.0
 urllib3~=2.0.4
 flask-sock==0.6.0
 gunicorn==21.2.0
 redis==5.0.1
+aiohttp==3.8.5
+asyncio==3.4.3