Fixing exl2 and other quantize tests again. (#2419)
* Fixing exl2 and other quantize tests again.
* Mark exl2 as non-release (so CI tests them; needs to be removed later).
* Fixing exl2 (by disabling cuda graphs).
* Fix quantization defaults without cuda graphs on exl2 (linked to new issues with it).
* Removing serde override.
* Go back to released exl2 and remove log.
* Adding warnings for deprecated bitsandbytes + upgrading the info log to warn.
parent 9aaa12e7ac
commit 57b3495823
@@ -250,6 +250,9 @@ RUN cd server && \
     pip install nvidia-nccl-cu12==2.22.3

 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
+# This is needed because exl2 tries to load flash-attn
+# And fails with our builds.
+ENV EXLLAMA_NO_FLASH_ATTN=1

 # Deps before the binaries
 # The binaries change on every build given we burn the SHA into them
@@ -21,7 +21,6 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
     return flash_llama_exl2_handle.client


-@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@@ -33,7 +32,6 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
     assert response == ignore_logprob_response_snapshot


-@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_all_params(
@@ -60,7 +58,6 @@ async def test_flash_llama_exl2_all_params(
     assert response == ignore_logprob_response_snapshot


-@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_load(
@@ -30,11 +30,18 @@ struct RawConfig {
     n_positions: Option<usize>,
     model_type: Option<String>,
     max_seq_len: Option<usize>,
+    quantization_config: Option<QuantizationConfig>,
+}
+
+#[derive(Deserialize)]
+struct QuantizationConfig {
+    quant_method: Option<Quantization>,
 }

 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
+    quantize: Option<Quantization>,
 }

 impl From<RawConfig> for Config {
@@ -43,13 +50,16 @@ impl From<RawConfig> for Config {
             .max_position_embeddings
             .or(other.max_seq_len)
             .or(other.n_positions);
+        let quantize = other.quantization_config.and_then(|q| q.quant_method);
         Config {
             max_position_embeddings,
+            quantize,
         }
     }
 }

-#[derive(Clone, Copy, Debug, ValueEnum)]
+#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
+#[serde(rename_all = "kebab-case")]
 enum Quantization {
     /// 4 bit quantization. Requires a specific AWQ quantized model:
     /// <https://hf.co/models?search=awq>.
@@ -72,17 +82,17 @@ enum Quantization {
     Marlin,
     /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
     /// but it is known that the model will be much slower to run than the native f16.
-    #[deprecated(
-        since = "1.1.0",
-        note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
-    )]
+    // #[deprecated(
+    //     since = "1.1.0",
+    //     note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
+    // )]
     Bitsandbytes,
     /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
     /// but it is known that the model will be much slower to run than the native f16.
-    BitsandbytesNF4,
+    BitsandbytesNf4,
     /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
     /// perplexity performance for you model
-    BitsandbytesFP4,
+    BitsandbytesFp4,
     /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
     /// This dtype has native ops should be the fastest if available.
     /// This is currently not the fastest because of local unpacking + padding to satisfy matrix
@@ -99,10 +109,10 @@ impl std::fmt::Display for Quantization {
             Quantization::Bitsandbytes => {
                 write!(f, "bitsandbytes")
             }
-            Quantization::BitsandbytesNF4 => {
+            Quantization::BitsandbytesNf4 => {
                 write!(f, "bitsandbytes-nf4")
             }
-            Quantization::BitsandbytesFP4 => {
+            Quantization::BitsandbytesFp4 => {
                 write!(f, "bitsandbytes-fp4")
             }
             Quantization::Exl2 => {
@@ -1085,6 +1095,7 @@ fn spawn_shards(
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
     max_input_tokens: usize,
+    quantize: Option<Quantization>,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1106,7 +1117,6 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let otlp_service_name = args.otlp_service_name.clone();
-        let quantize = args.quantize;
         let speculate = args.speculate;
         let dtype = args.dtype;
         let trust_remote_code = args.trust_remote_code;
@@ -1429,65 +1439,68 @@ fn main() -> Result<(), LauncherError> {

     tracing::info!("{:#?}", args);

-    let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
-        let model_id = args.model_id.clone();
-        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
-        let filename = if !path.exists() {
-            // Assume it's a hub id
-
-            let api = if let Ok(token) = std::env::var("HF_TOKEN") {
-                // env variable has precedence over on file token.
-                ApiBuilder::new().with_token(Some(token)).build()?
-            } else {
-                Api::new()?
-            };
-            let repo = if let Some(ref revision) = args.revision {
-                api.repo(Repo::with_revision(
-                    model_id,
-                    RepoType::Model,
-                    revision.to_string(),
-                ))
-            } else {
-                api.model(model_id)
-            };
-            repo.get("config.json")?
-        } else {
-            path.push("config.json");
-            path
-        };
-
-        let content = std::fs::read_to_string(filename)?;
-        let config: RawConfig = serde_json::from_str(&content)?;
-
-        if config.model_type == Some("gemma2".to_string()) {
-            tracing::info!("Forcing flash decoding because of softcap usage");
-            std::env::set_var("ATTENTION", "flashdecoding");
-        }
-        let config: Config = config.into();
-
-        // Quantization usually means you're even more RAM constrained.
-        let max_default = 4096;
-
-        if let Some(max_position_embeddings) = config.max_position_embeddings {
-            if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
-                Ok(max_default)
-            } else {
-                Ok(max_position_embeddings)
-            }
-        } else {
-            Err(Box::new(LauncherError::ArgumentValidation(
-                "no max defined".to_string(),
-            )))
-        }
-    };
-    let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
+    let get_max_positions_quantize =
+        || -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
+            let model_id = args.model_id.clone();
+            let mut path = std::path::Path::new(&args.model_id).to_path_buf();
+            let filename = if !path.exists() {
+                // Assume it's a hub id
+
+                let api = if let Ok(token) = std::env::var("HF_TOKEN") {
+                    // env variable has precedence over on file token.
+                    ApiBuilder::new().with_token(Some(token)).build()?
+                } else {
+                    Api::new()?
+                };
+                let repo = if let Some(ref revision) = args.revision {
+                    api.repo(Repo::with_revision(
+                        model_id,
+                        RepoType::Model,
+                        revision.to_string(),
+                    ))
+                } else {
+                    api.model(model_id)
+                };
+                repo.get("config.json")?
+            } else {
+                path.push("config.json");
+                path
+            };
+
+            let content = std::fs::read_to_string(filename)?;
+            let config: RawConfig = serde_json::from_str(&content)?;
+
+            if config.model_type == Some("gemma2".to_string()) {
+                tracing::info!("Forcing flash decoding because of softcap usage");
+                std::env::set_var("ATTENTION", "flashdecoding");
+            }
+            let config: Config = config.into();
+            let quantize = config.quantize;
+
+            // Quantization usually means you're even more RAM constrained.
+            let max_default = 4096;
+
+            if let Some(max_position_embeddings) = config.max_position_embeddings {
+                if max_position_embeddings > max_default {
+                    let max = max_position_embeddings;
+                    if args.max_input_tokens.is_none()
+                        && args.max_total_tokens.is_none()
+                        && args.max_batch_prefill_tokens.is_none()
+                    {
+                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                    }
+                    Ok((max_default, quantize))
+                } else {
+                    Ok((max_position_embeddings, quantize))
+                }
+            } else {
+                Err(Box::new(LauncherError::ArgumentValidation(
+                    "no max defined".to_string(),
+                )))
+            }
+        };
+    let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
+        get_max_positions_quantize().unwrap_or((4096, None));

     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
@@ -1544,18 +1557,26 @@ fn main() -> Result<(), LauncherError> {
         )));
     }

-    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
+    if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
+        tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
+    }
+    let quantize = args.quantize.or(quantize);
+    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
         (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
         #[allow(deprecated)]
         (
             None,
             Some(
                 Quantization::Bitsandbytes
-                | Quantization::BitsandbytesNF4
-                | Quantization::BitsandbytesFP4,
+                | Quantization::BitsandbytesNf4
+                | Quantization::BitsandbytesFp4,
             ),
         ) => {
-            tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
+            tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
             vec![]
         }
+        (None, Some(Quantization::Exl2)) => {
+            tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
+            vec![]
+        }
         _ => {
@@ -1672,6 +1693,7 @@ fn main() -> Result<(), LauncherError> {
         cuda_graphs,
         max_total_tokens,
         max_input_tokens,
+        quantize,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
@@ -1070,6 +1070,30 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
 [package.extras]
 dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"]

+[[package]]
+name = "markdown-it-py"
+version = "3.0.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
+    {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
 [[package]]
 name = "markupsafe"
 version = "2.1.5"
@@ -1207,6 +1231,17 @@ torch = "*"
 type = "url"
 url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"

+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+    {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"
@@ -2277,6 +2312,20 @@ files = [
 [package.dependencies]
 typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"

+[[package]]
+name = "pygments"
+version = "2.18.0"
+description = "Pygments is a syntax highlighting package written in Python."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"},
+    {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"},
+]
+
+[package.extras]
+windows-terminal = ["colorama (>=0.4.6)"]
+
 [[package]]
 name = "pytest"
 version = "7.4.4"
@@ -2508,6 +2557,24 @@ urllib3 = ">=1.21.1,<3"
 socks = ["PySocks (>=1.5.6,!=1.5.7)"]
 use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]

+[[package]]
+name = "rich"
+version = "13.7.1"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+    {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
+    {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0"
+pygments = ">=2.13.0,<3.0.0"
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
 [[package]]
 name = "rpds-py"
 version = "0.19.0"
@@ -3584,4 +3651,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1"
+content-hash = "0ff7a244a409b616490cb238995bbe28dedf67ccb8855edafa2b71ee2e777dbd"
@@ -46,6 +46,7 @@ marlin-kernels = [
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
+rich = "^13.7.1"

 [tool.poetry.extras]
 torch = ["torch"]
@@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
+mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
 prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
+pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
+rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
@@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
+mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
 prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
+pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
+rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
@@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
+mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
 prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
+pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
+rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
@@ -652,6 +652,7 @@ class CausalLM(Model):
             dtype=dtype,
             device=device,
         )
+        self.quantize = quantize
         return self

     @property
@@ -412,6 +412,7 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.quantize = quantize
         self.process_group, _rank, world_size = initialize_torch_distributed()
         if world_size > 1:
             raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
@@ -676,6 +676,7 @@ class Seq2SeqLM(Model):
             dtype=dtype,
             device=device,
         )
+        self.quantize = quantize
         return self

     @property