local-llm-server/llm_server/llm/openai/transform.py

import concurrent.futures
import re
import secrets
import string
import time
import traceback
from typing import Dict, List
import tiktoken
from flask import jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis

ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.


def build_openai_response(prompt, response, model=None):
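    """
    Wrap a backend completion in an OpenAI-style chat.completion payload.

    Strips instruct-formatting ("### XXX" lines) out of the response, counts
    prompt/response tokens, and attaches the queue wait estimate as a
    rate-limit header when proxy stats are available.
    """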
    # Separate the user's prompt from the context
    x = prompt.split('### USER:')
    if len(x) > 1:
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response
    # y = response.split('\n### ')
    # if len(y) > 1:
    #     response = re.sub(r'\n$', '', y[0].strip(' '))
    response = re.sub(ANTI_RESPONSE_RE, '', response)
    response = re.sub(ANTI_CONTINUATION_RE, '', response)

    # TODO: async/await
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    response_tokens = llm_server.llm.get_token_count(response)
    running_model = redis.get('running_model', str, 'ERROR')

    response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
stats = redis.get('proxy_stats', dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
def generate_oai_string(length=24):
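    """Generate a random alphanumeric string, used for OpenAI-style IDs (e.g. "chatcmpl-...")."""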
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))


def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
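    """
    Trim a list of OpenAI-style messages until the rendered prompt fits in the context window.

    Messages are dropped starting roughly a third of the way into the conversation.
    Token counts are first estimated locally with tiktoken, then re-checked against
    the backend tokenizer (get_token_count) before the result is accepted.
    """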
    tokenizer = tiktoken.get_encoding("cl100k_base")

    def get_token_count_tiktoken_thread(msg):
        return len(tokenizer.encode(msg["content"]))

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))

    total_tokens = sum(token_counts)
    formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens

    # If total tokens exceed the limit, start trimming
    if total_tokens > context_token_limit:
        while True:
            while total_tokens + formatting_tokens > context_token_limit:
                # Calculate the index to start removing messages from
                remove_index = len(prompt) // 3

                while remove_index < len(prompt):
                    total_tokens -= token_counts[remove_index]
                    prompt.pop(remove_index)
                    token_counts.pop(remove_index)
                    if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
                        break

            def get_token_count_thread(msg):
                return get_token_count(msg["content"])

            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                token_counts = list(executor.map(get_token_count_thread, prompt))

            total_tokens = sum(token_counts)
            formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens

            if total_tokens + formatting_tokens > context_token_limit:
                # Start over, but this time calculate the token count using the backend
                with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                    token_counts = list(executor.map(get_token_count_thread, prompt))
            else:
                break

    return prompt


def transform_messages_to_prompt(oai_messages):
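    """
    Flatten OpenAI chat messages into the instruct-style prompt used by the backend.

    system -> "### INSTRUCTION:", user -> "### USER:", assistant -> "### ASSISTANT:",
    with the configured system prompt prepended and a trailing "### RESPONSE: " marker.
    Returns False for malformed messages and '' if formatting raises an exception.
    """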
    try:
        prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
        for msg in oai_messages:
            if not msg.get('content') or not msg.get('role'):
                return False
            if msg['role'] == 'system':
                prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
            elif msg['role'] == 'user':
                prompt += f'### USER: {msg["content"]}\n\n'
            elif msg['role'] == 'assistant':
                prompt += f'### ASSISTANT: {msg["content"]}\n\n'
            else:
                return False
    except Exception as e:
        # TODO: use logging
        traceback.print_exc()
        return ''
    prompt = prompt.strip(' ').strip('\n').strip('\n\n')  # TODO: this is really lazy
    prompt += '\n\n### RESPONSE: '
    return prompt