update database, handle null prompt in tokenizer, convert top_p for vllm on openai, actually validate prompt on streaming
parent 2d299dbae5
commit 11e84db59c
@@ -13,15 +13,15 @@ def create_db():
             backend_url TEXT,
             request_url TEXT,
             generation_time FLOAT,
-            prompt TEXT,
+            prompt LONGTEXT,
             prompt_tokens INTEGER,
-            response TEXT,
+            response LONGTEXT,
             response_tokens INTEGER,
             response_status INTEGER,
             parameters TEXT,
-            CHECK (parameters IS NULL OR JSON_VALID(parameters)),
+            # CHECK (parameters IS NULL OR JSON_VALID(parameters)),
             headers TEXT,
-            CHECK (headers IS NULL OR JSON_VALID(headers)),
+            # CHECK (headers IS NULL OR JSON_VALID(headers)),
             timestamp INTEGER
         )
         ''')

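Both parts of this hunk point at MySQL/MariaDB semantics (an assumption based on LONGTEXT, JSON_VALID and the `#` comment syntax; the diff does not name the database): a TEXT column stores at most 65,535 bytes, which long prompts and responses can easily exceed, while LONGTEXT allows roughly 4 GB; and because `#` starts a line comment in MySQL, prefixing the two CHECK clauses disables the JSON_VALID constraints without deleting them from the source. A small standalone illustration of the size problem; the limit quoted is MySQL's, nothing below comes from the repo:

# Why LONGTEXT: a MySQL TEXT column caps a value at 65,535 bytes, so a large
# prompt or response would be rejected or truncated; LONGTEXT raises the limit
# to about 4 GB. Standalone sketch, not code from this repository.
MYSQL_TEXT_MAX_BYTES = 65_535

prompt = 'x' * 100_000  # e.g. a big context sent to the backend
print(len(prompt.encode('utf-8')) <= MYSQL_TEXT_MAX_BYTES)  # False: too big for TEXT
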
@@ -9,6 +9,9 @@ tokenizer = tiktoken.get_encoding("cl100k_base")
 
 
 def tokenize(prompt: str) -> int:
+    if not prompt:
+        # The tokenizers have issues when the prompt is None.
+        return 0
     try:
         r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
         j = r.json()

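The new guard returns a count of zero for a missing or empty prompt instead of forwarding it to the backend tokenizer. A rough standalone illustration of that behavior; the real function POSTs to the backend's /tokenize endpoint, which is stubbed out here with a whitespace split purely to keep the example runnable:

# Sketch of the guard's effect only; the backend call is faked.
def tokenize(prompt: str) -> int:
    if not prompt:
        # None and '' both skip the backend request entirely.
        return 0
    return len(prompt.split())  # stand-in for the real /tokenize call

assert tokenize(None) == 0
assert tokenize('') == 0
assert tokenize('hello world') == 2
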
@@ -14,8 +14,7 @@ class OobaRequestHandler(RequestHandler):
         super().__init__(*args, **kwargs)
 
     def handle_request(self):
-        if self.used:
-            raise Exception('Can only use a RequestHandler object once.')
+        assert not self.used
 
         request_valid, invalid_response = self.validate_request()
         if not request_valid:

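The explicit raise is collapsed into an assert. One side effect worth noting (not called out in the commit): a failed assert raises a bare AssertionError instead of the old descriptive message, and asserts are stripped entirely when Python runs with -O. A minimal sketch of the new behavior, using a throwaway class rather than the real RequestHandler:

# Minimal sketch: reusing a handler now trips AssertionError.
class Handler:
    def __init__(self):
        self.used = False

    def handle_request(self):
        assert not self.used
        self.used = True
        return 'ok'

h = Handler()
h.handle_request()
try:
    h.handle_request()  # second use
except AssertionError:
    print('handler already used')
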
@@ -70,6 +70,9 @@ class OpenAIRequestHandler(RequestHandler):
         if opts.openai_force_no_hashes:
             self.parameters['stop'].append('### ')
 
+        if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
+            self.request_json_body['top_p'] = 0.01
+
         llm_request = {**self.parameters, 'prompt': self.prompt}
         (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
 
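OpenAI-style clients sometimes send a top_p of 0 to mean "effectively greedy", but vLLM's sampling parameters reject a top_p of 0, so the handler bumps it to 0.01 before building the backend request. An isolated sketch of that normalization; the function name and dict-in/dict-out shape are illustrative, the real code mutates request_json_body inline:

# Isolated sketch of the top_p conversion; only the 0 -> 0.01 rule comes from the diff.
def normalize_top_p(body: dict, mode: str) -> dict:
    if mode == 'vllm' and body.get('top_p') == 0:
        # vLLM does not accept top_p == 0, so substitute a very small value.
        body['top_p'] = 0.01
    return body

print(normalize_top_p({'top_p': 0}, 'vllm'))    # {'top_p': 0.01}
print(normalize_top_p({'top_p': 0.9}, 'vllm'))  # unchanged
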
@@ -49,10 +49,10 @@ def stream(ws):
 
         handler = OobaRequestHandler(request, request_json_body)
         generated_text = ''
-        input_prompt = None
+        input_prompt = request_json_body['prompt']
         response_status_code = 0
         start_time = time.time()
-        request_valid, invalid_response = handler.validate_request()
+        request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
         if not request_valid:
             err_msg = invalid_response[0].json['results'][0]['text']
             ws.send(json.dumps({

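This is the "actually validate prompt on streaming" part of the commit: input_prompt is now read from request_json_body before validation rather than staying None until the else branch further down, and it is passed to validate_request(prompt=...), so the prompt really gets checked before any tokens are streamed. A rough stand-in for that ordering; the validator and sender are fakes and the names are illustrative, only the control flow mirrors the change:

# Stand-in for the new streaming order of operations.
def handle_stream(request_json_body, send, validate):
    input_prompt = request_json_body['prompt']        # read the prompt up front
    request_valid, err_msg = validate(prompt=input_prompt)
    if not request_valid:
        send({'event': 'error', 'message': err_msg})  # reject before streaming anything
        return
    send({'event': 'text_stream', 'text': f'echo: {input_prompt}'})

handle_stream(
    {'prompt': 'x' * 10},
    send=print,
    validate=lambda prompt: (len(prompt) <= 5, 'prompt too long'),
)
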
@@ -69,7 +69,6 @@ def stream(ws):
             thread.start()
             thread.join()
         else:
-            input_prompt = request_json_body['prompt']
             msg_to_backend = {
                 **handler.parameters,
                 'prompt': input_prompt,

@@ -142,7 +141,9 @@ def stream(ws):
         thread = threading.Thread(target=background_task_exception)
         thread.start()
         thread.join()
+
         ws.send(json.dumps({
             'event': 'stream_end',
             'message_num': message_num
         }))
+        ws.close()  # this is important if we encountered an error and exited early.

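The new ws.close() after the stream_end event tears the socket down explicitly, which the inline comment notes matters when an error forced an early exit. A hedged sketch of the same guarantee expressed with try/finally, a common alternative pattern; the diff itself simply calls close() once at the end, and the FakeSocket below is purely illustrative:

# Alternative sketch: try/finally closes the socket on every exit path.
import json

class FakeSocket:
    def send(self, msg): print('send:', msg)
    def close(self): print('closed')

def stream(ws):
    try:
        ws.send(json.dumps({'event': 'stream_end'}))
    finally:
        ws.close()  # runs even if the body above raised

stream(FakeSocket())
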
@@ -16,14 +16,16 @@ from llm_server.llm import get_token_count
 from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error
 
-# TODO: allow setting more custom ratelimits per-token
-# TODO: add more excluding to SYSTEM__ tokens
 # TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
 # TODO: support turbo-instruct on openai endpoint
 # TODO: option to trim context in openai mode so that we silently fit the model's context
 # TODO: validate system tokens before excluding them
-# TODO: unify logging thread in a function and use async/await instead
+# TODO: make sure prompts are logged even when the user cancels generation
 
 # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
+# TODO: unify logging thread in a function and use async/await instead
+# TODO: add more excluding to SYSTEM__ tokens
+
+
 try:
     import vllm