update database, tokenizer handle null prompt, convert top_p to vllm on openai, actually validate prompt on streaming,

This commit is contained in:
Cyberes 2023-09-25 22:32:48 -06:00
parent 2d299dbae5
commit 11e84db59c
6 changed files with 20 additions and 12 deletions

View File

@ -13,15 +13,15 @@ def create_db():
backend_url TEXT,
request_url TEXT,
generation_time FLOAT,
prompt TEXT,
prompt LONGTEXT,
prompt_tokens INTEGER,
response TEXT,
response LONGTEXT,
response_tokens INTEGER,
response_status INTEGER,
parameters TEXT,
CHECK (parameters IS NULL OR JSON_VALID(parameters)),
# CHECK (parameters IS NULL OR JSON_VALID(parameters)),
headers TEXT,
CHECK (headers IS NULL OR JSON_VALID(headers)),
# CHECK (headers IS NULL OR JSON_VALID(headers)),
timestamp INTEGER
)
''')

View File

@ -9,6 +9,9 @@ tokenizer = tiktoken.get_encoding("cl100k_base")
def tokenize(prompt: str) -> int:
if not prompt:
# The tokenizers have issues when the prompt is None.
return 0
try:
r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
j = r.json()

View File

@ -14,8 +14,7 @@ class OobaRequestHandler(RequestHandler):
super().__init__(*args, **kwargs)
def handle_request(self):
if self.used:
raise Exception('Can only use a RequestHandler object once.')
assert not self.used
request_valid, invalid_response = self.validate_request()
if not request_valid:

View File

@ -70,6 +70,9 @@ class OpenAIRequestHandler(RequestHandler):
if opts.openai_force_no_hashes:
self.parameters['stop'].append('### ')
if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
self.request_json_body['top_p'] = 0.01
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)

View File

@ -49,10 +49,10 @@ def stream(ws):
handler = OobaRequestHandler(request, request_json_body)
generated_text = ''
input_prompt = None
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
if not request_valid:
err_msg = invalid_response[0].json['results'][0]['text']
ws.send(json.dumps({
@ -69,7 +69,6 @@ def stream(ws):
thread.start()
thread.join()
else:
input_prompt = request_json_body['prompt']
msg_to_backend = {
**handler.parameters,
'prompt': input_prompt,
@ -142,7 +141,9 @@ def stream(ws):
thread = threading.Thread(target=background_task_exception)
thread.start()
thread.join()
ws.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
ws.close() # this is important if we encountered and error and exited early.

View File

@ -16,14 +16,16 @@ from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
# TODO: allow setting more custom ratelimits per-token
# TODO: add more excluding to SYSTEM__ tokens
# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
# TODO: support turbo-instruct on openai endpoint
# TODO: option to trim context in openai mode so that we silently fit the model's context
# TODO: validate system tokens before excluding them
# TODO: unify logging thread in a function and use async/await instead
# TODO: make sure prompts are logged even when the user cancels generation
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead
# TODO: add more excluding to SYSTEM__ tokens
try:
import vllm