From 1352f70847efe365337b188b7d91ee235e19069e Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 2 Dec 2024 07:51:00 +0100
Subject: [PATCH] Fix prefix caching for chat completion since we removed
 logprobs.

---
 load_tests/Makefile                           |  9 ++
 load_tests/common.js                          | 94 +++++++++++++++++++
 load_tests/filter.py                          | 26 +++++
 load_tests/orca.py                            | 27 ++++++
 router/src/lib.rs                             |  3 +-
 .../models/flash_causal_lm.py                 |  3 -
 6 files changed, 157 insertions(+), 5 deletions(-)
 create mode 100644 load_tests/Makefile
 create mode 100644 load_tests/common.js
 create mode 100644 load_tests/filter.py
 create mode 100644 load_tests/orca.py

diff --git a/load_tests/Makefile b/load_tests/Makefile
new file mode 100644
index 00000000..9199aa3b
--- /dev/null
+++ b/load_tests/Makefile
@@ -0,0 +1,9 @@
+
+ShareGPT_V3_unfiltered_cleaned_split.json:
+	wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+
+prepare_share: ShareGPT_V3_unfiltered_cleaned_split.json
+	python filter.py
+
+prepare_orca:
+	python orca.py
diff --git a/load_tests/common.js b/load_tests/common.js
new file mode 100644
index 00000000..d890bf67
--- /dev/null
+++ b/load_tests/common.js
@@ -0,0 +1,94 @@
+import { check } from 'k6';
+import { scenario } from 'k6/execution';
+import http from 'k6/http';
+import { Trend, Counter } from 'k6/metrics';
+
+const host = __ENV.HOST;
+const model_id = __ENV.MODEL_ID;
+const timePerToken = new Trend('time_per_token', true);
+const tokens = new Counter('tokens');
+const new_tokens = new Counter('new_tokens');
+const input_tokens = new Counter('input_tokens');
+const max_new_tokens = 50;
+
+// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
+const shareGPT = JSON.parse(open("small.json"))
+
+
+export function get_options() {
+    return {
+        thresholds: {
+            http_req_failed: ['rate==0'],
+            // time_per_token: [{
+            //     threshold: `p(50)<${5 * reference_latency_ms}`,
+            //     abortOnFail: true,
+            //     delayAbortEval: '10s'
+            // }],
+        },
+        scenarios: {
+            // single_user: {
+            //     executor: 'constant-arrival-rate',
+            //     duration: '60s',
+            //     preAllocatedVUs: 1,
+            //     rate: 20,
+            //     timeUnit: '1s',
+            // },
+            // load_test: {
+            //     executor: 'constant-arrival-rate',
+            //     duration: '60s',
+            //     preAllocatedVUs: 100,
+            //     rate: 1,
+            //     timeUnit: '1s',
+            // },
+            // breakpoint: {
+            //     executor: 'ramping-arrival-rate', //Assure load increase if the system slows
+            //     preAllocatedVUs: 300,
+            //     stages: [
+            //         { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
+            //     ],
+            // },
+            throughput: {
+                executor: 'shared-iterations',
+                vus: 100,
+                iterations: 200,
+                maxDuration: '40s',
+            },
+        },
+    };
+}
+
+function generate_payload(gpt, max_new_tokens) {
+    const input = gpt["conversations"][0]["value"];
+    return { "messages": [{ "role": "user", "content": input }], "temperature": 0, "model": `${model_id}`, "max_tokens": max_new_tokens }
+}
+
+export const options = get_options();
+
+export default function run() {
+    const headers = { 'Content-Type': 'application/json' };
+    const query = shareGPT[scenario.iterationInTest % shareGPT.length];
+    const payload = JSON.stringify(generate_payload(query, max_new_tokens));
+    const res = http.post(`http://${host}/v1/chat/completions`, payload, {
+        headers,
+    });
+    if (res.status >= 400 && res.status < 500) {
+        return;
+    }
+
+
+    check(res, {
+        'Post status is 200': (res) => res.status === 200,
+    });
+    const duration = res.timings.duration;
+
+    if (res.status === 200) {
+        const body = res.json();
+        const completion_tokens = body.usage.completion_tokens;
+        const latency_ms_per_token = duration / completion_tokens;
+        timePerToken.add(latency_ms_per_token);
+        const prompt_tokens = body.usage.prompt_tokens;
+        input_tokens.add(prompt_tokens);
+        new_tokens.add(completion_tokens);
+        tokens.add(completion_tokens + prompt_tokens);
+    }
+}
diff --git a/load_tests/filter.py b/load_tests/filter.py
new file mode 100644
index 00000000..a00226ed
--- /dev/null
+++ b/load_tests/filter.py
@@ -0,0 +1,26 @@
+import json
+
+
+def main():
+    with open("./ShareGPT_V3_unfiltered_cleaned_split.json", "r") as f:
+        data = json.load(f)
+
+    # Select only the first 2k conversations that start with a human.
+    max = 2000
+    conversations = []
+    for conversation in data:
+        conv = conversation.get("conversations")
+        if conv and conv[0]["from"] == "human":
+            # Trim the rest of the output
+            conversation["conversations"] = conversation["conversations"][:1]
+            conversations.append(conversation)
+
+        if len(conversations) >= max:
+            break
+
+    with open("./small.json", "w") as f:
+        json.dump(conversations, f, indent=4)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/load_tests/orca.py b/load_tests/orca.py
new file mode 100644
index 00000000..e445afd5
--- /dev/null
+++ b/load_tests/orca.py
@@ -0,0 +1,27 @@
+import json
+import datasets
+import tqdm
+
+
+def main():
+    dataset = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
+    # Select only the first 2k conversations that start with a human.
+    max = min(2000, len(dataset))
+    conversations = []
+    for item in tqdm.tqdm(dataset, total=max):
+        conversation = {
+            "conversations": [
+                {"from": "human", "value": item["question"]},
+            ],
+            "id": item["id"],
+        }
+        conversations.append(conversation)
+        if len(conversations) >= max:
+            break
+
+    with open("./small.json", "w") as f:
+        json.dump(conversations, f, indent=4)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/router/src/lib.rs b/router/src/lib.rs
index ea697c3a..00c493c4 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -923,7 +923,6 @@ impl ChatRequest {
             messages,
             seed,
             stop,
-            stream,
             tools,
             tool_choice,
             tool_prompt,
@@ -1003,7 +1002,7 @@ impl ChatRequest {
             truncate: None,
             watermark: false,
             details: true,
-            decoder_input_details: !stream,
+            decoder_input_details: false,
             seed,
             top_n_tokens: top_logprobs,
             grammar,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 73225f10..98c6f419 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1560,9 +1560,6 @@ class FlashCausalLM(Model):
             batch_num_blocks = batch.num_blocks
 
             num_tokens = batch.to_pb().current_tokens
-            logger.info(f"BLOCKS {batch.num_blocks}")
-            free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-            logger.info(f"Free memory {free_memory}")
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
                 _, _batch, _ = self.generate_token(batch)
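
Usage note (not part of the patch): the request that load_tests/common.js issues can be reproduced by hand to sanity-check a server before running the benchmark. The Python sketch below mirrors generate_payload() and the time_per_token computation from the k6 script; the host, port, model id, and prompt are placeholder assumptions, and the requests dependency is not introduced by this patch.

import requests

# Placeholders, playing the role of k6's __ENV.HOST and __ENV.MODEL_ID.
HOST = "127.0.0.1:3000"
MODEL_ID = "some-model-id"
MAX_NEW_TOKENS = 50  # same cap as max_new_tokens in common.js

# Same payload shape as generate_payload() in common.js; the prompt is arbitrary.
payload = {
    "messages": [{"role": "user", "content": "What is Deep Learning?"}],
    "temperature": 0,
    "model": MODEL_ID,
    "max_tokens": MAX_NEW_TOKENS,
}

resp = requests.post(
    f"http://{HOST}/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json=payload,
)
resp.raise_for_status()
body = resp.json()

# Roughly the metric the k6 script records as time_per_token:
# request duration in milliseconds divided by the number of generated tokens.
duration_ms = resp.elapsed.total_seconds() * 1000
completion_tokens = body["usage"]["completion_tokens"]
print("ms per generated token:", duration_ms / completion_tokens)

For the benchmark itself, "make prepare_share" (or "make prepare_orca") in load_tests/ produces small.json, after which an invocation along the lines of "k6 run -e HOST=127.0.0.1:3000 -e MODEL_ID=some-model-id common.js" drives the throughput scenario; the exact k6 command line is an assumption here, since the Makefile in this patch only covers data preparation.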