preemo_text-generation-infe.../server/text_generation_server/models/__init__.py

313 lines
9.6 KiB
Python
Raw Normal View History

import os
2023-01-20 04:24:39 -07:00
import torch
2023-03-24 07:02:14 -06:00
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
2023-01-31 10:53:56 -07:00
from typing import Optional
2023-03-07 10:52:22 -07:00
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, AutoCausalLM
2023-04-03 11:06:42 -06:00
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
2023-03-07 10:52:22 -07:00
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
2023-03-07 10:52:22 -07:00
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.ct2_causal_lm import CT2CausalLM
2023-01-20 04:24:39 -07:00
# This server only runs inference, so autograd bookkeeping is switched off
# globally at import time.
torch.set_grad_enabled(False)

# Allow TF32 inside cuDNN kernels. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Allow TF32 for matmul as well. This flag defaults to False in
# PyTorch 1.12 and later, so it has to be enabled explicitly.
torch.backends.cuda.matmul.allow_tf32 = True
# Names exported by `from text_generation_server.models import *`.
# Entries are strings naming objects imported or defined in this module.
__all__ = [
    # Base class
    "Model",
    # Model implementations
    "BLOOMSharded",
    "CausalLM",
    "AutoCausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    # Factory entry point
    "get_model",
]
# Template for the error raised when a model/sharding combination is only
# available through the Flash Attention implementations.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# The Flash Attention model classes are optional: if any of their modules
# cannot be imported, a warning is logged and FLASH_ATTENTION is set to False
# so that `get_model` can pick a non-flash implementation (or raise).
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` must contain *names* (strings). Appending the class objects
    # themselves makes `from text_generation_server.models import *` fail
    # with a TypeError.
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
2023-01-31 10:53:56 -07:00
def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the model implementation matching ``model_id``.

    Resolution order:

    1. CTranslate2 when ``quantize`` contains ``"ct2"``.
    2. Model-id based shortcuts (``facebook/galactica``, ``bigcode/``), which
       are resolved before the model configuration is fetched.
    3. The ``model_type`` entry of the model configuration.
    4. The generic transformers auto classes (not available when sharded or
       with gptq / 4-bit bitsandbytes quantization).
    5. The ``auto_map`` entry of the configuration, only if
       ``trust_remote_code`` is set.

    Args:
        model_id: Model identifier, passed to the config loader and to the
            selected model class.
        revision: Optional model revision, forwarded unchanged.
        sharded: Whether a sharded implementation is requested.
        quantize: Optional quantization scheme name, forwarded to the model
            class.
        dtype: ``None``, ``"float16"`` or ``"bfloat16"``. ``None`` selects
            ``torch.float16``. The CTranslate2 backend receives this value
            as given (not converted to a ``torch.dtype``).
        trust_remote_code: Forwarded to the config loader and the model class.

    Returns:
        The instantiated model.

    Raises:
        RuntimeError: If ``dtype`` is not one of the accepted values.
        NotImplementedError: If the requested sharded setup is not available
            for the detected model.
        ValueError: If the generic fallback is requested with an unsupported
            option (sharding, gptq, 4-bit) or the model type is unsupported.
        KeyError: If the model configuration has no ``model_type`` entry.
    """
    # Keep the caller's value: CT2CausalLM is given the un-converted dtype.
    dtype_ct2 = dtype
    if dtype is None:
        dtype = torch.float16
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    def load(model_cls):
        # Single place where the shared constructor arguments are spelled
        # out, so every backend receives exactly the same keywords.
        return model_cls(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def load_santacoder():
        # Shared by the "bigcode/" model-id shortcut and the "gpt_bigcode"
        # model type.
        if FLASH_ATTENTION:
            return load(FlashSantacoderSharded)
        if sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        return load(SantaCoder)

    if quantize is not None and "ct2" in quantize:
        return CT2CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype_ct2,
            trust_remote_code=trust_remote_code,
        )

    if "facebook/galactica" in model_id:
        # Previously called with a fused `dtypetrust_remote_code=` keyword,
        # so `trust_remote_code` was never passed under its real name.
        return load(GalacticaSharded)

    if model_id.startswith("bigcode/"):
        return load_santacoder()

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return load_santacoder()

    if model_type == "bloom":
        return load(BLOOMSharded)

    if model_type == "mpt":
        # NOTE: unlike the other backends, no dtype is forwarded here.
        return MPTSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return load(FlashNeoXSharded)
        if sharded:
            return load(GPTNeoxSharded)
        return load(AutoCausalLM)

    if model_type == "llama":
        if FLASH_ATTENTION:
            return load(FlashLlama)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return load(AutoCausalLM)

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        uses_alibi = config_dict.get("alibi", False)
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")
                )
            if uses_alibi:
                raise NotImplementedError("sharded is not supported for this model")
            return load(FlashRWSharded)
        # Not sharded: the flash implementation is only used without alibi.
        if FLASH_ATTENTION and not uses_alibi:
            return load(FlashRWSharded)
        return load(RW)

    if model_type == "opt":
        return load(OPTSharded)

    if model_type == "t5":
        return load(T5Sharded)

    # From here on only the generic transformers auto classes are left.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return load(AutoCausalLM)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return load(Seq2SeqLM)

    # Last resort: custom modeling code declared by the repository itself,
    # only honoured when the caller opted in to remote code.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return load(AutoCausalLM)
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return load(Seq2SeqLM)

    raise ValueError(f"Unsupported model type {model_type}")