import tiktoken

from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import oobabooga, vllm
from llm_server.logging import create_logger


def fallback_tokenizer(prompt: str) -> int:
    # Estimate the token count with tiktoken's cl100k_base encoding and pad
    # by 10 tokens, since this encoding may not match the backend's tokenizer.
    tokenizer = tiktoken.get_encoding("cl100k_base")
    return len(tokenizer.encode(prompt)) + 10


def get_token_count(prompt: str, backend_url: str) -> int:
    backend_url = cluster_config.validate_backend(backend_url)
    if not backend_url:
        logger = create_logger('tokenizer')
        logger.warning('using fallback tokenizer as there is no valid backend')
        return fallback_tokenizer(prompt)

    backend_mode = cluster_config.get_backend(backend_url).get('mode')
    if not backend_mode:
        logger = create_logger('tokenizer')
        logger.warning("using fallback tokenizer as the backend isn't initialized")
        return fallback_tokenizer(prompt)

    # Dispatch to the tokenizer matching the backend's inference engine.
    if backend_mode == 'vllm':
        return vllm.tokenize(prompt, backend_url)
    elif backend_mode == 'ooba':
        return oobabooga.tokenize(prompt)
    else:
        raise ValueError(f'unknown backend mode: {backend_mode}')
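

if __name__ == '__main__':
    # Minimal usage sketch: the URL below is a hypothetical example and must
    # point at a backend already registered in cluster_config; otherwise the
    # fallback tokenizer's estimate is returned.
    print(get_token_count('Hello, world!', 'http://127.0.0.1:8000'))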