chore: add pre-commit (#1569)

This commit is contained in:
OlivierDehaene 2024-02-16 11:58:58 +01:00 committed by GitHub
parent 142cdabed3
commit 9946165ee0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
78 changed files with 346 additions and 234 deletions

View File

@ -71,12 +71,11 @@ jobs:
pip install pytest pip install pytest
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
pytest -s -vv server/tests pytest -s -vv server/tests
- name: Run Rust fmt - name: Pre-commit checks
run: | run: |
cargo fmt --check pip install pre-commit
- name: Run Rust clippy pre-commit install
run: | pre-commit run --all-files
cargo clippy
- name: Run Rust tests - name: Run Rust tests
run: | run: |
cargo test cargo test

1
.gitignore vendored
View File

@ -11,4 +11,3 @@ server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh *_hip.cuh
server/exllama_kernels/exllama_kernels/hip_buffers.cuh server/exllama_kernels/exllama_kernels/hip_buffers.cuh
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp

18
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,18 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
exclude: docs/source/basic_tutorials/launcher.md
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/doublify/pre-commit-rust
rev: v1.0
hooks:
- id: fmt
- id: cargo-check
- id: clippy

View File

@ -29,4 +29,3 @@ tui = {package = "ratatui", version = "0.23", default-features = false, features
tracing = "0.1.37" tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
hf-hub = "0.3.1" hf-hub = "0.3.1"

View File

@ -134,6 +134,7 @@ class Parameters(BaseModel):
raise ValidationError("`value` cannot be empty for `json` grammar") raise ValidationError("`value` cannot be empty for `json` grammar")
return v return v
class Request(BaseModel): class Request(BaseModel):
# Prompt # Prompt
inputs: str inputs: str

View File

@ -9,4 +9,3 @@ Standard attention mechanism uses High Bandwidth Memory (HBM) to store, read and
It is implemented for supported models. You can check out the complete list of models that support Flash Attention [here](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models), for models with flash prefix. It is implemented for supported models. You can check out the complete list of models that support Flash Attention [here](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models), for models with flash prefix.
You can learn more about Flash Attention by reading the paper in this [link](https://arxiv.org/abs/2205.14135). You can learn more about Flash Attention by reading the paper in this [link](https://arxiv.org/abs/2205.14135).

View File

@ -54,7 +54,9 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_mamba_load(fused_kernel_mamba, generate_load, generous_response_snapshot): async def test_mamba_load(
fused_kernel_mamba, generate_load, generous_response_snapshot
):
responses = await generate_load( responses = await generate_load(
fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4 fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
) )

View File

@ -27,6 +27,7 @@ pub struct Validation {
} }
impl Validation { impl Validation {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
workers: usize, workers: usize,
tokenizer: Option<Tokenizer>, tokenizer: Option<Tokenizer>,

View File

@ -40,5 +40,3 @@ __forceinline__ __device__ void dequant_6bit_16
#endif #endif
#endif #endif

View File

@ -251,9 +251,9 @@ class LlamaMLP(nn.Module):
if "gelu" not in act if "gelu" not in act
else lambda x: torch.nn.functional.gelu( else lambda x: torch.nn.functional.gelu(
x, x,
approximate="tanh" approximate=(
if act in ["gelu_fast", "gelu_pytorch_tanh"] "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
else "none", ),
) )
) )
# Fuse gate and up proj # Fuse gate and up proj

View File

@ -255,9 +255,9 @@ class MistralMLP(nn.Module):
if "gelu" not in act if "gelu" not in act
else lambda x: torch.nn.functional.gelu( else lambda x: torch.nn.functional.gelu(
x, x,
approximate="tanh" approximate=(
if act in ["gelu_fast", "gelu_pytorch_tanh"] "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
else "none", ),
) )
) )
# Fuse gate and up proj # Fuse gate and up proj

View File

@ -344,9 +344,9 @@ class BlockSparseMoE(nn.Module):
if "gelu" in act: if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu( self.act = lambda x: torch.nn.functional.gelu(
x, x,
approximate="tanh" approximate=(
if act in ["gelu_fast", "gelu_pytorch_tanh"] "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
else "none", ),
) )
elif "silu" in act: elif "silu" in act:
self.act = torch.nn.functional.silu self.act = torch.nn.functional.silu
@ -600,9 +600,9 @@ class DenseMoE(nn.Module):
if "gelu" in act: if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu( self.act = lambda x: torch.nn.functional.gelu(
x, x,
approximate="tanh" approximate=(
if act in ["gelu_fast", "gelu_pytorch_tanh"] "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
else "none", ),
) )
elif "silu" in act: elif "silu" in act:
self.act = torch.nn.functional.silu self.act = torch.nn.functional.silu

