hf_text-generation-inference/integration-tests/models
Nicolas Patry abd58ff82c
feat(server): Rework model loading (#344)
# What does this PR do?

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of loading the weights from
multiple files into appropriate tensors (potentially sharded)
- TP layers are now "shells": they contain the code to know what kind of
sharding we need plus a possible `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or, next, a
GPTQ Linear.
- All modeling code is explicitly written for sharding; the process group
is just a no-op for non-sharded code (removes a lot of test cases)
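
The layout above can be illustrated with a minimal, pure-Python sketch. It is
not the code from this PR: `PlainLinear`, `RowParallelShell`,
`get_sharded_cols`, `all_reduce_sum` and `run` are hypothetical names, plain
lists stand in for tensors, dictionaries stand in for weight files, and the
ranks are simulated in one process instead of using a real process group.

```python
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]  # row-major: [out_features][in_features]


class Weights:
    """Finds named tensors across several 'files' and slices them per rank."""

    def __init__(self, files: List[Dict[str, Matrix]], rank: int, world_size: int):
        self.files = files
        self.rank = rank
        self.world_size = world_size

    def get_tensor(self, name: str) -> Matrix:
        for f in self.files:
            if name in f:
                return f[name]
        raise KeyError(name)

    def get_sharded_cols(self, name: str) -> Matrix:
        """Return this rank's slice of the input dimension (hypothetical helper)."""
        full = self.get_tensor(name)
        in_features = len(full[0])
        assert in_features % self.world_size == 0, "input dim must divide evenly"
        block = in_features // self.world_size
        start = self.rank * block
        return [row[start:start + block] for row in full]


class PlainLinear:
    """The contained linear; a quantized variant could be swapped in here."""

    def __init__(self, weight: Matrix):
        self.weight = weight

    def __call__(self, x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weight]


class RowParallelShell:
    """A 'shell': it does not inherit from a linear, it wraps one and only
    knows how the input is sharded for its rank."""

    def __init__(self, linear: PlainLinear, rank: int, world_size: int):
        self.linear = linear
        self.rank = rank
        self.world_size = world_size

    def partial_forward(self, x_full: Vector) -> Vector:
        block = len(x_full) // self.world_size
        x_local = x_full[self.rank * block:(self.rank + 1) * block]
        return self.linear(x_local)


def all_reduce_sum(partials: List[Vector]) -> Vector:
    """Simulated all_reduce: element-wise sum over ranks.
    With a single rank this returns the same values, i.e. it is a no-op."""
    return [sum(vals) for vals in zip(*partials)]


def run(world_size: int) -> Vector:
    # Two fake "files"; the tensor we need lives in the second one.
    files = [
        {"other.weight": [[0.0]]},
        {"dense.weight": [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]},
    ]
    x = [1.0, 1.0, 1.0, 1.0]
    partials = []
    for rank in range(world_size):
        weights = Weights(files, rank, world_size)
        shell = RowParallelShell(
            PlainLinear(weights.get_sharded_cols("dense.weight")), rank, world_size
        )
        partials.append(shell.partial_forward(x))
    return all_reduce_sum(partials)
```

In this sketch the same `run` path serves both cases: `run(1)` is the
non-sharded model and `run(2)` splits the input dimension across two simulated
ranks, and both produce the same output. That is the property that lets the
modeling code be written once for the sharded case.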

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
2023-06-08 14:51:52 +02:00
__snapshots__ feat(server): Rework model loading (#344) 2023-06-08 14:51:52 +02:00
test_bloom_560m.py feat(server): only compute prefill logprobs when asked (#406) 2023-06-02 17:12:30 +02:00
test_bloom_560m_sharded.py feat(server): only compute prefill logprobs when asked (#406) 2023-06-02 17:12:30 +02:00
test_flash_falcon.py feat(server): only compute prefill logprobs when asked (#406) 2023-06-02 17:12:30 +02:00
test_flash_llama.py feat(server): only compute prefill logprobs when asked (#406) 2023-06-02 17:12:30 +02:00
test_flash_neox.py feat(server): Rework model loading (#344) 2023-06-08 14:51:52 +02:00
test_flash_neox_sharded.py feat(server): only compute prefill logprobs when asked (#406) 2023-06-02 17:12:30 +02:00
test_flash_santacoder.py feat(server): only compute prefill logprobs when asked (#406) 2023-06-02 17:12:30 +02:00
test_flash_starcoder.py feat(server): only compute prefill logprobs when asked (#406) 2023-06-02 17:12:30 +02:00
test_mt0_base.py feat(server): only compute prefill logprobs when asked (#406) 2023-06-02 17:12:30 +02:00
test_neox.py feat(server): Rework model loading (#344) 2023-06-08 14:51:52 +02:00
test_neox_sharded.py feat(server): Rework model loading (#344) 2023-06-08 14:51:52 +02:00
test_t5_sharded.py feat(server): only compute prefill logprobs when asked (#406) 2023-06-02 17:12:30 +02:00