From ba552e1a821d7a8f60ac0c576aab1857a523038a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 28 Nov 2023 17:54:26 +0100 Subject: [PATCH] Let each model resolve their own default dtype. (#1287) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- integration-tests/conftest.py | 8 ++++++++ integration-tests/models/test_idefics.py | 2 +- server/text_generation_server/cli.py | 2 +- server/text_generation_server/models/__init__.py | 4 +++- 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index c1cbe7f3..d2216241 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -210,6 +210,7 @@ def launcher(event_loop): quantize: Optional[str] = None, trust_remote_code: bool = False, use_flash_attention: bool = True, + dtype: Optional[str] = None ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -237,6 +238,9 @@ def launcher(event_loop): if quantize is not None: args.append("--quantize") args.append(quantize) + if dtype is not None: + args.append("--dtype") + args.append(dtype) if trust_remote_code: args.append("--trust-remote-code") @@ -269,6 +273,7 @@ def launcher(event_loop): quantize: Optional[str] = None, trust_remote_code: bool = False, use_flash_attention: bool = True, + dtype: Optional[str] = None ): port = random.randint(8000, 10_000) @@ -279,6 +284,9 @@ def launcher(event_loop): if quantize is not None: args.append("--quantize") args.append(quantize) + if dtype is not None: + args.append("--dtype") + args.append(dtype) if trust_remote_code: args.append("--trust-remote-code") diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index 5f4571b5..5a81a4f0 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -3,7 +3,7 @@ import pytest @pytest.fixture(scope="module") def idefics_handle(launcher): - with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2) as handle: + with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle: yield handle diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index b741a84c..3abe86af 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -76,7 +76,7 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value - if dtype is not None and quantize is not None: + if dtype is not None and quantize not in {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5b1b5715..ab3b25b7 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -87,7 +87,9 @@ def get_model( trust_remote_code: bool, ) -> Model: if dtype is None: - dtype = torch.float16 + # Keep it as default for now and let + # every model resolve their own default dtype. + dtype = None elif dtype == "float16": dtype = torch.float16 elif dtype == "bfloat16":