View File

@ -187,9 +187,9 @@ class FlashMLP(nn.Module):
if "gelu" not in act if "gelu" not in act
else lambda x: torch.nn.functional.gelu( else lambda x: torch.nn.functional.gelu(
x, x,
approximate="tanh" approximate=(
if act in ["gelu_fast", "gelu_pytorch_tanh"] "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
else "none", ),
) )
) )

View File

@ -225,9 +225,9 @@ class PhiMLP(nn.Module):
if "gelu" not in act if "gelu" not in act
else lambda x: torch.nn.functional.gelu( else lambda x: torch.nn.functional.gelu(
x, x,
approximate="tanh" approximate=(
if act in ["gelu_fast", "gelu_pytorch_tanh"] "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
else "none", ),
) )
) )

View File

@ -69,7 +69,12 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device) qzeros = qzeros.to(device=weights.device)
bits, groupsize, _, quant_method, = weights._get_gptq_params() (
bits,
groupsize,
_,
quant_method,
) = weights._get_gptq_params()
if quant_method == "gptq": if quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device) g_idx = g_idx.to(device=weights.device)
@ -306,9 +311,9 @@ class MLP(nn.Module):
if "gelu" not in act if "gelu" not in act
else lambda x: torch.nn.functional.gelu( else lambda x: torch.nn.functional.gelu(
x, x,
approximate="tanh" approximate=(
if act in ["gelu_fast", "gelu_pytorch_tanh"] "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
else "none", ),
) )
) )

View File

