feat: add ruff and resolve issue (#2262)

* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
drbh 2024-07-26 10:29:09 -04:00 committed by GitHub
parent 4b49c50f4c
commit bab02ff2bc
74 changed files with 267 additions and 302 deletions

View File

@ -16,3 +16,8 @@ repos:
- id: fmt
- id: cargo-check
- id: clippy
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
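
The new hook runs ruff over the staged Python files; `--fix` applies the automatic fixes and `--exit-non-zero-on-fix` still fails the run whenever something was rewritten, so the corrected files have to be re-staged before the commit goes through. A minimal sketch of the kind of violation it cleans up on its own (the file contents are purely illustrative):

# sketch: `os` is imported but never used, so ruff reports F401 on this file;
# with `--fix` the unused import line is simply removed
import os
import json

def read_config(path):
    with open(path) as f:
        return json.load(f)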

View File

@ -19,5 +19,15 @@ DEPRECATION_WARNING = (
"Please use the `InferenceClient` from the `huggingface_hub` package instead."
)
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
from text_generation.client import Client, AsyncClient # noqa E402
from text_generation.inference_api import ( # noqa E402
InferenceAPIClient,
InferenceAPIAsyncClient,
)
__all__ = [
"Client",
"AsyncClient",
"InferenceAPIClient",
"InferenceAPIAsyncClient",
]
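
The explicit `__all__` marks these names as deliberate re-exports, so ruff's unused-import rule (F401) stays quiet, while the `# noqa E402` comments acknowledge that the imports intentionally sit below module-level code (the deprecation warning defined above them). A stripped-down version of the same pattern, using stdlib stand-ins rather than the real modules:

# hypothetical __init__.py-style sketch
DEPRECATION_WARNING = "use the new client instead"   # executes before the imports below

from json import dumps, loads  # noqa: E402  (import not at top of file, on purpose)

__all__ = ["DEPRECATION_WARNING", "dumps", "loads"]  # names re-exported intentionally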

View File

@ -21,7 +21,7 @@ def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
List[DeployedModel]: list of all currently deployed models
"""
resp = requests.get(
f"https://api-inference.huggingface.co/framework/text-generation-inference",
"https://api-inference.huggingface.co/framework/text-generation-inference",
headers=headers,
timeout=5,
)

View File

@ -4,7 +4,6 @@ import json
import math
import os
import random
import re
import shutil
import subprocess
import sys
@ -271,7 +270,7 @@ class LauncherHandle:
try:
await self.client.generate("test")
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
time.sleep(1)
raise RuntimeError("Health check failed")

View File

@ -1,7 +1,4 @@
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")

View File

@ -1,6 +1,4 @@
import pytest
import requests
import io
import base64

View File

@ -74,9 +74,7 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
generated_texts = [r.generated_text for r in responses]
assert (
generated_texts[0] == " \nAssistant: A rooster stands"
), f"{response.generated_text}"
assert generated_texts[0] == " \nAssistant: A rooster stands"
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]

View File

@ -1,7 +1,4 @@
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
@ -91,7 +88,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
@ -129,7 +126,7 @@ async def test_flash_llama_grammar_tools_auto(
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
@ -168,7 +165,7 @@ async def test_flash_llama_grammar_tools_choice(
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
@ -241,7 +238,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
stream=False,
)
assert responses.choices[0].message.content == None
assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls == [
{
"function": {

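Several assertions in these tests switch from `== None` to `is None` (ruff rule E711). The difference matters because `==` dispatches to a type's `__eq__`, which can claim equality with anything, whereas `is` checks identity against the `None` singleton, which is what the assertion actually means. A self-contained illustration (the class is contrived):

class AlwaysEqual:
    def __eq__(self, other):
        return True            # claims to equal everything, including None

value = AlwaysEqual()
print(value == None)           # True, and flagged by E711
print(value is None)           # False, the check the tests intend
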
View File

@ -20,7 +20,7 @@ def main():
break
with open("./small.json", "w") as f:
data = json.dump(conversations, f, indent=4)
json.dump(conversations, f, indent=4)
if __name__ == "__main__":

View File

@ -1,11 +1,9 @@
import os
import requests
import tempfile
import pytest
import huggingface_hub.constants
from huggingface_hub import hf_api
import text_generation_server.utils.hub
from text_generation_server.utils.hub import (

View File

@ -2,7 +2,6 @@ import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
from text_generation_server.utils.weights import DefaultWeightsLoader
class ProcessGroup:

View File

@ -2,7 +2,6 @@ import pytest
import torch
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
@ -86,15 +85,6 @@ dummy_file_system = {
],
dtype=torch.float32,
),
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
},
"test_get_weights_row": {
"weight.weight": torch.tensor(
@ -966,7 +956,7 @@ def test_get_multi_weights_col_exl2():
prefix = "weight"
try:
w = weights.get_multi_weights_col(
weights.get_multi_weights_col(
prefixes=[prefix],
dim=0,
)

View File

@ -4,15 +4,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Set, Tuple
from typing import Dict, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
if TYPE_CHECKING:
from text_generation_server.models.model import Model
@dataclass
class ModuleMap:

View File

@ -4,7 +4,7 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
@ -26,9 +26,6 @@ from text_generation_server.utils.sgmv import (
use_cutlass_shrink,
)
if TYPE_CHECKING:
from text_generation_server.models.model import Model
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size

View File

@ -4,12 +4,11 @@ import typer
from pathlib import Path
from loguru import logger
from typing import Optional, List, Dict
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters
from text_generation_server.utils.log import log_master
app = typer.Typer()
@ -165,7 +164,7 @@ def download_weights(
# currently by default we don't merge the weights with the base model
if merge_lora:
try:
adapter_config_filename = hf_hub_download(
hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
@ -285,9 +284,9 @@ def download_weights(
if auto_convert:
if not trust_remote_code:
logger.warning(
f"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
f"Pickle files are unsafe and can essentially contain remote code execution!"
f"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
"Pickle files are unsafe and can essentially contain remote code execution!"
"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
)
logger.warning(
@ -319,7 +318,7 @@ def download_weights(
# Name for this varible depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])
except Exception as e:
except Exception:
discard_names = []
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files, discard_names)
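
Two small lint patterns are cleaned up in the hunks above: the warning text was an f-string with no placeholders (F541), so the `f` prefix did nothing, and the caught exception was bound to `e` but never used (F841). A toy reproduction of both, with illustrative values:

greeting = f"breaking change ahead"    # F541: f-string without any placeholders
greeting = "breaking change ahead"     # identical value, no lint

try:
    discard = int("not a number")
except ValueError as err:              # F841: `err` is bound but never used...
    discard = 0

try:
    discard = int("not a number")
except ValueError:                     # ...so the binding is dropped entirely
    discard = 0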

View File

@ -18,3 +18,17 @@ from text_generation_server.layers.lora import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
__all__ = [
"get_linear",
"FastLinear",
"TensorParallelColumnLinear",
"TensorParallelRowLinear",
"TensorParallelEmbedding",
"SpeculativeHead",
"LoraLinear",
"TensorParallelMultiAdapterLinear",
"TensorParallelAdapterRowLinear",
"load_layer_norm",
"load_conv2d",
]

View File

@ -13,3 +13,12 @@ elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
__all__ = [
"attention",
"paged_attention",
"reshape_and_cache",
"SUPPORTS_WINDOWING",
"Seqlen",
]

View File

@ -10,7 +10,6 @@ _PARTITION_SIZE = 512
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"

View File

@ -747,11 +747,8 @@ class _attention(torch.autograd.Function):
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
grid = lambda META: (
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
nheads_q,
batch,
)
def grid(META):
return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch
encoded_softmax = None
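
The Triton grid callback is converted from a lambda bound to a name into a plain function, which is what ruff's E731 rule asks for: a `def` gives the callable a real name in tracebacks and keeps behaviour identical. A dependency-free sketch of the two equivalent forms (the block sizes here are made up):

grid = lambda meta: (256 // meta["BLOCK_M"], 8, 4)   # E731: lambda assigned to a variable

def grid(meta):                                       # equivalent named function
    return (256 // meta["BLOCK_M"], 8, 4)

print(grid({"BLOCK_M": 64}))                          # (4, 8, 4)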

View File

@ -15,7 +15,6 @@ ENGINE = "triton" if use_triton else "ck"
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"

View File

@ -1,6 +1,5 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
import math
from typing import Optional
import torch
import torch.nn as nn

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
from functools import lru_cache
import bitsandbytes as bnb
import torch

View File

@ -12,17 +12,26 @@ from text_generation_server.utils.weights import (
Weights,
)
from text_generation_server.utils.log import log_master, log_once
import importlib.util
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
try:
import fbgemm_gpu.experimental.gen_ai
def is_fbgemm_gpu_available():
try:
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
except ModuleNotFoundError:
return False
if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
except (ImportError, ModuleNotFoundError):
else:
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
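
The availability probe now uses `importlib.util.find_spec`, which looks a module up on the import path without actually importing it, so the check is cheap and has no side effects. `find_spec` raises `ModuleNotFoundError` when the parent package is missing, which is why the helper wraps it in try/except. A generic version of the same probe (module names are examples only):

import importlib.util

def has_module(dotted_name: str) -> bool:
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except ModuleNotFoundError:        # raised when a parent package does not exist
        return False

print(has_module("json"))                        # True
print(has_module("not_installed.experimental"))  # False

Note that, unlike the previous try/import, `find_spec` only confirms the package is present; a package that is installed but breaks at import time would still pass this check.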

View File

@ -8,6 +8,34 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
major, _minor = torch.cuda.get_device_capability()
except Exception:
major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
if V2:
from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear, # noqa: F401
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear, # noqa: F401
)
HAS_EXLLAMA = "1"
except ImportError:
pass
@dataclass
class GPTQWeight(Weight):
@ -55,7 +83,7 @@ class GPTQWeight(Weight):
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
return ExllamaQuantLinear(self, bias)
@ -73,45 +101,6 @@ class GPTQWeight(Weight):
)
try:
major, _minor = torch.cuda.get_device_capability()
except Exception:
major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
if V2:
from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllamav2 import (
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllama import (
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "1"
except ImportError:
pass
from text_generation_server.layers.gptq.quant_linear import QuantLinear
class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.
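
The exllama detection block now sits near the top of the module, and `ExllamaQuantLinear` is imported with `# noqa: F401`: the name is not used in this file directly but is kept so that other code (including the GPTQWeight hunk above, which imports it back from this package) can do `from text_generation_server.layers.gptq import ExllamaQuantLinear`. This is the inline counterpart of the `__all__` approach used in the client package; a minimal stdlib stand-in:

# kept only so other modules can do `from this_module import ParseError`
from json import JSONDecodeError as ParseError  # noqa: F401  (re-export, unused here)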

View File

@ -206,10 +206,13 @@ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
output = torch.empty(
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
)
grid = lambda META: (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
def grid(META):
return (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
matmul_248_kernel[grid](
input,
qweight,

View File

@ -15,6 +15,7 @@ from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error
from text_generation_server.utils.weights import DefaultWeightsLoader
@ -372,7 +373,7 @@ def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@ -404,7 +405,7 @@ def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@ -448,7 +449,7 @@ def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@ -504,7 +505,7 @@ def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@ -546,7 +547,7 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@ -700,6 +701,8 @@ def sequential(
pass
def add_batch(name):
nonlocal gptq
def tmp(_, inp, out):
gptq[name].add_batch(inp[0].data, out.data)
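
The bare `except:` clauses in the calibration-data helpers are narrowed to `except Exception:` (ruff rule E722). The behavioural difference is that `KeyboardInterrupt` and `SystemExit` derive from `BaseException`, not `Exception`, so the narrowed handler still falls back to the fast tokenizer on ordinary failures but no longer swallows Ctrl-C or `sys.exit()`. A small sketch of the fallback pattern (the loaders are stand-ins):

def load_tokenizer(strict_loader, relaxed_loader):
    try:
        return strict_loader()
    except Exception:              # a bare `except:` would also trap KeyboardInterrupt
        return relaxed_loader()

print(load_tokenizer(lambda: 1 / 0, lambda: "fallback tokenizer"))   # fallback tokenizer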

View File

@ -0,0 +1,56 @@
import torch
# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(
y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
"""
Compute SNR between y_pred(tensor) and y_real(tensor)
SNR can be calcualted as following equation:
SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.
SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
Args:
y_pred (torch.Tensor): _description_
y_real (torch.Tensor): _description_
reduction (str, optional): _description_. Defaults to 'mean'.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
torch.Tensor: _description_
"""
if y_pred.shape != y_real.shape:
raise ValueError(
f"Can not compute snr loss for tensors with different shape. "
f"({y_pred.shape} and {y_real.shape})"
)
reduction = str(reduction).lower()
if y_pred.ndim == 1:
y_pred = y_pred.unsqueeze(0)
y_real = y_real.unsqueeze(0)
y_pred = y_pred.flatten(start_dim=1)
y_real = y_real.flatten(start_dim=1)
noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
signal_power = torch.pow(y_real, 2).sum(dim=-1)
snr = (noise_power) / (signal_power + 1e-7)
if reduction == "mean":
return torch.mean(snr)
elif reduction == "sum":
return torch.sum(snr)
elif reduction == "none":
return snr
else:
raise ValueError("Unsupported reduction method.")
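
This new helper is imported by the GPTQ quantization script above, presumably to measure how far a quantized layer's output drifts from the reference (lower is better, 0 means identical). A quick synthetic check of its behaviour; the tensors here are made up:

import torch

reference = torch.randn(4, 256)
approx = reference + 0.01 * torch.randn_like(reference)        # stand-in for a dequantized result

print(torch_snr_error(approx, reference))                       # small value: close match
print(torch_snr_error(torch.zeros_like(reference), reference))  # roughly 1.0: signal fully lost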

View File

@ -1,5 +1,3 @@
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F

View File

@ -1,12 +1,8 @@
import math
import os
from typing import TYPE_CHECKING, Optional, Tuple, List
from typing import TYPE_CHECKING, Optional, List
import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (

View File

@ -1,6 +1,3 @@
from typing import List, Tuple
import torch
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinLinear,

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union
import torch
import torch.nn as nn
@ -85,7 +85,7 @@ class MarlinWeightsLoader(WeightsLoader):
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
@ -104,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader):
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1

View File

@ -2,12 +2,9 @@ import os
import math
import torch
from torch import nn
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda":
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops

View File

@ -33,7 +33,7 @@ class SpeculativeHead(torch.nn.Module):
except KeyError:
try:
speculator = MedusaHeadV1.load(config, prefix, weights)
except:
except Exception:
speculator = MedusaHeadV2(config, prefix, weights)
lm_head = None
else:

View File

@ -2,7 +2,6 @@ import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
@ -50,7 +49,7 @@ class TensorParallelHead(SuperLayer):
# If the piece and LM head embeddings are shared, we have
# non-quantized weights...
weight = weights.get_tensor(f"{prefix}.weight")
except:
except Exception:
# ...otherwise they are quantized.
weight = weights.get_weights_col(prefix)
should_gather = weights.process_group.size() > 1
@ -67,15 +66,6 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
quantize = None
# See above, exl2 LM head can be quantized or not.
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
quantize = None
else:
quantize = config.quantize
return TensorParallelHead(
get_linear(weight, bias=None),
process_group=weights.process_group,

View File

@ -1,3 +1,6 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import enum
import os
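
The file-wide `# ruff: noqa: F821` is needed because the model-type names compared against later in `get_model` (LLAMA, BAICHUAN, PHI3, and so on) are not defined by ordinary top-level assignments; they appear to be injected into the module namespace from the model-type enum at import time, so ruff's undefined-name rule cannot resolve them statically. An illustrative, self-contained reconstruction of that situation (not the real ModelType definition):

import enum

class ModelType(enum.Enum):
    LLAMA = "llama"
    BAICHUAN = "baichuan"

for member in ModelType:
    globals()[member.name] = member.value    # creates LLAMA / BAICHUAN at runtime

print(LLAMA)  # noqa: F821 -- valid at runtime, invisible to static name resolution
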
@ -712,6 +715,7 @@ def get_model(
)
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
print(f">>> model_type: {model_type}")
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@ -856,7 +860,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
else:
if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashCausalLM(

View File

@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection
@ -377,7 +377,7 @@ class CausalLMBatch(Batch):
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
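
The `type(...) == tuple` checks are rewritten in two different ways, both of which satisfy ruff's E721 rule: `type(x) is tuple` keeps the exact-type semantics, while `isinstance(x, tuple)` also accepts tuple subclasses. For these `past_key_values` containers the values appear to be plain tuples, so the two forms should behave identically here, but the distinction is worth keeping in mind:

from collections import namedtuple

Point = namedtuple("Point", "x y")
p = Point(1, 2)

print(type(p) is tuple)        # False: exact-type check rejects the subclass
print(isinstance(p, tuple))    # True: isinstance accepts tuple subclasses
print(type((1, 2)) is tuple, isinstance((1, 2), tuple))   # True True for a plain tuple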

View File

@ -908,7 +908,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
loss = None
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return (

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Optional, Tuple
import torch
from torch import nn
@ -9,9 +9,7 @@ from transformers.modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
@ -446,11 +444,12 @@ class CLIPEncoder(nn.Module):
class CLIPTextTransformer(nn.Module):
def __init__(self, prefix: str, config: CLIPTextConfig):
def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPTextEmbeddings(config)
# Initialize weights and apply final processing with `self.post_init()`
self.encoder = CLIPEncoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
@ -505,7 +504,7 @@ class CLIPTextTransformer(nn.Module):
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),
@ -515,7 +514,7 @@ class CLIPTextTransformer(nn.Module):
]
else:
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
pooled_output = last_hidden_state[
last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),
@ -565,9 +564,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return self.text_model(
input_ids=input_ids,
@ -580,7 +576,6 @@ class CLIPVisionTransformer(nn.Module):
def __init__(self, prefix, config: CLIPVisionConfig, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
@ -661,9 +656,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled CLS states
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return self.vision_model(
pixel_values=pixel_values,
@ -799,14 +791,12 @@ class CLIPModel(nn.Module):
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
vision_outputs = self.vision_model(
pixel_values=pixel_values,
return_dict=return_dict,
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=return_dict,
)
image_embeds = vision_outputs[1]

View File

@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
attention,
reshape_and_cache,
)
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,

View File

@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.utils.log import log_once
class DbrxAttentionConfig(PretrainedConfig):

View File

@ -39,6 +39,12 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class DeepseekV2Config(PretrainedConfig):
def __init__(

View File

@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import (
UnquantizedWeight,
Weights,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
@ -277,7 +276,7 @@ class LlamaMLP(nn.Module):
bias=bias,
)
else:
prefixes = [f"gate_proj", f"up_proj"]
prefixes = ["gate_proj", "up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,

View File

@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
Seqlen,
paged_attention,
attention,
reshape_and_cache,
@ -38,7 +37,6 @@ from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)

View File

@ -21,7 +21,6 @@
import torch
import torch.distributed
import numpy as np
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
@ -31,7 +30,6 @@ if SYSTEM != "ipex":
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from loguru import logger
from text_generation_server.layers.attention import (
paged_attention,

View File

@ -16,7 +16,6 @@
import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear

View File

@ -15,7 +15,6 @@ from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (

View File

@ -14,7 +14,7 @@
# limitations under the License.
""" PyTorch Idefics2 model."""
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
@ -22,10 +22,8 @@ from torch import nn
import math
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

View File

@ -19,6 +19,7 @@ import numpy as np
from PIL import Image
import transformers
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
resize,
@ -293,6 +294,4 @@ class IdeficsImageProcessor(BaseImageProcessor):
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
import transformers
transformers.IdeficsImageProcessor = IdeficsImageProcessor

View File

@ -21,10 +21,8 @@
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
@ -33,13 +31,6 @@ from transformers.modeling_outputs import (
CausalLMOutputWithPast,
dataclass,
)
from transformers.modeling_utils import PretrainedConfig
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_vision import (
IdeficsVisionTransformer,
@ -56,6 +47,7 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.import_utils import SYSTEM
from loguru import logger
if SYSTEM == "cuda":
import dropout_layer_norm
@ -237,7 +229,7 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module):
prefix="model.embed_tokens", weights=weights
)
self.additional_weight = nn.Parameter(
weights.get_tensor(f"model.embed_tokens.additional_embedding.weight")
weights.get_tensor("model.embed_tokens.additional_embedding.weight")
)
def forward(self, input_ids):
@ -499,7 +491,6 @@ class IdeficsAttention(nn.Module):
# if not hasattr(nn.functional, "scaled_dot_product_attention"):
# raise ValueError("this model requires pytorch 2.0 or higher")
process_group = weights.process_group
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@ -1024,7 +1015,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
if config.use_resampler:
perceiver_config = config.perceiver_config
self.perceiver_resampler = IdeficsPerceiverResampler(
prefix=f"model.perceiver_resampler",
prefix="model.perceiver_resampler",
config=config,
embed_dim=config.vision_config.embed_dim,
depth=perceiver_config.resampler_depth,
@ -1052,7 +1043,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
# self.gradient_checkpointing = False
self.norm = IdeficsRMSNorm(
prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
# self.gradient_checkpointing = False

View File

@ -169,7 +169,6 @@ class IdeficsPerceiverAttention(nn.Module):
self.qk_scale = self.head_dim**-0.5
process_group = weights.process_group
if n_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "

View File

@ -28,9 +28,6 @@ from transformers.tokenization_utils_base import (
TruncationStrategy,
)
from transformers.utils import TensorType, is_torch_available
from text_generation_server.models.custom_modeling.idefics_image_processing import (
IdeficsImageProcessor,
)
if is_torch_available():

View File

@ -129,7 +129,6 @@ class IdeficsVisionAttention(nn.Module):
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
process_group = weights.process_group
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@ -460,7 +459,6 @@ class IdeficsVisionTransformer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = IdeficsVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights

View File

@ -14,7 +14,7 @@
# limitations under the License.
""" PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint

View File

@ -4,7 +4,6 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
import math
import os
import warnings
from typing import List, Optional, Tuple, Union
import torch
@ -194,7 +193,7 @@ def flash_attn_fn(
):
try:
from flash_attn import bert_padding, flash_attn_interface
except:
except Exception:
raise RuntimeError("Please install flash-attn==1.0.3.post0")
check_valid_inputs(query, key, value)
if past_key_value is not None:
@ -207,7 +206,7 @@ def flash_attn_fn(
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
if attn_bias is not None:
raise NotImplementedError(f"attn_bias not implemented for flash attn.")
raise NotImplementedError("attn_bias not implemented for flash attn.")
(batch_size, seqlen) = query.shape[:2]
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
@ -269,13 +268,13 @@ def triton_flash_attn_fn(
):
try:
from .flash_attn_triton import flash_attn_func
except:
except Exception:
_installed = False
if version.parse(torch.__version__) < version.parse("2.0.0"):
_installed = True
try:
from flash_attn.flash_attn_triton import flash_attn_func
except:
except Exception:
_installed = False
if not _installed:
raise RuntimeError(
@ -292,9 +291,9 @@ def triton_flash_attn_fn(
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
if dropout_p:
raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
raise NotImplementedError("Dropout not implemented for attn_impl: triton.")
if needs_weights:
raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
raise NotImplementedError("attn_impl: triton cannot return attn weights.")
if key_padding_mask is not None:
warnings.warn(
"Propagating key_padding_mask to the attention module "
@ -428,7 +427,7 @@ class MultiQueryAttention(nn.Module):
additive bias.
"""
def __init__(self, config, prefix, weights):
def __init__(self, config, prefix, weights, verbose=False):
super().__init__()
attn_impl = config.attn_config.attn_impl
self.attn_impl = config.attn_config.attn_impl
@ -445,7 +444,7 @@ class MultiQueryAttention(nn.Module):
self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
fuse_splits = (d_model, d_model + self.head_dim)
(d_model, d_model + self.head_dim)
if self.qk_ln:
raise NotImplementedError("qk_ln not supported")
if self.attn_impl == "flash":
@ -795,7 +794,9 @@ class MPTModel(MPTPreTrainedModel):
self.alibi = config.attn_config.alibi
self.alibi_bias_max = config.attn_config.alibi_bias_max
if config.init_device == "mixed":
if dist.get_local_rank() == 0:
# TODO: reimplement mixed device initialization
# dist.get_local_rank() == 0:
if True:
config.init_device = "cpu"
else:
config.init_device = "meta"
@ -1016,7 +1017,7 @@ class MPTModel(MPTPreTrainedModel):
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
f"past_key_values must provide a past_key_value for each attention "
"past_key_values must provide a past_key_value for each attention "
+ f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
)
past_position = past_key_values[0][0].size(1)
@ -1182,7 +1183,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
input_ids = input_ids[:, -1].unsqueeze(-1)
if self.transformer.prefix_lm:
prefix_mask = torch.ones_like(attention_mask)
if kwargs.get("use_cache") == False:
if kwargs.get("use_cache") is False:
raise NotImplementedError(
"MPT with prefix_lm=True does not support use_cache=False."
)

View File

@ -21,25 +21,14 @@ import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import GPTNeoXConfig
from loguru import logger
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
@ -133,7 +122,6 @@ class GPTNeoXAttention(nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
max_positions = config.max_position_embeddings
# ??? TODO
# self.register_buffer(
# "bias",

View File

@ -5,7 +5,7 @@ import torch.distributed
import math
from torch import nn
from typing import Optional, List, Tuple, Any
from typing import Optional, List, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

View File

@ -1,20 +1,15 @@
from typing import Optional, Tuple, Union
from typing import Optional, Tuple
import warnings
import math
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
_create_4d_causal_attention_mask,
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from transformers import SiglipConfig, SiglipVisionConfig
from torch.nn.init import _calculate_fan_in_and_fan_out
from text_generation_server.layers.tensor_parallel import (
TensorParallelEmbedding,
@ -244,9 +239,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
return hidden_state[:, 0]
import warnings
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@ -264,12 +256,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.uniform_(2 * lower - 1, 2 * upper - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
@ -313,9 +305,6 @@ def trunc_normal_tf_(
tensor.mul_(std).add_(mean)
from torch.nn.init import _calculate_fan_in_and_fan_out
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
@ -349,9 +338,6 @@ def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
from transformers import PreTrainedModel
class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@ -393,7 +379,6 @@ class SiglipVisionTransformer(nn.Module):
def __init__(self, prefix, config: SiglipVisionConfig, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights

View File

@ -45,6 +45,15 @@ from text_generation_server.layers import (
SpeculativeHead,
)
# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
class PartialTPEmbedding(nn.Module):
def __init__(self, prefix: str, weights):
@ -1132,12 +1141,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
output = (logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
return (

View File

@ -42,7 +42,7 @@ def load_vision_model(prefix, config, weights):
)
return SiglipVisionTransformer(
prefix=f"vision_tower.vision_model", config=config, weights=weights
prefix="vision_tower.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")

View File

@ -1194,7 +1194,7 @@ class FlashCausalLM(Model):
if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed")
logger.exception("Decode cuda graph warmup failed")
else:
log_master(
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."

View File

@ -2,23 +2,15 @@ import re
import torch
import torch.distributed
from typing import List, Optional, Type
from transformers import (
AutoTokenizer,
AutoConfig,
PreTrainedTokenizerBase,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks

View File

@ -1,13 +1,8 @@
import torch
import torch.distributed
from typing import List, Optional, Tuple
from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_processing import (

View File

@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch):
image_hidden_states = self.image_hidden_states[keep_indices]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection
@ -456,7 +456,7 @@ class IdeficsCausalLMBatch(Batch):
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values

View File

@ -2,7 +2,6 @@ import torch
import torch.distributed
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional
import os
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig,
)
@ -20,7 +19,7 @@ from text_generation_server.models.custom_modeling.mamba_modeling import (
InferenceParams,
)
from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type, Dict
from typing import Any, List, Tuple, Type, Dict
from text_generation_server.models.types import (
Batch,
Tokens,
@ -31,7 +30,7 @@ from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.utils import NextTokenChooser, StoppingCriteria
def new_inference_params(
@ -299,7 +298,6 @@ class MambaBatch(Batch):
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
max_seqlen = 0
seqlen_offset = 0
(n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
@ -485,7 +483,7 @@ class Mamba(Model):
for bs in CUDA_GRAPHS:
self.cuda_graph_warmup(bs)
except Exception:
logger.exception(f"Decode cuda graph warmup failed")
logger.exception("Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
@ -534,7 +532,7 @@ class Mamba(Model):
}
self.cuda_graphs[batch_size] = graph_dict
def tunableop_warmup(self, seqlen: int):
def tunableop_warmup(self, batch_size: int, seqlen: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
n_blocks = len(self.model.blocks)

View File

@ -2,7 +2,7 @@ import inspect
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase

View File

@ -3,16 +3,11 @@ from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable, Optional, Tuple
from typing import Iterable
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLM,
VlmCausalLMBatch,
image_text_replacement,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig
from text_generation_server.pb.generate_pb2 import Request

View File

@ -1,7 +1,6 @@
import torch
import torch.distributed
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
@ -11,7 +10,7 @@ from transformers import (
AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
@ -254,7 +253,7 @@ class Seq2SeqLMBatch(Batch):
]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [
[t for t in layer] for layer in self.past_key_values
]
@ -430,7 +429,7 @@ class Seq2SeqLMBatch(Batch):
batch.encoder_last_hidden_state = None
# Ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t for t in layer] for layer in batch.past_key_values
]

View File

@ -1,4 +1,3 @@
from functools import total_ordering
import torch
from abc import ABC, abstractmethod

View File

@ -9,7 +9,7 @@ from loguru import logger
from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional, Dict
from typing import List, Optional
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor

View File

@ -1,15 +1,13 @@
import torch
from loguru import logger
import subprocess
import os
import importlib.util
def is_ipex_available():
try:
import intel_extension_for_pytorch
except ImportError:
return False
return True
return importlib.util.find_spec("intel_extension_for_pytorch") is not None
def get_cuda_free_memory(device, memory_fraction):

View File

@ -2,9 +2,17 @@ import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
from text_generation_server.utils.merges.utils import (
calculate_majority_sign_mask,
disjoint_merge,
prune,
)
import torch
if TYPE_CHECKING:
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.utils.adapter import ModuleMap
class AdapterParameters:
def __init__(
@ -17,17 +25,6 @@ class AdapterParameters:
self.majority_sign_method = majority_sign_method
from text_generation_server.utils.merges.utils import (
calculate_majority_sign_mask,
disjoint_merge,
prune,
)
if TYPE_CHECKING:
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.utils.adapter import ModuleMap
def _apply_weights(
tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:

View File

@ -28,7 +28,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
low_cpu_mem_usage=True,
)
logger.info("Peft model detected.")
logger.info(f"Merging the lora weights.")
logger.info("Merging the lora weights.")
base_model_id = model.peft_config["default"].base_model_name_or_path

View File

@ -6,7 +6,6 @@ from typing import Optional
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
WeightsLoader,
)

View File

@ -1,7 +1,6 @@
import re
from typing import List, Optional, Tuple, Set, Union
import math
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType