chore: add pre-commit (#1569)

This commit is contained in:
OlivierDehaene 2024-02-16 11:58:58 +01:00 committed by GitHub
parent 142cdabed3
commit 9946165ee0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
78 changed files with 346 additions and 234 deletions

View File

@ -71,12 +71,11 @@ jobs:
pip install pytest
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
pytest -s -vv server/tests
- name: Run Rust fmt
- name: Pre-commit checks
run: |
cargo fmt --check
- name: Run Rust clippy
run: |
cargo clippy
pip install pre-commit
pre-commit install
pre-commit run --all-files
- name: Run Rust tests
run: |
cargo test

1
.gitignore vendored
View File

@ -11,4 +11,3 @@ server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp

18
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,18 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
exclude: docs/source/basic_tutorials/launcher.md
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/doublify/pre-commit-rust
rev: v1.0
hooks:
- id: fmt
- id: cargo-check
- id: clippy

View File

@ -29,4 +29,3 @@ tui = {package = "ratatui", version = "0.23", default-features = false, features
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
hf-hub = "0.3.1"

View File

@ -134,6 +134,7 @@ class Parameters(BaseModel):
raise ValidationError("`value` cannot be empty for `json` grammar")
return v
class Request(BaseModel):
# Prompt
inputs: str

View File

@ -9,4 +9,3 @@ Standard attention mechanism uses High Bandwidth Memory (HBM) to store, read and
It is implemented for supported models. You can check out the complete list of models that support Flash Attention [here](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models), for models with flash prefix.
You can learn more about Flash Attention by reading the paper in this [link](https://arxiv.org/abs/2205.14135).

View File

@ -54,7 +54,9 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_load(fused_kernel_mamba, generate_load, generous_response_snapshot):
async def test_mamba_load(
fused_kernel_mamba, generate_load, generous_response_snapshot
):
responses = await generate_load(
fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
)

View File

@ -27,6 +27,7 @@ pub struct Validation {
}
impl Validation {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
workers: usize,
tokenizer: Option<Tokenizer>,

View File

@ -40,5 +40,3 @@ __forceinline__ __device__ void dequant_6bit_16
#endif
#endif

View File

@ -251,9 +251,9 @@ class LlamaMLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj

View File

@ -255,9 +255,9 @@ class MistralMLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj

View File

@ -344,9 +344,9 @@ class BlockSparseMoE(nn.Module):
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu
@ -600,9 +600,9 @@ class DenseMoE(nn.Module):
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu

View File

@ -187,9 +187,9 @@ class FlashMLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)

View File

@ -225,9 +225,9 @@ class PhiMLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)

View File

@ -69,7 +69,12 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
bits, groupsize, _, quant_method, = weights._get_gptq_params()
(
bits,
groupsize,
_,
quant_method,
) = weights._get_gptq_params()
if quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
@ -306,9 +311,9 @@ class MLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)

View File

@ -66,6 +66,7 @@ class IdeficsVisionConfig(PretrainedConfig):
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
"""
model_type = "idefics"
attribute_map = {
"hidden_size": "embed_dim",
@ -125,6 +126,7 @@ class IdeficsPerceiverConfig(PretrainedConfig):
qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
Whether or not to use qk layer norms in perceiver
"""
model_type = "idefics"
def __init__(
@ -219,6 +221,7 @@ class IdeficsConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "idefics"
is_composition = True

View File

@ -123,10 +123,10 @@ def expand_inputs_for_generation(
raise ValueError(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
)
encoder_outputs[
"last_hidden_state"
] = encoder_outputs.last_hidden_state.index_select(
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
encoder_outputs["last_hidden_state"] = (
encoder_outputs.last_hidden_state.index_select(
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
)
)
model_kwargs["encoder_outputs"] = encoder_outputs
return input_ids, model_kwargs

View File

@ -133,6 +133,7 @@ class IdeficsProcessor(ProcessorMixin):
An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "IdeficsImageProcessor"
tokenizer_class = "LlamaTokenizerFast"

View File

@ -19,10 +19,12 @@ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import math
from dataclasses import dataclass
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
max_seqlen: int
max_batch_size: int
conv_states: torch.Tensor
@ -137,13 +139,28 @@ class MambaBlock(nn.Module):
def step(self, hidden_states, conv_state, ssm_state):
xz = self.in_proj(hidden_states.squeeze(1))
x, z = xz.chunk(2, dim=-1) # (B D)
x = causal_conv1d_update(x, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation)
x = causal_conv1d_update(
x,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = F.linear(dt, self.dt_proj.weight)
A = self.negA
y = selective_state_update(
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
ssm_state,
x,
dt,
A,
B,
C,
self.D,
z=z,
dt_bias=self.dt_proj.bias,
dt_softplus=True,
)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()

View File

@ -2,6 +2,7 @@
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
import math
import os
import warnings

View File

@ -35,7 +35,6 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
tracer = trace.get_tracer(__name__)
@dataclass
class FlashCausalLMBatch(Batch):
batch_id: int
@ -1213,8 +1212,9 @@ class FlashCausalLM(Model):
# accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(i, next_token_id)
)
# Update values
batch.input_lengths[i] = input_length + n_accepted_ids

View File

@ -92,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)

View File

@ -114,7 +114,9 @@ class IdeficsCausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
@ -401,9 +403,9 @@ class IdeficsCausalLMBatch(Batch):
pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224)
)
pixel_values[
start_index:end_index, :curr_batch_max_num_images
] = batch.pixel_values
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
batch.pixel_values
)
if image_attention_mask is None:
image_attention_mask = batch.image_attention_mask.new_zeros(
@ -500,14 +502,14 @@ class IdeficsCausalLMBatch(Batch):
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
padded_past_keys[
start_index:end_index, :, -past_seq_len:, :
] = past_keys[:, :, -past_seq_len:, :]
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
past_keys[:, :, -past_seq_len:, :]
)
else:
# BLOOM case
padded_past_keys[
start_index:end_index, :, :, -past_seq_len:
] = past_keys[:, :, :, -past_seq_len:]
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
past_keys[:, :, :, -past_seq_len:]
)
del past_keys
start_index = end_index
@ -525,9 +527,9 @@ class IdeficsCausalLMBatch(Batch):
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
padded_past_values[
start_index:end_index, :, -past_seq_len:, :
] = past_values[:, :, -past_seq_len:, :]
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
past_values[:, :, -past_seq_len:, :]
)
del past_values
# Update values
@ -603,9 +605,11 @@ class IdeficsCausalLM(Model):
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
@ -836,9 +840,9 @@ class IdeficsCausalLM(Model):
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
batch.image_attention_mask[
:, -batch.padding_right_offset, :
] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
)
# Decrease right offset
batch.padding_right_offset -= 1

