This repository has been archived on 2024-10-27. You can view files and clone it, but cannot push or open issues or pull requests.

108 lines
4.4 KiB

import concurrent.futures
import re
import secrets
import string
import traceback
from typing import Dict, List
import tiktoken
from llm_server import opts
from llm_server.llm import get_token_count
# Matches a "### NAME" header (e.g. "### USER: ") at the very start of the
# string. The pattern is compiled without re.MULTILINE, so ``^`` does not
# match after newlines. Group 1 is the header name without the optional
# trailing colon, and a whitespace character must follow.
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')
# With re.search, group 1 captures from the start of the first line that
# contains "### " through the end of the string. The leading ``.*?`` cannot
# cross a newline, while ``(.|\n)*`` consumes everything after the header.
# NOTE(review): group 2 only holds the last character consumed by
# ``(.|\n)*``; it looks incidental rather than intended.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')
def generate_oai_string(length=24):
    """Return a random string of *length* ASCII letters and digits.

    Each character is picked independently with ``secrets.choice``, so the
    result is suitable for identifiers that should not be guessable.
    """
    charset = string.ascii_letters + string.digits
    picked = [secrets.choice(charset) for _ in range(length)]
    return ''.join(picked)
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
    """Drop chat messages until the formatted prompt fits in the context window.

    Token counts come from ``get_token_count(..., backend_url)``. Each message
    body is counted separately, in parallel. The extra tokens added by
    ``transform_messages_to_prompt`` (headers, system prompt, separators) are
    measured as the difference between the full prompt's count and the sum of
    the message counts.

    Trimming keeps the first third of the conversation and deletes the
    messages that follow it, oldest first, until the running estimate fits.
    Everything is then re-measured, because the formatting overhead shrinks
    as messages disappear. This repeats until the re-measured prompt fits.

    Args:
        prompt: Chat messages; each must have a ``"content"`` key.
        context_token_limit: Maximum number of tokens the formatted prompt
            may occupy.
        backend_url: Passed through to ``get_token_count``.

    Returns:
        A new list with the surviving messages in their original order. The
        caller's list is not modified. If the prompt still does not fit with
        every message removed (e.g. the system prompt alone is too long), an
        empty list is returned instead of looping forever.

    NOTE(review): assumes every message passes the validation in
    ``transform_messages_to_prompt``. Otherwise that function returns
    ``False``/``''`` and the formatting-overhead estimate is meaningless.
    TODO confirm callers validate first.
    """

    def _count_message(msg: Dict[str, str]) -> int:
        # Token count of one message body, as the backend sees it.
        return get_token_count(msg["content"], backend_url)

    def _measure(messages):
        # Return (per-message counts, their sum, formatting overhead).
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            counts = list(executor.map(_count_message, messages))
        total = sum(counts)
        formatting = get_token_count(transform_messages_to_prompt(messages), backend_url) - total
        return counts, total, formatting

    # Work on a copy so the caller's list is left untouched.
    messages = list(prompt)
    token_counts, total_tokens, formatting_tokens = _measure(messages)

    # Every pass removes at least one message (len // 3 < len for len >= 1),
    # so the loop terminates. The `messages` guard stops it once nothing is
    # left to remove.
    while messages and total_tokens + formatting_tokens > context_token_limit:
        remove_index = len(messages) // 3
        # Cheap pass: update the running total from the cached counts.
        while remove_index < len(messages) and total_tokens + formatting_tokens > context_token_limit:
            total_tokens -= token_counts.pop(remove_index)
            del messages[remove_index]
        # Re-measure: the formatting overhead changed along with the messages.
        token_counts, total_tokens, formatting_tokens = _measure(messages)

    return messages
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
    """Cut 100-character chunks out of *prompt* until it fits the token limit.

    Chunks are removed at index ``len(prompt) // 3``, so the first third of
    the text is kept and the text after it is deleted from the front. While
    cutting, the size is estimated locally with tiktoken's ``cl100k_base``
    encoding. After each pass the prompt is re-counted with
    ``get_token_count(prompt, backend_url)``, and that count decides whether
    another pass is needed.

    Args:
        prompt: Text to shorten.
        context_token_limit: Maximum number of tokens allowed.
        backend_url: Passed through to ``get_token_count``.

    Returns:
        The (possibly shortened) prompt. If the limit cannot be met even by an
        empty string, ``''`` is returned instead of looping forever.
    """
    tokenizer = tiktoken.get_encoding("cl100k_base")
    token_count = get_token_count(prompt, backend_url)

    # Each pass enters the inner loop with token_count over the limit, so it
    # removes at least one character. A non-empty prompt therefore always
    # shrinks, and the `prompt` guard ends the loop once nothing is left.
    while prompt and token_count > context_token_limit:
        remove_index = len(prompt) // 3
        # Estimate with tiktoken while cutting, to avoid a backend call per chunk.
        while remove_index < len(prompt) and token_count > context_token_limit:
            prompt = prompt[:remove_index] + prompt[remove_index + 100:]
            token_count = len(tokenizer.encode(prompt))
        # Re-check with the backend's count before deciding to stop.
        token_count = get_token_count(prompt, backend_url)

    return prompt
def transform_messages_to_prompt(oai_messages, *, system_prompt=None):
    """Render OpenAI-style chat messages into a single ``### ROLE:`` prompt.

    The output is built in this order:

    * an ``### INSTRUCTION:`` section holding the system prompt;
    * one section per message;
    * an open ``### RESPONSE: `` header for the model to complete.

    Sections are separated by a blank line.

    Args:
        oai_messages: Iterable of dicts with ``"role"`` and ``"content"``
            keys. Supported roles are ``system``, ``user`` and ``assistant``.
        system_prompt: Text for the leading instruction section. When
            ``None`` (the default), ``opts.openai_system_prompt`` is used.

    Returns:
        The rendered prompt string in the normal case. ``False`` if a message
        has an empty or missing ``"role"`` or ``"content"``. ``''`` if a
        message has an unknown role, or if rendering fails for any other
        reason; the problem is logged in both cases.
    """
    import logging

    logger = logging.getLogger(__name__)
    role_headers = {
        'system': '### INSTRUCTION',
        'user': '### USER',
        'assistant': '### ASSISTANT',
    }

    try:
        if system_prompt is None:
            system_prompt = opts.openai_system_prompt
        # Every section, including the system prompt, ends with a blank line.
        # This puts the next "### " header on its own line instead of gluing
        # it onto the previous text.
        sections = [f'### INSTRUCTION: {system_prompt}\n\n']
        for msg in oai_messages:
            if not msg.get('content') or not msg.get('role'):
                return False
            header = role_headers.get(msg['role'])
            if header is None:
                logger.warning('Unknown role: %s', msg['role'])
                return ''
            sections.append(f'{header}: {msg["content"]}\n\n')
    except Exception:
        # Best-effort boundary (e.g. a non-dict message): log and return ''.
        logger.exception('Failed to convert OpenAI messages to a prompt')
        return ''

    # Drop the trailing blank line (and stray edge spaces) before adding the
    # response header.
    prompt = ''.join(sections).strip(' ').strip('\n')
    prompt += '\n\n### RESPONSE: '
    return prompt