local-llm-server/llm_server/llm/openai/transform.py

109 lines
4.5 KiB
Python

import concurrent.futures
import re
import secrets
import string
import traceback
from typing import Dict, List
import tiktoken
from llm_server import opts
from llm_server.llm import get_token_count
# Pattern for a "### XXX" header at the very start of a string (no re.MULTILINE,
# so "^" only anchors at position 0). Group 1 lazily captures the label up to the
# first whitespace character, leaving out a ":" that directly precedes it.
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
# Pattern that, once a "### " marker is found, consumes everything through the end
# of the string; the trailing (.|\n)* is what lets it run across newlines. Note that
# group 1 includes the marker itself and any text preceding it on the same line,
# not only what follows it.
# NOTE(review): (.|\n)* walks the input one alternation at a time; presumably inputs
# are short enough for that not to matter -- confirm against callers.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
def generate_oai_string(length=24):
    """Return a random string of ``length`` ASCII letters and digits.

    Every character is picked independently with ``secrets.choice``. A
    ``length`` of zero (or less) yields an empty string.
    """
    charset = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(secrets.choice(charset))
    return ''.join(picked)
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
    """Drop messages from ``prompt`` until the rendered prompt fits the token limit.

    The conversation size is the token count that ``get_token_count`` reports
    for ``transform_messages_to_prompt(prompt)``. While that exceeds
    ``context_token_limit``, messages are popped at index ``len(prompt) // 3``
    (so the first third of the conversation is kept the longest) until an
    estimate says the prompt fits; the estimate is then confirmed with a fresh
    count and trimming resumes if it was too optimistic.

    ``prompt`` is modified in place and the same list object is returned.

    Args:
        prompt: Chat messages; each must have a ``"content"`` entry.
        context_token_limit: Maximum number of tokens the rendered prompt may use.
        backend_url: Passed through to ``get_token_count``.

    Returns:
        The trimmed ``prompt`` list. If even an empty conversation does not fit
        (the fixed formatting alone exceeds the limit) the list comes back
        empty instead of looping forever.
    """

    def measure():
        """Return (per-message counts, their sum, formatting overhead) for ``prompt``."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            counts = list(executor.map(lambda msg: get_token_count(msg["content"], backend_url), prompt))
        content_total = sum(counts)
        # Whatever the rendered prompt costs beyond the raw message contents.
        overhead = get_token_count(transform_messages_to_prompt(prompt), backend_url) - content_total
        return counts, content_total, overhead

    token_counts, total_tokens, formatting_tokens = measure()
    if total_tokens + formatting_tokens <= context_token_limit:
        return prompt

    while True:
        # Estimate pass: subtract the cached per-message counts instead of
        # re-counting after every removal. ``formatting_tokens`` is stale here,
        # which is why the result is re-measured below.
        # The ``prompt and`` guard matters: with nothing left to remove, the
        # estimate can never drop and this loop would otherwise never exit.
        while prompt and total_tokens + formatting_tokens > context_token_limit:
            remove_index = len(prompt) // 3
            while remove_index < len(prompt):
                total_tokens -= token_counts.pop(remove_index)
                prompt.pop(remove_index)
                if total_tokens + formatting_tokens <= context_token_limit:
                    break

        # Confirm the estimate with real counts for what is left.
        token_counts, total_tokens, formatting_tokens = measure()
        if not prompt or total_tokens + formatting_tokens <= context_token_limit:
            break
    return prompt
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
    """Cut characters out of ``prompt`` until its token count fits the limit.

    The authoritative count comes from ``get_token_count(prompt, backend_url)``.
    While trimming, 100-character slices are removed starting at index
    ``len(prompt) // 3`` and the size is estimated locally with the tiktoken
    ``cl100k_base`` encoding; once the estimate fits, the result is confirmed
    with ``get_token_count`` and trimming resumes if the estimate was too
    optimistic.

    Args:
        prompt: The text to shorten.
        context_token_limit: Maximum number of tokens the text may use.
        backend_url: Passed through to ``get_token_count``.

    Returns:
        The (possibly shortened) prompt. If the limit cannot be met even by an
        empty string, the empty string is returned instead of looping forever.
    """
    token_count = get_token_count(prompt, backend_url)
    if token_count <= context_token_limit:
        return prompt

    # Only load the local tokenizer when trimming is actually required.
    tokenizer = tiktoken.get_encoding("cl100k_base")

    while True:
        # The ``prompt and`` guard matters: once the string is empty nothing
        # more can be removed and this loop would otherwise never exit.
        while prompt and token_count > context_token_limit:
            remove_index = len(prompt) // 3
            while remove_index < len(prompt):
                prompt = prompt[:remove_index] + prompt[remove_index + 100:]
                # disallowed_special=() makes tiktoken treat special-token text
                # (e.g. "<|endoftext|>") found in the prompt as ordinary text
                # rather than raising ValueError.
                token_count = len(tokenizer.encode(prompt, disallowed_special=()))
                if token_count <= context_token_limit:
                    break

        # Confirm the local estimate with the authoritative count.
        token_count = get_token_count(prompt, backend_url)
        if not prompt or token_count <= context_token_limit:
            break
    return prompt
def transform_messages_to_prompt(oai_messages):
    """Render OpenAI-style chat messages as a single "### ROLE: text" prompt.

    The prompt opens with ``### INSTRUCTION: `` followed by
    ``opts.openai_system_prompt``; each message is then appended under a header
    chosen by its role (``system`` -> INSTRUCTION, ``user`` -> USER,
    ``assistant`` -> ASSISTANT) and the result ends with ``### RESPONSE: ``.

    Side effect: every processed message has its ``content`` replaced by
    ``str(content)``.

    Returns:
        The prompt string on success; ``False`` as soon as a message lacks a
        ``content`` or ``role`` key; ``''`` if anything raises (including an
        unknown role), after the traceback has been printed.
    """
    try:
        # NOTE(review): nothing is inserted between the system prompt and the
        # first message -- presumably the configured prompt carries its own
        # trailing blank line; verify against opts.
        pieces = [f'### INSTRUCTION: {opts.openai_system_prompt}']
        for msg in oai_messages:
            if not ('content' in msg.keys() and 'role' in msg.keys()):
                return False

            text = str(msg['content'])  # Prevent any weird issues.
            msg['content'] = text

            role = msg['role']
            if role == 'system':
                header = '### INSTRUCTION: '
            elif role == 'user':
                header = '### USER: '
            elif role == 'assistant':
                header = '### ASSISTANT: '
            else:
                raise Exception(f'Unknown role: {msg["role"]}')
            pieces.append(header + text + '\n\n')
    except Exception:
        # TODO: use logging
        traceback.print_exc()
        return ''

    # Trim surrounding spaces first, then surrounding newlines (in that order).
    body = ''.join(pieces).strip(' ').strip('\n')
    return body + '\n\n### RESPONSE: '