import concurrent.futures
import re
import secrets
import string
import traceback
from typing import Dict, List

import tiktoken

from llm_server import opts
from llm_server.llm import get_token_count

ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" header line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything from a "### XXX" header onwards.
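# Illustrative note (not from the original source): the first pattern captures the
# role name from a header line, e.g. ANTI_RESPONSE_RE.match('### USER: hello')
# yields 'USER' as group(1); the second matches from the line containing the first
# "### XXX" header through the end of the string, presumably so callers can strip
# text where the model starts writing the next conversation turn itself.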
def generate_oai_string(length=24):
    # Generate a random alphanumeric string in the style of an OpenAI ID.
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
    def get_token_count_thread(msg):
        return get_token_count(msg["content"], backend_url)

    # Count the tokens of every message concurrently.
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        token_counts = list(executor.map(get_token_count_thread, prompt))

    total_tokens = sum(token_counts)
    # Tokens consumed by the "### ROLE:" formatting on top of the message contents.
    formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens

    # If the total exceeds the limit, start trimming.
    if total_tokens + formatting_tokens > context_token_limit:
        while True:
            while total_tokens + formatting_tokens > context_token_limit:
                # Start removing messages a third of the way in, which keeps the
                # system prompt and the most recent messages intact.
                remove_index = len(prompt) // 3

                while remove_index < len(prompt):
                    # Popping at remove_index shifts later messages down, so this
                    # repeatedly deletes the message at that position.
                    total_tokens -= token_counts[remove_index]
                    prompt.pop(remove_index)
                    token_counts.pop(remove_index)
                    if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
                        break

            # Re-count everything to verify the running totals.
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                token_counts = list(executor.map(get_token_count_thread, prompt))

            total_tokens = sum(token_counts)
            formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
            if total_tokens + formatting_tokens <= context_token_limit:
                break
            # Still over the limit: loop and trim again with the fresh counts.
    return prompt
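# Usage sketch (illustrative; the limit and URL are placeholders, not values from
# this project):
#   messages = [{'role': 'system', 'content': 'Be brief.'},
#               {'role': 'user', 'content': 'Hello!'}]
#   messages = trim_messages_to_fit(messages, context_token_limit=4096,
#                                   backend_url='http://127.0.0.1:5000')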
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
    tokenizer = tiktoken.get_encoding("cl100k_base")
    token_count = get_token_count(prompt, backend_url)

    # If the token count exceeds the limit, start trimming.
    if token_count > context_token_limit:
        while True:
            while token_count > context_token_limit:
                # Start removing characters a third of the way into the string.
                remove_index = len(prompt) // 3

                while remove_index < len(prompt):
                    # Cut 100 characters at a time, estimating the new size with
                    # tiktoken locally to avoid a backend round trip per iteration.
                    prompt = prompt[:remove_index] + prompt[remove_index + 100:]
                    token_count = len(tokenizer.encode(prompt))
                    if token_count <= context_token_limit or remove_index == len(prompt):
                        break

            # Verify the local estimate against the backend's tokenizer; if the
            # prompt still exceeds the limit, loop and trim again.
            token_count = get_token_count(prompt, backend_url)
            if token_count <= context_token_limit:
                break
    return prompt
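# Usage sketch (illustrative placeholder values, as above):
#   text = trim_string_to_fit(text, context_token_limit=2048,
#                             backend_url='http://127.0.0.1:5000')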
def transform_messages_to_prompt(oai_messages):
    try:
        # A "\n\n" separator is appended so the first message's header starts on
        # its own line (this assumes opts.openai_system_prompt does not already
        # end with one).
        prompt = f'### INSTRUCTION: {opts.openai_system_prompt}\n\n'
        for msg in oai_messages:
            if 'content' not in msg or 'role' not in msg:
                # Malformed message.
                return False
            msg['content'] = str(msg['content'])  # Prevent any weird issues.
            if msg['role'] == 'system':
                prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
            elif msg['role'] == 'user':
                prompt += f'### USER: {msg["content"]}\n\n'
            elif msg['role'] == 'assistant':
                prompt += f'### ASSISTANT: {msg["content"]}\n\n'
            else:
                raise Exception(f'Unknown role: {msg["role"]}')
    except Exception:
        # TODO: use logging
        traceback.print_exc()
        return ''

    prompt = prompt.strip()
    prompt += '\n\n### RESPONSE: '
    return prompt
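# Example (illustrative; the exact output depends on opts.openai_system_prompt):
#   transform_messages_to_prompt([{'role': 'user', 'content': 'Hello!'},
#                                 {'role': 'assistant', 'content': 'Hi there.'}])
#   -> '### INSTRUCTION: <system prompt>\n\n### USER: Hello!\n\n'
#      '### ASSISTANT: Hi there.\n\n### RESPONSE: '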