rewrite tokenizer, restructure validation

parent 62412f4873
commit cb99c3490e

@@ -14,7 +14,10 @@ class DatabaseConnection:
             password=password,
             database=database,
             charset='utf8mb4',
             autocommit=True,
         )

+        # Test it.
+        conn = self.db_pool.connection()
+        del conn

@@ -9,11 +9,11 @@ from llm_server.llm.vllm import tokenize


 def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
-    prompt_tokens = llm_server.llm.tokenizer(prompt)
+    prompt_tokens = llm_server.llm.get_token_count(prompt)

     if not is_error:
         if not response_tokens:
-            response_tokens = llm_server.llm.tokenizer(response)
+            response_tokens = llm_server.llm.get_token_count(response)
     else:
         response_tokens = None

@@ -0,0 +1,12 @@
+from llm_server.llm import oobabooga, vllm
+from llm_server.routes.cache import redis
+
+
+def get_token_count(prompt):
+    backend_mode = redis.get('backend_mode', str)
+    if backend_mode == 'vllm':
+        return vllm.tokenize(prompt)
+    elif backend_mode == 'ooba':
+        return oobabooga.tokenize(prompt)
+    else:
+        raise Exception(backend_mode)

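For context: this new module is the single entry point for token counting, dispatching on the backend mode stored in Redis. A minimal usage sketch (assuming the llm_server package is importable, a Redis server is reachable, and backend_mode has been set at startup the way server.py does further down; names match the code above):

    from llm_server.routes.cache import redis
    from llm_server.llm import get_token_count

    # server.py sets this once at startup; 'vllm' routes to the tiktoken-based
    # estimator added in the vllm tokenize module below.
    redis.set('backend_mode', 'vllm')

    count = get_token_count('Hello, world!')  # int; includes the +10 margin from vllm.tokenize()
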
@@ -1 +0,0 @@
-tokenizer = None

@@ -3,7 +3,7 @@ from typing import Tuple, Union
 import flask

 from llm_server import opts
-from llm_server.llm.backend import tokenizer
+from llm_server.llm import get_token_count


 class LLMBackend:

@@ -27,11 +27,17 @@ class LLMBackend:
         """
         raise NotImplementedError

-    def validate_request(self, parameters: dict) -> Tuple[bool, Union[str, None]]:
-        raise NotImplementedError
+    def validate_request(self, parameters: dict, prompt: str, request: flask.Request) -> Tuple[bool, Union[str, None]]:
+        """
+        If a backend needs to do other checks not related to the prompt or parameters.
+        Default is no extra checks preformed.
+        :param parameters:
+        :return:
+        """
+        return True, None

     def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
-        prompt_len = len(tokenizer(prompt))
+        prompt_len = get_token_count(prompt)
         if prompt_len > opts.context_size - 10:
             return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
         return True, None

@@ -0,0 +1 @@
+from .tokenize import tokenize

@@ -1,6 +1,7 @@
 """
 This file is used by the worker that processes requests.
 """
+import traceback

 import requests

@@ -13,6 +14,7 @@ def generate(json_data: dict):
     except requests.exceptions.ReadTimeout:
         return False, None, 'Request to backend timed out'
     except Exception as e:
+        traceback.print_exc()
         return False, None, 'Request to backend encountered error'
     if r.status_code != 200:
         return False, r, f'Backend returned {r.status_code}'

@@ -0,0 +1,7 @@
+import tiktoken
+
+tokenizer = tiktoken.get_encoding("cl100k_base")
+
+
+def tokenize(prompt: str) -> int:
+    return len(tokenizer.encode(prompt)) + 10

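A quick sanity check of the new estimator (a sketch only; exact counts depend on the tiktoken version, and cl100k_base is just an approximation of whatever tokenizer the backend model actually uses):

    from llm_server.llm.vllm import tokenize

    # cl100k_base splits 'Hello, world!' into 4 tokens; tokenize() pads the
    # estimate by 10 to stay on the safe side of the context limit.
    print(tokenize('Hello, world!'))  # -> 14
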
@@ -3,6 +3,7 @@ This file is used by the worker that processes requests.
 """
 import json
 import time
+import traceback
 from uuid import uuid4

 import requests

@@ -17,6 +18,9 @@ from llm_server import opts
 def prepare_json(json_data: dict):
     # logit_bias is not currently supported
     # del json_data['logit_bias']
+
+    # Convert back to VLLM.
+    json_data['max_tokens'] = json_data.pop('max_new_tokens')
     return json_data


@@ -43,8 +47,8 @@ def transform_to_text(json_request, api_response):
     if data['choices'][0]['finish_reason']:
         finish_reason = data['choices'][0]['finish_reason']

-    prompt_tokens = len(llm_server.llm.tokenizer(prompt))
-    completion_tokens = len(llm_server.llm.tokenizer(text))
+    prompt_tokens = len(llm_server.llm.get_token_count(prompt))
+    completion_tokens = len(llm_server.llm.get_token_count(text))

     # https://platform.openai.com/docs/api-reference/making-requests?lang=python
     return {

@@ -83,7 +87,8 @@ def handle_blocking_request(json_data: dict):
     except requests.exceptions.ReadTimeout:
         return False, None, 'Request to backend timed out'
     except Exception as e:
-        return False, None, 'Request to backend encountered error'  # f'{e.__class__.__name__}: {e}'
+        traceback.print_exc()
+        return False, None, 'Request to backend encountered error'
     if r.status_code != 200:
         return False, r, f'Backend returned {r.status_code}'
     return True, r, None

@@ -1,11 +1,8 @@
-import traceback
 from typing import Tuple, Union

-import requests
 from flask import jsonify
 from vllm import SamplingParams

-import llm_server
 from llm_server import opts
 from llm_server.database.database import log_prompt
 from llm_server.llm.llm_backend import LLMBackend

@@ -42,25 +39,9 @@ class VLLMBackend(LLMBackend):
             )
         except ValueError as e:
             return None, str(e).strip('.')
-        return vars(sampling_params), None
-
-    def validate_request(self, parameters) -> (bool, Union[str, None]):
-        if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
-            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
-        return True, None
+        # We use max_new_tokens throughout the server.
+        result = vars(sampling_params)
+        result['max_new_tokens'] = result.pop('max_tokens')

-    # def tokenize(self, prompt):
-    #     try:
-    #         r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
-    #         j = r.json()
-    #         return j['length']
-    #     except:
-    #         # Fall back to whatever the superclass is doing.
-    #         print(traceback.format_exc())
-    #         return super().tokenize(prompt)
-
-    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
-        prompt_len = llm_server.llm.tokenizer(prompt)
-        if prompt_len > opts.context_size:
-            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
-        return True, None
+        return result, None

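The rename in get_parameters() pairs with prepare_json() in the vLLM worker above: the server uses max_new_tokens everywhere, and the value is converted back to vLLM's max_tokens only when the request is submitted. A self-contained sketch of that round trip (a plain dict with hypothetical values stands in for vars(sampling_params)):

    # Stand-in for vars(sampling_params); the real dict comes from vLLM's SamplingParams.
    result = {'max_tokens': 250, 'temperature': 0.7}

    # VLLMBackend.get_parameters(): expose the limit under the server-wide name.
    result['max_new_tokens'] = result.pop('max_tokens')
    assert result == {'temperature': 0.7, 'max_new_tokens': 250}

    # prepare_json() in the worker: convert back right before the request hits vLLM.
    result['max_tokens'] = result.pop('max_new_tokens')
    assert result == {'temperature': 0.7, 'max_tokens': 250}
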
@@ -18,6 +18,6 @@ def openai_chat_completions():
         return OpenAIRequestHandler(request).handle_request()
     except Exception as e:
         print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
-        print(traceback.format_exc())
+        traceback.print_exc()
         print(request.data)
         return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 200

@@ -131,8 +131,8 @@ def build_openai_response(prompt, response):
     if len(x) > 1:
         response = re.sub(r'\n$', '', y[0].strip(' '))

-    prompt_tokens = llm_server.llm.tokenizer(prompt)
-    response_tokens = llm_server.llm.tokenizer(response)
+    prompt_tokens = llm_server.llm.get_token_count(prompt)
+    response_tokens = llm_server.llm.get_token_count(response)
     return jsonify({
         "id": f"chatcmpl-{uuid4()}",
         "object": "chat.completion",

@@ -69,30 +69,55 @@ class RequestHandler:
         parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
         return parameters, parameters_invalid_msg

-    def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
-        self.parameters, parameters_invalid_msg = self.get_parameters()
-        request_valid = False
-        invalid_request_err_msg = None
-        if self.parameters:
-            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
+    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
+        """
+        This needs to be called at the start of the subclass handle_request() method.
+        :param prompt:
+        :param do_log:
+        :return:
+        """
+        invalid_request_err_msgs = []

-        if not request_valid:
-            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (not bool(parameters_invalid_msg), parameters_invalid_msg)] if not valid and msg]
-            combined_error_message = ', '.join(error_messages)
-            backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
-            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+        self.parameters, parameters_invalid_msg = self.get_parameters()  # Parameters will be None if invalid.
+        if self.parameters and not parameters_invalid_msg:
+            # Backends shouldn't check max_new_tokens, but rather things specific to their backend.
+            # Let the RequestHandler do the generic checks.
+            if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
+                invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}')
+
+            if prompt:
+                prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
+                if not prompt_valid:
+                    invalid_request_err_msgs.append(invalid_prompt_err_msg)
+
+            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters, prompt, self.request)
+            if not request_valid:
+                invalid_request_err_msgs.append(invalid_request_err_msg)
+        else:
+            invalid_request_err_msgs.append(parameters_invalid_msg)
+
+        if len(invalid_request_err_msgs):
+            if len(invalid_request_err_msgs) > 1:
+                # Format multiple error messages each on a new line.
+                e = [f'\n{x}.' for x in invalid_request_err_msgs]
+                combined_error_message = '\n'.join(e)
+            else:
+                # Otherwise, just grab the first and only one.
+                combined_error_message = invalid_request_err_msgs[0] + '.'
+            backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}', 'error')
+            if do_log:
+                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
             return False, self.handle_error(backend_response)
         return True, (None, 0)

     def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
         prompt = llm_request['prompt']
         if not self.is_client_ratelimited():
-            # Validate the prompt right before submission since the backend handler may have changed something.
-            prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
-            if not prompt_valid:
-                backend_response = format_sillytavern_err(f'Validation Error: {invalid_prompt_err_msg}.', 'error')
-                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
-                return (False, None, None, 0), self.handle_error(backend_response)
+            # Validate again before submission since the backend handler may have changed something.
+            # Also, this is the first time we validate the prompt.
+            request_valid, invalid_response = self.validate_request(prompt, do_log=True)
+            if not request_valid:
+                return (False, None, None, 0), invalid_response

             event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
         else:

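The intended call pattern after this restructure: subclass handlers call validate_request() once up front (cheap parameter checks, no logging), and generate_response() re-validates with the prompt and do_log=True right before queueing. A hypothetical subclass illustrating the flow (a sketch only; ExampleRequestHandler is not part of this commit and the real handlers live elsewhere in the repo):

    class ExampleRequestHandler(RequestHandler):
        def handle_request(self):
            # First pass: parameters only, no prompt yet, nothing logged.
            request_valid, invalid_response = self.validate_request()
            if not request_valid:
                return invalid_response

            llm_request = {'prompt': self.request_json_body.get('prompt', ''), **self.parameters}
            # generate_response() calls validate_request(prompt, do_log=True) again before queueing.
            (success, _, _, _), (response, status_code) = self.generate_response(llm_request)
            return response, status_code
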
@@ -10,8 +10,7 @@ import llm_server
 from llm_server.database.conn import db_pool
 from llm_server.database.create import create_db
 from llm_server.database.database import get_number_of_rows
-from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
-from llm_server.llm.vllm.vllm_backend import VLLMBackend
+from llm_server.llm import get_token_count
 from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error

@@ -96,6 +95,8 @@ if not opts.verify_ssl:

     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

+redis.set('backend_mode', opts.mode)
+
 if config['load_num_prompts']:
     redis.set('proompts', get_number_of_rows('prompts'))

@@ -108,7 +109,7 @@ if opts.mode == 'oobabooga':
     raise NotImplementedError
     # llm_server.llm.tokenizer = OobaboogaBackend()
 elif opts.mode == 'vllm':
-    llm_server.llm.tokenizer = llm_server.llm.vllm.tokenize
+    llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
 else:
     raise Exception
