Commit Graph

345 Commits

Author SHA1 Message Date
Daniël de Kok 52e48739a5
Remove vLLM dependency for CUDA (#2751)
* Remove vLLM dependency for CUDA

This change adds `attention-kernels` as a dependency for paged
attention and cache reshaping. With that, we don't use vLLM
anywhere for CUDA.

Tested run (since we don't have paged attention in CI):

```
❯ ATTENTION=paged python -m pytest integration-tests -k "llama and awq" --release
[...]
5 snapshots passed.
```

* Fix clippy warning
2024-11-17 17:34:50 +01:00
Billel Mokeddem 4f4857a4ac
Fix: Change embeddings to embedding (#2738)
fix: change embeddings to embedding

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-135.us-west-2.compute.internal>
2024-11-15 13:16:15 +01:00
Billel Mokeddem f9ee46f740
Fix: Change model_type from ssm to mamba (#2740)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-135.us-west-2.compute.internal>
2024-11-15 13:15:36 +01:00
Daniël de Kok a785000842
Add initial support for compressed-tensors checkpoints (#2732)
compressed-tensors is a safetensors extension for sparse, quantized
tensors. The format is more powerful than earlier AWQ/GPTQ/FP8
quantization, because

- Different quantizer configurations can be used for different targets.
- The format can specify input/output quantizers in addition to weight
  quantizers.
- Configurable exclusions for quantization.

This change adds a dependency on the `compressed-tensors` package for
its configuration parsing and layer matching functionality.

The following types of quantization are supported in this PR:

- W8A16 and W4A16 INT using GPTQ-Marlin kernels.
- W8A8 and W8A16 FP using FP8-Marlin and cutlass kernels.

Support for other quantization types will be added in subsequent PRs.
2024-11-10 13:54:07 +01:00
Nicolas Patry 9fde566602
Fixing linting on main. (#2719) 2024-11-04 15:21:41 +01:00
Travis Addair aadc9cb485
Fix prefix caching + speculative decoding (#2711) 2024-11-04 15:08:43 +01:00
Nicolas Patry a5593ba83e
Hotfixing auto length (warmup max_s was wrong). (#2716) 2024-11-04 09:55:54 +01:00
drbh 6e3220529d
fix: create position ids for text only input (#2714)
* fix: create position ids for text only input

* fix: prefer repeat over expand to avoid clone
2024-11-02 08:40:05 +08:00
drbh 01dacf8e8f
fix cuda graphs for qwen2-vl (#2708)
* feat: support multidimensional position ids on batch to enable cuda graphs on qwen2-vl

* fix: only check model type if config exists

* fix: adjust sharding and lm head logic

* fix qwen2 failure in intel cpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix: return correct shape logits and add streaming test

* fix: remove unused import and refactor test

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-11-01 03:05:34 +01:00
drbh befd9f6735
Support qwen2 vl (#2689)
* feat: add support for qwen2 vl model

* feat: fix token padding, enable warmup and process basic request

* fix: improve get_position_ids, add lift embed_tokens

* fix: remove get_cos_sin_hack dev function

* feat: add simple test chat with meesage and text

* fix: lint test

* fix: adjust positional embeddings for multi dimensional position ids

* fix: update docs and lint unused vars

* fix: include linted file

* fix: add norm after text output

* fix: format model file

* fix: adjust for ruff lints

* fix: remove unused rotate_half

* feat: refactors and calc num features

* fix: prefer position_ids passed from vlm causal lm and reset ids on batch

* fix: adjust get_position_ids if not available and add required args to signatures

* fix: adjust resize case for qwen2_vl warmup

* fix: avoid qwen2 vl specific paths with qwen2
2024-10-30 12:40:51 -04:00
Nicolas Patry 3a9cdc3241
Fixing auto bloom test. (#2699) 2024-10-28 06:14:11 +01:00
Nicolas Patry 90b226db29
We can have a tokenizer anywhere. (#2527)
* We can have a tokenizer anywhere.

* Handling potential lack of offsets (python tokenizer)

* Remove redundancy.

* Fixing the tests.

* Flake.lock update ?

* Fixing the  GIL locking.

* Fixing mamba by using the transformers version.

* Adding the legacy handle.

* Ellide lifetime.

* Lint.

* Deprecation message.

* Fixing bad rebase.
2024-10-28 05:00:24 +01:00
Nicolas Patry 0c9b6cdd76
Choosing input/total tokens automatically based on available VRAM? (#2673)
* Choosing input/total tokens automatically based on available VRAM?

* Update doc.

* Remove generated files.

* Trying to fix non chunking targets.

* Attempt #2

* fix.

* QuantLinear is rocm compatible.

* Much simpler logic after the overhead.

* Updating logic + non flash.

* Revert doc text.

* Simple updates.

* Fix integration mt0 (transformers update).
2024-10-28 04:59:49 +01:00
OlivierDehaene 6f88bd9390
feat: add triton kernels to decrease latency of large batches (#2687)
* feat: add triton kernels to decrease latency of large batches

* cast to int32

* fix kernel

* fix kernel

* disable triton on rocm

* fix speculation

* add slots filtering kernel
2024-10-25 21:10:00 +00:00
Daniël de Kok 0f346a3296
Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels (#2688)
* Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels

Performance and accuracy of these kernels are on par (tested with Llama
70B and 405B). Removes a dependency and resolves some stability issues
we have been seeing.

* Update test snapshots
2024-10-25 16:40:47 +02:00
Daniël de Kok eab07f746c
Add support for FP8 KV cache scales (#2628)
* Add support for FP8 KV cache scales

Since FP8 only has limited dynamic range, we can scale keys/values
before storing them into the cache (and unscale them in attention). To
avoid rescaling the cache as the absmax values change, good scales are
usually determined per layer using calibration calibration data and stored
in the checkpoint.

This change adds support for for using key-value scales and loading them
from checkpoints in the two most common formats:

- Separate per-layer `k_scale` and `v_scale` scalars.
- Per-layer `kv_scale` scalar (older format).

Currently, scales are only used with an `float8_e4m3fn` cache.

Besides adding support for key/value scales, the `fp8_quantize` function
is also extended to support quantization with a kernel vendored from
vLLM. This is slightly faster than the PyTorch implementation, but also
scales in FP32, potentially improving accuracy.

* Update FP8 KV cache test to use checkpoint with scales

* `can_scale`: check that the attention is flashinfer
2024-10-24 16:36:18 +02:00
OlivierDehaene 27ff1871b5
hotfix: fix flashllama 2024-10-23 13:22:31 +02:00
OlivierDehaene 03c9388bf7
feat: natively support Granite models (#2682)
* feat: natively support Granite models

* Update doc
2024-10-23 10:04:05 +00:00
Nicolas Patry 153ff3740b
CI job. Gpt awq 4 (#2665)
* add gptq and awq int4 support in intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix ci failure

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* set kv cache dtype

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine the code according to the review command

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Simplifying conditionals + reverting integration tests values.

* Unused import

* Fix redundant import.

* Revert change after rebase.

* Upgrading the tests (TP>1 fix changes to use different kernels.)

* Update server/text_generation_server/layers/gptq/__init__.py

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
2024-10-18 17:55:53 +02:00
drbh 5f32dea1e2
fix: prefer inplace softmax to avoid copy (#2661)
* fix: prefer inplace softmax to avoid copy

* Update server/text_generation_server/models/flash_causal_lm.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-10-17 08:49:02 -04:00
Daniël de Kok 59ea38cbca
Simplify the `attention` function (#2609)
* Simplify the `attention` function

- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the
  `PREFILL_IN_KVCACHE` constant.
- Make it kwargs-only (to avoid mixing up the various `Tensor` args).

* Fixup flashinfer support
2024-10-17 10:42:52 +02:00
Daniël de Kok 5bbe1ce028
Support `e4m3fn` KV cache (#2655)
* Support `e4m3fn` KV cache

* Make check more obvious
2024-10-17 10:42:16 +02:00
OlivierDehaene a6a0c97ed9
feat: prefill chunking (#2600)
* wip

* rollback

* refactor to use prefix/postfix namming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot
of the times).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-10-16 12:49:33 +02:00
Mohit Sharma 704a58c807
Fp8 e4m3_fnuz support for rocm (#2588)
* (feat) fp8 fnuz support for rocm

* (review comments) Fix compression_config load, type hints

* (bug) update all has_tensor

* (review_comments) fix typo and added comments

* (nit) improved comment
2024-10-16 09:54:50 +02:00
Nicolas Patry cf04a43fb1
Fixing linters. (#2650) 2024-10-15 12:43:49 +02:00
Dmitry Rogozhkin 58848cb471
feat: enable pytorch xpu support for non-attention models (#2561)
XPU backend is available natively (without IPEX) in pytorch starting
from pytorch 2.4. This commit extends TGI to cover the case when user
has XPU support thru pytorch 2.4, but does not have IPEX installed.
Models which don't require attention can work. For attention required
models more work is needed to provide attention implementation.

Tested with the following models:
* teknium/OpenHermes-2.5-Mistral-7B
* bigscience/bloom-560m
* google/gemma-7b
* google/flan-t5-xxl

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
2024-10-14 18:28:49 +02:00
Wang, Yi 57f9685dc3
enable mllama in intel platform (#2610)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-10-07 21:15:09 +02:00
Daniël de Kok 2358c2bb54
Add basic FP8 KV cache support (#2603)
* Add basic FP8 KV cache support

This change adds rudimentary FP8 KV cache support. The support is
enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so
uses this type for the KV cache. However support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.

* Fix Cargo.toml
2024-10-04 17:51:48 +02:00
Nicolas Patry d18ed5cfc5
Mllama flash version (#2585)
* Working loading state.

* Preprocessing.

* Working state ? (Broke idefics1 temporarily).

* Cleaner condition.

* Fix idefics.

* Updating config, removing TODO

* Mllama

* Ugrade transformers 4.45

* Flashing mllama.

* Starting to get there.

* Working state.

* Integrations tests for mllama (cutting to 10 tokens because there seems'
to be instability after (meaning size of the batch matters.

* Updating model link.

* Earlier assert.

* Fix vlm ?

* remove log.

* Force ignore all images but last.

* Default dtype bfloat16.

* Update integration test after switch to bf16.

* Remove dead code.

* Removed dead code.

* Upgrade the flake to latest transformers/tokenizers

* Move to hf tgi-nix

* Upgrade to 0.5.0
2024-10-02 11:22:13 +02:00
drbh 93a7042d7e
feat: support phi3.5 moe (#2479)
* feat: support phi3.5 moe model loading

* fix: prefer llama base model and improve rotary logic

* feat: return reasonable generation and add integration test

* fix: run lint and update docs

* fix: rerun lint for openapi docs

* fix: prefer do_sample false unless temp is set by user, and update chat tests

* fix: small typo adjustments

* fix: consolidate long rope paths

* fix: revert greedy by default and test changes

* Vendor configuration so that we don't have to `trust_remote_code`

* Use SparseMoELayer

* Add support for dense MoE

* Some type annotations

* Add the usual model tests

* Ruff.

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-09-30 11:15:09 +02:00
Mohit Sharma f9e561eced
Update ROCM libs and improvements (#2579)
* style

* update torch

* ix issues

* fix clone

* revert mkl

* added custom PA

* style

* fix style

* style

* hide env vart

* fix mixtral model

* add skinny kernel and merge fixes

* fixed style

* fix issue for sliding window models

* addressed review comments

* fix import

* improved error messag

* updated default value

* remove import

* fix imports after rebase

* float16 dep

* improve dockerfile

* cleaned dockerfile
2024-09-30 10:54:32 +02:00
Daniël de Kok 1028996fb3
flashinfer: pass window size and dtype (#2574) 2024-09-28 18:41:41 +02:00
Daniël de Kok 5b6b74e21d
Improve support for GPUs with capability < 8 (#2575)
* Improve support for GPUs with capability < 8

- For models that cannot use flashinfer, use flash-attn v1 + paged
  attention for models with a compute capability older than 8.
- Disable prefix caching when using paged attention.
- When using flash-attn v1, pass the key/value, rather than the
  cache, since v1 cannot use block tables.

* nix: add flash-attn-v1 to the server environment

* Move disabling prefix caching into the block of exceptions

* Capability as `usize`s
2024-09-27 16:19:42 +02:00
Alvaro Bartolome 0b7df77178
Add LoRA adapters support for Gemma2 (#2567)
* Add LoRA adapters support for Gemma2

* Make `black` formatting happy
2024-09-26 10:54:08 +02:00
Daniël de Kok 3f14cd1420
Add `DenseMoELayer` and wire it up in Mixtral/Deepseek V2 (#2537)
This replaces the custom layers in both models.
2024-09-24 14:27:06 +02:00
Daniël de Kok c29dc89c18
Add support for scalar FP8 weight scales (#2550)
* Add support for scalar FP8 weight scales

* Support LLM compressor FP8 checkpoints on H100

On H100, we use fbgemm-gpu, which requires bfloat16 as the input dtype.
However, we wouldn't pick up fp8 quantization for models quantized with
LLM compressor. This change adds enough parsing to detect if models have
FP8-quantized weights.

* Remove stray debug print
2024-09-24 13:57:40 +02:00
Wang, Yi f478aa77ad
hotfix: ipex fails since cuda moe kernel is not supported (#2532)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-09-20 10:02:55 +02:00
Daniël de Kok ce85efa968
Move to moe-kernels package and switch to common MoE layer (#2511)
* Move to moe-kernels package and switch to common MoE layer

This change introduces the new `moe-kernels` package:

- Add `moe-kernels` as a dependency.
- Introduce a `SparseMoELayer` module that can be used by MoE
  models.
- Port over Mixtral and Deepseek.

* Make `cargo check` pass

* Update runner
2024-09-17 18:08:58 +02:00
Wang, Yi 3ac7df2b6d
hotfix : enable intel ipex cpu and xpu in python3.11 (#2517)
enable intel ipex cpu and xpu in python3.11

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-09-12 17:23:49 +02:00
Nicolas Patry dae3bf1d87
Fix tokenization yi (#2507)
* Fixing odd tokenization self modifications on the Rust side (load and
resave in Python).

* Fixing the builds ?

* Fix the gh action?

* Fixing the location ?

* Validation is odd.

* Try a faster runner

* Upgrade python version.

* Remove sccache

* No sccache.

* Getting libpython maybe ?

* List stuff.

* Monkey it up.

* have no idea at this point

* Tmp.

* Shot in the dark.

* Tmate the hell out of this.

* Desperation.

* WTF.

* -y.

* Apparently 3.10 is not available anymore.

* Updating the dockerfile to make libpython discoverable at runtime too.

* Put back rust tests.

* Why do we want mkl on AMD ?

* Forcing 3.11 ?
2024-09-11 22:41:56 +02:00
Nicolas Patry a4e3e8c608
Prefix test - Different kind of load test to trigger prefix test bugs. (#2490)
* Adding prefix test.

* [WIP] tmp dump of integration load tests.

* Remove other tensor creation.

* Fixed the radix tree.

Used a slice everywhere in radix.rs to keep the cheap Arc cloning
instead of recomputing the input_ids.

* Fix parsing

* Is it really flashinfer version ?

* Remove some comments.

* Revert the max prefix hit.

* Adding numpy to diff.

* Upgraded flashinfer.

* Upgrading some stuff.

* Are we done yet ?

* Minor fixup

* Remove 1 log and put back the other.

* Add comment for why slot 0 is OK.

* Mounting on the job.

* Get me a debug branch

* Debugging CIs is fun.

* Attempt #28

* wip

* Tmate.

* Praying.

* Updating VLM causal model with updated context.

* Important line got squashed.

* Tmate again.

* Fingers crossed.

* We want only 1 run of integration tests.....

---------

Co-authored-by: Guillaume LEGENDRE <glegendre01@gmail.com>
2024-09-11 18:10:40 +02:00
Wang, Yi 5cd8025f18
hotfix: fix regression of attention api change in intel platform (#2439)
fix regression caused by attention api change. ipex.varlen_attention does not support paged-cache
format kv input now.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-09-05 17:41:39 +02:00
drbh 6cb42f49ae
feat: support lora revisions and qkv_proj weights (#2482)
* feat: support lora revisions and qkv_proj weights

* fix: add qkv_proj weights to weight test
2024-09-02 13:09:06 -04:00
Nicolas Patry d9fbbaafb0
Tied embeddings in MLP speculator. (#2473)
* Tied embeddings in MLP speculator.

* Fixing the scale_weight when users decide to not use the speculation as
much as defined in the config.

* Adding scaling support + optimize some ops.
2024-08-29 17:44:54 +02:00
Nicolas Patry e415b690a6
Lots of improvements (Still 2 allocators) (#2449)
* Making prefix/flashinfer the default and testing the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.

* More specific codes.

* Update lock

* Updating integration tests with new values with FI/FD.

Remove paged as a default too, and using FD everywhere.

* Update cargo lock ?

* Upgrade to 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review

Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade resolution system for less errors in resolution.

* Remove lambda for cleaner function.

* Handling debugger.

* OVerride the env in server tests.

* Is this enough to make it work ?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for vlm.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` in order to have the correct tokens for chat
input and not (since it's super important with the prefixing now)

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new vlms.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room

* Only n_heads / process_group.size() are necessary.

* Revert the integrationt tests change (seem linked to head_size
modification).

* Adding error message when assert is violated.

* Fixing the free algorithm to handle times where the common prefix is
smaller.

* Apply suggestions from code review

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2024-08-29 16:29:01 +02:00
drbh 30be188400
Fix: don't apply post layernorm in SiglipVisionTransformer (#2459)
* Fix: don't apply post layernorm in SiglipVisionTransformer

This fixes a bug with LLaVA Next when using Siglip as the vision model. LLaVA Next expects the output of the vision model to be the encoder outputs before layernorm (see original transformers implementation here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813).

This also makes Siglip consistent with the existing Clip implementation:

https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613

* fix: adjust pali gemma for post layer norm and small refactors

---------

Co-authored-by: Travis Addair <tgaddair@gmail.com>
2024-08-26 17:04:46 -04:00
Nicolas Patry b70ae0969f
Prefix caching (#2402)
* Prefix caching WIP

* Fixing prefix attention.

* Fixing flashinfer import.

* Fixing black.

* Fixing medusa (still wrong outputs, but functional).

* Just medusa values now.

* Fixing medusa without prefix caching.

* Fixing prefix caching.

* Medusa requires reshaping.

* Removing the logs.

* Remove router.nix

* Fixup:

- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when user wants prefill logprobs.

* Update flake.lock

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2024-08-20 11:15:30 +02:00
Nicolas Patry 57b3495823
Fixing exl2 and other quanize tests again. (#2419)
* Fixing exl2 and other quanize tests again.

* Mark exl2 as non release (so CI tests them, needs to be removed latet).

* Fixing exl2 (by disabling cuda graphs)

* Fix quantization defaults without cuda graphs on exl2 (linked to new
issues with it).

* Removing serde override.

* Go back to released exl2 and remove log.

* Adding warnings for deprecated bitsandbytes + upgrade info to warn.
2024-08-15 11:12:51 +02:00
Nicolas Patry f3b5c69441
Upgrading exl2. (#2415)
* Upgrading exl2.

* Fixing the other pathways.

* Fix idefics.
2024-08-14 11:58:08 +02:00
Wang, Yi 59922f9bc1
add numa to improve cpu inference perf (#2330)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-08-13 15:33:55 +02:00