import concurrent.futures
import re
import secrets
import string
import time
import traceback
from typing import Dict, List

import tiktoken
from flask import jsonify, make_response

import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis

ANTI_RESPONSE_RE = re.compile(r'^### (.*?):?\s')  # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?:?(.|\n)*)')  # Match everything after a "### XXX" line.


def build_openai_response(prompt, response, model=None):
    """Wrap a backend completion in an OpenAI-style chat.completion JSON response."""
    # Separate the user's prompt from the context.
    x = prompt.split('### USER:')
    if len(x) > 1:
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response.
    # y = response.split('\n### ')
    # if len(y) > 1:
    #     response = re.sub(r'\n$', '', y[0].strip(' '))
    response = re.sub(ANTI_RESPONSE_RE, '', response)
    response = re.sub(ANTI_CONTINUATION_RE, '', response)

    # TODO: async/await
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    response_tokens = llm_server.llm.get_token_count(response)
    running_model = redis.get('running_model', str, 'ERROR')

    # Build the Flask response under a new name so we don't shadow the completion text.
    resp = make_response(jsonify({
        "id": f"chatcmpl-{generate_oai_string(30)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": running_model if opts.openai_expose_our_model else model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response,
            },
            "logprobs": None,
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    }), 200)

    stats = redis.get('proxy_stats', dict)
    if stats:
        resp.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
    return resp


def generate_oai_string(length=24):
    """Generate a random alphanumeric string, e.g. for fake OpenAI completion IDs."""
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))


def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
    """Drop messages from the middle of the conversation until it fits the context window.

    A fast tiktoken pass does the bulk of the trimming; the backend tokenizer then
    verifies the result, since the two tokenizers may disagree.
    """
    tokenizer = tiktoken.get_encoding("cl100k_base")

    def get_token_count_tiktoken_thread(msg):
        return len(tokenizer.encode(msg["content"]))

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))

    total_tokens = sum(token_counts)
    # Tokens added by the "### USER:"-style formatting, beyond the message contents themselves.
    formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens

    # If total tokens exceed the limit, start trimming.
    if total_tokens > context_token_limit:
        while True:
            # Guard on `prompt` so an emptied conversation can't spin forever.
            while prompt and total_tokens + formatting_tokens > context_token_limit:
                # Calculate the index to start removing messages from.
                remove_index = len(prompt) // 3

                while remove_index < len(prompt):
                    total_tokens -= token_counts[remove_index]
                    prompt.pop(remove_index)
                    token_counts.pop(remove_index)
                    if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
                        break

            def get_token_count_thread(msg):
                return get_token_count(msg["content"])

            # Verify the trimmed conversation against the backend's tokenizer.
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                token_counts = list(executor.map(get_token_count_thread, prompt))

            total_tokens = sum(token_counts)
            formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens

            if prompt and total_tokens + formatting_tokens > context_token_limit:
                # Start over, but this time calculate the token count using the backend.
                with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                    token_counts = list(executor.map(get_token_count_thread, prompt))
            else:
                break

    return prompt
def transform_messages_to_prompt(oai_messages):
    """Flatten an OpenAI-style message list into a single "### ROLE:" delimited prompt.

    Returns False for malformed messages and '' if an unexpected error occurs.
    """
    try:
        prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
        for msg in oai_messages:
            if not msg.get('content') or not msg.get('role'):
                return False
            if msg['role'] == 'system':
                prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
            elif msg['role'] == 'user':
                prompt += f'### USER: {msg["content"]}\n\n'
            elif msg['role'] == 'assistant':
                prompt += f'### ASSISTANT: {msg["content"]}\n\n'
            else:
                return False
    except Exception:
        # TODO: use logging
        traceback.print_exc()
        return ''
    # str.strip() takes a set of characters, so ' \n' covers both spaces and newlines.
    prompt = prompt.strip(' \n')
    prompt += '\n\n### RESPONSE: '
    return prompt
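
# Minimal usage sketch (illustrative only; the message list and token limit below
# are made-up example values, not part of this module):
#
#   messages = [
#       {'role': 'system', 'content': 'You are a helpful assistant.'},
#       {'role': 'user', 'content': 'Hello!'},
#   ]
#   messages = trim_prompt_to_fit(messages, context_token_limit=4096)
#   prompt = transform_messages_to_prompt(messages)
#   # `prompt` now ends with "### RESPONSE: " and is ready for the backend; the
#   # backend's completion text can then be wrapped for the client with
#   # build_openai_response(prompt, completion).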