From 347a82b7e1f255999acb708b4e90eff89a432d3e Mon Sep 17 00:00:00 2001 From: Cyberes Date: Thu, 28 Sep 2023 03:54:20 -0600 Subject: [PATCH] avoid sending to backend to tokenize if it's greater than our specified context size --- llm_server/llm/vllm/tokenize.py | 24 ++++++++++++++++-------- llm_server/routes/v1/generate_stats.py | 2 +- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py index 6bb9343..a698fd6 100644 --- a/llm_server/llm/vllm/tokenize.py +++ b/llm_server/llm/vllm/tokenize.py @@ -5,14 +5,22 @@ from llm_server import opts def tokenize(prompt: str) -> int: - tokenizer = tiktoken.get_encoding("cl100k_base") if not prompt: # The tokenizers have issues when the prompt is None. return 0 - try: - r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) - j = r.json() - return j['length'] - except Exception as e: - print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}') - return len(tokenizer.encode(prompt)) + 10 + tokenizer = tiktoken.get_encoding("cl100k_base") + + # First we tokenize it locally to determine if it's worth sending it to the backend. + initial_estimate = len(tokenizer.encode(prompt)) + if initial_estimate <= opts.context_size + 200: + try: + r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + j = r.json() + return j['length'] + except Exception as e: + print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}') + return len(tokenizer.encode(prompt)) + 10 + else: + # If the result was greater than our context size, return the estimate. + # We won't be sending it through the backend so it does't need to be accurage. + return initial_estimate diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index 5bf026b..e144099 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -89,7 +89,6 @@ def generate_stats(regen: bool = False): 'average_generation_elapsed_sec': int(average_generation_time), # 'estimated_avg_tps': estimated_avg_tps, 'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None, - 'nvidia': netdata_stats }, 'online': online, 'endpoints': { @@ -115,6 +114,7 @@ def generate_stats(regen: bool = False): 'anthropicKeys': '∞', }, 'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None, + 'nvidia': netdata_stats } result = deep_sort(output)