rewrite tokenizer, restructure validation

Cyberes 2023-09-24 13:02:30 -06:00
parent 62412f4873
commit cb99c3490e
15 changed files with 98 additions and 56 deletions

View File

@ -14,7 +14,10 @@ class DatabaseConnection:
            password=password,
            database=database,
            charset='utf8mb4',
            autocommit=True,
        )
        # Test it.
        conn = self.db_pool.connection()
        del conn
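
For context, a minimal standalone sketch (not part of this commit) of the "open one connection and throw it away" pattern above, assuming DBUtils' PooledDB with a pymysql creator (both are assumptions, since the diff does not show how the pool is constructed): grabbing a connection once forces the pool to actually connect, so bad credentials or an unreachable host fail at startup instead of on the first request.

import pymysql
from dbutils.pooled_db import PooledDB

# Hypothetical pool mirroring the options above; host/user/password are placeholders.
pool = PooledDB(creator=pymysql, host='localhost', user='llm', password='secret',
                database='llm_server', charset='utf8mb4', autocommit=True)

# Test it: connecting once surfaces configuration errors immediately.
conn = pool.connection()
del conn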

View File

@ -9,11 +9,11 @@ from llm_server.llm.vllm import tokenize
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
    prompt_tokens = llm_server.llm.tokenizer(prompt)
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    if not is_error:
        if not response_tokens:
            response_tokens = llm_server.llm.tokenizer(response)
            response_tokens = llm_server.llm.get_token_count(response)
    else:
        response_tokens = None

View File

@ -0,0 +1,12 @@
from llm_server.llm import oobabooga, vllm
from llm_server.routes.cache import redis


def get_token_count(prompt):
    backend_mode = redis.get('backend_mode', str)
    if backend_mode == 'vllm':
        return vllm.tokenize(prompt)
    elif backend_mode == 'ooba':
        return oobabooga.tokenize(prompt)
    else:
        raise Exception(backend_mode)
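
A short usage sketch (not from the commit) of the new dispatcher; the prompt string is just an example, and 'backend_mode' is assumed to already be set in Redis, as server.py does at startup further down.

from llm_server.llm import get_token_count

# Dispatches to vllm.tokenize or oobabooga.tokenize based on the 'backend_mode'
# key in Redis, replacing the old module-level `tokenizer` callable.
n_tokens = get_token_count('An example prompt.')
print(n_tokens)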

View File

@ -1 +0,0 @@
tokenizer = None

View File

@ -3,7 +3,7 @@ from typing import Tuple, Union
import flask
from llm_server import opts
from llm_server.llm.backend import tokenizer
from llm_server.llm import get_token_count
class LLMBackend:
@ -27,11 +27,17 @@ class LLMBackend:
"""
raise NotImplementedError
def validate_request(self, parameters: dict) -> Tuple[bool, Union[str, None]]:
raise NotImplementedError
def validate_request(self, parameters: dict, prompt: str, request: flask.Request) -> Tuple[bool, Union[str, None]]:
"""
If a backend needs to do other checks not related to the prompt or parameters.
Default is no extra checks preformed.
:param parameters:
:return:
"""
return True, None
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
prompt_len = len(tokenizer(prompt))
prompt_len = get_token_count(prompt)
if prompt_len > opts.context_size - 10:
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
return True, None
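
A hypothetical subclass (illustrative only, not in the repository) showing the intent of the new hook: backend-specific checks that have nothing to do with the prompt or parameters go in validate_request, while generic limits such as max_new_tokens are now handled by RequestHandler (see below).

from typing import Tuple, Union

import flask

from llm_server.llm.llm_backend import LLMBackend


class ExampleBackend(LLMBackend):
    def validate_request(self, parameters: dict, prompt: str, request: flask.Request) -> Tuple[bool, Union[str, None]]:
        # A backend-specific check unrelated to the prompt or the parameters;
        # the header name is made up for illustration.
        if request.headers.get('X-Example-Client') == 'blocked':
            return False, 'this client is not supported by the backend'
        return True, None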

View File

@ -0,0 +1 @@
from .tokenize import tokenize

View File

@ -1,6 +1,7 @@
"""
This file is used by the worker that processes requests.
"""
import traceback
import requests
@ -13,6 +14,7 @@ def generate(json_data: dict):
    except requests.exceptions.ReadTimeout:
        return False, None, 'Request to backend timed out'
    except Exception as e:
        traceback.print_exc()
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'

View File

@ -0,0 +1,7 @@
import tiktoken
tokenizer = tiktoken.get_encoding("cl100k_base")
def tokenize(prompt: str) -> int:
return len(tokenizer.encode(prompt)) + 10
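
cl100k_base is only an approximation of whatever tokenizer the backend model actually uses, and the +10 reads as a small safety margin so the estimate errs high (that interpretation is an assumption, not stated in the commit). A quick standalone check:

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
# Same estimate tokenize() above would produce for this prompt.
print(len(enc.encode("Hello, world!")) + 10)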

View File

@ -3,6 +3,7 @@ This file is used by the worker that processes requests.
"""
import json
import time
import traceback
from uuid import uuid4
import requests
@ -17,6 +18,9 @@ from llm_server import opts
def prepare_json(json_data: dict):
    # logit_bias is not currently supported
    # del json_data['logit_bias']

    # Convert back to VLLM.
    json_data['max_tokens'] = json_data.pop('max_new_tokens')
    return json_data
@ -43,8 +47,8 @@ def transform_to_text(json_request, api_response):
    if data['choices'][0]['finish_reason']:
        finish_reason = data['choices'][0]['finish_reason']

    prompt_tokens = len(llm_server.llm.tokenizer(prompt))
    completion_tokens = len(llm_server.llm.tokenizer(text))
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    completion_tokens = llm_server.llm.get_token_count(text)

    # https://platform.openai.com/docs/api-reference/making-requests?lang=python
    return {
@ -83,7 +87,8 @@ def handle_blocking_request(json_data: dict):
    except requests.exceptions.ReadTimeout:
        return False, None, 'Request to backend timed out'
    except Exception as e:
        return False, None, 'Request to backend encountered error'  # f'{e.__class__.__name__}: {e}'
        traceback.print_exc()
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None

View File

@ -1,11 +1,8 @@
import traceback
from typing import Tuple, Union
import requests
from flask import jsonify
from vllm import SamplingParams
import llm_server
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
@ -42,25 +39,9 @@ class VLLMBackend(LLMBackend):
            )
        except ValueError as e:
            return None, str(e).strip('.')
        return vars(sampling_params), None

    def validate_request(self, parameters) -> (bool, Union[str, None]):
        if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None

        # We use max_new_tokens throughout the server.
        result = vars(sampling_params)
        result['max_new_tokens'] = result.pop('max_tokens')

    # def tokenize(self, prompt):
    #     try:
    #         r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    #         j = r.json()
    #         return j['length']
    #     except:
    #         # Fall back to whatever the superclass is doing.
    #         print(traceback.format_exc())
    #         return super().tokenize(prompt)

    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
        prompt_len = llm_server.llm.tokenizer(prompt)
        if prompt_len > opts.context_size:
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
        return True, None

        return result, None

View File

@ -18,6 +18,6 @@ def openai_chat_completions():
        return OpenAIRequestHandler(request).handle_request()
    except Exception as e:
        print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
        print(traceback.format_exc())
        traceback.print_exc()
        print(request.data)
        return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 200

View File

@ -131,8 +131,8 @@ def build_openai_response(prompt, response):
    if len(x) > 1:
        response = re.sub(r'\n$', '', y[0].strip(' '))

    prompt_tokens = llm_server.llm.tokenizer(prompt)
    response_tokens = llm_server.llm.tokenizer(response)
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    response_tokens = llm_server.llm.get_token_count(response)

    return jsonify({
        "id": f"chatcmpl-{uuid4()}",
        "object": "chat.completion",

View File

@ -69,30 +69,55 @@ class RequestHandler:
        parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
        return parameters, parameters_invalid_msg

    def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
        self.parameters, parameters_invalid_msg = self.get_parameters()
        request_valid = False
        invalid_request_err_msg = None
        if self.parameters:
            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)

    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
        """
        This needs to be called at the start of the subclass handle_request() method.
        :param prompt:
        :param do_log:
        :return:
        """
        invalid_request_err_msgs = []

        if not request_valid:
            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (not bool(parameters_invalid_msg), parameters_invalid_msg)] if not valid and msg]
            combined_error_message = ', '.join(error_messages)
            backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)

        self.parameters, parameters_invalid_msg = self.get_parameters()  # Parameters will be None if invalid.
        if self.parameters and not parameters_invalid_msg:
            # Backends shouldn't check max_new_tokens, but rather things specific to their backend.
            # Let the RequestHandler do the generic checks.
            if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
                invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}')
            if prompt:
                prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
                if not prompt_valid:
                    invalid_request_err_msgs.append(invalid_prompt_err_msg)
            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters, prompt, self.request)
            if not request_valid:
                invalid_request_err_msgs.append(invalid_request_err_msg)
        else:
            invalid_request_err_msgs.append(parameters_invalid_msg)

        if len(invalid_request_err_msgs):
            if len(invalid_request_err_msgs) > 1:
                # Format multiple error messages each on a new line.
                e = [f'\n{x}.' for x in invalid_request_err_msgs]
                combined_error_message = '\n'.join(e)
            else:
                # Otherwise, just grab the first and only one.
                combined_error_message = invalid_request_err_msgs[0] + '.'
            backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}', 'error')
            if do_log:
                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
            return False, self.handle_error(backend_response)
        return True, (None, 0)

    def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
        prompt = llm_request['prompt']
        if not self.is_client_ratelimited():
            # Validate the prompt right before submission since the backend handler may have changed something.
            prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
            if not prompt_valid:
                backend_response = format_sillytavern_err(f'Validation Error: {invalid_prompt_err_msg}.', 'error')
                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
                return (False, None, None, 0), self.handle_error(backend_response)
            # Validate again before submission since the backend handler may have changed something.
            # Also, this is the first time we validate the prompt.
            request_valid, invalid_response = self.validate_request(prompt, do_log=True)
            if not request_valid:
                return (False, None, None, 0), invalid_response
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:

View File

@ -10,8 +10,7 @@ import llm_server
from llm_server.database.conn import db_pool
from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
@ -96,6 +95,8 @@ if not opts.verify_ssl:
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

redis.set('backend_mode', opts.mode)

if config['load_num_prompts']:
    redis.set('proompts', get_number_of_rows('prompts'))
@ -108,7 +109,7 @@ if opts.mode == 'oobabooga':
    raise NotImplementedError
    # llm_server.llm.tokenizer = OobaboogaBackend()
elif opts.mode == 'vllm':
    llm_server.llm.tokenizer = llm_server.llm.vllm.tokenize
    llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
else:
    raise Exception