@ -66,6 +66,7 @@ class IdeficsVisionConfig(PretrainedConfig):
initializer_range (`float`, *optional*, defaults to 0.02): initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
""" """
model_type = "idefics" model_type = "idefics"
attribute_map = { attribute_map = {
"hidden_size": "embed_dim", "hidden_size": "embed_dim",
@ -125,6 +126,7 @@ class IdeficsPerceiverConfig(PretrainedConfig):
qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`): qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
Whether or not to use qk layer norms in perceiver Whether or not to use qk layer norms in perceiver
""" """
model_type = "idefics" model_type = "idefics"
def __init__( def __init__(
@ -219,6 +221,7 @@ class IdeficsConfig(PretrainedConfig):
>>> # Accessing the model configuration >>> # Accessing the model configuration
>>> configuration = model.config >>> configuration = model.config
```""" ```"""
model_type = "idefics" model_type = "idefics"
is_composition = True is_composition = True

View File

@ -123,11 +123,11 @@ def expand_inputs_for_generation(
raise ValueError( raise ValueError(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
) )
encoder_outputs[ encoder_outputs["last_hidden_state"] = (
"last_hidden_state" encoder_outputs.last_hidden_state.index_select(
] = encoder_outputs.last_hidden_state.index_select(
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
) )
)
model_kwargs["encoder_outputs"] = encoder_outputs model_kwargs["encoder_outputs"] = encoder_outputs
return input_ids, model_kwargs return input_ids, model_kwargs

View File

@ -133,6 +133,7 @@ class IdeficsProcessor(ProcessorMixin):
An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input. An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image) image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
image_processor_class = "IdeficsImageProcessor" image_processor_class = "IdeficsImageProcessor"
tokenizer_class = "LlamaTokenizerFast" tokenizer_class = "LlamaTokenizerFast"

View File

@ -19,10 +19,12 @@ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import math import math
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
class InferenceParams: class InferenceParams:
"""Inference parameters that are passed to the main model in order """Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference.""" to efficienly calculate and store the context during inference."""
max_seqlen: int max_seqlen: int
max_batch_size: int max_batch_size: int
conv_states: torch.Tensor conv_states: torch.Tensor
@ -137,13 +139,28 @@ class MambaBlock(nn.Module):
def step(self, hidden_states, conv_state, ssm_state): def step(self, hidden_states, conv_state, ssm_state):
xz = self.in_proj(hidden_states.squeeze(1)) xz = self.in_proj(hidden_states.squeeze(1))
x, z = xz.chunk(2, dim=-1) # (B D) x, z = xz.chunk(2, dim=-1) # (B D)
x = causal_conv1d_update(x, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation) x = causal_conv1d_update(
x,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
x_db = self.x_proj(x) # (B dt_rank+2*d_state) x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = F.linear(dt, self.dt_proj.weight) dt = F.linear(dt, self.dt_proj.weight)
A = self.negA A = self.negA
y = selective_state_update( y = selective_state_update(
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True ssm_state,
x,
dt,
A,
B,
C,
self.D,
z=z,
dt_bias=self.dt_proj.bias,
dt_softplus=True,
) )
out = self.out_proj(y) out = self.out_proj(y)
return out.unsqueeze(1), conv_state.clone(), ssm_state.clone() return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()

View File

@ -2,6 +2,7 @@
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
""" """
import math import math
import os import os
import warnings import warnings

View File

@ -35,7 +35,6 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@dataclass @dataclass
class FlashCausalLMBatch(Batch): class FlashCausalLMBatch(Batch):
batch_id: int batch_id: int
@ -1213,8 +1212,9 @@ class FlashCausalLM(Model):
# accept each new token for this specific request since we may # accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding # have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids: for next_token_id in _next_token_ids:
batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id) batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(i, next_token_id)
)
# Update values # Update values
batch.input_lengths[i] = input_length + n_accepted_ids batch.input_lengths[i] = input_length + n_accepted_ids

View File

@ -92,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic # Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs)) inputs.append(escape_custom_split_sequence(r.inputs))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )

View File

@ -114,7 +114,9 @@ class IdeficsCausalLMBatch(Batch):
for i, r in enumerate(pb.requests): for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
inputs.append(r.inputs) inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )
@ -401,9 +403,9 @@ class IdeficsCausalLMBatch(Batch):
pixel_values = batch.pixel_values.new_zeros( pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224) (total_batch_size, max_num_images, 3, 224, 224)
) )
pixel_values[ pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
start_index:end_index, :curr_batch_max_num_images batch.pixel_values
] = batch.pixel_values )
if image_attention_mask is None: if image_attention_mask is None:
image_attention_mask = batch.image_attention_mask.new_zeros( image_attention_mask = batch.image_attention_mask.new_zeros(
@ -500,14 +502,14 @@ class IdeficsCausalLMBatch(Batch):
# We slice the keys to remove the padding from previous batches # We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1 past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last: if batch.keys_head_dim_last:
padded_past_keys[ padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
start_index:end_index, :, -past_seq_len:, : past_keys[:, :, -past_seq_len:, :]
] = past_keys[:, :, -past_seq_len:, :] )
else: else:
# BLOOM case # BLOOM case
padded_past_keys[ padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
start_index:end_index, :, :, -past_seq_len: past_keys[:, :, :, -past_seq_len:]
] = past_keys[:, :, :, -past_seq_len:] )
del past_keys del past_keys
start_index = end_index start_index = end_index
@ -525,9 +527,9 @@ class IdeficsCausalLMBatch(Batch):
end_index = start_index + len(batch) end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches # We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1 past_seq_len = batch.max_input_length - 1
padded_past_values[ padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
start_index:end_index, :, -past_seq_len:, : past_values[:, :, -past_seq_len:, :]
] = past_values[:, :, -past_seq_len:, :] )
del past_values del past_values
# Update values # Update values
@ -603,9 +605,11 @@ class IdeficsCausalLM(Model):
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1 if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None, else None
),
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -836,9 +840,9 @@ class IdeficsCausalLM(Model):
# Update attention_mask as we added a new token to input_ids # Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1 batch.attention_mask[:, -batch.padding_right_offset] = 1
batch.image_attention_mask[ batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
:, -batch.padding_right_offset, : batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] )
# Decrease right offset # Decrease right offset
batch.padding_right_offset -= 1 batch.padding_right_offset -= 1

View File

