chore: add pre-commit (#1569)
parent 142cdabed3
commit 9946165ee0
@@ -71,12 +71,11 @@ jobs:
           pip install pytest
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
           pytest -s -vv server/tests
-      - name: Run Rust fmt
-        run: |
-          cargo fmt --check
-      - name: Run Rust clippy
-        run: |
-          cargo clippy
+      - name: Pre-commit checks
+        run: |
+          pip install pre-commit
+          pre-commit install
+          pre-commit run --all-files
       - name: Run Rust tests
         run: |
           cargo test

@@ -11,4 +11,3 @@ server/exllama_kernels/exllama_kernels/hip_func/
 *_hip.cuh
 server/exllama_kernels/exllama_kernels/hip_buffers.cuh
 server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
-

@@ -0,0 +1,18 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+        exclude: docs/source/basic_tutorials/launcher.md
+  - repo: https://github.com/psf/black
+    rev: 24.2.0
+    hooks:
+      - id: black
+  - repo: https://github.com/doublify/pre-commit-rust
+    rev: v1.0
+    hooks:
+      - id: fmt
+      - id: cargo-check
+      - id: clippy

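The new .pre-commit-config.yaml above wires check-yaml, end-of-file-fixer, trailing-whitespace, Black 24.2.0, and the Rust fmt/cargo-check/clippy hooks. As a rough local equivalent of the new CI step, here is a minimal Python sketch (not part of the commit; it assumes the pre-commit package is installed in the current environment) that installs the git hook and runs every hook over the whole tree:

# Hypothetical helper, not from the repository: mirrors the new CI step.
import subprocess
import sys


def run_pre_commit_checks() -> int:
    """Install the git hook, then run every configured hook on all files."""
    for cmd in (
        [sys.executable, "-m", "pre_commit", "install"],
        [sys.executable, "-m", "pre_commit", "run", "--all-files"],
    ):
        result = subprocess.run(cmd)
        if result.returncode != 0:
            return result.returncode
    return 0


if __name__ == "__main__":
    sys.exit(run_pre_commit_checks())
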
@@ -29,4 +29,3 @@ tui = {package = "ratatui", version = "0.23", default-features = false, features
 tracing = "0.1.37"
 tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
 hf-hub = "0.3.1"
-

@@ -134,6 +134,7 @@ class Parameters(BaseModel):
             raise ValidationError("`value` cannot be empty for `json` grammar")
         return v
 
+
 class Request(BaseModel):
     # Prompt
     inputs: str

@@ -9,4 +9,3 @@ Standard attention mechanism uses High Bandwidth Memory (HBM) to store, read and
 It is implemented for supported models. You can check out the complete list of models that support Flash Attention [here](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models), for models with flash prefix.
 
 You can learn more about Flash Attention by reading the paper in this [link](https://arxiv.org/abs/2205.14135).
-

@@ -54,7 +54,9 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_mamba_load(fused_kernel_mamba, generate_load, generous_response_snapshot):
+async def test_mamba_load(
+    fused_kernel_mamba, generate_load, generous_response_snapshot
+):
     responses = await generate_load(
         fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
     )

@@ -27,6 +27,7 @@ pub struct Validation {
 }
+
 impl Validation {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         workers: usize,
         tokenizer: Option<Tokenizer>,

@@ -40,5 +40,3 @@ __forceinline__ __device__ void dequant_6bit_16
 #endif
 
 #endif
-
-

@@ -251,9 +251,9 @@ class LlamaMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
         # Fuse gate and up proj

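The hunk above, and the matching hunks in the other model files below, are pure Black 24.2.0 reformatting: the multi-line conditional that selects the GELU `approximate` mode is wrapped in parentheses without changing behavior. A small standalone sketch (not from the repository; it assumes PyTorch 1.12 or newer, where torch.nn.functional.gelu accepts the `approximate` keyword) showing that both spellings pick the same variant:

import torch

# Hypothetical check, not part of the commit: the pre- and post-Black
# conditional expressions evaluate to the same `approximate` argument.
for act in ["gelu", "gelu_fast", "gelu_pytorch_tanh"]:
    old_style = (
        "tanh"
        if act in ["gelu_fast", "gelu_pytorch_tanh"]
        else "none"
    )
    new_style = (
        "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
    )
    assert old_style == new_style

    act_fn = lambda x: torch.nn.functional.gelu(x, approximate=new_style)
    print(act, act_fn(torch.randn(4)))
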
@@ -255,9 +255,9 @@ class MistralMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
         # Fuse gate and up proj

@@ -344,9 +344,9 @@ class BlockSparseMoE(nn.Module):
         if "gelu" in act:
             self.act = lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         elif "silu" in act:
             self.act = torch.nn.functional.silu

@@ -600,9 +600,9 @@ class DenseMoE(nn.Module):
         if "gelu" in act:
             self.act = lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         elif "silu" in act:
             self.act = torch.nn.functional.silu

@@ -187,9 +187,9 @@ class FlashMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
 

@@ -225,9 +225,9 @@ class PhiMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
 

@@ -69,7 +69,12 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
         qzeros = qzeros.to(device=weights.device)
 
-        bits, groupsize, _, quant_method, = weights._get_gptq_params()
+        (
+            bits,
+            groupsize,
+            _,
+            quant_method,
+        ) = weights._get_gptq_params()
         if quant_method == "gptq":
             g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
             g_idx = g_idx.to(device=weights.device)

@@ -306,9 +311,9 @@ class MLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
 

@@ -66,6 +66,7 @@ class IdeficsVisionConfig(PretrainedConfig):
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
     """
+
     model_type = "idefics"
     attribute_map = {
         "hidden_size": "embed_dim",

@@ -125,6 +126,7 @@ class IdeficsPerceiverConfig(PretrainedConfig):
         qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
             Whether or not to use qk layer norms in perceiver
     """
+
     model_type = "idefics"
 
     def __init__(

@@ -219,6 +221,7 @@ class IdeficsConfig(PretrainedConfig):
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
+
     model_type = "idefics"
     is_composition = True
 

@@ -123,10 +123,10 @@ def expand_inputs_for_generation(
             raise ValueError(
                 "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
             )
-        encoder_outputs[
-            "last_hidden_state"
-        ] = encoder_outputs.last_hidden_state.index_select(
-            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
-        )
+        encoder_outputs["last_hidden_state"] = (
+            encoder_outputs.last_hidden_state.index_select(
+                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
+            )
+        )
         model_kwargs["encoder_outputs"] = encoder_outputs
     return input_ids, model_kwargs

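For context on the reshaped assignment above, a tiny standalone illustration (not from the repository) of what index_select along dimension 0 does here: each batch row of the encoder's hidden states is duplicated according to an expanded index, which is how inputs are fanned out when several sequences are generated per prompt:

import torch

# Hypothetical toy shapes: batch of 2 prompts, 3 positions, hidden size 4.
last_hidden_state = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)

# Repeat each batch element twice, as when generating 2 sequences per prompt.
expand_size = 2
expanded_return_idx = torch.arange(last_hidden_state.shape[0]).repeat_interleave(
    expand_size
)

expanded = last_hidden_state.index_select(
    0, expanded_return_idx.to(last_hidden_state.device)
)
print(expanded.shape)  # torch.Size([4, 3, 4]) -- rows 0, 0, 1, 1 of the original
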
@@ -133,6 +133,7 @@ class IdeficsProcessor(ProcessorMixin):
             An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
         image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
     """
+
     attributes = ["image_processor", "tokenizer"]
     image_processor_class = "IdeficsImageProcessor"
     tokenizer_class = "LlamaTokenizerFast"

@@ -19,10 +19,12 @@ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math
 from dataclasses import dataclass
 
+
 @dataclass
 class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficienly calculate and store the context during inference."""
+
     max_seqlen: int
     max_batch_size: int
     conv_states: torch.Tensor

@@ -137,13 +139,28 @@ class MambaBlock(nn.Module):
     def step(self, hidden_states, conv_state, ssm_state):
         xz = self.in_proj(hidden_states.squeeze(1))
         x, z = xz.chunk(2, dim=-1)  # (B D)
-        x = causal_conv1d_update(x, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation)
+        x = causal_conv1d_update(
+            x,
+            conv_state,
+            self.conv1d.weight.squeeze(1),
+            self.conv1d.bias,
+            self.activation,
+        )
         x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
         dt = F.linear(dt, self.dt_proj.weight)
         A = self.negA
         y = selective_state_update(
-            ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+            ssm_state,
+            x,
+            dt,
+            A,
+            B,
+            C,
+            self.D,
+            z=z,
+            dt_bias=self.dt_proj.bias,
+            dt_softplus=True,
         )
         out = self.out_proj(y)
         return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()

@@ -2,6 +2,7 @@
 
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
+
 import math
 import os
 import warnings

@@ -35,7 +35,6 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
 tracer = trace.get_tracer(__name__)
 
 
-
 @dataclass
 class FlashCausalLMBatch(Batch):
     batch_id: int

@@ -1213,8 +1212,9 @@ class FlashCausalLM(Model):
             # accept each new token for this specific request since we may
             # have more than one new token per request with speculative decoding
             for next_token_id in _next_token_ids:
-                batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
-
+                batch.next_token_chooser = (
+                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+                )
 
             # Update values
             batch.input_lengths[i] = input_length + n_accepted_ids

@@ -92,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -114,7 +114,9 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -401,9 +403,9 @@ class IdeficsCausalLMBatch(Batch):
             pixel_values = batch.pixel_values.new_zeros(
                 (total_batch_size, max_num_images, 3, 224, 224)
             )
-            pixel_values[
-                start_index:end_index, :curr_batch_max_num_images
-            ] = batch.pixel_values
+            pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
+                batch.pixel_values
+            )
 
             if image_attention_mask is None:
                 image_attention_mask = batch.image_attention_mask.new_zeros(

@@ -500,14 +502,14 @@ class IdeficsCausalLMBatch(Batch):
             # We slice the keys to remove the padding from previous batches
             past_seq_len = batch.max_input_length - 1
             if batch.keys_head_dim_last:
-                padded_past_keys[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_keys[:, :, -past_seq_len:, :]
+                padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_keys[:, :, -past_seq_len:, :]
+                )
             else:
                 # BLOOM case
-                padded_past_keys[
-                    start_index:end_index, :, :, -past_seq_len:
-                ] = past_keys[:, :, :, -past_seq_len:]
+                padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                    past_keys[:, :, :, -past_seq_len:]
+                )
             del past_keys
 
             start_index = end_index

@@ -525,9 +527,9 @@ class IdeficsCausalLMBatch(Batch):
             end_index = start_index + len(batch)
             # We slice the past values to remove the padding from previous batches
             past_seq_len = batch.max_input_length - 1
-            padded_past_values[
-                start_index:end_index, :, -past_seq_len:, :
-            ] = past_values[:, :, -past_seq_len:, :]
+            padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                past_values[:, :, -past_seq_len:, :]
+            )
             del past_values
 
             # Update values

@@ -603,9 +605,11 @@ class IdeficsCausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )

@@ -836,9 +840,9 @@ class IdeficsCausalLM(Model):
 
         # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
-        batch.image_attention_mask[
-            :, -batch.padding_right_offset, :
-        ] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
+        batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
+            batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
+        )
         # Decrease right offset
         batch.padding_right_offset -= 1
 

@@ -15,7 +15,10 @@ from text_generation_server.utils import (
 )
 from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
 import time
-from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
+from text_generation_server.models.custom_modeling.mamba_modeling import (
+    MambaModel,
+    InferenceParams,
+)
 from text_generation_server.models import Model
 from typing import Any, List, Optional, Tuple, Type, Dict
 from text_generation_server.models.types import (

@@ -28,21 +31,35 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
-def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: int, d_state: int, seqlen_offset: int, dtype: torch.dtype, device: torch.device):
+
+def new_inference_params(
+    n_blocks: int,
+    batch_size: int,
+    d_inner: int,
+    d_conv: int,
+    d_state: int,
+    seqlen_offset: int,
+    dtype: torch.dtype,
+    device: torch.device,
+):
     max_seqlen = 0
     conv_states = torch.zeros(
-        (n_blocks,
-        batch_size,
-        d_inner,
-        d_conv,),
+        (
+            n_blocks,
+            batch_size,
+            d_inner,
+            d_conv,
+        ),
         device=device,
         dtype=dtype,
     )
     ssm_states = torch.zeros(
-        (n_blocks,
-        batch_size,
-        d_inner,
-        d_state,),
+        (
+            n_blocks,
+            batch_size,
+            d_inner,
+            d_state,
+        ),
         device=device,
         dtype=dtype,
     )

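To make the allocated shapes concrete, a standalone re-implementation sketch with toy dimensions (the numbers are assumptions for illustration, not values from the repository); new_inference_params above fills the same two buffers:

import torch

# Hypothetical dimensions: 2 Mamba blocks, batch of 3, d_inner=8, d_conv=4, d_state=16.
n_blocks, batch_size, d_inner, d_conv, d_state = 2, 3, 8, 4, 16

conv_states = torch.zeros((n_blocks, batch_size, d_inner, d_conv), dtype=torch.float32)
ssm_states = torch.zeros((n_blocks, batch_size, d_inner, d_state), dtype=torch.float32)

print(conv_states.shape)  # torch.Size([2, 3, 8, 4])
print(ssm_states.shape)   # torch.Size([2, 3, 8, 16])
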
@@ -52,7 +69,6 @@ def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: i
         seqlen_offset=seqlen_offset,
         conv_states=conv_states,
         ssm_states=ssm_states,
-
     )
     return inference_params
 

@@ -124,7 +140,9 @@ class MambaBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -251,7 +269,9 @@ class MambaBatch(Batch):
 
         # TODO
         # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
-        self.inference_params.conv_states = self.inference_params.conv_states[:, indices]
+        self.inference_params.conv_states = self.inference_params.conv_states[
+            :, indices
+        ]
         self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
         return self
 

@@ -280,13 +300,20 @@
         max_seqlen = 0
         seqlen_offset = 0
 
-        (n_blocks, _, d_inner, d_conv) = (
-            batches[0].inference_params.conv_states.shape
-        )
+        (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
         (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
         dtype = batches[0].inference_params.conv_states.dtype
         device = batches[0].inference_params.conv_states.device
-        inference_params = new_inference_params(n_blocks=n_blocks, batch_size=total_batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=device, dtype=dtype)
+        inference_params = new_inference_params(
+            n_blocks=n_blocks,
+            batch_size=total_batch_size,
+            d_state=d_state,
+            d_conv=d_conv,
+            d_inner=d_inner,
+            seqlen_offset=seqlen_offset,
+            device=device,
+            dtype=dtype,
+        )
 
         # Batch tensors
         input_ids = None

@@ -334,13 +361,20 @@
                 max_input_length - batch.max_input_length
             ) * len(batch)
 
-            inference_params.max_seqlen = max(inference_params.max_seqlen, batch.inference_params.max_seqlen)
+            inference_params.max_seqlen = max(
+                inference_params.max_seqlen, batch.inference_params.max_seqlen
+            )
             assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
-            inference_params.seqlen_offset = max(inference_params.seqlen_offset, batch.inference_params.seqlen_offset)
-
+            inference_params.seqlen_offset = max(
+                inference_params.seqlen_offset, batch.inference_params.seqlen_offset
+            )
 
-            inference_params.conv_states[:, start_index:end_index] = batch.inference_params.conv_states
-            inference_params.ssm_states[:, start_index:end_index] = batch.inference_params.ssm_states
+            inference_params.conv_states[:, start_index:end_index] = (
+                batch.inference_params.conv_states
+            )
+            inference_params.ssm_states[:, start_index:end_index] = (
+                batch.inference_params.ssm_states
+            )
 
             start_index = end_index
 

@@ -452,36 +486,39 @@ class Mamba(Model):
 
         # Important seqlen_offset to go through the update mecanism with the state
         seqlen_offset = 1
-        inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
+        inference_params = new_inference_params(
+            n_blocks=n_blocks,
+            batch_size=batch_size,
+            d_state=d_state,
+            d_conv=d_conv,
+            d_inner=d_inner,
+            seqlen_offset=seqlen_offset,
+            device=self.device,
+            dtype=self.dtype,
+        )
 
         graph = torch.cuda.CUDAGraph()
 
         torch.cuda.synchronize()
         # Run once outside to warmup
-        self.model.forward(
-            input_ids=input_ids,
-            inference_params=inference_params
-        )
+        self.model.forward(input_ids=input_ids, inference_params=inference_params)
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
             logits = self.model.forward(
-                input_ids=input_ids,
-                inference_params=inference_params
+                input_ids=input_ids, inference_params=inference_params
             )
         torch.cuda.synchronize()
         graph_dict = {
             "input_ids": input_ids,
             "inference_params": inference_params,
             "graph": graph,
-            "logits": logits
+            "logits": logits,
         }
         self.cuda_graphs[batch_size] = graph_dict
 
     def forward(
-        self,
-        input_ids: torch.Tensor,
-        inference_params: Any
+        self, input_ids: torch.Tensor, inference_params: Any
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         bs = input_ids.shape[0]
         padded_bs = bs

@@ -504,15 +541,21 @@
 
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
-        cuda_graph["input_ids"][: bs] = input_ids
-        cuda_graph["inference_params"].conv_states[:, : bs] = inference_params.conv_states
-        cuda_graph["inference_params"].ssm_states[:, : bs] = inference_params.ssm_states
+        cuda_graph["input_ids"][:bs] = input_ids
+        cuda_graph["inference_params"].conv_states[
+            :, :bs
+        ] = inference_params.conv_states
+        cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states
 
         # Replay the graph
         cuda_graph["graph"].replay()
 
-        inference_params.conv_states.copy_(cuda_graph["inference_params"].conv_states[:, :bs])
-        inference_params.ssm_states.copy_(cuda_graph["inference_params"].ssm_states[:, :bs])
+        inference_params.conv_states.copy_(
+            cuda_graph["inference_params"].conv_states[:, :bs]
+        )
+        inference_params.ssm_states.copy_(
+            cuda_graph["inference_params"].ssm_states[:, :bs]
+        )
 
         # Slice output to the correct shape
         return cuda_graph["logits"][:bs]

@@ -533,14 +576,20 @@
         d_state = self.model.config.d_state
         d_conv = self.model.config.d_conv
         d_inner = self.model.config.d_inner
-        inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
+        inference_params = new_inference_params(
+            n_blocks=n_blocks,
+            batch_size=batch_size,
+            d_state=d_state,
+            d_conv=d_conv,
+            d_inner=d_inner,
+            seqlen_offset=seqlen_offset,
+            device=self.device,
+            dtype=self.dtype,
+        )
         batch.inference_params = inference_params
 
         # Forward pass
-        logits = self.forward(
-            input_ids, inference_params=batch.inference_params
-        )
-
+        logits = self.forward(input_ids, inference_params=batch.inference_params)
 
         # batch.inference_params = new_inference_params
         # Results

@@ -694,9 +743,9 @@ class Mamba(Model):
             generations.append(generation)
 
             # Update values
-            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
-                next_token_id_squeezed.item()
-            )
+            batch.next_token_choosers[i] = batch.next_token_choosers[
+                i
+            ].advance_grammar(next_token_id_squeezed.item())
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length

@@ -36,9 +36,11 @@ class RW(CausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )

@@ -96,7 +96,9 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -351,9 +353,9 @@ class Seq2SeqLMBatch(Batch):
                 (total_batch_size, max_input_length),
             )
             # Copy to correct indices
-            attention_mask[
-                start_index:end_index, -batch.max_input_length :
-            ] = batch.attention_mask[:, -batch.max_input_length :]
+            attention_mask[start_index:end_index, -batch.max_input_length :] = (
+                batch.attention_mask[:, -batch.max_input_length :]
+            )
 
             # Create padded tensor
             if decoder_input_ids is None:

@@ -547,9 +549,11 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )

@@ -750,7 +754,7 @@ class Seq2SeqLM(Model):
 
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(
+                for top_token_ids, top_token_logprobs in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(

@@ -88,14 +88,16 @@ class Generation:
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
             request_id=self.request_id,
-            prefill_tokens=self.prefill_tokens.to_pb()
-            if self.prefill_tokens is not None
-            else None,
+            prefill_tokens=(
+                self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
+            ),
             tokens=self.tokens.to_pb(),
-            generated_text=self.generated_text.to_pb()
-            if self.generated_text is not None
-            else None,
-            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
-            if self.top_tokens is not None
-            else None,
+            generated_text=(
+                self.generated_text.to_pb() if self.generated_text is not None else None
+            ),
+            top_tokens=(
+                [top_tokens.to_pb() for top_tokens in self.top_tokens]
+                if self.top_tokens is not None
+                else None
+            ),
         )

@@ -182,7 +182,7 @@ try:
             )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
 
             zeros = (zeros >> zeros_shifter[None, :]) & maxq
-            zeros = (zeros + 1) & maxq # eventually avoid overflow
+            zeros = (zeros + 1) & maxq  # eventually avoid overflow
 
             a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
             b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

@@ -355,7 +355,9 @@ def get_linear(weight, bias, quantize):
                 "to use Exllama/GPTQ kernels for AWQ inference."
             )
         if not HAS_AWQ:
-            raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly")
+            raise NotImplementedError(
+                "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
+            )
         linear = WQLinear(
             w_bit=bits,
             group_size=groupsize,

@@ -516,7 +516,7 @@ class GrammarLogitProcessor(LogitsProcessor):
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
             schema = build_regex_from_object(schema)
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
-            pass # schema is already a regex just here for clarity
+            pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)
         logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm

@@ -409,8 +409,12 @@ class HeterogeneousNextTokenChooser:
 
     def advance_grammar_single(self, grammar_state_index: int, next_id: int):
         if self.grammar_processor is not None:
-            self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
-                next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
-            )
+            self.fsm_grammar_states[grammar_state_index] = (
+                self.grammar_processor.advance_at_index(
+                    next_id,
+                    self.fsm_grammar_states[grammar_state_index],
+                    grammar_state_index,
+                )
+            )
         return self