Commit Graph

38 Commits

Author SHA1 Message Date
Daniël de Kok 52e48739a5
Remove vLLM dependency for CUDA (#2751)
* Remove vLLM dependency for CUDA

This change adds `attention-kernels` as a dependency for paged
attention and cache reshaping. With that, we don't use vLLM
anywhere for CUDA.

Tested run (since we don't have paged attention in CI):

```
❯ ATTENTION=paged python -m pytest integration-tests -k "llama and awq" --release
[...]
5 snapshots passed.
```

* Fix clippy warning
2024-11-17 17:34:50 +01:00
Daniël de Kok eab07f746c
Add support for FP8 KV cache scales (#2628)
* Add support for FP8 KV cache scales

Since FP8 only has limited dynamic range, we can scale keys/values
before storing them into the cache (and unscale them in attention). To
avoid rescaling the cache as the absmax values change, good scales are
usually determined per layer using calibration calibration data and stored
in the checkpoint.

This change adds support for for using key-value scales and loading them
from checkpoints in the two most common formats:

- Separate per-layer `k_scale` and `v_scale` scalars.
- Per-layer `kv_scale` scalar (older format).

Currently, scales are only used with an `float8_e4m3fn` cache.

Besides adding support for key/value scales, the `fp8_quantize` function
is also extended to support quantization with a kernel vendored from
vLLM. This is slightly faster than the PyTorch implementation, but also
scales in FP32, potentially improving accuracy.

* Update FP8 KV cache test to use checkpoint with scales

* `can_scale`: check that the attention is flashinfer
2024-10-24 16:36:18 +02:00
Daniël de Kok 1b914f37e7
flashinfer: reminder to remove contiguous call in the future (#2685) 2024-10-24 14:59:56 +02:00
Daniël de Kok 8ec57558cd
Break cycle between the attention implementations and KV cache (#2627) 2024-10-17 14:54:22 +02:00
Daniël de Kok 59ea38cbca
Simplify the `attention` function (#2609)
* Simplify the `attention` function

- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the
  `PREFILL_IN_KVCACHE` constant.
- Make it kwargs-only (to avoid mixing up the various `Tensor` args).

* Fixup flashinfer support
2024-10-17 10:42:52 +02:00
Daniël de Kok 5bbe1ce028
Support `e4m3fn` KV cache (#2655)
* Support `e4m3fn` KV cache

* Make check more obvious
2024-10-17 10:42:16 +02:00
OlivierDehaene a6a0c97ed9
feat: prefill chunking (#2600)
* wip

* rollback

* refactor to use prefix/postfix namming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot
of the times).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-10-16 12:49:33 +02:00
Nicolas Patry 0c478846c5
Fixing intel Supports windowing. (#2637) 2024-10-11 21:47:03 +02:00
Nicolas Patry 8b295aa498
Upgrade minor rust version (Fixes rust build compilation cache) (#2617)
* Upgrade minor rust version (Fixes rust build compilation cache)

* Black
2024-10-08 09:42:50 +02:00
Florian Zimmermeister 0da4df4b96
Fix FP8 KV-cache condition (#2611)
Update kv_cache.py
2024-10-07 09:34:19 +02:00
Daniël de Kok 2358c2bb54
Add basic FP8 KV cache support (#2603)
* Add basic FP8 KV cache support

This change adds rudimentary FP8 KV cache support. The support is
enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so
uses this type for the KV cache. However support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.

* Fix Cargo.toml
2024-10-04 17:51:48 +02:00
Mohit Sharma f9e561eced
Update ROCM libs and improvements (#2579)
* style

* update torch

* ix issues

* fix clone

* revert mkl

* added custom PA

* style

* fix style

* style

* hide env vart

* fix mixtral model

* add skinny kernel and merge fixes

* fixed style

* fix issue for sliding window models

* addressed review comments

* fix import

* improved error messag

* updated default value

* remove import

* fix imports after rebase

* float16 dep

* improve dockerfile

* cleaned dockerfile
2024-09-30 10:54:32 +02:00
Daniël de Kok 1028996fb3
flashinfer: pass window size and dtype (#2574) 2024-09-28 18:41:41 +02:00
Daniël de Kok 5b6b74e21d
Improve support for GPUs with capability < 8 (#2575)
* Improve support for GPUs with capability < 8

- For models that cannot use flashinfer, use flash-attn v1 + paged
  attention for models with a compute capability older than 8.
- Disable prefix caching when using paged attention.
- When using flash-attn v1, pass the key/value, rather than the
  cache, since v1 cannot use block tables.

* nix: add flash-attn-v1 to the server environment

* Move disabling prefix caching into the block of exceptions

* Capability as `usize`s
2024-09-27 16:19:42 +02:00
Nicolas Patry dd8691b7c5
More tensor cores. (#2558)
* More tensor cores.

* Fixing the logic.

* Gemma is modified by this.
2024-09-24 23:57:26 +02:00
Wang, Yi 3ac7df2b6d
hotfix : enable intel ipex cpu and xpu in python3.11 (#2517)
enable intel ipex cpu and xpu in python3.11

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-09-12 17:23:49 +02:00
Wang, Yi 5cd8025f18
hotfix: fix regression of attention api change in intel platform (#2439)
fix regression caused by attention api change. ipex.varlen_attention does not support paged-cache
format kv input now.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-09-05 17:41:39 +02:00
Nicolas Patry e415b690a6
Lots of improvements (Still 2 allocators) (#2449)
* Making prefix/flashinfer the default and testing the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.

* More specific codes.

* Update lock

* Updating integration tests with new values with FI/FD.

Remove paged as a default too, and using FD everywhere.

* Update cargo lock ?

* Upgrade to 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review

Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade resolution system for less errors in resolution.

* Remove lambda for cleaner function.

* Handling debugger.

* OVerride the env in server tests.

* Is this enough to make it work ?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for vlm.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` in order to have the correct tokens for chat
input and not (since it's super important with the prefixing now)

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new vlms.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room

* Only n_heads / process_group.size() are necessary.

* Revert the integrationt tests change (seem linked to head_size
modification).

* Adding error message when assert is violated.

* Fixing the free algorithm to handle times where the common prefix is
smaller.

* Apply suggestions from code review

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2024-08-29 16:29:01 +02:00
Nicolas Patry b70ae0969f
Prefix caching (#2402)
* Prefix caching WIP

* Fixing prefix attention.

* Fixing flashinfer import.

* Fixing black.

* Fixing medusa (still wrong outputs, but functional).

* Just medusa values now.

* Fixing medusa without prefix caching.

* Fixing prefix caching.

* Medusa requires reshaping.

* Removing the logs.

* Remove router.nix

* Fixup:

- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when user wants prefill logprobs.

* Update flake.lock

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2024-08-20 11:15:30 +02:00
drbh 1cebccc72b
fix: adds causal to attention params (#2408)
fix: adds causal to attention params to check when using flash attn v1
2024-08-13 16:19:46 +02:00
Nicolas Patry 7a48a84784
Using an enum for flash backens (paged/flashdecoding/flashinfer) (#2385)
* Using an enum for flash backens (paged/flashdecoding/flashinfer)

* Early exit on server too.

* Clippy.

* Fix clippy and fmt.
2024-08-09 16:41:17 +02:00
Daniël de Kok 7830de1566
Add FlashInfer support (#2354)
This change adds support for FlashInfer. FlashInfer can be enabled using
`FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`.
Since this functionality is currently only for testing, FlashInfer is
not installed anywhere yet.

The FlashInfer API is quite different from FlashAttention/vLLM in that
it requires more global bookkeeping:

* A wrapper class needs to be contstructed (which we just call *state*).
  Since this is fairly expensive (due to pinned host memory allocation),
  we only do this once in a FlashCausalLM instance or for each CUDA
  Graph size.
* Each model forward call needs to be wrapped in `begin_forward` and
  `end_forward`. This sets up data structures that can be reused for all
  calls to attention for that forward call.

When calling attention, we need access to the state object. To avoid
passing an argument down the call chain (which would require changes to
all models), we use a context variable.

Each model forward call is wrapped using a context manager that does all
the bookkeeping for such a call:

* Set the context variable to the forward call's state.
* Call `begin_forward` on the state.
* Yield.
* Call `end_forward` on the state.
* Reset the context variable.

We cannot use a single shared global variable for this, since e.g. CUDA
Graphs of different sizes each have their own state.
2024-08-09 11:42:00 +02:00
drbh 2ca5980634
Pr 2337 ci branch (#2379)
* hotfix: fix xpu crash brought by code refine. torch.xpu rely on import ipex

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* reable gemma2 in xpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix in regression in ipex flashattention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
2024-08-08 12:30:29 -04:00
drbh 29b8d19cdf
fix: return the out tensor rather then the functions return value (#2361) 2024-08-06 13:49:53 +02:00
drbh 215ed3ad52
fix: attempt forward on flash attn2 to check hardware support (#2335)
* fix: attempt forward on flash attn2 to check hardware support

* fix: warn window_size_left when using flash attn 1

* fix: prefer version check over test op and avoid window_size_left if not flash attn2

* fix: improve condtional and error message

* fix: update sliding window conditional

* fix: simplify changes and revert model changes

* fix: avoid changing conditional

* fix: typo tweak
2024-08-05 09:11:40 -04:00
Daniël de Kok 47447ef017
Unify attention output handling (#2343)
- Always return the hidden states.
- Create the output tensor inside the `attention` and `paged_attention`
  functions.

This removes the difference between how the output is handled between
attention (output parameter) and paged attention (return value). This
also removes the assumption that the attention implementation can
write to an output tensor (in preparation of FlashInfer).
2024-08-01 17:03:28 +02:00
drbh bab02ff2bc
feat: add ruff and resolve issue (#2262)
* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
2024-07-26 10:29:09 -04:00
Nicolas Patry 6aeb669072
Softcapping for gemma2. (#2273)
* Softcapping for gemma2.

* Less clutter.

* No access to transformers config, only config_dict here.

* 0.0 is the null value in the C++ API.
2024-07-22 18:27:10 +02:00
OlivierDehaene 53ec0b790b
feat(fp8): use fbgemm kernels and load fp8 weights directly (#2248)
* feat(fp8): add support for fbgemm

* allow loading fp8 weights directly

* update outlines

* fix makefile

* build fbgemm

* avoid circular import and fix dockerfile

* add default dtype

* refactored weights loader

* fix auto conversion

* fix quantization config parsing

* force new nccl on install

* missing get_weights implementation

* increase timeout
2024-07-20 19:02:04 +02:00
Nicolas Patry dea9c0dc74
Fixing rocm. (#2164) 2024-07-02 12:01:08 +02:00
Wang, Yi 5d97e0c4a3
fix FlashDecoding change's regression in intel platform (#2161)
install triton because GPTQParams needs it.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-07-02 11:56:07 +02:00
Nicolas Patry 4327210e6b
[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)
* Using flash decoding

Conditional flashdecoding.

Fix max_q.

Working kvcache

Working version with flash decoding.

Make it work for mistral.

Fix after rebase..

Less intrusive.

REvert changes in modeling.

Speedup flashdecoding.

HHachweew
Hack to make other models work.

Fixing non flash decoding llama path.

Router logic knows about page size.

Missing 2 models.

Missing cohere.

Fixing cohere flash decoding.

Revamped all this architecture.

Fix cohere.

Fixing falcon.

Enabling custom block size schedule.

Update router/src/infer.rs

Not sending preallocated output.

* Making it work on non flash decoding.

* Fix Cohere.

* Fix non decoding paths.

* Rebased.

* No need for cache_manager anymore.

* Update?

* "ipex" -> "cpu"

* These do not belong.

* Factoring cu_seqlen_qk for better abstracting over every model.

* Fixing non flash tests/imports.

* Changing return everywhere.

* Update mistral past.

* Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).

* Fixup mistral clamping (had issues with cuda graphs).

* No need to recreate anything actually.
2024-07-01 23:28:00 +02:00
Wang, Yi 5da4cfab1c
refine get xpu free memory/enable Qwen2/gemma2/gemma/phi in intel platform (#2132)
* refine get xpu free memory

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable qwen2 in xpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable gemma/gemma2/phi in intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-07-01 14:32:54 +02:00
Nicolas Patry 9e2fdf57c0
Removing IPEX_AVAIL. (#2115)
* Removing IPEX_AVAIL.

Chose to unify CPU and XPU under `ipex`. Most code is exactly similar
except for a very few spots.

The biggest number of spots is the kv-cache layout and the flash_xxx.py
files.
Since those files should be removed soon and factored away, we should
not need them.

* Forgot a few places.

* Unrelated change.

* Fixing HF_TOKEN.

* HF_TOKEN
2024-06-25 13:20:57 +02:00
Wang, Yi b64c70c9e7
Cpu tgi (#1936)
* add CPU tgi support

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* ipex distributed ops support

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Funtowicz Morgan <mfuntowicz@users.noreply.github.com>
2024-06-25 12:21:29 +02:00
fxmarty 9b3674d903
ROCm and sliding windows fixes (#2033)
* update vllm commit & fix models using sliding window

* update

* update commit

* fix bug where tunableop is bound to cuda graph even when cuda graph are disabled

* enable tunableop by default

* fix sliding window

* address review

* dead code

* precise comment

* is it flaky?
2024-06-10 15:09:50 +08:00
Nicolas Patry 0a94fad79f
Fixing rocm. (#2021)
# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
2024-06-05 14:41:34 +02:00
Nicolas Patry 06edde9491
Purely refactors paged/attention into `layers/attention` and make hardware differences more obvious with 1 file per hardware. (#1986)
# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
2024-05-31 17:57:01 +02:00