import concurrent.futures
import re
import secrets
import string
import time
import traceback
from typing import Dict, List

import tiktoken
from flask import jsonify, make_response

import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis

ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.


def build_openai_response(prompt, response, model=None):
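    """
    Wrap a backend completion in an OpenAI-style chat.completion payload.

    Strips the instruction-template scaffolding from the prompt and response,
    counts tokens for the usage block, and returns a Flask response (with a
    rate-limit header when proxy stats are available in Redis).
    """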
    # Separate the user's prompt from the context
    x = prompt.split('### USER:')
    if len(x) > 1:
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response
    # y = response.split('\n### ')
    # if len(y) > 1:
    #     response = re.sub(r'\n$', '', y[0].strip(' '))
    response = re.sub(ANTI_RESPONSE_RE, '', response)
    response = re.sub(ANTI_CONTINUATION_RE, '', response)

    # TODO: async/await
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    response_tokens = llm_server.llm.get_token_count(response)
    running_model = redis.get('running_model', str, 'ERROR')

    response = make_response(jsonify({
        "id": f"chatcmpl-{generate_oai_string(30)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": running_model if opts.openai_expose_our_model else model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response,
            },
            "logprobs": None,
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    }), 200)

    stats = redis.get('proxy_stats', dict)
    if stats:
        response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
    return response


def generate_oai_string(length=24):
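    """Return a random alphanumeric string for OpenAI-style IDs (e.g. the chatcmpl- suffix)."""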
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))


def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
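    """
    Drop messages until the rendered prompt (including the template's formatting
    tokens) fits within context_token_limit. Removal starts one third of the way
    into the list, so the earliest context is preserved. Token counts are first
    estimated with tiktoken, then re-checked against the backend tokenizer.
    The input list is modified in place and also returned.
    """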
    tokenizer = tiktoken.get_encoding("cl100k_base")

    def get_token_count_tiktoken_thread(msg):
        return len(tokenizer.encode(msg["content"]))

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))

    total_tokens = sum(token_counts)
    formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens

    # If total tokens exceed the limit, start trimming
    if total_tokens > context_token_limit:
        while True:
            while total_tokens + formatting_tokens > context_token_limit:
                # Calculate the index to start removing messages from
                remove_index = len(prompt) // 3

                while remove_index < len(prompt):
                    total_tokens -= token_counts[remove_index]
                    prompt.pop(remove_index)
                    token_counts.pop(remove_index)
                    if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
                        break

            def get_token_count_thread(msg):
                return get_token_count(msg["content"])

            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                token_counts = list(executor.map(get_token_count_thread, prompt))

            total_tokens = sum(token_counts)
            formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens

            if total_tokens + formatting_tokens > context_token_limit:
                # Start over, but this time calculate the token count using the backend
                with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                    token_counts = list(executor.map(get_token_count_thread, prompt))
            else:
                break

    return prompt


def transform_messages_to_prompt(oai_messages):
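    """
    Render an OpenAI-style messages list into the backend's instruction template:
    'system' messages become '### INSTRUCTION:', 'user' becomes '### USER:',
    'assistant' becomes '### ASSISTANT:', and a trailing '### RESPONSE: ' is
    appended for the model to complete. Returns False if a message is missing a
    role or content (or uses an unknown role), and '' if rendering fails.
    """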
    try:
        prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
        for msg in oai_messages:
            if not msg.get('content') or not msg.get('role'):
                return False
            if msg['role'] == 'system':
                prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
            elif msg['role'] == 'user':
                prompt += f'### USER: {msg["content"]}\n\n'
            elif msg['role'] == 'assistant':
                prompt += f'### ASSISTANT: {msg["content"]}\n\n'
            else:
                return False
    except Exception:
        # TODO: use logging
        traceback.print_exc()
        return ''

    prompt = prompt.strip(' ').strip('\n').strip('\n\n')  # TODO: this is really lazy
    prompt += '\n\n### RESPONSE: '
    return prompt
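
# Rough usage sketch (illustrative only; `messages`, `context_token_limit`,
# `backend_response` and `requested_model` are placeholders, not part of this module):
#
#   messages = trim_prompt_to_fit(messages, context_token_limit)
#   prompt = transform_messages_to_prompt(messages)
#   # ... send `prompt` to the backend and collect `backend_response` ...
#   return build_openai_response(prompt, backend_response, model=requested_model)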