View File

@ -15,7 +15,10 @@ from text_generation_server.utils import (
)
from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
import time
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaModel,
InferenceParams,
)
from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type, Dict
from text_generation_server.models.types import (
@ -28,21 +31,35 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: int, d_state: int, seqlen_offset: int, dtype: torch.dtype, device: torch.device):
def new_inference_params(
n_blocks: int,
batch_size: int,
d_inner: int,
d_conv: int,
d_state: int,
seqlen_offset: int,
dtype: torch.dtype,
device: torch.device,
):
max_seqlen = 0
conv_states = torch.zeros(
(n_blocks,
batch_size,
d_inner,
d_conv,),
(
n_blocks,
batch_size,
d_inner,
d_conv,
),
device=device,
dtype=dtype,
)
ssm_states = torch.zeros(
(n_blocks,
batch_size,
d_inner,
d_state,),
(
n_blocks,
batch_size,
d_inner,
d_state,
),
device=device,
dtype=dtype,
)
@ -52,7 +69,6 @@ def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: i
seqlen_offset=seqlen_offset,
conv_states=conv_states,
ssm_states=ssm_states,
)
return inference_params
@ -124,7 +140,9 @@ class MambaBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
@ -251,7 +269,9 @@ class MambaBatch(Batch):
# TODO
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
self.inference_params.conv_states = self.inference_params.conv_states[:, indices]
self.inference_params.conv_states = self.inference_params.conv_states[
:, indices
]
self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
return self
@ -280,13 +300,20 @@ class MambaBatch(Batch):
max_seqlen = 0
seqlen_offset = 0
(n_blocks, _, d_inner, d_conv) = (
batches[0].inference_params.conv_states.shape
)
(n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
(_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
dtype = batches[0].inference_params.conv_states.dtype
device = batches[0].inference_params.conv_states.device
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=total_batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=device, dtype=dtype)
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=total_batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=device,
dtype=dtype,
)
# Batch tensors
input_ids = None
@ -334,13 +361,20 @@ class MambaBatch(Batch):
max_input_length - batch.max_input_length
) * len(batch)
inference_params.max_seqlen = max(inference_params.max_seqlen, batch.inference_params.max_seqlen)
inference_params.max_seqlen = max(
inference_params.max_seqlen, batch.inference_params.max_seqlen
)
assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
inference_params.seqlen_offset = max(inference_params.seqlen_offset, batch.inference_params.seqlen_offset)
inference_params.seqlen_offset = max(
inference_params.seqlen_offset, batch.inference_params.seqlen_offset
)
inference_params.conv_states[:, start_index:end_index] = batch.inference_params.conv_states
inference_params.ssm_states[:, start_index:end_index] = batch.inference_params.ssm_states
inference_params.conv_states[:, start_index:end_index] = (
batch.inference_params.conv_states
)
inference_params.ssm_states[:, start_index:end_index] = (
batch.inference_params.ssm_states
)
start_index = end_index
@ -452,36 +486,39 @@ class Mamba(Model):
# Important seqlen_offset to go through the update mecanism with the state
seqlen_offset = 1
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
graph = torch.cuda.CUDAGraph()
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(
input_ids=input_ids,
inference_params=inference_params
)
self.model.forward(input_ids=input_ids, inference_params=inference_params)
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
logits = self.model.forward(
input_ids=input_ids,
inference_params=inference_params
input_ids=input_ids, inference_params=inference_params
)
torch.cuda.synchronize()
graph_dict = {
"input_ids": input_ids,
"inference_params": inference_params,
"graph": graph,
"logits": logits
"logits": logits,
}
self.cuda_graphs[batch_size] = graph_dict
def forward(
self,
input_ids: torch.Tensor,
inference_params: Any
self, input_ids: torch.Tensor, inference_params: Any
) -> Tuple[torch.Tensor, torch.Tensor]:
bs = input_ids.shape[0]
padded_bs = bs
@ -504,15 +541,21 @@ class Mamba(Model):
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: bs] = input_ids
cuda_graph["inference_params"].conv_states[:, : bs] = inference_params.conv_states
cuda_graph["inference_params"].ssm_states[:, : bs] = inference_params.ssm_states
cuda_graph["input_ids"][:bs] = input_ids
cuda_graph["inference_params"].conv_states[
:, :bs
] = inference_params.conv_states
cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states
# Replay the graph
cuda_graph["graph"].replay()
inference_params.conv_states.copy_(cuda_graph["inference_params"].conv_states[:, :bs])
inference_params.ssm_states.copy_(cuda_graph["inference_params"].ssm_states[:, :bs])
inference_params.conv_states.copy_(
cuda_graph["inference_params"].conv_states[:, :bs]
)
inference_params.ssm_states.copy_(
cuda_graph["inference_params"].ssm_states[:, :bs]
)
# Slice output to the correct shape
return cuda_graph["logits"][:bs]
@ -533,14 +576,20 @@ class Mamba(Model):
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
d_inner = self.model.config.d_inner
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
batch.inference_params = inference_params
# Forward pass
logits = self.forward(
input_ids, inference_params=batch.inference_params
)
logits = self.forward(input_ids, inference_params=batch.inference_params)
# batch.inference_params = new_inference_params
# Results
@ -694,9 +743,9 @@ class Mamba(Model):
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.next_token_choosers[i] = batch.next_token_choosers[
i
].advance_grammar(next_token_id_squeezed.item())
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length

