import asyncio import json import os import sys import time from pathlib import Path try: import websockets except ImportError: print("Websockets package not found. Make sure it's installed.") script_path = os.path.dirname(os.path.realpath(__file__)) def parse_bash_config(file_path): config = {} with open(file_path, 'r') as f: for line in f: if line.startswith('#') or '=' not in line: continue key, value = line.strip().split('=', 1) if value.startswith('"') and value.endswith('"'): value = value[1:-1] elif value.startswith('(') and value.endswith(')'): value = value[1:-1].split() value = [v.strip('"') for v in value] config[key] = value return config config = parse_bash_config(Path(script_path, 'config.sh')) async def run(context): request = { 'prompt': context, 'max_new_tokens': 250, 'auto_max_new_tokens': False, 'max_tokens_second': 0, 'preset': 'None', 'do_sample': True, 'temperature': 0.7, 'top_p': 0.1, 'typical_p': 1, 'epsilon_cutoff': 0, 'eta_cutoff': 0, 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, 'no_repeat_ngram_size': 0, 'num_beams': 1, 'penalty_alpha': 0, 'length_penalty': 1, 'early_stopping': False, 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, 'guidance_scale': 1, 'negative_prompt': '', 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, 'ban_eos_token': False, 'custom_token_bans': '', 'skip_special_tokens': True, 'stopping_strings': [] } socket_type = 'ws://' if config['HOST'].startswith('https://'): socket_type = 'wss://' config['HOST'] = config['HOST'].strip('http://') config['HOST'] = config['HOST'].strip('https://') print('Connecting to', f'{socket_type}{config["HOST"]}/api/v1/stream') async with websockets.connect(f'{socket_type}{config["HOST"]}/api/v1/stream', ping_interval=None) as websocket: await websocket.send(json.dumps(request)) yield context # Remove this if you just want to see the reply while True: incoming_data = await websocket.recv() incoming_data = json.loads(incoming_data) print(incoming_data) match incoming_data['event']: # case 'text_stream': # yield incoming_data['text'] case 'stream_end': return async def print_response_stream(prompt): # try: async for response in run(prompt): print(response, end='') sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. # except Exception as e: # print(e) if __name__ == '__main__': prompt = "Write a 300 word story about an apple tree.\n\n" while True: print('--> START <--') asyncio.run(print_response_stream(prompt)) print('--> DONE <--') time.sleep(2)