Fix prefix caching for chat completion since we removed logprobs.

This commit is contained in:
Nicolas Patry 2024-12-02 07:51:00 +01:00
parent db1114955a
commit 1352f70847
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
6 changed files with 157 additions and 5 deletions

9
load_tests/Makefile Normal file
View File

@ -0,0 +1,9 @@
ShareGPT_V3_unfiltered_cleaned_split.json:
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
prepare_share: ShareGPT_V3_unfiltered_cleaned_split.json
python filter.py
prepare_orca:
python orca.py

94
load_tests/common.js Normal file
View File

@ -0,0 +1,94 @@
import { check } from 'k6';
import { scenario } from 'k6/execution';
import http from 'k6/http';
import { Trend, Counter } from 'k6/metrics';
const host = __ENV.HOST;
const model_id = __ENV.MODEL_ID;
const timePerToken = new Trend('time_per_token', true);
const tokens = new Counter('tokens');
const new_tokens = new Counter('new_tokens');
const input_tokens = new Counter('input_tokens');
const max_new_tokens = 50;
// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
const shareGPT = JSON.parse(open("small.json"))
export function get_options() {
return {
thresholds: {
http_req_failed: ['rate==0'],
// time_per_token: [{
// threshold: `p(50)<${5 * reference_latency_ms}`,
// abortOnFail: true,
// delayAbortEval: '10s'
// }],
},
scenarios: {
// single_user: {
// executor: 'constant-arrival-rate',
// duration: '60s',
// preAllocatedVUs: 1,
// rate: 20,
// timeUnit: '1s',
// },
// load_test: {
// executor: 'constant-arrival-rate',
// duration: '60s',
// preAllocatedVUs: 100,
// rate: 1,
// timeUnit: '1s',
// },
// breakpoint: {
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows
// preAllocatedVUs: 300,
// stages: [
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
// ],
// },
throughput: {
executor: 'shared-iterations',
vus: 100,
iterations: 200,
maxDuration: '40s',
},
},
};
}
function generate_payload(gpt, max_new_tokens) {
const input = gpt["conversations"][0]["value"];
return { "messages": [{ "role": "user", "content": input }], "temperature": 0, "model": `${model_id}`, "max_tokens": max_new_tokens }
}
export const options = get_options();
export default function run() {
const headers = { 'Content-Type': 'application/json' };
const query = shareGPT[scenario.iterationInTest % shareGPT.length];
const payload = JSON.stringify(generate_payload(query, max_new_tokens));
const res = http.post(`http://${host}/v1/chat/completions`, payload, {
headers,
});
if (res.status >= 400 && res.status < 500) {
return;
}
check(res, {
'Post status is 200': (res) => res.status === 200,
});
const duration = res.timings.duration;
if (res.status === 200) {
const body = res.json();
const completion_tokens = body.usage.completion_tokens;
const latency_ms_per_token = duration / completion_tokens;
timePerToken.add(latency_ms_per_token);
const prompt_tokens = body.usage.prompt_tokens;
input_tokens.add(prompt_tokens);
new_tokens.add(completion_tokens);
tokens.add(completion_tokens + prompt_tokens);
}
}

26
load_tests/filter.py Normal file
View File

@ -0,0 +1,26 @@
import json
def main():
with open("./ShareGPT_V3_unfiltered_cleaned_split.json", "r") as f:
data = json.load(f)
# Select only the first 2k conversations that start with a human.
max = 2000
conversations = []
for conversation in data:
conv = conversation.get("conversations")
if conv and conv[0]["from"] == "human":
# Trim the rest of the output
conversation["conversations"] = conversation["conversations"][:1]
conversations.append(conversation)
if len(conversation) >= max:
break
with open("./small.json", "w") as f:
data = json.dump(conversations, f, indent=4)
if __name__ == "__main__":
main()

27
load_tests/orca.py Normal file
View File

@ -0,0 +1,27 @@
import json
import datasets
import tqdm
def main():
dataset = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
# Select only the first 2k conversations that start with a human.
max = min(2000, len(dataset))
conversations = []
for item in tqdm.tqdm(dataset, total=max):
conversation = {
"conversations": [
{"from": "human", "value": item["question"]},
],
"id": item["id"],
}
conversations.append(conversation)
if len(conversations) >= max:
break
with open("./small.json", "w") as f:
json.dump(conversations, f, indent=4)
if __name__ == "__main__":
main()

View File

@ -923,7 +923,6 @@ impl ChatRequest {
messages, messages,
seed, seed,
stop, stop,
stream,
tools, tools,
tool_choice, tool_choice,
tool_prompt, tool_prompt,
@ -1003,7 +1002,7 @@ impl ChatRequest {
truncate: None, truncate: None,
watermark: false, watermark: false,
details: true, details: true,
decoder_input_details: !stream, decoder_input_details: false,
seed, seed,
top_n_tokens: top_logprobs, top_n_tokens: top_logprobs,
grammar, grammar,

View File

@ -1560,9 +1560,6 @@ class FlashCausalLM(Model):
batch_num_blocks = batch.num_blocks batch_num_blocks = batch.num_blocks
num_tokens = batch.to_pb().current_tokens num_tokens = batch.to_pb().current_tokens
logger.info(f"BLOCKS {batch.num_blocks}")
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
logger.info(f"Free memory {free_memory}")
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.tuning_enable(False)
_, _batch, _ = self.generate_token(batch) _, _batch, _ = self.generate_token(batch)