@ -15,7 +15,10 @@ from text_generation_server.utils import (
) )
from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
import time import time
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaModel,
InferenceParams,
)
from text_generation_server.models import Model from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type, Dict from typing import Any, List, Optional, Tuple, Type, Dict
from text_generation_server.models.types import ( from text_generation_server.models.types import (
@ -28,21 +31,35 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: int, d_state: int, seqlen_offset: int, dtype: torch.dtype, device: torch.device):
def new_inference_params(
n_blocks: int,
batch_size: int,
d_inner: int,
d_conv: int,
d_state: int,
seqlen_offset: int,
dtype: torch.dtype,
device: torch.device,
):
max_seqlen = 0 max_seqlen = 0
conv_states = torch.zeros( conv_states = torch.zeros(
(n_blocks, (
n_blocks,
batch_size, batch_size,
d_inner, d_inner,
d_conv,), d_conv,
),
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
ssm_states = torch.zeros( ssm_states = torch.zeros(
(n_blocks, (
n_blocks,
batch_size, batch_size,
d_inner, d_inner,
d_state,), d_state,
),
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
@ -52,7 +69,6 @@ def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: i
seqlen_offset=seqlen_offset, seqlen_offset=seqlen_offset,
conv_states=conv_states, conv_states=conv_states,
ssm_states=ssm_states, ssm_states=ssm_states,
) )
return inference_params return inference_params
@ -124,7 +140,9 @@ class MambaBatch(Batch):
for i, r in enumerate(pb.requests): for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
inputs.append(r.inputs) inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )
@ -251,7 +269,9 @@ class MambaBatch(Batch):
# TODO # TODO
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary. # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
self.inference_params.conv_states = self.inference_params.conv_states[:, indices] self.inference_params.conv_states = self.inference_params.conv_states[
:, indices
]
self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices] self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
return self return self
@ -280,13 +300,20 @@ class MambaBatch(Batch):
max_seqlen = 0 max_seqlen = 0
seqlen_offset = 0 seqlen_offset = 0
(n_blocks, _, d_inner, d_conv) = ( (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
batches[0].inference_params.conv_states.shape
)
(_, _, _, d_state) = batches[0].inference_params.ssm_states.shape (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
dtype = batches[0].inference_params.conv_states.dtype dtype = batches[0].inference_params.conv_states.dtype
device = batches[0].inference_params.conv_states.device device = batches[0].inference_params.conv_states.device
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=total_batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=device, dtype=dtype) inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=total_batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=device,
dtype=dtype,
)
# Batch tensors # Batch tensors
input_ids = None input_ids = None
@ -334,13 +361,20 @@ class MambaBatch(Batch):
max_input_length - batch.max_input_length max_input_length - batch.max_input_length
) * len(batch) ) * len(batch)
inference_params.max_seqlen = max(inference_params.max_seqlen, batch.inference_params.max_seqlen) inference_params.max_seqlen = max(
inference_params.max_seqlen, batch.inference_params.max_seqlen
)
assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset" assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
inference_params.seqlen_offset = max(inference_params.seqlen_offset, batch.inference_params.seqlen_offset) inference_params.seqlen_offset = max(
inference_params.seqlen_offset, batch.inference_params.seqlen_offset
)
inference_params.conv_states[:, start_index:end_index] = (
inference_params.conv_states[:, start_index:end_index] = batch.inference_params.conv_states batch.inference_params.conv_states
inference_params.ssm_states[:, start_index:end_index] = batch.inference_params.ssm_states )
inference_params.ssm_states[:, start_index:end_index] = (
batch.inference_params.ssm_states
)
start_index = end_index start_index = end_index
@ -452,36 +486,39 @@ class Mamba(Model):
# Important seqlen_offset to go through the update mecanism with the state # Important seqlen_offset to go through the update mecanism with the state
seqlen_offset = 1 seqlen_offset = 1
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype) inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
torch.cuda.synchronize() torch.cuda.synchronize()
# Run once outside to warmup # Run once outside to warmup
self.model.forward( self.model.forward(input_ids=input_ids, inference_params=inference_params)
input_ids=input_ids,
inference_params=inference_params
)
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
logits = self.model.forward( logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids, inference_params=inference_params
inference_params=inference_params
) )
torch.cuda.synchronize() torch.cuda.synchronize()
graph_dict = { graph_dict = {
"input_ids": input_ids, "input_ids": input_ids,
"inference_params": inference_params, "inference_params": inference_params,
"graph": graph, "graph": graph,
"logits": logits "logits": logits,
} }
self.cuda_graphs[batch_size] = graph_dict self.cuda_graphs[batch_size] = graph_dict
def forward( def forward(
self, self, input_ids: torch.Tensor, inference_params: Any
input_ids: torch.Tensor,
inference_params: Any
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
bs = input_ids.shape[0] bs = input_ids.shape[0]
padded_bs = bs padded_bs = bs
@ -505,14 +542,20 @@ class Mamba(Model):
# Copy inputs to the static inputs of the cuda graph # Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded # Static inputs are potentially padded
cuda_graph["input_ids"][:bs] = input_ids cuda_graph["input_ids"][:bs] = input_ids
cuda_graph["inference_params"].conv_states[:, : bs] = inference_params.conv_states cuda_graph["inference_params"].conv_states[
:, :bs
] = inference_params.conv_states
cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states
# Replay the graph # Replay the graph
cuda_graph["graph"].replay() cuda_graph["graph"].replay()
inference_params.conv_states.copy_(cuda_graph["inference_params"].conv_states[:, :bs]) inference_params.conv_states.copy_(
inference_params.ssm_states.copy_(cuda_graph["inference_params"].ssm_states[:, :bs]) cuda_graph["inference_params"].conv_states[:, :bs]
)
inference_params.ssm_states.copy_(
cuda_graph["inference_params"].ssm_states[:, :bs]
)
# Slice output to the correct shape # Slice output to the correct shape
return cuda_graph["logits"][:bs] return cuda_graph["logits"][:bs]
@ -533,14 +576,20 @@ class Mamba(Model):
d_state = self.model.config.d_state d_state = self.model.config.d_state
d_conv = self.model.config.d_conv d_conv = self.model.config.d_conv
d_inner = self.model.config.d_inner d_inner = self.model.config.d_inner
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype) inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
batch.inference_params = inference_params batch.inference_params = inference_params
# Forward pass # Forward pass
logits = self.forward( logits = self.forward(input_ids, inference_params=batch.inference_params)
input_ids, inference_params=batch.inference_params
)
# batch.inference_params = new_inference_params # batch.inference_params = new_inference_params
# Results # Results
@ -694,9 +743,9 @@ class Mamba(Model):
generations.append(generation) generations.append(generation)
# Update values # Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( batch.next_token_choosers[i] = batch.next_token_choosers[
next_token_id_squeezed.item() i
) ].advance_grammar(next_token_id_squeezed.item())
batch.input_ids[i, 0] = next_token_id batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length batch.input_lengths[i] = new_input_length

