import concurrent.futures
import re
import secrets
import string
import traceback
from typing import Dict, List

import tiktoken

from llm_server import opts
from llm_server.llm import get_token_count

ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.

# Upper bound on concurrent per-message token-count calls.
_TOKEN_COUNT_WORKERS = 10

# Number of characters removed per step when trimming a raw string.
_TRIM_CHUNK_CHARS = 100

# Maps a chat-message role to the section header used in the flattened prompt.
_ROLE_HEADERS = {
    'system': 'INSTRUCTION',
    'user': 'USER',
    'assistant': 'ASSISTANT',
}


def generate_oai_string(length=24):
    """Return a random string of ASCII letters and digits.

    Characters are drawn with ``secrets.choice``.

    :param length: number of characters to generate.
    :return: a string of ``length`` characters.
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))


def _count_message_tokens(messages: List[Dict[str, str]], backend_url: str) -> List[int]:
    """Return ``get_token_count`` of each message's ``content``, in order."""

    def count_one(msg):
        return get_token_count(msg["content"], backend_url)

    # The per-message counts are fanned out over a small thread pool;
    # ``executor.map`` yields the results in input order.
    with concurrent.futures.ThreadPoolExecutor(max_workers=_TOKEN_COUNT_WORKERS) as executor:
        return list(executor.map(count_one, messages))


def _measure_messages(messages: List[Dict[str, str]], backend_url: str):
    """Return ``(token_counts, total_tokens, formatting_tokens)`` for ``messages``.

    ``formatting_tokens`` is the token count of the fully rendered prompt
    minus the sum of the per-message content counts, i.e. the overhead
    added by ``transform_messages_to_prompt``.
    """
    token_counts = _count_message_tokens(messages, backend_url)
    total_tokens = sum(token_counts)
    # NOTE(review): transform_messages_to_prompt() returns False or '' for
    # messages it cannot render; how get_token_count() treats those values is
    # not visible from this module - confirm.
    rendered_tokens = get_token_count(transform_messages_to_prompt(messages), backend_url)
    return token_counts, total_tokens, rendered_tokens - total_tokens


def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
    """Drop messages until the rendered prompt fits in ``context_token_limit``.

    Messages are removed starting at index ``len(prompt) // 3`` and working
    towards the end of the list, so the earliest third of the conversation
    is the last part to be discarded.

    The list is trimmed IN PLACE and the same list object is returned.

    :param prompt: chat messages; each is indexed by ``"content"`` here and
        passed to ``transform_messages_to_prompt``.
    :param context_token_limit: maximum number of tokens allowed.
    :param backend_url: forwarded unchanged to ``get_token_count``.
    :return: ``prompt``, possibly shortened. It may come back empty when even
        the formatting overhead alone does not fit in the limit.
    """
    token_counts, total_tokens, formatting_tokens = _measure_messages(prompt, backend_url)

    # Stop as soon as the prompt fits, or when there is nothing left to drop.
    # Without the emptiness guard the loop would never terminate once every
    # message was gone and the overhead alone still exceeded the limit.
    while prompt and total_tokens + formatting_tokens > context_token_limit:
        # Fast pass: reuse the counts already gathered. ``formatting_tokens``
        # is not updated while messages are dropped, so this is an estimate.
        while prompt and total_tokens + formatting_tokens > context_token_limit:
            remove_index = len(prompt) // 3
            while remove_index < len(prompt):
                total_tokens -= token_counts.pop(remove_index)
                prompt.pop(remove_index)
                if total_tokens + formatting_tokens <= context_token_limit:
                    break

        # Re-measure what is left so the outer condition is checked against
        # fresh numbers rather than the estimate.
        token_counts, total_tokens, formatting_tokens = _measure_messages(prompt, backend_url)

    return prompt


def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
    """Cut characters out of ``prompt`` until it fits in ``context_token_limit``.

    Chunks of ``_TRIM_CHUNK_CHARS`` characters are removed starting at index
    ``len(prompt) // 3``. While cutting, the size is estimated with the
    tiktoken ``cl100k_base`` encoding; after each trimming pass the result is
    re-checked with ``get_token_count``.

    :param prompt: the text to shorten.
    :param context_token_limit: maximum number of tokens allowed.
    :param backend_url: forwarded unchanged to ``get_token_count``.
    :return: the (possibly shortened) text. It may be empty when nothing fits.
    """
    token_count = get_token_count(prompt, backend_url)
    if token_count <= context_token_limit:
        return prompt

    # Only load the encoding once trimming is actually required.
    tokenizer = tiktoken.get_encoding("cl100k_base")

    # The emptiness guards prevent an endless loop when an empty string is
    # still reported as being over the limit.
    while prompt and token_count > context_token_limit:
        while prompt and token_count > context_token_limit:
            remove_index = len(prompt) // 3
            while remove_index < len(prompt):
                prompt = prompt[:remove_index] + prompt[remove_index + _TRIM_CHUNK_CHARS:]
                token_count = len(tokenizer.encode(prompt))
                if token_count <= context_token_limit:
                    break

        # The tiktoken figure is only an estimate; confirm with get_token_count.
        token_count = get_token_count(prompt, backend_url)

    return prompt


def transform_messages_to_prompt(oai_messages):
    """Flatten chat messages into a single ``### HEADER: text`` prompt.

    The prompt starts with ``opts.openai_system_prompt`` as an INSTRUCTION
    section, followed by one section per message (``system`` -> INSTRUCTION,
    ``user`` -> USER, ``assistant`` -> ASSISTANT) and ends with an open
    ``### RESPONSE: `` section. Sections are separated by a blank line.

    :param oai_messages: iterable of mappings with ``role`` and ``content``.
    :return: the prompt string. Returns ``False`` when a message has an empty
        or missing ``role``/``content`` or an unknown role, and ``''`` when an
        unexpected exception occurs; both are falsy, so callers can test the
        result with ``if not prompt``.
    """
    try:
        # Terminate the system section with a blank line like every other
        # section, so the first message header starts on its own line.
        prompt = f'### INSTRUCTION: {opts.openai_system_prompt}\n\n'
        for msg in oai_messages:
            if not msg.get('content') or not msg.get('role'):
                return False
            header = _ROLE_HEADERS.get(msg['role'])
            if header is None:
                return False
            prompt += f'### {header}: {msg["content"]}\n\n'
    except Exception:
        # TODO: use logging
        traceback.print_exc()
        return ''

    # Spaces are stripped first, then newlines, as two separate passes.
    prompt = prompt.strip(' ').strip('\n')  # TODO: this is really lazy
    prompt += '\n\n### RESPONSE: '
    return prompt