feat(server): GPTQ quantization (step1) (#277)

Changes only the type from `bool` to `Option<Enum>` pretty much
everywhere.
- Use `Optional[str]` in Python (easier to manage than importing type
everywhere). Except for the cli to get proper validation
- Updated all models to handle gracefully new values. (Error out if
unknown value, or gptq since not implemented).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2023-05-12 14:46:41 +02:00 committed by GitHub
parent 4f6d038c0b
commit 76a48cd365
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 136 additions and 49 deletions

View File

@ -1,4 +1,4 @@
use clap::Parser;
use clap::{Parser, ValueEnum};
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
@ -16,6 +16,26 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
mod env_runtime;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
Bitsandbytes,
Gptq,
}
impl std::fmt::Display for Quantization {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// To keep in track with `server`.
match self {
Quantization::Bitsandbytes => {
write!(f, "bitsandbytes")
}
Quantization::Gptq => {
write!(f, "gptq")
}
}
}
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@ -46,10 +66,10 @@ struct Args {
#[clap(long, env)]
num_shard: Option<usize>,
/// Wether you want the model to be quantized or not. This will use bitsandbytes for
/// quantization on the fly.
#[clap(long, env)]
quantize: bool,
/// Wether you want the model to be quantized or not. This will use `bitsandbytes` for
/// quantization on the fly, or `gptq`.
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// The maximum amount of concurrent requests for this particular deployment.
/// Having a low limit will refuse clients requests instead of having them
@ -218,7 +238,7 @@ enum ShardStatus {
fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: bool,
quantize: Option<Quantization>,
uds_path: String,
rank: usize,
world_size: usize,
@ -257,8 +277,9 @@ fn shard_manager(
shard_argv.push("--sharded".to_string());
}
if quantize {
shard_argv.push("--quantize".to_string())
if let Some(quantize) = quantize {
shard_argv.push("--quantize".to_string());
shard_argv.push(quantize.to_string())
}
// Model optional revision

View File

@ -5,17 +5,23 @@ import typer
from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
app = typer.Typer()
class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
gptq = "gptq"
@app.command()
def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: bool = False,
quantize: Optional[Quantization] = None,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
@ -55,6 +61,8 @@ def serve(
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
server.serve(model_id, revision, sharded, quantize, uds_path)

View File

@ -91,7 +91,7 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
) -> Model:
if "facebook/galactica" in model_id:
if sharded:

View File

@ -49,7 +49,12 @@ class BloomCausalLMBatch(CausalLMBatch):
class BLOOM(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
super(BLOOM, self).__init__(
model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
)
@ -61,7 +66,10 @@ class BLOOM(CausalLM):
class BLOOMSharded(BLOOM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
@ -113,7 +121,7 @@ class BLOOMSharded(BLOOM):
def load_weights(
model,
filenames: List[str],
quantize: bool,
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
@ -167,7 +175,7 @@ class BLOOMSharded(BLOOM):
tensor = tensor.contiguous().to(dtype)
if quantize:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -217,9 +225,14 @@ class BLOOMSharded(BLOOM):
return linear
module.linear = replace_linear(state)
else:
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":

View File

@ -447,7 +447,7 @@ class CausalLM(Model):
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
quantize: Optional[str] = None,
decode_buffer: int = 3,
):
if torch.cuda.is_available():
@ -468,7 +468,7 @@ class CausalLM(Model):
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
load_in_8bit=quantize == "bitsandbytes",
).eval()
tokenizer.pad_token_id = (
self.model.config.pad_token_id

View File

@ -105,7 +105,7 @@ class FastLinear(nn.Linear):
self.bnb_linear = None
def prepare_weights(self, quantize: bool = False):
if quantize:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -129,8 +129,12 @@ class FastLinear(nn.Linear):
# Delete reference to data
self.weight = None
self.bias = None
else:
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:

View File

@ -92,8 +92,8 @@ class FastLinear(nn.Linear):
self.quantized = False
self.bnb_linear = None
def prepare_weights(self, quantize: bool = False):
if quantize:
def prepare_weights(self, quantize: Optional[str] = None):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -117,8 +117,12 @@ class FastLinear(nn.Linear):
# Delete reference to data
self.weight = None
self.bias = None
else:
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:

View File

@ -67,8 +67,8 @@ class FastLinear(nn.Linear):
self.quantized = False
self.bnb_linear = None
def prepare_weights(self, quantize: bool = False):
if quantize:
def prepare_weights(self, quantize: Optional[str] = None):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -92,8 +92,12 @@ class FastLinear(nn.Linear):
# Delete reference to data
self.weight = None
self.bias = None
else:
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:

View File

@ -393,7 +393,7 @@ class FlashCausalLM(Model):
model_cls: Type[PreTrainedModel],
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
quantize: Optional[str] = None,
decode_buffer: int = 3,
):
if torch.cuda.is_available():
@ -410,7 +410,7 @@ class FlashCausalLM(Model):
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize,
load_in_8bit=quantize == "bitsandbytes",
)
.eval()
.to(device)

View File

@ -154,7 +154,10 @@ class FlashLlama(FlashCausalLM):
class FlashLlamaSharded(FlashLlama):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
self.past_pad = None
self.process_group, rank, world_size = initialize_torch_distributed()

View File

@ -193,7 +193,10 @@ class Galactica(OPT):
class GalacticaSharded(Galactica):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
@ -244,7 +247,7 @@ class GalacticaSharded(Galactica):
def load_weights(
model,
filenames: List[str],
quantize: bool,
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
@ -299,7 +302,7 @@ class GalacticaSharded(Galactica):
tensor = tensor.contiguous().to(dtype)
if quantize:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -349,9 +352,14 @@ class GalacticaSharded(Galactica):
return linear
module.linear = replace_linear(state)
else:
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":

View File

@ -32,7 +32,10 @@ except Exception as e:
class GPTNeoxSharded(CausalLM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
@ -83,7 +86,7 @@ class GPTNeoxSharded(CausalLM):
def load_weights(
model,
filenames: List[str],
quantize: bool,
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
@ -148,7 +151,7 @@ class GPTNeoxSharded(CausalLM):
tensor = tensor.contiguous().to(dtype)
if quantize:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -198,9 +201,14 @@ class GPTNeoxSharded(CausalLM):
return linear
module.linear = replace_linear(state)
else:
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor

View File

@ -14,7 +14,12 @@ EOD = "<|endoftext|>"
class SantaCoder(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
@ -46,7 +51,7 @@ class SantaCoder(CausalLM):
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=True, # required
)
.to(device)

View File

@ -501,7 +501,7 @@ class Seq2SeqLM(Model):
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
quantize: Optional[str] = None,
decode_buffer: int = 3,
):
if torch.cuda.is_available():
@ -519,7 +519,7 @@ class Seq2SeqLM(Model):
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
load_in_8bit=quantize == "bitsandbytes",
).eval()
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"

View File

@ -32,7 +32,10 @@ except Exception as e:
class T5Sharded(Seq2SeqLM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
@ -83,7 +86,7 @@ class T5Sharded(Seq2SeqLM):
def load_weights(
model,
filenames: List[str],
quantize: bool,
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
@ -154,7 +157,7 @@ class T5Sharded(Seq2SeqLM):
tensor = tensor.contiguous().to(dtype)
if quantize:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -205,8 +208,14 @@ class T5Sharded(Seq2SeqLM):
module.linear = replace_linear(state)
else:
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor

View File

@ -100,14 +100,14 @@ def serve(
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: bool,
quantize: Optional[str],
uds_path: Path,
):
async def serve_inner(
model_id: str,
revision: Optional[str],
sharded: bool = False,
quantize: bool = False,
quantize: Optional[str] = None,
):
unix_socket_template = "unix://{}-{}"
if sharded: