Fixing exl2 and other quanize tests again. (#2419)
* Fixing exl2 and other quanize tests again. * Mark exl2 as non release (so CI tests them, needs to be removed latet). * Fixing exl2 (by disabling cuda graphs) * Fix quantization defaults without cuda graphs on exl2 (linked to new issues with it). * Removing serde override. * Go back to released exl2 and remove log. * Adding warnings for deprecated bitsandbytes + upgrade info to warn.
This commit is contained in:
parent
9aaa12e7ac
commit
57b3495823
|
@ -250,6 +250,9 @@ RUN cd server && \
|
|||
pip install nvidia-nccl-cu12==2.22.3
|
||||
|
||||
ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
|
||||
# This is needed because exl2 tries to load flash-attn
|
||||
# And fails with our builds.
|
||||
ENV EXLLAMA_NO_FLASH_ATTN=1
|
||||
|
||||
# Deps before the binaries
|
||||
# The binaries change on every build given we burn the SHA into them
|
||||
|
|
|
@ -21,7 +21,6 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
|
|||
return flash_llama_exl2_handle.client
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
|
||||
|
@ -33,7 +32,6 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
|
|||
assert response == ignore_logprob_response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_exl2_all_params(
|
||||
|
@ -60,7 +58,6 @@ async def test_flash_llama_exl2_all_params(
|
|||
assert response == ignore_logprob_response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_exl2_load(
|
||||
|
|
|
@ -30,11 +30,18 @@ struct RawConfig {
|
|||
n_positions: Option<usize>,
|
||||
model_type: Option<String>,
|
||||
max_seq_len: Option<usize>,
|
||||
quantization_config: Option<QuantizationConfig>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct QuantizationConfig {
|
||||
quant_method: Option<Quantization>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Config {
|
||||
max_position_embeddings: Option<usize>,
|
||||
quantize: Option<Quantization>,
|
||||
}
|
||||
|
||||
impl From<RawConfig> for Config {
|
||||
|
@ -43,13 +50,16 @@ impl From<RawConfig> for Config {
|
|||
.max_position_embeddings
|
||||
.or(other.max_seq_len)
|
||||
.or(other.n_positions);
|
||||
let quantize = other.quantization_config.and_then(|q| q.quant_method);
|
||||
Config {
|
||||
max_position_embeddings,
|
||||
quantize,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
enum Quantization {
|
||||
/// 4 bit quantization. Requires a specific AWQ quantized model:
|
||||
/// <https://hf.co/models?search=awq>.
|
||||
|
@ -72,17 +82,17 @@ enum Quantization {
|
|||
Marlin,
|
||||
/// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
|
||||
/// but it is known that the model will be much slower to run than the native f16.
|
||||
#[deprecated(
|
||||
since = "1.1.0",
|
||||
note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
|
||||
)]
|
||||
// #[deprecated(
|
||||
// since = "1.1.0",
|
||||
// note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
|
||||
// )]
|
||||
Bitsandbytes,
|
||||
/// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
|
||||
/// but it is known that the model will be much slower to run than the native f16.
|
||||
BitsandbytesNF4,
|
||||
BitsandbytesNf4,
|
||||
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
|
||||
/// perplexity performance for you model
|
||||
BitsandbytesFP4,
|
||||
BitsandbytesFp4,
|
||||
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
|
||||
/// This dtype has native ops should be the fastest if available.
|
||||
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
|
||||
|
@ -99,10 +109,10 @@ impl std::fmt::Display for Quantization {
|
|||
Quantization::Bitsandbytes => {
|
||||
write!(f, "bitsandbytes")
|
||||
}
|
||||
Quantization::BitsandbytesNF4 => {
|
||||
Quantization::BitsandbytesNf4 => {
|
||||
write!(f, "bitsandbytes-nf4")
|
||||
}
|
||||
Quantization::BitsandbytesFP4 => {
|
||||
Quantization::BitsandbytesFp4 => {
|
||||
write!(f, "bitsandbytes-fp4")
|
||||
}
|
||||
Quantization::Exl2 => {
|
||||
|
@ -1085,6 +1095,7 @@ fn spawn_shards(
|
|||
cuda_graphs: Vec<usize>,
|
||||
max_total_tokens: usize,
|
||||
max_input_tokens: usize,
|
||||
quantize: Option<Quantization>,
|
||||
max_log_level: LevelFilter,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
|
@ -1106,7 +1117,6 @@ fn spawn_shards(
|
|||
let shutdown_sender = shutdown_sender.clone();
|
||||
let otlp_endpoint = args.otlp_endpoint.clone();
|
||||
let otlp_service_name = args.otlp_service_name.clone();
|
||||
let quantize = args.quantize;
|
||||
let speculate = args.speculate;
|
||||
let dtype = args.dtype;
|
||||
let trust_remote_code = args.trust_remote_code;
|
||||
|
@ -1429,65 +1439,68 @@ fn main() -> Result<(), LauncherError> {
|
|||
|
||||
tracing::info!("{:#?}", args);
|
||||
|
||||
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
|
||||
let model_id = args.model_id.clone();
|
||||
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
|
||||
let filename = if !path.exists() {
|
||||
// Assume it's a hub id
|
||||
let get_max_positions_quantize =
|
||||
|| -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
|
||||
let model_id = args.model_id.clone();
|
||||
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
|
||||
let filename = if !path.exists() {
|
||||
// Assume it's a hub id
|
||||
|
||||
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
|
||||
// env variable has precedence over on file token.
|
||||
ApiBuilder::new().with_token(Some(token)).build()?
|
||||
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
|
||||
// env variable has precedence over on file token.
|
||||
ApiBuilder::new().with_token(Some(token)).build()?
|
||||
} else {
|
||||
Api::new()?
|
||||
};
|
||||
let repo = if let Some(ref revision) = args.revision {
|
||||
api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
))
|
||||
} else {
|
||||
api.model(model_id)
|
||||
};
|
||||
repo.get("config.json")?
|
||||
} else {
|
||||
Api::new()?
|
||||
path.push("config.json");
|
||||
path
|
||||
};
|
||||
let repo = if let Some(ref revision) = args.revision {
|
||||
api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
))
|
||||
} else {
|
||||
api.model(model_id)
|
||||
};
|
||||
repo.get("config.json")?
|
||||
} else {
|
||||
path.push("config.json");
|
||||
path
|
||||
};
|
||||
|
||||
let content = std::fs::read_to_string(filename)?;
|
||||
let config: RawConfig = serde_json::from_str(&content)?;
|
||||
let content = std::fs::read_to_string(filename)?;
|
||||
let config: RawConfig = serde_json::from_str(&content)?;
|
||||
|
||||
if config.model_type == Some("gemma2".to_string()) {
|
||||
tracing::info!("Forcing flash decoding because of softcap usage");
|
||||
std::env::set_var("ATTENTION", "flashdecoding");
|
||||
}
|
||||
let config: Config = config.into();
|
||||
|
||||
// Quantization usually means you're even more RAM constrained.
|
||||
let max_default = 4096;
|
||||
|
||||
if let Some(max_position_embeddings) = config.max_position_embeddings {
|
||||
if max_position_embeddings > max_default {
|
||||
let max = max_position_embeddings;
|
||||
if args.max_input_tokens.is_none()
|
||||
&& args.max_total_tokens.is_none()
|
||||
&& args.max_batch_prefill_tokens.is_none()
|
||||
{
|
||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||
}
|
||||
Ok(max_default)
|
||||
} else {
|
||||
Ok(max_position_embeddings)
|
||||
if config.model_type == Some("gemma2".to_string()) {
|
||||
tracing::info!("Forcing flash decoding because of softcap usage");
|
||||
std::env::set_var("ATTENTION", "flashdecoding");
|
||||
}
|
||||
} else {
|
||||
Err(Box::new(LauncherError::ArgumentValidation(
|
||||
"no max defined".to_string(),
|
||||
)))
|
||||
}
|
||||
};
|
||||
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
|
||||
let config: Config = config.into();
|
||||
let quantize = config.quantize;
|
||||
|
||||
// Quantization usually means you're even more RAM constrained.
|
||||
let max_default = 4096;
|
||||
|
||||
if let Some(max_position_embeddings) = config.max_position_embeddings {
|
||||
if max_position_embeddings > max_default {
|
||||
let max = max_position_embeddings;
|
||||
if args.max_input_tokens.is_none()
|
||||
&& args.max_total_tokens.is_none()
|
||||
&& args.max_batch_prefill_tokens.is_none()
|
||||
{
|
||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||
}
|
||||
Ok((max_default, quantize))
|
||||
} else {
|
||||
Ok((max_position_embeddings, quantize))
|
||||
}
|
||||
} else {
|
||||
Err(Box::new(LauncherError::ArgumentValidation(
|
||||
"no max defined".to_string(),
|
||||
)))
|
||||
}
|
||||
};
|
||||
let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
|
||||
get_max_positions_quantize().unwrap_or((4096, None));
|
||||
|
||||
let max_input_tokens = {
|
||||
match (args.max_input_tokens, args.max_input_length) {
|
||||
|
@ -1544,18 +1557,26 @@ fn main() -> Result<(), LauncherError> {
|
|||
)));
|
||||
}
|
||||
|
||||
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
|
||||
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
||||
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
||||
}
|
||||
let quantize = args.quantize.or(quantize);
|
||||
let cuda_graphs = match (&args.cuda_graphs, &quantize) {
|
||||
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
|
||||
#[allow(deprecated)]
|
||||
(
|
||||
None,
|
||||
Some(
|
||||
Quantization::Bitsandbytes
|
||||
| Quantization::BitsandbytesNF4
|
||||
| Quantization::BitsandbytesFP4,
|
||||
| Quantization::BitsandbytesNf4
|
||||
| Quantization::BitsandbytesFp4,
|
||||
),
|
||||
) => {
|
||||
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
|
||||
tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
|
||||
vec![]
|
||||
}
|
||||
(None, Some(Quantization::Exl2)) => {
|
||||
tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
|
||||
vec![]
|
||||
}
|
||||
_ => {
|
||||
|
@ -1672,6 +1693,7 @@ fn main() -> Result<(), LauncherError> {
|
|||
cuda_graphs,
|
||||
max_total_tokens,
|
||||
max_input_tokens,
|
||||
quantize,
|
||||
max_log_level,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
|
|
|
@ -1070,6 +1070,30 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
|
|||
[package.extras]
|
||||
dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "3.0.0"
|
||||
description = "Python port of markdown-it. Markdown parsing, done right!"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
|
||||
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
mdurl = ">=0.1,<1.0"
|
||||
|
||||
[package.extras]
|
||||
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
|
||||
code-style = ["pre-commit (>=3.0,<4.0)"]
|
||||
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
|
||||
linkify = ["linkify-it-py (>=1,<3)"]
|
||||
plugins = ["mdit-py-plugins"]
|
||||
profiling = ["gprof2dot"]
|
||||
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
|
||||
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
||||
|
||||
[[package]]
|
||||
name = "markupsafe"
|
||||
version = "2.1.5"
|
||||
|
@ -1207,6 +1231,17 @@ torch = "*"
|
|||
type = "url"
|
||||
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
description = "Markdown URL utilities"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
|
||||
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
|
@ -2277,6 +2312,20 @@ files = [
|
|||
[package.dependencies]
|
||||
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.18.0"
|
||||
description = "Pygments is a syntax highlighting package written in Python."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"},
|
||||
{file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
windows-terminal = ["colorama (>=0.4.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "7.4.4"
|
||||
|
@ -2508,6 +2557,24 @@ urllib3 = ">=1.21.1,<3"
|
|||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "13.7.1"
|
||||
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
|
||||
{file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
markdown-it-py = ">=2.2.0"
|
||||
pygments = ">=2.13.0,<3.0.0"
|
||||
|
||||
[package.extras]
|
||||
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
||||
|
||||
[[package]]
|
||||
name = "rpds-py"
|
||||
version = "0.19.0"
|
||||
|
@ -3584,4 +3651,4 @@ torch = ["torch"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.13"
|
||||
content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1"
|
||||
content-hash = "0ff7a244a409b616490cb238995bbe28dedf67ccb8855edafa2b71ee2e777dbd"
|
||||
|
|
|
@ -46,6 +46,7 @@ marlin-kernels = [
|
|||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||
]
|
||||
rich = "^13.7.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
torch = ["torch"]
|
||||
|
|
|
@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
|||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
|||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
|
|
@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
|||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
|||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
|
|
@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
|||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
|||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
|
|
@ -652,6 +652,7 @@ class CausalLM(Model):
|
|||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.quantize = quantize
|
||||
return self
|
||||
|
||||
@property
|
||||
|
|
|
@ -412,6 +412,7 @@ class Mamba(Model):
|
|||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.quantize = quantize
|
||||
self.process_group, _rank, world_size = initialize_torch_distributed()
|
||||
if world_size > 1:
|
||||
raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
|
||||
|
|
|
@ -676,6 +676,7 @@ class Seq2SeqLM(Model):
|
|||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.quantize = quantize
|
||||
return self
|
||||
|
||||
@property
|
||||
|
|
Loading…
Reference in New Issue