View File

@ -36,9 +36,11 @@ class RW(CausalLM):
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1 if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None, else None
),
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )

View File

@ -96,7 +96,9 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs) inputs.append(r.inputs)
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1) decoder_input_lengths.append(1)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )
@ -351,9 +353,9 @@ class Seq2SeqLMBatch(Batch):
(total_batch_size, max_input_length), (total_batch_size, max_input_length),
) )
# Copy to correct indices # Copy to correct indices
attention_mask[ attention_mask[start_index:end_index, -batch.max_input_length :] = (
start_index:end_index, -batch.max_input_length : batch.attention_mask[:, -batch.max_input_length :]
] = batch.attention_mask[:, -batch.max_input_length :] )
# Create padded tensor # Create padded tensor
if decoder_input_ids is None: if decoder_input_ids is None:
@ -547,9 +549,11 @@ class Seq2SeqLM(Model):
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1 if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None, else None
),
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -750,7 +754,7 @@ class Seq2SeqLM(Model):
if top_n_tokens > 0: if top_n_tokens > 0:
all_top_tokens = [] all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip( for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs top_token_ids, top_token_logprobs
): ):
toptoken_texts = self.tokenizer.batch_decode( toptoken_texts = self.tokenizer.batch_decode(

View File

@ -88,14 +88,16 @@ class Generation:
def to_pb(self) -> generate_pb2.Generation: def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation( return generate_pb2.Generation(
request_id=self.request_id, request_id=self.request_id,
prefill_tokens=self.prefill_tokens.to_pb() prefill_tokens=(
if self.prefill_tokens is not None self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
else None, ),
tokens=self.tokens.to_pb(), tokens=self.tokens.to_pb(),
generated_text=self.generated_text.to_pb() generated_text=(
if self.generated_text is not None self.generated_text.to_pb() if self.generated_text is not None else None
else None, ),
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] top_tokens=(
[top_tokens.to_pb() for top_tokens in self.top_tokens]
if self.top_tokens is not None if self.top_tokens is not None
else None, else None
),
) )

View File

@ -355,7 +355,9 @@ def get_linear(weight, bias, quantize):
"to use Exllama/GPTQ kernels for AWQ inference." "to use Exllama/GPTQ kernels for AWQ inference."
) )
if not HAS_AWQ: if not HAS_AWQ:
raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly") raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
linear = WQLinear( linear = WQLinear(
w_bit=bits, w_bit=bits,
group_size=groupsize, group_size=groupsize,

View File

@ -409,8 +409,12 @@ class HeterogeneousNextTokenChooser:
def advance_grammar_single(self, grammar_state_index: int, next_id: int): def advance_grammar_single(self, grammar_state_index: int, next_id: int):
if self.grammar_processor is not None: if self.grammar_processor is not None:
self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index( self.fsm_grammar_states[grammar_state_index] = (
next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index self.grammar_processor.advance_at_index(
next_id,
self.fsm_grammar_states[grammar_state_index],
grammar_state_index,
)
) )
return self return self