feat: add simple ttft load_test
This commit is contained in:
parent
0759ec495e
commit
fe3991e857
|
@ -0,0 +1,108 @@
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
HOST = os.getenv("HOST", "localhost:3000")
|
||||||
|
MODEL_ID = os.getenv("MODEL_ID", "default-model")
|
||||||
|
NUM_REQUESTS = 10
|
||||||
|
MAX_NEW_TOKENS = 100
|
||||||
|
TIMEOUT = 30
|
||||||
|
|
||||||
|
|
||||||
|
def load_inputs(filename):
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
inputs = []
|
||||||
|
for item in data:
|
||||||
|
if "conversations" in item:
|
||||||
|
if len(item["conversations"]) > 0:
|
||||||
|
inputs.append(item["conversations"][0]["value"])
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
def generate_payload(input_text):
|
||||||
|
return {
|
||||||
|
"messages": [{"role": "user", "content": input_text}],
|
||||||
|
"temperature": 0,
|
||||||
|
"model": MODEL_ID,
|
||||||
|
"max_tokens": MAX_NEW_TOKENS,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def benchmark_sse(session, input_text):
|
||||||
|
payload = generate_payload(input_text)
|
||||||
|
start_time = time()
|
||||||
|
first_token_time = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
f"http://{HOST}/v1/chat/completions", json=payload, timeout=TIMEOUT
|
||||||
|
) as response:
|
||||||
|
async for line in response.content:
|
||||||
|
if line.startswith(b"data:"):
|
||||||
|
if first_token_time is None:
|
||||||
|
first_token_time = time()
|
||||||
|
return (first_token_time - start_time) * 1000
|
||||||
|
|
||||||
|
if first_token_time is None:
|
||||||
|
raise Exception("No SSE data received within the timeout period")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise Exception(f"Request timed out after {TIMEOUT} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_benchmark(inputs, same_input=False):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
tasks = []
|
||||||
|
longest_input = 0
|
||||||
|
for i in range(NUM_REQUESTS):
|
||||||
|
input_text = inputs[0] if same_input else inputs[i % len(inputs)]
|
||||||
|
if len(input_text) > longest_input:
|
||||||
|
longest_input = len(input_text)
|
||||||
|
task = asyncio.create_task(benchmark_sse(session, input_text))
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i, task in enumerate(asyncio.as_completed(tasks), 1):
|
||||||
|
try:
|
||||||
|
time_to_first_event = await task
|
||||||
|
results.append(time_to_first_event)
|
||||||
|
print(
|
||||||
|
f"Request {i}: Time to first event - {time_to_first_event:.2f}ms longest input: {longest_input}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Request {i} failed: {str(e)}")
|
||||||
|
|
||||||
|
if results:
|
||||||
|
avg_time = sum(results) / len(results)
|
||||||
|
print(f"\nAverage time to first event: {avg_time:.2f}ms")
|
||||||
|
else:
|
||||||
|
print("\nNo successful requests")
|
||||||
|
|
||||||
|
return avg_time if results else None
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
inputs = load_inputs("small.json")
|
||||||
|
|
||||||
|
print("Running benchmark with same input:")
|
||||||
|
same_input_avg = await run_benchmark(inputs, same_input=True)
|
||||||
|
|
||||||
|
# sleep for a second to avoid the next inputs in the same batch
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
print("\nRunning benchmark with different inputs:")
|
||||||
|
different_inputs_avg = await run_benchmark(inputs, same_input=False)
|
||||||
|
|
||||||
|
if same_input_avg and different_inputs_avg:
|
||||||
|
print(f"\nSame input average: {same_input_avg:.2f}ms")
|
||||||
|
print(f"Different inputs average: {different_inputs_avg:.2f}ms")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
Loading…
Reference in New Issue