Commit Graph

13 Commits

Author SHA1 Message Date
Nicolas Patry 1da07e85aa
feat(server): Add Non flash MPT. (#514)
# What does this PR do?


This adds a non flash version of MPT.
Flash is harder because we need to create a bias ready cuda kernel of
flash attention.

Fixes
https://github.com/huggingface/text-generation-inference/issues/361
Fixes
https://github.com/huggingface/text-generation-inference/issues/491
Fixes
https://github.com/huggingface/text-generation-inference/issues/290
2023-07-03 13:01:46 +02:00
OlivierDehaene e74bd41e0f
feat(server): add paged attention to flash models (#516)
Closes #478
2023-06-30 19:09:59 +02:00
Nicolas Patry abd58ff82c
feat(server): Rework model loading (#344)
# What does this PR do?

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of handling loading the weights from
multiple files into appropiate tensors (potentially sharded)
- TP layers now are "shells", they contain the code to know what kind of
sharding we need + eventual `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or GPTq
Linear next.
- All modeling code is explictly made for sharding, process group is
just no-ops for non sharded code (removes a lot of test cases)

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
2023-06-08 14:51:52 +02:00
OlivierDehaene 895c5f1562
feat(server): only compute prefill logprobs when asked (#406)
Close #288
2023-06-02 17:12:30 +02:00
OlivierDehaene 87dc034b59
feat(server): add retry on download (#384) 2023-05-31 10:57:53 +02:00
OlivierDehaene 444400b457 increase health checks 2023-05-31 10:55:59 +02:00
OlivierDehaene b8b950b37c
feat(server): support RefinedWeb models (#379) 2023-05-30 18:25:19 +02:00
OlivierDehaene 62f91f78ac
feat(server): support vectorized warpers in flash causal lm (#317)
Co-authored-by: Joel Lamy-Poirier <joel.lamy-poirier@servicenow.com>
2023-05-26 12:30:27 +02:00
OlivierDehaene cfaa858070
feat(server): support fp16 for t5 (#360)
Fixes #349
2023-05-23 18:16:48 +02:00
OlivierDehaene 91d9beec90
fix(server): fix init for flash causal lm (#352)
Fixes #347
2023-05-22 15:05:32 +02:00
OlivierDehaene 5a58226130
fix(server): fix decode token (#334)
Fixes #333

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2023-05-16 23:23:27 +02:00
OlivierDehaene dbdc587ddd
feat(integration-tests): improve comparison and health checks (#336) 2023-05-16 20:22:11 +02:00
OlivierDehaene e71471bec9
feat: add snapshot testing (#282) 2023-05-15 23:36:30 +02:00