hf_text-generation-inference/server/text_generation_server/models/flash_llama.py

import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from transformers.models.llama import LlamaTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
    LlamaConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashLlama(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        use_medusa: Optional[str] = None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

        try:
            tokenizer = LlamaTokenizer.from_pretrained(
                model_id,
                revision=revision,
                padding_side="left",
                truncation_side="left",
                trust_remote_code=trust_remote_code,
            )
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                revision=revision,
                padding_side="left",
                truncation_side="left",
                trust_remote_code=trust_remote_code,
            )

        config = LlamaConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id)

        model = FlashLlamaForCausalLM(config, weights)
        if use_medusa:
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json

            medusa_config = hf_hub_download(
                use_medusa, revision=revision, filename="config.json"
            )
            with open(medusa_config, "r") as f:
                config = json.load(f)
            medusa_head = hf_hub_download(
                use_medusa, revision=revision, filename="medusa_lm_head.pt"
            )
            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
            )
            lm_head = model.lm_head
            model.lm_head = MedusaModel(config, weights, lm_head)

        torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )