From 2f243a1a150da40fc71cbdd08cd07e314cf7098e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 22 May 2024 16:22:57 +0200 Subject: [PATCH] Creating doc automatically for supported models. (#1929) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- .github/workflows/autodocs.yml | 6 +- docs/source/supported_models.md | 48 +++-- .../text_generation_server/models/__init__.py | 181 +++++++++++++++--- update_doc.py | 88 ++++++++- 4 files changed, 267 insertions(+), 56 deletions(-) diff --git a/.github/workflows/autodocs.yml b/.github/workflows/autodocs.yml index c378e177..48ed01e2 100644 --- a/.github/workflows/autodocs.yml +++ b/.github/workflows/autodocs.yml @@ -13,11 +13,7 @@ jobs: - name: Install Launcher id: install-launcher - env: - REF: ${{ github.head_ref }} - REPO: ${{ github.repository }} - run: cargo install --git "https://github.com/$REPO" --branch "$REF" text-generation-launcher - + run: cargo install --path launcher/ - name: Check launcher Docs are up-to-date run: | echo text-generation-launcher --help diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index d478085e..4b6cf731 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -1,30 +1,36 @@ + # Supported Models and Hardware Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models are hardware are supported. ## Supported Models -The following models are optimized and can be served with TGI, which uses custom CUDA kernels for better inference. You can add the flag `--disable-custom-kernels` at the end of the `docker run` command if you wish to disable them. - -- [BLOOM](https://huggingface.co/bigscience/bloom) -- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl) -- [Galactica](https://huggingface.co/facebook/galactica-120b) -- [GPT-2](https://huggingface.co/openai-community/gpt2) -- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) -- [Llama](https://github.com/facebookresearch/llama) -- [OPT](https://huggingface.co/facebook/opt-66b) -- [SantaCoder](https://huggingface.co/bigcode/santacoder) -- [Starcoder](https://huggingface.co/bigcode/starcoder) -- [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b) -- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b) -- [MPT](https://huggingface.co/mosaicml/mpt-30b) -- [Llama V2](https://huggingface.co/meta-llama) -- [Code Llama](https://huggingface.co/codellama) +- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) +- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) +- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) +- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) +- [Gemma](https://huggingface.co/google/gemma-7b) +- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) +- [Dbrx](https://huggingface.co/databricks/dbrx-instruct) +- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj) - [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) -- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) -- [Phi](https://huggingface.co/microsoft/phi-2) -- [Idefics](HuggingFaceM4/idefics-9b-instruct) (Multimodal) -- [Llava-next](llava-hf/llava-v1.6-mistral-7b-hf) (Multimodal) +- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1) +- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder) +- [Phi](https://huggingface.co/microsoft/phi-1_5) +- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) +- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct) +- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) +- [Qwen 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) +- [Opt](https://huggingface.co/facebook/opt-6.7b) +- [T5](https://huggingface.co/google/flan-t5-xxl) +- [Galactica](https://huggingface.co/facebook/galactica-120b) +- [SantaCoder](https://huggingface.co/bigcode/santacoder) +- [Bloom](https://huggingface.co/bigscience/bloom-560m) +- [Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct) +- [Gpt2](https://huggingface.co/openai-community/gpt2) +- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) +- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal) + If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models: @@ -39,4 +45,4 @@ If you wish to serve a supported model that already exists on a local folder, ju ```bash text-generation-launcher --model-id -`````` +``` diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 9e5676f5..b319ab5d 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,4 +1,5 @@ import torch +import enum import os from loguru import logger @@ -116,6 +117,142 @@ if MAMBA_AVAILABLE: __all__.append(Mamba) +class ModelType(enum.Enum): + IDEFICS2 = { + "type": "idefics2", + "name": "Idefics 2", + "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", + "multimodal": True, + } + LLAVA_NEXT = { + "type": "llava_next", + "name": "Llava Next (1.6)", + "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf", + "multimodal": True, + } + LLAMA = { + "type": "llama", + "name": "Llama", + "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct", + } + PHI3 = { + "type": "phi3", + "name": "Phi 3", + "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", + } + GEMMA = { + "type": "gemma", + "name": "Gemma", + "url": "https://huggingface.co/google/gemma-7b", + } + COHERE = { + "type": "cohere", + "name": "Cohere", + "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus", + } + DBRX = { + "type": "dbrx", + "name": "Dbrx", + "url": "https://huggingface.co/databricks/dbrx-instruct", + } + MAMBA = { + "type": "ssm", + "name": "Mamba", + "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", + } + MISTRAL = { + "type": "mistral", + "name": "Mistral", + "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2", + } + MIXTRAL = { + "type": "mixtral", + "name": "Mixtral", + "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1", + } + GPT_BIGCODE = { + "type": "gpt_bigcode", + "name": "Gpt Bigcode", + "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder", + } + PHI = { + "type": "phi", + "name": "Phi", + "url": "https://huggingface.co/microsoft/phi-1_5", + } + BAICHUAN = { + "type": "baichuan", + "name": "Baichuan", + "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat", + } + FALCON = { + "type": "falcon", + "name": "Falcon", + "url": "https://huggingface.co/tiiuae/falcon-7b-instruct", + } + STARCODER2 = { + "type": "starcoder2", + "name": "StarCoder 2", + "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1", + } + QWEN2 = { + "type": "qwen2", + "name": "Qwen 2", + "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1", + } + OPT = { + "type": "opt", + "name": "Opt", + "url": "https://huggingface.co/facebook/opt-6.7b", + } + T5 = { + "type": "t5", + "name": "T5", + "url": "https://huggingface.co/google/flan-t5-xxl", + } + GALACTICA = { + "type": "galactica", + "name": "Galactica", + "url": "https://huggingface.co/facebook/galactica-120b", + } + SANTACODER = { + "type": "santacoder", + "name": "SantaCoder", + "url": "https://huggingface.co/bigcode/santacoder", + } + BLOOM = { + "type": "bloom", + "name": "Bloom", + "url": "https://huggingface.co/bigscience/bloom-560m", + } + MPT = { + "type": "mpt", + "name": "Mpt", + "url": "https://huggingface.co/mosaicml/mpt-7b-instruct", + } + GPT2 = { + "type": "gpt2", + "name": "Gpt2", + "url": "https://huggingface.co/openai-community/gpt2", + } + GPT_NEOX = { + "type": "gpt_neox", + "name": "Gpt Neox", + "url": "https://huggingface.co/EleutherAI/gpt-neox-20b", + } + IDEFICS = { + "type": "idefics", + "name": "Idefics", + "url": "https://huggingface.co/HuggingFaceM4/idefics-9b", + "multimodal": True, + } + + +__GLOBALS = locals() +for data in ModelType: + __GLOBALS[data.name] = data.value["type"] + + def get_model( model_id: str, revision: Optional[str], @@ -267,7 +404,7 @@ def get_model( else: logger.info(f"Unknown quantization method {method}") - if model_type == "ssm": + if model_type == MAMBA: return Mamba( model_id, revision, @@ -288,8 +425,8 @@ def get_model( ) if ( - model_type == "gpt_bigcode" - or model_type == "gpt2" + model_type == GPT_BIGCODE + or model_type == GPT2 and model_id.startswith("bigcode/") ): if FLASH_ATTENTION: @@ -315,7 +452,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "bloom": + if model_type == BLOOM: return BLOOMSharded( model_id, revision, @@ -324,7 +461,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == "mpt": + elif model_type == MPT: return MPTSharded( model_id, revision, @@ -333,7 +470,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == "gpt2": + elif model_type == GPT2: if FLASH_ATTENTION: return FlashGPT2( model_id, @@ -354,7 +491,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == "gpt_neox": + elif model_type == GPT_NEOX: if FLASH_ATTENTION: return FlashNeoXSharded( model_id, @@ -383,7 +520,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - elif model_type == "phi": + elif model_type == PHI: if FLASH_ATTENTION: return FlashPhi( model_id, @@ -418,7 +555,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3": + elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: if FLASH_ATTENTION: return FlashLlama( model_id, @@ -439,7 +576,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - if model_type == "gemma": + if model_type == GEMMA: if FLASH_ATTENTION: return FlashGemma( model_id, @@ -461,7 +598,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "cohere": + if model_type == COHERE: if FLASH_ATTENTION: return FlashCohere( model_id, @@ -483,7 +620,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "dbrx": + if model_type == DBRX: if FLASH_ATTENTION: return FlashDbrx( model_id, @@ -505,7 +642,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]: + if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]: if sharded: if FLASH_ATTENTION: if config_dict.get("alibi", False): @@ -539,7 +676,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "mistral": + if model_type == MISTRAL: sliding_window = config_dict.get("sliding_window", -1) if ( ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) @@ -566,7 +703,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "mixtral": + if model_type == MIXTRAL: sliding_window = config_dict.get("sliding_window", -1) if ( ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) @@ -593,7 +730,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "starcoder2": + if model_type == STARCODER2: sliding_window = config_dict.get("sliding_window", -1) if ( ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) @@ -621,7 +758,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "qwen2": + if model_type == QWEN2: sliding_window = config_dict.get("sliding_window", -1) if ( ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) @@ -647,7 +784,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "opt": + if model_type == OPT: return OPTSharded( model_id, revision, @@ -657,7 +794,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "t5": + if model_type == T5: return T5Sharded( model_id, revision, @@ -666,7 +803,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - if model_type == "idefics": + if model_type == IDEFICS: if FLASH_ATTENTION: return IDEFICSSharded( model_id, @@ -678,7 +815,7 @@ def get_model( ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) - if model_type == "idefics2": + if model_type == IDEFICS2: if FLASH_ATTENTION: return Idefics2( model_id, @@ -703,7 +840,7 @@ def get_model( else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) - if model_type == "llava_next": + if model_type == LLAVA_NEXT: if FLASH_ATTENTION: return LlavaNext( model_id, diff --git a/update_doc.py b/update_doc.py index 6127418c..5da81c72 100644 --- a/update_doc.py +++ b/update_doc.py @@ -1,13 +1,34 @@ import subprocess import argparse +import ast + +TEMPLATE = """ +# Supported Models and Hardware + +Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models are hardware are supported. + +## Supported Models + +SUPPORTED_MODELS + +If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models: + +```python +# for causal LMs/text-generation models +AutoModelForCausalLM.from_pretrained(, device_map="auto")` +# or, for text-to-text generation models +AutoModelForSeq2SeqLM.from_pretrained(, device_map="auto") +``` + +If you wish to serve a supported model that already exists on a local folder, just point to the local folder. + +```bash +text-generation-launcher --model-id +``` +""" -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--check", action="store_true") - - args = parser.parse_args() - +def check_cli(check: bool): output = subprocess.check_output(["text-generation-launcher", "--help"]).decode( "utf-8" ) @@ -41,7 +62,7 @@ def main(): block = [] filename = "docs/source/basic_tutorials/launcher.md" - if args.check: + if check: with open(filename, "r") as f: doc = f.read() if doc != final_doc: @@ -53,12 +74,63 @@ def main(): ).stdout.decode("utf-8") print(diff) raise Exception( - "Doc is not up-to-date, run `python update_doc.py` in order to update it" + "Cli arguments Doc is not up-to-date, run `python update_doc.py` in order to update it" ) else: with open(filename, "w") as f: f.write(final_doc) +def check_supported_models(check: bool): + filename = "server/text_generation_server/models/__init__.py" + with open(filename, "r") as f: + tree = ast.parse(f.read()) + + enum_def = [ + x for x in tree.body if isinstance(x, ast.ClassDef) and x.name == "ModelType" + ][0] + _locals = {} + _globals = {} + exec(f"import enum\n{ast.unparse(enum_def)}", _globals, _locals) + ModelType = _locals["ModelType"] + list_string = "" + for data in ModelType: + list_string += f"- [{data.value['name']}]({data.value['url']})" + if data.value.get("multimodal", None): + list_string += " (Multimodal)" + list_string += "\n" + + final_doc = TEMPLATE.replace("SUPPORTED_MODELS", list_string) + + filename = "docs/source/supported_models.md" + if check: + with open(filename, "r") as f: + doc = f.read() + if doc != final_doc: + tmp = "supported.md" + with open(tmp, "w") as g: + g.write(final_doc) + diff = subprocess.run( + ["diff", tmp, filename], capture_output=True + ).stdout.decode("utf-8") + print(diff) + raise Exception( + "Supported models is not up-to-date, run `python update_doc.py` in order to update it" + ) + else: + with open(filename, "w") as f: + f.write(final_doc) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--check", action="store_true") + + args = parser.parse_args() + + check_cli(args.check) + check_supported_models(args.check) + + if __name__ == "__main__": main()