Prefix test - Different kind of load test to trigger prefix test bugs. (#2490)

* Adding prefix test.

* [WIP] tmp dump of integration load tests.

* Remove other tensor creation.

* Fixed the radix tree.

Used a slice everywhere in radix.rs to keep the cheap Arc cloning
instead of recomputing the input_ids.

* Fix parsing

* Is it really the flashinfer version?

* Remove some comments.

* Revert the max prefix hit.

* Adding numpy to diff.

* Upgraded flashinfer.

* Upgrading some stuff.

* Are we done yet?

* Minor fixup

* Remove 1 log and put back the other.

* Add comment for why slot 0 is OK.

* Mounting on the job.

* Get me a debug branch

* Debugging CIs is fun.

* Attempt #28

* wip

* Tmate.

* Praying.

* Updating VLM causal model with updated context.

* Important line got squashed.

* Tmate again.

* Fingers crossed.

* We want only 1 run of integration tests...

---------

Co-authored-by: Guillaume LEGENDRE <glegendre01@gmail.com>
Nicolas Patry 2024-09-11 18:10:40 +02:00 committed by GitHub
parent eabbbbda23
commit a4e3e8c608
19 changed files with 4113 additions and 1077 deletions


@ -376,10 +376,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 // Send generation responses back to the infer task
 // If the receive an error from the Flume channel, it means that the client dropped the
 // request and we need to stop generating hence why we unwrap_or(true)
-let stopped = send_responses(generation, entry).map_err(|err| {
+let stopped = send_responses(generation, entry).inspect_err(|_err| {
     tracing::error!("Entry response channel error.");
     metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
-    err
 }).unwrap_or(true);
 if stopped {
     entries.remove(&id).expect("ID not found in entries. This is a bug.");
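
A note on the pattern this hunk adopts: `Result::inspect_err` (stable since Rust 1.76) runs a closure on the error without consuming it, which replaces the old `map_err(|err| { ...; err })` idiom of logging and then handing the error back. A minimal standalone illustration (not the router code itself):

```rust
fn send() -> Result<bool, &'static str> {
    Err("channel closed")
}

fn main() {
    // Old idiom: log inside map_err, then return the error unchanged.
    let stopped_old = send()
        .map_err(|err| {
            eprintln!("Entry response channel error.");
            err
        })
        .unwrap_or(true);

    // New idiom: inspect_err only observes the error, nothing to hand back.
    let stopped_new = send()
        .inspect_err(|_err| {
            eprintln!("Entry response channel error.");
        })
        .unwrap_or(true);

    assert_eq!(stopped_old, stopped_new);
}
```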


@ -123,8 +123,6 @@ impl Allocator for RadixAllocator {
     prefill_tokens: prefill_tokens.clone(),
 };
-tracing::debug!("Blocks {blocks:?}");
 self.allocation_id += 1;
 self.allocations.insert(self.allocation_id, allocation);
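
This ties back to the "Fixed the radix tree" note above: lookups hand around slices of one Arc-backed token buffer instead of rebuilding `input_ids`. A rough sketch of that idea with hypothetical types (the real radix.rs is more involved):

```rust
use std::sync::Arc;

// A "prefix" is just a clone of the shared Arc plus a length: cloning bumps
// a refcount, and shrinking to a matched prefix reuses the same allocation.
#[derive(Clone)]
struct TokenSlice {
    ids: Arc<Vec<u32>>,
    len: usize,
}

impl TokenSlice {
    fn as_slice(&self) -> &[u32] {
        &self.ids[..self.len]
    }

    // O(1): no token is recomputed or copied.
    fn prefix(&self, len: usize) -> TokenSlice {
        assert!(len <= self.len);
        TokenSlice { ids: Arc::clone(&self.ids), len }
    }
}

// Longest shared prefix between a request and a cached entry.
fn shared_prefix_len(a: &[u32], b: &[u32]) -> usize {
    a.iter().zip(b).take_while(|(x, y)| x == y).count()
}

fn main() {
    let request = TokenSlice { ids: Arc::new(vec![1, 2, 3, 4, 5]), len: 5 };
    let cached = [1u32, 2, 3, 9];
    let n = shared_prefix_len(request.as_slice(), &cached);
    let hit = request.prefix(n);
    assert_eq!(hit.as_slice(), &[1, 2, 3]);
}
```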


@ -492,6 +492,24 @@
     "type": "github"
   }
 },
+"flake-utils_7": {
+  "inputs": {
+    "systems": "systems_7"
+  },
+  "locked": {
+    "lastModified": 1710146030,
+    "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
+    "owner": "numtide",
+    "repo": "flake-utils",
+    "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
+    "type": "github"
+  },
+  "original": {
+    "owner": "numtide",
+    "repo": "flake-utils",
+    "type": "github"
+  }
+},
 "gitignore": {
   "inputs": {
     "nixpkgs": [
@ -700,16 +718,16 @@
 },
 "nixpkgs_6": {
   "locked": {
-    "lastModified": 1723912943,
-    "narHash": "sha256-39F9GzyhxYcY3wTeKuEFWRJWcrGBosO4nf4xzMTWZX8=",
-    "owner": "danieldk",
+    "lastModified": 1724915739,
+    "narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
+    "owner": "nixos",
     "repo": "nixpkgs",
-    "rev": "b82cdca86dbb30013b76c4b55d48806476820a5c",
+    "rev": "85be051bb60943d3328d91aaf2598798f87e19af",
     "type": "github"
   },
   "original": {
-    "owner": "danieldk",
-    "ref": "cuda-12.4",
+    "owner": "nixos",
+    "ref": "nixos-unstable-small",
     "repo": "nixpkgs",
     "type": "github"
   }
@ -835,11 +853,11 @@
   ]
 },
 "locked": {
-  "lastModified": 1724638882,
-  "narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
+  "lastModified": 1725848835,
+  "narHash": "sha256-u4lCr+tOEWhsFiww5G04U5jUNzaQJi0/ZMIDGiLeT14=",
   "owner": "oxalica",
   "repo": "rust-overlay",
-  "rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
+  "rev": "2ef910a6276a2f34513d18f2f826a8dea72c3b3f",
   "type": "github"
 },
 "original": {
@ -938,17 +956,33 @@
     "type": "github"
   }
 },
+"systems_7": {
+  "locked": {
+    "lastModified": 1681028828,
+    "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+    "owner": "nix-systems",
+    "repo": "default",
+    "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+    "type": "github"
+  },
+  "original": {
+    "owner": "nix-systems",
+    "repo": "default",
+    "type": "github"
+  }
+},
 "tgi-nix": {
   "inputs": {
     "flake-compat": "flake-compat_4",
+    "flake-utils": "flake-utils_7",
     "nixpkgs": "nixpkgs_6"
   },
   "locked": {
-    "lastModified": 1725011596,
-    "narHash": "sha256-zfq8lOXFgJnKxxsqSelHuKUvhxgH3cEmLoAgsOO62Cg=",
+    "lastModified": 1725868835,
+    "narHash": "sha256-6OFEaFFRCG/JKkU6kHV08EPEGM1MCuKZ70NlGJcL/JY=",
     "owner": "danieldk",
     "repo": "tgi-nix",
-    "rev": "717c2b07e38538abf05237cca65b2d1363c2c9af",
+    "rev": "87afbe21e2d2cc17e177c9965a64ba68ad7c22da",
     "type": "github"
   },
   "original": {


@ -19,6 +19,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
 from text_generation import AsyncClient
 from text_generation.types import (
     BestOfSequence,
+    Message,
     ChatComplete,
     ChatCompletionChunk,
     ChatCompletionComplete,
@ -97,25 +98,25 @@ class ResponseComparator(JSONSnapshotExtension):
     ) -> bool:
         def convert_data(data):
             data = json.loads(data)
-            if isinstance(data, Dict) and "choices" in data:
-                choices = data["choices"]
-                if isinstance(choices, List) and len(choices) >= 1:
-                    if "delta" in choices[0]:
-                        return ChatCompletionChunk(**data)
-                    if "text" in choices[0]:
-                        return Completion(**data)
-                return ChatComplete(**data)
-            if isinstance(data, Dict):
-                return Response(**data)
-            if isinstance(data, List):
-                if (
-                    len(data) > 0
-                    and "object" in data[0]
-                    and data[0]["object"] == "text_completion"
-                ):
-                    return [Completion(**d) for d in data]
-                return [Response(**d) for d in data]
-            raise NotImplementedError
+            return _convert_data(data)
+
+        def _convert_data(data):
+            if isinstance(data, Dict):
+                if "choices" in data:
+                    data["choices"] = list(
+                        sorted(data["choices"], key=lambda x: x["index"])
+                    )
+                    choices = data["choices"]
+                    if isinstance(choices, List) and len(choices) >= 1:
+                        if "delta" in choices[0]:
+                            return ChatCompletionChunk(**data)
+                        if "text" in choices[0]:
+                            return Completion(**data)
+                    return ChatComplete(**data)
+                else:
+                    return Response(**data)
+            if isinstance(data, List):
+                return [_convert_data(d) for d in data]
+            raise NotImplementedError

         def eq_token(token: Token, other: Token) -> bool:
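
The refactor above makes snapshot comparison order-insensitive: `choices` get sorted by `index`, and lists are converted element by element. The same normalization idea, sketched in Rust with serde_json (a hypothetical helper, not part of the test suite):

```rust
use serde_json::Value;

// Sort "choices" arrays by their "index" field and recurse into lists so
// two payloads compare equal regardless of arrival order.
fn normalize(v: Value) -> Value {
    match v {
        Value::Array(items) => Value::Array(items.into_iter().map(normalize).collect()),
        Value::Object(mut map) => {
            if let Some(Value::Array(choices)) = map.get_mut("choices") {
                choices.sort_by_key(|c| c["index"].as_i64().unwrap_or(0));
            }
            Value::Object(map)
        }
        other => other,
    }
}

fn main() {
    let a = serde_json::json!({"choices": [{"index": 1}, {"index": 0}]});
    let b = serde_json::json!({"choices": [{"index": 0}, {"index": 1}]});
    assert_eq!(normalize(a), normalize(b));
}
```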
@ -571,3 +572,38 @@ def generate_load():
         return await asyncio.gather(*futures)

     return generate_load_inner
+
+
+@pytest.fixture(scope="module")
+def generate_multi():
+    async def generate_load_inner(
+        client: AsyncClient,
+        prompts: List[str],
+        max_new_tokens: int,
+        seed: Optional[int] = None,
+    ) -> List[Response]:
+        import numpy as np
+
+        arange = np.arange(len(prompts))
+        perm = np.random.permutation(arange)
+        rperm = [-1] * len(perm)
+        for i, p in enumerate(perm):
+            rperm[p] = i
+
+        shuffled_prompts = [prompts[p] for p in perm]
+        futures = [
+            client.chat(
+                messages=[Message(role="user", content=prompt)],
+                max_tokens=max_new_tokens,
+                temperature=0,
+                seed=seed,
+            )
+            for prompt in shuffled_prompts
+        ]
+
+        shuffled_responses = await asyncio.gather(*futures)
+        responses = [shuffled_responses[p] for p in rperm]
+        return responses
+
+    return generate_load_inner
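
`generate_multi` shuffles the prompts, fires them concurrently, then applies the inverse permutation so responses come back in prompt order no matter how the batcher interleaves them (exactly the pairing a prefix-caching bug would scramble). The bookkeeping in isolation, as a sketch rather than the fixture itself:

```rust
fn main() {
    let prompts = ["p0", "p1", "p2", "p3"];
    // Pretend this came from a random shuffle: shuffled[i] = prompts[perm[i]].
    let perm = [2usize, 0, 3, 1];

    // Inverse permutation: rperm[perm[i]] = i.
    let mut rperm = vec![0usize; perm.len()];
    for (i, &p) in perm.iter().enumerate() {
        rperm[p] = i;
    }

    // "Send" in shuffled order...
    let shuffled: Vec<_> = perm.iter().map(|&p| prompts[p]).collect();
    // ...then un-shuffle the responses back to prompt order.
    let responses: Vec<_> = rperm.iter().map(|&i| shuffled[i]).collect();
    assert_eq!(responses, prompts);
}
```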


@ -1,38 +1,38 @@
 {
   "choices": [
     {
-      "finish_reason": "stop",
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "text": " A Beginners Guide\nDeep learning is a subset"
+    },
+    {
+      "finish_reason": "length",
       "index": 1,
       "logprobs": null,
-      "text": " PR for more information?"
+      "text": " This is a question that has puzzled many people for"
     },
     {
       "finish_reason": "length",
       "index": 3,
       "logprobs": null,
-      "text": "hd20220811-"
+      "text": "usculas_minusculas(s):\n    \"\"\"\n"
     },
-    {
-      "finish_reason": "length",
-      "index": 0,
-      "logprobs": null,
-      "text": "le Business Incubator is providing a workspace"
-    },
     {
       "finish_reason": "length",
       "index": 2,
       "logprobs": null,
-      "text": " severely flawed and often has a substandard"
+      "text": " Paris\nWhat is the capital of France?\nThe"
     }
   ],
-  "created": 1722014725,
+  "created": 1725877154,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native",
   "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 8,
-    "total_tokens": 44
+    "completion_tokens": 40,
+    "prompt_tokens": 22,
+    "total_tokens": 62
   }
 }


@ -5,12 +5,12 @@
       "finish_reason": "",
       "index": 0,
       "logprobs": null,
-      "text": "\n"
+      "text": " A"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -20,12 +20,72 @@
       "finish_reason": "",
       "index": 1,
       "logprobs": null,
-      "text": "\n"
+      "text": " This"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "object": "text_completion",
+  "system_fingerprint": "2.2.1-dev0-native"
+},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " Paris"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "us"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Beginner"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " is"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -38,9 +98,9 @@
       "text": "\n"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -50,12 +110,12 @@
       "finish_reason": "",
       "index": 3,
       "logprobs": null,
-      "text": "hd"
+      "text": "cul"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -65,12 +125,12 @@
       "finish_reason": "",
       "index": 0,
       "logprobs": null,
-      "text": "\n"
+      "text": "s"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -80,12 +140,12 @@
       "finish_reason": "",
       "index": 1,
       "logprobs": null,
-      "text": "\n"
+      "text": " a"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -95,12 +155,12 @@
       "finish_reason": "",
       "index": 2,
       "logprobs": null,
-      "text": "\n"
+      "text": "What"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -110,12 +170,12 @@
       "finish_reason": "",
       "index": 3,
       "logprobs": null,
-      "text": "aho"
+      "text": "as"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -125,12 +185,12 @@
       "finish_reason": "",
       "index": 0,
       "logprobs": null,
-      "text": "2"
+      "text": " Guide"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -140,252 +200,12 @@
       "finish_reason": "",
       "index": 1,
       "logprobs": null,
-      "text": "2"
+      "text": " question"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "2"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "ima"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "."
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "."
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "."
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "\n"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Sarah"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Yes"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " And"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "i"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "'"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": ","
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " what"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "'"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "s"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Moh"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -398,9 +218,9 @@
       "text": " is"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -410,12 +230,12 @@
       "finish_reason": "",
       "index": 3,
       "logprobs": null,
-      "text": "m"
+      "text": "_minus"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -425,12 +245,12 @@
       "finish_reason": "",
       "index": 0,
       "logprobs": null,
-      "text": " Room"
+      "text": "\n"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -440,12 +260,12 @@
       "finish_reason": "",
       "index": 1,
       "logprobs": null,
-      "text": "s"
+      "text": " that"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -458,9 +278,9 @@
       "text": " the"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -470,12 +290,12 @@
       "finish_reason": "",
       "index": 3,
       "logprobs": null,
-      "text": " tired"
+      "text": "cul"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -485,12 +305,12 @@
       "finish_reason": "",
       "index": 0,
       "logprobs": null,
-      "text": ":"
+      "text": "Deep"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -500,12 +320,12 @@
       "finish_reason": "",
       "index": 1,
       "logprobs": null,
-      "text": "'"
+      "text": " has"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -518,9 +338,9 @@
       "text": " capital"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -530,12 +350,192 @@
       "finish_reason": "",
       "index": 3,
       "logprobs": null,
-      "text": ","
+      "text": "as"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+  "object": "text_completion",
+  "system_fingerprint": "2.2.1-dev0-native"
+},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " learning"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " puzzled"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " of"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "(s"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " is"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " many"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " France"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "):\n"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " a"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " people"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "?\n"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " "
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -545,12 +545,12 @@
       "finish_reason": "length",
       "index": 0,
       "logprobs": null,
-      "text": " She"
+      "text": " subset"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -560,12 +560,12 @@
       "finish_reason": "length",
       "index": 1,
       "logprobs": null,
-      "text": " scale"
+      "text": " for"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -575,12 +575,12 @@
       "finish_reason": "length",
       "index": 2,
       "logprobs": null,
-      "text": " of"
+      "text": "The"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 },
@ -590,12 +590,12 @@
       "finish_reason": "length",
       "index": 3,
       "logprobs": null,
-      "text": " its"
+      "text": " \"\"\"\n"
     }
   ],
-  "created": 1724833943,
+  "created": 1725883643,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
   "system_fingerprint": "2.2.1-dev0-native"
 }


@ -4,17 +4,17 @@
       "finish_reason": "length",
       "index": 0,
       "logprobs": null,
-      "text": " PR for flake8"
+      "text": " A Beginners Guide\nDeep learning is a subset"
     }
   ],
-  "created": 1713284454,
+  "created": 1725876621,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "object": "text_completion",
-  "system_fingerprint": "2.0.1-native",
+  "system_fingerprint": "2.2.1-dev0-native",
   "usage": {
-    "completion_tokens": 5,
+    "completion_tokens": 10,
     "prompt_tokens": 6,
-    "total_tokens": 11
+    "total_tokens": 16
   }
 }


@ -11,7 +11,7 @@ from text_generation.types import (
 @pytest.fixture(scope="module")
 def flash_llama_completion_handle(launcher):
     with launcher(
-        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
     ) as handle:
         yield handle
@ -34,16 +34,19 @@ def test_flash_llama_completion_single_prompt(
         f"{flash_llama_completion.base_url}/v1/completions",
         json={
             "model": "tgi",
-            "prompt": "Say this is a test",
-            "max_tokens": 5,
-            "seed": 0,
+            "prompt": "What is Deep Learning?",
+            "max_tokens": 10,
+            "temperature": 0.0,
         },
         headers=flash_llama_completion.headers,
         stream=False,
     )
     response = response.json()
     assert len(response["choices"]) == 1
+    assert (
+        response["choices"][0]["text"]
+        == " A Beginners Guide\nDeep learning is a subset"
+    )
     assert response == response_snapshot
@ -53,9 +56,15 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
         f"{flash_llama_completion.base_url}/v1/completions",
         json={
             "model": "tgi",
-            "prompt": ["Say", "this", "is", "a"],
+            "prompt": [
+                "What is Deep Learning?",
+                "Is water wet?",
+                "What is the capital of France?",
+                "def mai",
+            ],
             "max_tokens": 10,
             "seed": 0,
+            "temperature": 0.0,
         },
         headers=flash_llama_completion.headers,
         stream=False,
@ -63,9 +72,16 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
     response = response.json()
     assert len(response["choices"]) == 4
-    all_indexes = [choice["index"] for choice in response["choices"]]
+    all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
     all_indexes.sort()
-    assert all_indexes == [0, 1, 2, 3]
+    all_indices, all_strings = zip(*all_indexes)
+    assert list(all_indices) == [0, 1, 2, 3]
+    assert list(all_strings) == [
+        " A Beginners Guide\nDeep learning is a subset",
+        " This is a question that has puzzled many people for",
+        " Paris\nWhat is the capital of France?\nThe",
+        'usculas_minusculas(s):\n    """\n',
+    ]
     assert response == response_snapshot
@ -77,19 +93,21 @@ async def test_flash_llama_completion_many_prompts_stream(
     request = {
         "model": "tgi",
         "prompt": [
-            "What color is the sky?",
+            "What is Deep Learning?",
             "Is water wet?",
             "What is the capital of France?",
             "def mai",
         ],
         "max_tokens": 10,
         "seed": 0,
+        "temperature": 0.0,
         "stream": True,
     }
     url = f"{flash_llama_completion.base_url}/v1/completions"
     chunks = []
+    strings = [""] * 4
     async with ClientSession(headers=flash_llama_completion.headers) as session:
         async with session.post(url, json=request) as response:
             # iterate over the stream
@ -108,7 +126,15 @@ async def test_flash_llama_completion_many_prompts_stream(
                 for c in chunk:
                     chunks.append(Completion(**c))
                     assert "choices" in c
-                    assert 0 <= c["choices"][0]["index"] <= 4
+                    index = c["choices"][0]["index"]
+                    assert 0 <= index <= 4
+                    strings[index] += c["choices"][0]["text"]

     assert response.status == 200
+    assert list(strings) == [
+        " A Beginners Guide\nDeep learning is a subset",
+        " This is a question that has puzzled many people for",
+        " Paris\nWhat is the capital of France?\nThe",
+        'usculas_minusculas(s):\n    """\n',
+    ]
     assert chunks == response_snapshot
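
Each streamed chunk carries the choice index it belongs to, so the test reassembles per-prompt text by appending into one slot per index. The reassembly step on its own, with made-up chunks:

```rust
fn main() {
    // Hypothetical stream: (choice index, text fragment) pairs, interleaved.
    let chunks = [(1usize, "Is"), (0, "What"), (1, " water"), (0, " is")];

    let mut strings = vec![String::new(); 2];
    for (index, text) in chunks {
        assert!(index < strings.len());
        strings[index].push_str(text);
    }
    assert_eq!(strings, ["What is", "Is water"]);
}
```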

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large


@ -6,9 +6,10 @@ authors = ["Nicolas Patry <nicolas@huggingface.co>"]
 [tool.poetry.dependencies]
 pydantic = "> 2, < 3"
-python = ">=3.9,<3.13"
+python = ">=3.10,<3.13"
 syrupy = "^4.7.1"
 text-generation = "^0.6.0"
 pytest = "^7.4.0"
 pytest-asyncio = "^0.21.1"
-docker = "^6.1.3"
+docker = "^7"
+numpy = "^1.20"


@ -1,34 +1,35 @@
-aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13"
-aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
-annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
-async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
-attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
-certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
-charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
-colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
-docker==6.1.3 ; python_version >= "3.9" and python_version < "3.13"
-exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11"
-filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
-frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
-iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13"
-multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13"
-packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
-pluggy==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
-pydantic-core==2.16.3 ; python_version >= "3.9" and python_version < "3.13"
-pydantic==2.6.4 ; python_version >= "3.9" and python_version < "3.13"
-pytest-asyncio==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
-pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13"
-pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
-pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
-syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
-text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
-tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
-websocket-client==1.6.2 ; python_version >= "3.9" and python_version < "3.13"
-yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+aiohappyeyeballs==2.4.0 ; python_version >= "3.10" and python_version < "3.13"
+aiohttp==3.10.5 ; python_version >= "3.10" and python_version < "3.13"
+aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "3.13"
+annotated-types==0.7.0 ; python_version >= "3.10" and python_version < "3.13"
+async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11"
+attrs==24.2.0 ; python_version >= "3.10" and python_version < "3.13"
+certifi==2024.8.30 ; python_version >= "3.10" and python_version < "3.13"
+charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "3.13"
+colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
+docker==7.1.0 ; python_version >= "3.10" and python_version < "3.13"
+exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11"
+filelock==3.16.0 ; python_version >= "3.10" and python_version < "3.13"
+frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "3.13"
+fsspec==2024.9.0 ; python_version >= "3.10" and python_version < "3.13"
+huggingface-hub==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
+idna==3.8 ; python_version >= "3.10" and python_version < "3.13"
+iniconfig==2.0.0 ; python_version >= "3.10" and python_version < "3.13"
+multidict==6.1.0 ; python_version >= "3.10" and python_version < "3.13"
+numpy==1.26.4 ; python_version >= "3.10" and python_version < "3.13"
+packaging==24.1 ; python_version >= "3.10" and python_version < "3.13"
+pluggy==1.5.0 ; python_version >= "3.10" and python_version < "3.13"
+pydantic-core==2.23.3 ; python_version >= "3.10" and python_version < "3.13"
+pydantic==2.9.1 ; python_version >= "3.10" and python_version < "3.13"
+pytest-asyncio==0.21.2 ; python_version >= "3.10" and python_version < "3.13"
+pytest==7.4.4 ; python_version >= "3.10" and python_version < "3.13"
+pywin32==306 ; python_version >= "3.10" and python_version < "3.13" and sys_platform == "win32"
+pyyaml==6.0.2 ; python_version >= "3.10" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.10" and python_version < "3.13"
+syrupy==4.7.1 ; python_version >= "3.10" and python_version < "3.13"
+text-generation==0.6.1 ; python_version >= "3.10" and python_version < "3.13"
+tomli==2.0.1 ; python_version >= "3.10" and python_version < "3.11"
+tqdm==4.66.5 ; python_version >= "3.10" and python_version < "3.13"
+typing-extensions==4.12.2 ; python_version >= "3.10" and python_version < "3.13"
+urllib3==2.2.2 ; python_version >= "3.10" and python_version < "3.13"
+yarl==1.11.1 ; python_version >= "3.10" and python_version < "3.13"


@ -1843,9 +1843,8 @@ fn main() -> Result<(), LauncherError> {
         shutdown.clone(),
         &shutdown_receiver,
     )
-    .map_err(|err| {
+    .inspect_err(|_| {
         shutdown_shards(shutdown.clone(), &shutdown_receiver);
-        err
     })?;

     // Default exit code


@ -336,6 +336,8 @@ pub enum InferError {
     ValidationError(#[from] ValidationError),
     #[error("Incomplete generation")]
     IncompleteGeneration,
+    #[error("Incomplete generation stream")]
+    IncompleteGenerationStream,
     #[error("Template error: {0}")]
     TemplateError(#[from] minijinja::Error),
     #[error("Missing template vatiable: {0}")]
@ -351,6 +353,7 @@ impl InferError {
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
             InferError::IncompleteGeneration => "incomplete_generation",
+            InferError::IncompleteGenerationStream => "incomplete_generation_stream",
             InferError::TemplateError(_) => "template_error",
             InferError::MissingTemplateVariable(_) => "missing_template_variable",
             InferError::ToolError(_) => "tool_error",


@ -318,7 +318,10 @@ pub(crate) async fn generate_internal(
     metrics::counter!("tgi_request_count").increment(1);

     // Do not long ultra long inputs, like image payloads.
-    tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
+    tracing::debug!(
+        "Input: {}",
+        &req.inputs.chars().take(1000).collect::<String>()
+    );

     let compute_characters = req.inputs.chars().count();
     let mut add_prompt = None;
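
The hunk above is not cosmetic: `&req.inputs[..1000]` slices by byte offset and panics when byte 1000 lands inside a multi-byte UTF-8 character, while taking 1000 `chars` is always boundary-safe. A self-contained repro of the difference:

```rust
fn main() {
    // 600 three-byte characters: byte offset 1000 is not a char boundary.
    let input = "€".repeat(600);

    // Byte slicing would panic here:
    // let _ = &input[..1000.min(input.len())];

    // Char-based truncation is safe regardless of the content.
    let truncated: String = input.chars().take(1000).collect();
    assert_eq!(truncated.chars().count(), 600);
}
```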
@ -674,7 +677,7 @@ async fn generate_stream_internal(
     // Check if generation reached the end
     // Skip if we already sent an error
     if !end_reached && !error {
-        let err = InferError::IncompleteGeneration;
+        let err = InferError::IncompleteGenerationStream;
         metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
         tracing::error!("{err}");
         yield Ok(Event::from(err));
@ -2555,6 +2558,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
+            InferError::IncompleteGenerationStream => StatusCode::INTERNAL_SERVER_ERROR,
             InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,


@ -1,2 +1,2 @@
 install-flashinfer:
-	pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4
+	pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4


@ -515,6 +515,7 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        assert len(pb.requests) > 0
         batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
         return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@ -640,6 +641,7 @@ class FlashCausalLMBatch(Batch):
         adapter_segments = torch.tensor(
             adapter_segments, dtype=torch.int32, device=device
         )
+        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()

         return type(self)(
             batch_id=self.batch_id,
@ -834,6 +836,8 @@ class FlashCausalLMBatch(Batch):
         start_slots = torch.concat(start_slots)

+        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
+
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
@ -1150,27 +1154,6 @@ class FlashCausalLM(Model):
             input_lengths=input_lengths,
             prefix_lens=prefix_lengths,
         )
-        self.cuda_graphs[bs] = {
-            "input_ids": input_ids,
-            "position_ids": position_ids,
-            "kv_cache": self.kv_cache,
-            "block_tables": block_tables,
-            "slots": slots,
-            "input_lengths": input_lengths_tensor,
-            "prefix_lengths": prefix_lengths_tensor,
-        }
-        seqlen = Seqlen(
-            input_lengths=input_lengths_tensor,
-            prefix_lengths=prefix_lengths_tensor,
-            cu_seqlen_q=None,
-            max_q=1,
-            max_k=max_s,
-        )
-        graph = torch.cuda.CUDAGraph()
-        self.cuda_graphs[bs]["graph"] = graph
-
         if ATTENTION == "flashinfer":
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )
@ -1187,21 +1170,38 @@ class FlashCausalLM(Model):
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
             )
-            self.cuda_graphs[bs]["state"] = state
         else:
             state = None

+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": self.kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths_tensor,
+            "prefix_lengths": prefix_lengths_tensor,
+            "state": state,
+            "graph": graph,
+        }
+
         torch.cuda.synchronize()
         # Run once outside to warmup
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=None,
-            input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
             state=state,
-            prefix_lens=prefix_lengths,
             prefix_lens_tensor=prefix_lengths_tensor,
         ):
+            seqlen = Seqlen(
+                input_lengths=input_lengths_tensor,
+                prefix_lengths=prefix_lengths_tensor,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=max_s,
+            )
             self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@ -1214,6 +1214,7 @@ class FlashCausalLM(Model):
                 prefill_cache_indices=None,
                 lm_head_indices=None,
             )
+            del seqlen

         torch.cuda.synchronize()
@ -1479,9 +1480,7 @@
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=cu_seqlen_prefill,
-            input_lengths=batch.input_lengths,
-            input_lengths_tensor=input_lengths + prefix_lens_tensor,
-            prefix_lens=batch.prefix_lens,
+            input_lengths_tensor=input_lengths,
             prefix_lens_tensor=prefix_lens_tensor,
         ):
             max_k = (input_lengths + prefix_lens_tensor).max().item()
@ -1519,26 +1518,28 @@ class FlashCausalLM(Model):
                 input_lengths=batch.input_lengths,
                 prefix_lens=batch.prefix_lens,
             )
+            # assert block_tables.shape[0] >= slots.shape[0]
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
         else:
             cuda_graph["block_tables"][
                 : block_tables.shape[0], : block_tables.shape[1]
             ] = block_tables
-        cuda_graph["slots"].fill_(-1)
+
+        # XXX: This is working only because block 0 is reserved for the healthcheck
+        # so it doesn't matter if we override it with bogus values.
+        cuda_graph["slots"].fill_(0)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + prefix_lens_tensor
-        )
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["prefix_lengths"].zero_()
+        cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor

         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
-            input_lengths=batch.input_lengths,
             input_lengths_tensor=cuda_graph["input_lengths"],
-            prefix_lens=batch.prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
-            state=cuda_graph.get("state"),
+            prefix_lens_tensor=cuda_graph["prefix_lengths"],
+            state=cuda_graph["state"],
         ):
             # Replay the graph
             cuda_graph["graph"].replay()
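
The XXX comment is the load-bearing invariant: a captured CUDA graph replays over fixed-size buffers, so the padded tail of the slots buffer must still hold valid indices. Filling with 0 is acceptable only because slot/block 0 is reserved for the healthcheck and may be clobbered with bogus values, whereas the previous -1 is not a valid slot at all. A sketch of that padding contract (hypothetical helper, independent of torch):

```rust
// Pad a fixed-capacity slot buffer for graph replay. Padding lanes must
// point at a *valid* slot; a reserved scratch slot (0, the healthcheck
// slot here) makes the garbage writes harmless, while -1 would index
// out of bounds during replay.
fn pad_slots(real: &[i64], capacity: usize) -> Vec<i64> {
    const SCRATCH_SLOT: i64 = 0;
    assert!(real.len() <= capacity);
    let mut slots = vec![SCRATCH_SLOT; capacity];
    slots[..real.len()].copy_from_slice(real);
    slots
}

fn main() {
    assert_eq!(pad_slots(&[7, 8, 9], 8), [7, 8, 9, 0, 0, 0, 0, 0]);
}
```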
@ -1767,7 +1768,7 @@ class FlashCausalLM(Model):
                 left = 0
             if n_accepted_ids > 1:
-                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
+                log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")

             current_stopped = False
             for j in range(index, index + n_accepted_ids):
@ -1922,9 +1923,7 @@
         *,
         block_tables: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
-        input_lengths: List[int],
         input_lengths_tensor: torch.Tensor,
-        prefix_lens: List[int],
         prefix_lens_tensor: torch.Tensor,
         state: Optional[Any] = None,
     ) -> ContextManager:
@ -1950,7 +1949,7 @@
                 #     ),
                 block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
-                input_lengths=input_lengths_tensor,
+                input_lengths=input_lengths_tensor + prefix_lens_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
@ -1960,7 +1959,7 @@
             assert input_lengths_tensor is not None
             return use_decode_state(
                 state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths_tensor,
+                input_lengths=input_lengths_tensor + prefix_lens_tensor,
                 block_tables=block_tables,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
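
Both flashinfer paths now pass `input_lengths_tensor + prefix_lens_tensor` to the attention state: with prefix caching, attention must span the cached prefix plus the newly fed tokens, not the new tokens alone. The arithmetic in isolation:

```rust
// Per-sequence KV length seen by attention = cached prefix + new tokens.
fn total_kv_lengths(input_lengths: &[u32], prefix_lengths: &[u32]) -> Vec<u32> {
    input_lengths.iter().zip(prefix_lengths).map(|(i, p)| i + p).collect()
}

fn main() {
    // e.g. the first request had 32 tokens served from the prefix cache
    assert_eq!(total_kv_lengths(&[8, 1], &[32, 0]), vec![40, 1]);
}
```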


@ -367,9 +367,7 @@ class VlmCausalLM(FlashCausalLM):
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=cu_seqlen_prefill,
-            input_lengths=batch.input_lengths,
             input_lengths_tensor=input_lengths,
-            prefix_lens=batch.prefix_lens,
             prefix_lens_tensor=prefix_lens_tensor,
         ):
             max_k = (input_lengths + prefix_lens_tensor).max().item()