chore: add pre-commit (#1569)
Parent: 142cdabed3
Commit: 9946165ee0
@@ -71,12 +71,11 @@ jobs:
           pip install pytest
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
           pytest -s -vv server/tests
-      - name: Run Rust fmt
+      - name: Pre-commit checks
         run: |
-          cargo fmt --check
-      - name: Run Rust clippy
-        run: |
-          cargo clippy
+          pip install pre-commit
+          pre-commit install
+          pre-commit run --all-files
       - name: Run Rust tests
         run: |
           cargo test

@@ -11,4 +11,3 @@ server/exllama_kernels/exllama_kernels/hip_func/
 *_hip.cuh
 server/exllama_kernels/exllama_kernels/hip_buffers.cuh
 server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
-

@@ -0,0 +1,18 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+        exclude: docs/source/basic_tutorials/launcher.md
+  - repo: https://github.com/psf/black
+    rev: 24.2.0
+    hooks:
+      - id: black
+  - repo: https://github.com/doublify/pre-commit-rust
+    rev: v1.0
+    hooks:
+      - id: fmt
+      - id: cargo-check
+      - id: clippy
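Reviewer note (not part of the diff): the new CI step defers all checks to pre-commit, and the `.pre-commit-config.yaml` added above is what drives it (YAML checks, end-of-file and trailing-whitespace fixers, black, and the Rust fmt / cargo-check / clippy hooks). A minimal sketch of reproducing the same checks locally, assuming `pre-commit` ends up on PATH, driven through Python's `subprocess`:

```python
# Hypothetical local helper mirroring the new "Pre-commit checks" CI step.
import subprocess


def run_precommit_checks() -> int:
    # Same three commands as the workflow step introduced in this commit.
    subprocess.run(["pip", "install", "pre-commit"], check=True)
    subprocess.run(["pre-commit", "install"], check=True)
    return subprocess.run(["pre-commit", "run", "--all-files"]).returncode


if __name__ == "__main__":
    raise SystemExit(run_precommit_checks())
```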
@@ -29,4 +29,3 @@ tui = {package = "ratatui", version = "0.23", default-features = false, features
 tracing = "0.1.37"
 tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
 hf-hub = "0.3.1"
-

@@ -134,6 +134,7 @@ class Parameters(BaseModel):
             raise ValidationError("`value` cannot be empty for `json` grammar")
         return v
 
+
 class Request(BaseModel):
     # Prompt
     inputs: str

@@ -9,4 +9,3 @@ Standard attention mechanism uses High Bandwidth Memory (HBM) to store, read and
 It is implemented for supported models. You can check out the complete list of models that support Flash Attention [here](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models), for models with flash prefix.
 
 You can learn more about Flash Attention by reading the paper in this [link](https://arxiv.org/abs/2205.14135).
-

@@ -54,7 +54,9 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_mamba_load(fused_kernel_mamba, generate_load, generous_response_snapshot):
+async def test_mamba_load(
+    fused_kernel_mamba, generate_load, generous_response_snapshot
+):
     responses = await generate_load(
         fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
     )

@@ -27,6 +27,7 @@ pub struct Validation {
 }
 
 impl Validation {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         workers: usize,
         tokenizer: Option<Tokenizer>,

@@ -40,5 +40,3 @@ __forceinline__ __device__ void dequant_6bit_16
 #endif
 
 #endif
-
-

@@ -251,9 +251,9 @@ class LlamaMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
         # Fuse gate and up proj
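Reviewer note (not part of the diff): this hunk, and the identical ones for the other MLP/MoE classes below, only re-wrap the gelu-approximation ternary in parentheses (black style); the selected behaviour is unchanged. A minimal standalone sketch of that mapping in plain PyTorch, where the `act` names are the ones listed in the diff and the tensor shape is illustrative:

```python
import torch
import torch.nn.functional as F


def make_gelu(act: str):
    # "gelu_fast" and "gelu_pytorch_tanh" select the tanh approximation;
    # any other gelu variant falls back to the exact kernel ("none").
    return lambda x: F.gelu(
        x,
        approximate=(
            "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
        ),
    )


x = torch.randn(4, 8)
assert torch.allclose(make_gelu("gelu")(x), F.gelu(x))
assert torch.allclose(make_gelu("gelu_pytorch_tanh")(x), F.gelu(x, approximate="tanh"))
```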
@@ -255,9 +255,9 @@ class MistralMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
         # Fuse gate and up proj

@@ -344,9 +344,9 @@ class BlockSparseMoE(nn.Module):
         if "gelu" in act:
             self.act = lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         elif "silu" in act:
             self.act = torch.nn.functional.silu

@@ -600,9 +600,9 @@ class DenseMoE(nn.Module):
         if "gelu" in act:
             self.act = lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         elif "silu" in act:
             self.act = torch.nn.functional.silu

@@ -187,9 +187,9 @@ class FlashMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
 

@@ -225,9 +225,9 @@ class PhiMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
 

@@ -69,7 +69,12 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
         qzeros = qzeros.to(device=weights.device)
 
-        bits, groupsize, _, quant_method, = weights._get_gptq_params()
+        (
+            bits,
+            groupsize,
+            _,
+            quant_method,
+        ) = weights._get_gptq_params()
         if quant_method == "gptq":
             g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
             g_idx = g_idx.to(device=weights.device)

@@ -306,9 +311,9 @@ class MLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
 

@@ -66,6 +66,7 @@ class IdeficsVisionConfig(PretrainedConfig):
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
     """
+
     model_type = "idefics"
     attribute_map = {
         "hidden_size": "embed_dim",

@@ -125,6 +126,7 @@ class IdeficsPerceiverConfig(PretrainedConfig):
         qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
             Whether or not to use qk layer norms in perceiver
     """
+
     model_type = "idefics"
 
     def __init__(

@@ -219,6 +221,7 @@ class IdeficsConfig(PretrainedConfig):
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
+
     model_type = "idefics"
     is_composition = True
 

@@ -123,10 +123,10 @@ def expand_inputs_for_generation(
             raise ValueError(
                 "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
             )
-        encoder_outputs[
-            "last_hidden_state"
-        ] = encoder_outputs.last_hidden_state.index_select(
-            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
+        encoder_outputs["last_hidden_state"] = (
+            encoder_outputs.last_hidden_state.index_select(
+                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
+            )
         )
         model_kwargs["encoder_outputs"] = encoder_outputs
         return input_ids, model_kwargs
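Reviewer note (not part of the diff): the re-wrap above keeps the same behaviour; `last_hidden_state` is still expanded along the batch dimension with `index_select` before generation. A minimal standalone sketch of that expansion, where the shapes and the repeat factor are illustrative rather than taken from the commit:

```python
import torch

batch_size, seq_len, hidden = 2, 3, 4
expand_size = 3  # e.g. num_return_sequences
last_hidden_state = torch.randn(batch_size, seq_len, hidden)

# Repeat every batch index `expand_size` times, then gather rows in that order.
expanded_return_idx = torch.arange(batch_size).repeat_interleave(expand_size)
expanded = last_hidden_state.index_select(
    0, expanded_return_idx.to(last_hidden_state.device)
)
assert expanded.shape == (batch_size * expand_size, seq_len, hidden)
```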
@@ -133,6 +133,7 @@ class IdeficsProcessor(ProcessorMixin):
             An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
         image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
     """
+
     attributes = ["image_processor", "tokenizer"]
     image_processor_class = "IdeficsImageProcessor"
     tokenizer_class = "LlamaTokenizerFast"

@@ -19,10 +19,12 @@ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math
 from dataclasses import dataclass
 
+
 @dataclass
 class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficienly calculate and store the context during inference."""
+
     max_seqlen: int
     max_batch_size: int
     conv_states: torch.Tensor

@@ -137,13 +139,28 @@ class MambaBlock(nn.Module):
     def step(self, hidden_states, conv_state, ssm_state):
         xz = self.in_proj(hidden_states.squeeze(1))
         x, z = xz.chunk(2, dim=-1)  # (B D)
-        x = causal_conv1d_update(x, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation)
+        x = causal_conv1d_update(
+            x,
+            conv_state,
+            self.conv1d.weight.squeeze(1),
+            self.conv1d.bias,
+            self.activation,
+        )
         x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
         dt = F.linear(dt, self.dt_proj.weight)
         A = self.negA
         y = selective_state_update(
-            ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+            ssm_state,
+            x,
+            dt,
+            A,
+            B,
+            C,
+            self.D,
+            z=z,
+            dt_bias=self.dt_proj.bias,
+            dt_softplus=True,
         )
         out = self.out_proj(y)
         return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()
@@ -2,6 +2,7 @@
 
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
+
 import math
 import os
 import warnings

@@ -35,7 +35,6 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
 tracer = trace.get_tracer(__name__)
 
 
-
 @dataclass
 class FlashCausalLMBatch(Batch):
     batch_id: int

@@ -1213,8 +1212,9 @@ class FlashCausalLM(Model):
                 # accept each new token for this specific request since we may
                 # have more than one new token per request with speculative decoding
                 for next_token_id in _next_token_ids:
-                    batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+                    batch.next_token_chooser = (
+                        batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+                    )
 
                 # Update values
                 batch.input_lengths[i] = input_length + n_accepted_ids

@@ -92,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -114,7 +114,9 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -401,9 +403,9 @@ class IdeficsCausalLMBatch(Batch):
                 pixel_values = batch.pixel_values.new_zeros(
                     (total_batch_size, max_num_images, 3, 224, 224)
                 )
-            pixel_values[
-                start_index:end_index, :curr_batch_max_num_images
-            ] = batch.pixel_values
+            pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
+                batch.pixel_values
+            )
 
             if image_attention_mask is None:
                 image_attention_mask = batch.image_attention_mask.new_zeros(

@@ -500,14 +502,14 @@ class IdeficsCausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys
 
                 start_index = end_index

@@ -525,9 +527,9 @@ class IdeficsCausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values
 
                 # Update values

@@ -603,9 +605,11 @@ class IdeficsCausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )

@@ -836,9 +840,9 @@ class IdeficsCausalLM(Model):
 
         # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
-        batch.image_attention_mask[
-            :, -batch.padding_right_offset, :
-        ] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
+        batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
+            batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
+        )
         # Decrease right offset
         batch.padding_right_offset -= 1
 
@@ -15,7 +15,10 @@ from text_generation_server.utils import (
 )
 from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
 import time
-from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
+from text_generation_server.models.custom_modeling.mamba_modeling import (
+    MambaModel,
+    InferenceParams,
+)
 from text_generation_server.models import Model
 from typing import Any, List, Optional, Tuple, Type, Dict
 from text_generation_server.models.types import (

@@ -28,21 +31,35 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
-def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: int, d_state: int, seqlen_offset: int, dtype: torch.dtype, device: torch.device):
+
+def new_inference_params(
+    n_blocks: int,
+    batch_size: int,
+    d_inner: int,
+    d_conv: int,
+    d_state: int,
+    seqlen_offset: int,
+    dtype: torch.dtype,
+    device: torch.device,
+):
     max_seqlen = 0
     conv_states = torch.zeros(
-        (n_blocks,
-        batch_size,
-        d_inner,
-        d_conv,),
+        (
+            n_blocks,
+            batch_size,
+            d_inner,
+            d_conv,
+        ),
         device=device,
         dtype=dtype,
     )
     ssm_states = torch.zeros(
-        (n_blocks,
-        batch_size,
-        d_inner,
-        d_state,),
+        (
+            n_blocks,
+            batch_size,
+            d_inner,
+            d_state,
+        ),
         device=device,
         dtype=dtype,
     )

@@ -52,7 +69,6 @@ def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: i
         seqlen_offset=seqlen_offset,
         conv_states=conv_states,
         ssm_states=ssm_states,
-
     )
     return inference_params
 

@@ -124,7 +140,9 @@ class MambaBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -251,7 +269,9 @@ class MambaBatch(Batch):
 
         # TODO
         # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
-        self.inference_params.conv_states = self.inference_params.conv_states[:, indices]
+        self.inference_params.conv_states = self.inference_params.conv_states[
+            :, indices
+        ]
         self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
         return self
 

@@ -280,13 +300,20 @@ class MambaBatch(Batch):
         max_seqlen = 0
         seqlen_offset = 0
 
-        (n_blocks, _, d_inner, d_conv) = (
-            batches[0].inference_params.conv_states.shape
-        )
+        (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
         (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
         dtype = batches[0].inference_params.conv_states.dtype
         device = batches[0].inference_params.conv_states.device
-        inference_params = new_inference_params(n_blocks=n_blocks, batch_size=total_batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=device, dtype=dtype)
+        inference_params = new_inference_params(
+            n_blocks=n_blocks,
+            batch_size=total_batch_size,
+            d_state=d_state,
+            d_conv=d_conv,
+            d_inner=d_inner,
+            seqlen_offset=seqlen_offset,
+            device=device,
+            dtype=dtype,
+        )
 
         # Batch tensors
         input_ids = None

@@ -334,13 +361,20 @@ class MambaBatch(Batch):
                 max_input_length - batch.max_input_length
             ) * len(batch)
 
-            inference_params.max_seqlen = max(inference_params.max_seqlen, batch.inference_params.max_seqlen)
+            inference_params.max_seqlen = max(
+                inference_params.max_seqlen, batch.inference_params.max_seqlen
+            )
             assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
-            inference_params.seqlen_offset = max(inference_params.seqlen_offset, batch.inference_params.seqlen_offset)
+            inference_params.seqlen_offset = max(
+                inference_params.seqlen_offset, batch.inference_params.seqlen_offset
+            )
 
-            inference_params.conv_states[:, start_index:end_index] = batch.inference_params.conv_states
-            inference_params.ssm_states[:, start_index:end_index] = batch.inference_params.ssm_states
+            inference_params.conv_states[:, start_index:end_index] = (
+                batch.inference_params.conv_states
+            )
+            inference_params.ssm_states[:, start_index:end_index] = (
+                batch.inference_params.ssm_states
+            )
 
             start_index = end_index
 
@@ -452,36 +486,39 @@ class Mamba(Model):
 
         # Important seqlen_offset to go through the update mecanism with the state
         seqlen_offset = 1
-        inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
+        inference_params = new_inference_params(
+            n_blocks=n_blocks,
+            batch_size=batch_size,
+            d_state=d_state,
+            d_conv=d_conv,
+            d_inner=d_inner,
+            seqlen_offset=seqlen_offset,
+            device=self.device,
+            dtype=self.dtype,
+        )
 
         graph = torch.cuda.CUDAGraph()
 
         torch.cuda.synchronize()
         # Run once outside to warmup
-        self.model.forward(
-            input_ids=input_ids,
-            inference_params=inference_params
-        )
+        self.model.forward(input_ids=input_ids, inference_params=inference_params)
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
             logits = self.model.forward(
-                input_ids=input_ids,
-                inference_params=inference_params
+                input_ids=input_ids, inference_params=inference_params
             )
         torch.cuda.synchronize()
         graph_dict = {
             "input_ids": input_ids,
             "inference_params": inference_params,
             "graph": graph,
-            "logits": logits
+            "logits": logits,
         }
         self.cuda_graphs[batch_size] = graph_dict
 
     def forward(
-        self,
-        input_ids: torch.Tensor,
-        inference_params: Any
+        self, input_ids: torch.Tensor, inference_params: Any
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         bs = input_ids.shape[0]
         padded_bs = bs

@@ -504,15 +541,21 @@ class Mamba(Model):
 
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
-        cuda_graph["input_ids"][: bs] = input_ids
-        cuda_graph["inference_params"].conv_states[:, : bs] = inference_params.conv_states
-        cuda_graph["inference_params"].ssm_states[:, : bs] = inference_params.ssm_states
+        cuda_graph["input_ids"][:bs] = input_ids
+        cuda_graph["inference_params"].conv_states[
+            :, :bs
+        ] = inference_params.conv_states
+        cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states
 
         # Replay the graph
         cuda_graph["graph"].replay()
 
-        inference_params.conv_states.copy_(cuda_graph["inference_params"].conv_states[:, :bs])
-        inference_params.ssm_states.copy_(cuda_graph["inference_params"].ssm_states[:, :bs])
+        inference_params.conv_states.copy_(
+            cuda_graph["inference_params"].conv_states[:, :bs]
+        )
+        inference_params.ssm_states.copy_(
+            cuda_graph["inference_params"].ssm_states[:, :bs]
+        )
 
         # Slice output to the correct shape
         return cuda_graph["logits"][:bs]
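Reviewer note (not part of the diff): the two Mamba hunks above only reformat the existing CUDA-graph flow: warm up, capture one forward pass into static buffers, then at run time copy the real inputs into those buffers, replay, and copy the states back out. A minimal standalone sketch of the same capture/replay pattern, where the tiny `nn.Linear` model and the shapes are illustrative rather than TGI code, and which only runs on a CUDA machine:

```python
import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    static_input = torch.zeros(8, 16, device="cuda")  # max batch size 8

    # Run once outside the graph to warm up, as the diff above does.
    torch.cuda.synchronize()
    model(static_input)
    torch.cuda.synchronize()

    # Capture a single forward pass; tensors used here become static buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # Replay: copy the (possibly smaller) real batch into the static input,
    # replay the captured kernels, then slice the padded output.
    real_input = torch.randn(5, 16, device="cuda")
    bs = real_input.shape[0]
    static_input[:bs] = real_input
    graph.replay()
    result = static_output[:bs]
```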
@@ -533,14 +576,20 @@ class Mamba(Model):
         d_state = self.model.config.d_state
         d_conv = self.model.config.d_conv
         d_inner = self.model.config.d_inner
-        inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
+        inference_params = new_inference_params(
+            n_blocks=n_blocks,
+            batch_size=batch_size,
+            d_state=d_state,
+            d_conv=d_conv,
+            d_inner=d_inner,
+            seqlen_offset=seqlen_offset,
+            device=self.device,
+            dtype=self.dtype,
+        )
         batch.inference_params = inference_params
 
         # Forward pass
-        logits = self.forward(
-            input_ids, inference_params=batch.inference_params
-        )
-
+        logits = self.forward(input_ids, inference_params=batch.inference_params)
 
         # batch.inference_params = new_inference_params
         # Results

@@ -694,9 +743,9 @@ class Mamba(Model):
             generations.append(generation)
 
             # Update values
-            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
-                next_token_id_squeezed.item()
-            )
+            batch.next_token_choosers[i] = batch.next_token_choosers[
+                i
+            ].advance_grammar(next_token_id_squeezed.item())
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length

@@ -36,9 +36,11 @@ class RW(CausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -96,7 +96,9 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -351,9 +353,9 @@ class Seq2SeqLMBatch(Batch):
                 (total_batch_size, max_input_length),
             )
             # Copy to correct indices
-            attention_mask[
-                start_index:end_index, -batch.max_input_length :
-            ] = batch.attention_mask[:, -batch.max_input_length :]
+            attention_mask[start_index:end_index, -batch.max_input_length :] = (
+                batch.attention_mask[:, -batch.max_input_length :]
+            )
 
             # Create padded tensor
             if decoder_input_ids is None:

@@ -547,9 +549,11 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )

@@ -750,7 +754,7 @@ class Seq2SeqLM(Model):
 
                 if top_n_tokens > 0:
                     all_top_tokens = []
-                    for (top_token_ids, top_token_logprobs) in zip(
+                    for top_token_ids, top_token_logprobs in zip(
                         top_token_ids, top_token_logprobs
                     ):
                         toptoken_texts = self.tokenizer.batch_decode(

@@ -88,14 +88,16 @@ class Generation:
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
             request_id=self.request_id,
-            prefill_tokens=self.prefill_tokens.to_pb()
-            if self.prefill_tokens is not None
-            else None,
+            prefill_tokens=(
+                self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
+            ),
             tokens=self.tokens.to_pb(),
-            generated_text=self.generated_text.to_pb()
-            if self.generated_text is not None
-            else None,
-            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
-            if self.top_tokens is not None
-            else None,
+            generated_text=(
+                self.generated_text.to_pb() if self.generated_text is not None else None
+            ),
+            top_tokens=(
+                [top_tokens.to_pb() for top_tokens in self.top_tokens]
+                if self.top_tokens is not None
+                else None
+            ),
         )
@@ -182,7 +182,7 @@ try:
         )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
 
         zeros = (zeros >> zeros_shifter[None, :]) & maxq
         zeros = (zeros + 1) & maxq  # eventually avoid overflow
 
         a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
         b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

@@ -355,7 +355,9 @@ def get_linear(weight, bias, quantize):
                 "to use Exllama/GPTQ kernels for AWQ inference."
             )
         if not HAS_AWQ:
-            raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly")
+            raise NotImplementedError(
+                "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
+            )
         linear = WQLinear(
             w_bit=bits,
             group_size=groupsize,

@@ -516,7 +516,7 @@ class GrammarLogitProcessor(LogitsProcessor):
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
             schema = build_regex_from_object(schema)
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
             pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)
         logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm

@@ -409,8 +409,12 @@ class HeterogeneousNextTokenChooser:
 
     def advance_grammar_single(self, grammar_state_index: int, next_id: int):
        if self.grammar_processor is not None:
-            self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
-                next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
+            self.fsm_grammar_states[grammar_state_index] = (
+                self.grammar_processor.advance_at_index(
+                    next_id,
+                    self.fsm_grammar_states[grammar_state_index],
+                    grammar_state_index,
+                )
             )
         return self
 