View File

@ -36,9 +36,11 @@ class RW(CausalLM):
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)

View File

@ -96,7 +96,9 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
@ -351,9 +353,9 @@ class Seq2SeqLMBatch(Batch):
(total_batch_size, max_input_length),
)
# Copy to correct indices
attention_mask[
start_index:end_index, -batch.max_input_length :
] = batch.attention_mask[:, -batch.max_input_length :]
attention_mask[start_index:end_index, -batch.max_input_length :] = (
batch.attention_mask[:, -batch.max_input_length :]
)
# Create padded tensor
if decoder_input_ids is None:
@ -547,9 +549,11 @@ class Seq2SeqLM(Model):
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
@ -750,7 +754,7 @@ class Seq2SeqLM(Model):
if top_n_tokens > 0:
all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(
for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(

View File

@ -88,14 +88,16 @@ class Generation:
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
request_id=self.request_id,
prefill_tokens=self.prefill_tokens.to_pb()
if self.prefill_tokens is not None
else None,
prefill_tokens=(
self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
),
tokens=self.tokens.to_pb(),
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
if self.top_tokens is not None
else None,
generated_text=(
self.generated_text.to_pb() if self.generated_text is not None else None
),
top_tokens=(
[top_tokens.to_pb() for top_tokens in self.top_tokens]
if self.top_tokens is not None
else None
),
)

View File

@ -182,7 +182,7 @@ try:
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) & maxq # eventually avoid overflow
zeros = (zeros + 1) & maxq # eventually avoid overflow
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

View File

@ -355,7 +355,9 @@ def get_linear(weight, bias, quantize):
"to use Exllama/GPTQ kernels for AWQ inference."
)
if not HAS_AWQ:
raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly")
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
linear = WQLinear(
w_bit=bits,
group_size=groupsize,

View File

@ -516,7 +516,7 @@ class GrammarLogitProcessor(LogitsProcessor):
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
schema = build_regex_from_object(schema)
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity
pass # schema is already a regex just here for clarity
fsm = RegexFSM(schema, tokenizer)
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm

View File

@ -409,8 +409,12 @@ class HeterogeneousNextTokenChooser:
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
if self.grammar_processor is not None:
self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
self.fsm_grammar_states[grammar_state_index] = (
self.grammar_processor.advance_at_index(
next_id,
self.fsm_grammar_states[grammar_state_index],
grammar_state_index,
)
)
return self