feat: add ruff and resolve issue (#2262)

* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
drbh 2024-07-26 10:29:09 -04:00 committed by GitHub
parent 4b49c50f4c
commit bab02ff2bc
74 changed files with 267 additions and 302 deletions

View File

@@ -16,3 +16,8 @@ repos:
- id: fmt
- id: cargo-check
- id: clippy
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.3.0
+  hooks:
+    - id: ruff
+      args: [--fix, --exit-non-zero-on-fix]

View File

@@ -19,5 +19,15 @@ DEPRECATION_WARNING = (
"Please use the `InferenceClient` from the `huggingface_hub` package instead."
)
-from text_generation.client import Client, AsyncClient
+from text_generation.client import Client, AsyncClient  # noqa E402
-from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
+from text_generation.inference_api import (  # noqa E402
+InferenceAPIClient,
+InferenceAPIAsyncClient,
+)
+__all__ = [
+"Client",
+"AsyncClient",
+"InferenceAPIClient",
+"InferenceAPIAsyncClient",
+]

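Note: the hunk above silences Ruff's E402 (imports placed after the module-level DEPRECATION_WARNING) and adds `__all__`, which is how Ruff's F401 rule recognizes names that are imported only to be re-exported. A self-contained sketch of the same pattern using stdlib names only (not code from this repository):

from json import dumps, loads  # re-exported below, so Ruff does not treat them as unused (F401)

__all__ = ["dumps", "loads"]

if __name__ == "__main__":
    print(dumps({"ok": True}), loads('{"ok": true}'))
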
View File

@@ -21,7 +21,7 @@ def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
List[DeployedModel]: list of all currently deployed models
"""
resp = requests.get(
-f"https://api-inference.huggingface.co/framework/text-generation-inference",
+"https://api-inference.huggingface.co/framework/text-generation-inference",
headers=headers,
timeout=5,
)

View File

@@ -4,7 +4,6 @@ import json
import math
import os
import random
-import re
import shutil
import subprocess
import sys
@@ -271,7 +270,7 @@ class LauncherHandle:
try:
await self.client.generate("test")
return
-except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
+except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
time.sleep(1)
raise RuntimeError("Health check failed")

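Note: the hunk above only drops the unused exception binding (`as e`) from the health-check polling loop. For reference, the shape of that loop is a bounded retry; a self-contained sketch with assumed attempt counts and delays, not the test harness itself:

import time

def wait_until_healthy(probe, attempts: int = 30, delay: float = 1.0) -> None:
    """Poll `probe` until it succeeds; retry on connection errors, then give up."""
    for _ in range(attempts):
        try:
            probe()
            return
        except ConnectionError:
            time.sleep(delay)
    raise RuntimeError("Health check failed")

wait_until_healthy(lambda: None)  # usage sketch: a probe that succeeds immediately
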
View File

@@ -1,7 +1,4 @@
import pytest
-import json
-from text_generation.types import GrammarType
@pytest.fixture(scope="module")

View File

@@ -1,6 +1,4 @@
import pytest
-import requests
-import io
import base64

View File

@@ -74,9 +74,7 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
generated_texts = [r.generated_text for r in responses]
-assert (
-generated_texts[0] == " \nAssistant: A rooster stands"
-), f"{response.generated_text}"
+assert generated_texts[0] == " \nAssistant: A rooster stands"
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]

View File

@@ -1,7 +1,4 @@
import pytest
-import json
-from text_generation.types import GrammarType
@pytest.fixture(scope="module")
@@ -91,7 +88,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
},
],
)
-assert response.choices[0].message.content == None
+assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
@@ -129,7 +126,7 @@ async def test_flash_llama_grammar_tools_auto(
},
],
)
-assert response.choices[0].message.content == None
+assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
@@ -168,7 +165,7 @@ async def test_flash_llama_grammar_tools_choice(
},
],
)
-assert response.choices[0].message.content == None
+assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
@@ -241,7 +238,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
stream=False,
)
-assert responses.choices[0].message.content == None
+assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls == [
{
"function": {

View File

@@ -20,7 +20,7 @@ def main():
break
with open("./small.json", "w") as f:
-data = json.dump(conversations, f, indent=4)
+json.dump(conversations, f, indent=4)
if __name__ == "__main__":

View File

@@ -1,11 +1,9 @@
import os
-import requests
import tempfile
import pytest
import huggingface_hub.constants
-from huggingface_hub import hf_api
import text_generation_server.utils.hub
from text_generation_server.utils.hub import (

View File

@@ -2,7 +2,6 @@ import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
-from text_generation_server.utils.weights import DefaultWeightsLoader
class ProcessGroup:

View File

@@ -2,7 +2,6 @@ import pytest
import torch
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
-UnquantizedWeight,
Weights,
WeightsLoader,
)
@@ -86,15 +85,6 @@ dummy_file_system = {
],
dtype=torch.float32,
),
-"weight.weight": torch.tensor(
-[
-[1, 2],
-[3, 4],
-[5, 6],
-[7, 8],
-],
-dtype=torch.float32,
-),
},
"test_get_weights_row": {
"weight.weight": torch.tensor(
@@ -966,7 +956,7 @@ def test_get_multi_weights_col_exl2():
prefix = "weight"
try:
-w = weights.get_multi_weights_col(
+weights.get_multi_weights_col(
prefixes=[prefix],
dim=0,
)

View File

@@ -4,15 +4,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Set, Tuple
+from typing import Dict, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
-if TYPE_CHECKING:
-from text_generation_server.models.model import Model
@dataclass
class ModuleMap:

View File

@@ -4,7 +4,7 @@
from collections import defaultdict
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
@@ -26,9 +26,6 @@ from text_generation_server.utils.sgmv import (
use_cutlass_shrink,
)
-if TYPE_CHECKING:
-from text_generation_server.models.model import Model
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size

View File

@@ -4,12 +4,11 @@ import typer
from pathlib import Path
from loguru import logger
-from typing import Optional, List, Dict
+from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters
-from text_generation_server.utils.log import log_master
app = typer.Typer()
@@ -165,7 +164,7 @@ def download_weights(
# currently by default we don't merge the weights with the base model
if merge_lora:
try:
-adapter_config_filename = hf_hub_download(
+hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
@@ -285,9 +284,9 @@ def download_weights(
if auto_convert:
if not trust_remote_code:
logger.warning(
-f"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
-f"Pickle files are unsafe and can essentially contain remote code execution!"
-f"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
+"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
+"Pickle files are unsafe and can essentially contain remote code execution!"
+"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
)
logger.warning(
@@ -319,7 +318,7 @@ def download_weights(
# Name for this varible depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])
-except Exception as e:
+except Exception:
discard_names = []
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files, discard_names)

View File

@@ -18,3 +18,17 @@ from text_generation_server.layers.lora import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
+__all__ = [
+"get_linear",
+"FastLinear",
+"TensorParallelColumnLinear",
+"TensorParallelRowLinear",
+"TensorParallelEmbedding",
+"SpeculativeHead",
+"LoraLinear",
+"TensorParallelMultiAdapterLinear",
+"TensorParallelAdapterRowLinear",
+"load_layer_norm",
+"load_conv2d",
+]

View File

@@ -13,3 +13,12 @@ elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
+__all__ = [
+"attention",
+"paged_attention",
+"reshape_and_cache",
+"SUPPORTS_WINDOWING",
+"Seqlen",
+]

View File

@@ -10,7 +10,6 @@ _PARTITION_SIZE = 512
try:
from vllm._C import cache_ops
-from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"

View File

@@ -747,11 +747,8 @@ class _attention(torch.autograd.Function):
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
-grid = lambda META: (
-triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
-nheads_q,
-batch,
-)
+def grid(META):
+return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch
encoded_softmax = None

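Note: Ruff's E731 flags assigning a lambda to a name, so the Triton launch-grid callback above becomes a small def, which gets a real __name__ and shows up properly in tracebacks. A self-contained sketch of the same shape, where plain integer ceiling division stands in for triton.cdiv and the sizes are made-up:

def cdiv(a: int, b: int) -> int:
    """Ceiling division, the pure-Python equivalent of triton.cdiv."""
    return (a + b - 1) // b

def grid(meta: dict) -> tuple:
    # one program per BLOCK_M chunk of the sequence, per head, per batch element
    return (cdiv(1024, meta["BLOCK_M"]), 16, 2)

assert grid({"BLOCK_M": 128}) == (8, 16, 2)
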
View File

@@ -15,7 +15,6 @@ ENGINE = "triton" if use_triton else "ck"
try:
from vllm._C import cache_ops
-from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"

View File

@@ -1,6 +1,5 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
-import math
from typing import Optional
import torch
import torch.nn as nn

View File

@@ -1,5 +1,4 @@
from dataclasses import dataclass
-from functools import lru_cache
import bitsandbytes as bnb
import torch

View File

@@ -12,17 +12,26 @@ from text_generation_server.utils.weights import (
Weights,
)
from text_generation_server.utils.log import log_master, log_once
+import importlib.util
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
-try:
-import fbgemm_gpu.experimental.gen_ai
+def is_fbgemm_gpu_available():
+try:
+return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
+except ModuleNotFoundError:
+return False
+if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
-except (ImportError, ModuleNotFoundError):
+else:
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")

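Note: the refactor above replaces a bare `import fbgemm_gpu.experimental.gen_ai` inside try/except with `importlib.util.find_spec`, which tests availability without actually importing the module (so it avoids the import cost and does not mask unrelated ImportErrors). A self-contained sketch of the same check, with hypothetical module names:

import importlib.util

def is_module_available(name: str) -> bool:
    """True if `name` could be imported; the module is not actually imported."""
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name is itself missing.
        return False

print(is_module_available("json"))          # True on any CPython
print(is_module_available("not_a_module"))  # False
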
View File

@@ -8,6 +8,34 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+try:
+major, _minor = torch.cuda.get_device_capability()
+except Exception:
+major = 1
+HAS_EXLLAMA = False
+CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
+V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
+if os.getenv("DISABLE_EXLLAMA") == "True":
+HAS_EXLLAMA = False
+elif CAN_EXLLAMA:
+try:
+if V2:
+from text_generation_server.layers.gptq.exllamav2 import (
+QuantLinear as ExllamaQuantLinear,  # noqa: F401
+)
+HAS_EXLLAMA = "2"
+else:
+from text_generation_server.layers.gptq.exllama import (
+Ex4bitLinear as ExllamaQuantLinear,  # noqa: F401
+)
+HAS_EXLLAMA = "1"
+except ImportError:
+pass
@dataclass
class GPTQWeight(Weight):
@@ -55,7 +83,7 @@ class GPTQWeight(Weight):
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
-f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
+"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
return ExllamaQuantLinear(self, bias)
@@ -73,45 +101,6 @@ class GPTQWeight(Weight):
)
-try:
-major, _minor = torch.cuda.get_device_capability()
-except Exception:
-major = 1
-HAS_EXLLAMA = False
-CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
-V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
-if os.getenv("DISABLE_EXLLAMA") == "True":
-HAS_EXLLAMA = False
-elif CAN_EXLLAMA:
-try:
-if V2:
-from text_generation_server.layers.gptq.exllamav2 import (
-QuantLinear as ExllamaQuantLinear,
-)
-from text_generation_server.layers.gptq.exllamav2 import (
-create_exllama_buffers,
-set_device,
-)
-HAS_EXLLAMA = "2"
-else:
-from text_generation_server.layers.gptq.exllama import (
-Ex4bitLinear as ExllamaQuantLinear,
-)
-from text_generation_server.layers.gptq.exllama import (
-create_exllama_buffers,
-set_device,
-)
-HAS_EXLLAMA = "1"
-except ImportError:
-pass
-from text_generation_server.layers.gptq.quant_linear import QuantLinear
class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.

View File

@@ -206,10 +206,13 @@ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
output = torch.empty(
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
)
-grid = lambda META: (
-triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
-* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
-)
+def grid(META):
+return (
+triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
+* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
+)
matmul_248_kernel[grid](
input,
qweight,

View File

@@ -15,6 +15,7 @@ from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
+from text_generation_server.layers.gptq.utils import torch_snr_error
from text_generation_server.utils.weights import DefaultWeightsLoader
@@ -372,7 +373,7 @@ def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
-except:
+except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@@ -404,7 +405,7 @@ def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
-except:
+except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@@ -448,7 +449,7 @@ def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
-except:
+except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@@ -504,7 +505,7 @@ def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
-except:
+except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@@ -546,7 +547,7 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
-except:
+except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
@@ -700,6 +701,8 @@ def sequential(
pass
def add_batch(name):
+nonlocal gptq
def tmp(_, inp, out):
gptq[name].add_batch(inp[0].data, out.data)

View File

@@ -0,0 +1,56 @@
import torch
# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(
y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
"""
Compute SNR between y_pred(tensor) and y_real(tensor)
SNR can be calcualted as following equation:
SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.
SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
Args:
y_pred (torch.Tensor): _description_
y_real (torch.Tensor): _description_
reduction (str, optional): _description_. Defaults to 'mean'.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
torch.Tensor: _description_
"""
if y_pred.shape != y_real.shape:
raise ValueError(
f"Can not compute snr loss for tensors with different shape. "
f"({y_pred.shape} and {y_real.shape})"
)
reduction = str(reduction).lower()
if y_pred.ndim == 1:
y_pred = y_pred.unsqueeze(0)
y_real = y_real.unsqueeze(0)
y_pred = y_pred.flatten(start_dim=1)
y_real = y_real.flatten(start_dim=1)
noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
signal_power = torch.pow(y_real, 2).sum(dim=-1)
snr = (noise_power) / (signal_power + 1e-7)
if reduction == "mean":
return torch.mean(snr)
elif reduction == "sum":
return torch.sum(snr)
elif reduction == "none":
return snr
else:
raise ValueError("Unsupported reduction method.")

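Note: a small usage sketch for the torch_snr_error helper added in the new file above (it is the name imported into quantize.py earlier in this diff). It assumes the text-generation-inference server package is importable; the tensors are made-up stand-ins for real and quantized layer outputs:

import torch
from text_generation_server.layers.gptq.utils import torch_snr_error

real = torch.randn(4, 128)
pred = real + 0.01 * torch.randn(4, 128)   # a slightly perturbed copy of `real`
err = torch_snr_error(pred, real, reduction="mean")
print(f"SNR error: {err.item():.6f}")       # close to zero when pred tracks real
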
View File

@@ -1,5 +1,3 @@
-from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F

View File

@@ -1,12 +1,8 @@
-import math
-import os
-from typing import TYPE_CHECKING, Optional, Tuple, List
+from typing import TYPE_CHECKING, Optional, List
import torch
import torch.distributed
-from accelerate import init_empty_weights
from torch import nn
-from torch.nn import functional as F
from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (

View File

@@ -1,6 +1,3 @@
-from typing import List, Tuple
-import torch
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinLinear,

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
import torch.nn as nn
@@ -85,7 +85,7 @@ class MarlinWeightsLoader(WeightsLoader):
)
except RuntimeError:
raise RuntimeError(
-f"Cannot load `marlin` weight, make sure the model is already quantized"
+"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
@@ -104,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader):
)
except RuntimeError:
raise RuntimeError(
-f"Cannot load `marlin` weight, make sure the model is already quantized"
+"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1

View File

@@ -2,12 +2,9 @@ import os
import math
import torch
from torch import nn
-from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda":
-from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops

View File

@@ -33,7 +33,7 @@ class SpeculativeHead(torch.nn.Module):
except KeyError:
try:
speculator = MedusaHeadV1.load(config, prefix, weights)
-except:
+except Exception:
speculator = MedusaHeadV2(config, prefix, weights)
lm_head = None
else:

View File

@@ -2,7 +2,6 @@ import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
-from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
@@ -50,7 +49,7 @@ class TensorParallelHead(SuperLayer):
# If the piece and LM head embeddings are shared, we have
# non-quantized weights...
weight = weights.get_tensor(f"{prefix}.weight")
-except:
+except Exception:
# ...otherwise they are quantized.
weight = weights.get_weights_col(prefix)
should_gather = weights.process_group.size() > 1
@@ -67,15 +66,6 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
-# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
-if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
-quantize = None
-# See above, exl2 LM head can be quantized or not.
-elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
-quantize = None
-else:
-quantize = config.quantize
return TensorParallelHead(
get_linear(weight, bias=None),
process_group=weights.process_group,

View File

@@ -1,3 +1,6 @@
+# ruff: noqa: F821
+# the above line disables the `undefined-name` rule for the model type variables
import torch
import enum
import os
@@ -712,6 +715,7 @@ def get_model(
)
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
+print(f">>> model_type: {model_type}")
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@@ -856,7 +860,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
-raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
+raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
else:
if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashCausalLM(

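Note: the file-level `# ruff: noqa: F821` directive above is needed because the model-type names used in get_model (LLAMA, BAICHUAN, ...) are injected into the module namespace at runtime, so static analysis sees them as undefined. A runnable toy sketch of why F821 fires and how the directive suppresses it (the constant here is hypothetical, not from the repository):

# ruff: noqa: F821
# Names injected into the module namespace at runtime look undefined to Ruff.
globals()["LLAMA"] = "llama"

print(LLAMA)  # would be flagged as F821 (undefined name) without the directive above
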
View File

@@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
]
# Ensure that past_key_values tensors can be updated in-place
-if type(self.past_key_values[0]) == tuple:
+if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection
@@ -377,7 +377,7 @@ class CausalLMBatch(Batch):
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
-if type(batch.past_key_values[0]) == tuple:
+if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values

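Note: the two comparisons above are what Ruff's E721 asks for: use `is` when testing for an exact type and isinstance() when subclasses should also pass. A tiny self-contained illustration (the tuple here is a made-up stand-in for a per-layer past_key_values entry):

layer = (1.0, 2.0)               # hypothetical (key, value) pair
assert type(layer) is tuple      # exact-type identity check, satisfies E721
assert isinstance(layer, tuple)  # also true for tuple subclasses
layer = list(layer)              # now the entries can be replaced in place
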
View File

@@ -908,7 +908,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
loss = None
if not return_dict:
-output = (lm_logits,) + transformer_outputs[1:]
+output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return (

View File

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
import torch
from torch import nn
@@ -9,9 +9,7 @@ from transformers.modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
-BaseModelOutput,
BaseModelOutputWithPooling,
-ImageClassifierOutput,
)
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
@@ -446,11 +444,12 @@ class CLIPEncoder(nn.Module):
class CLIPTextTransformer(nn.Module):
-def __init__(self, prefix: str, config: CLIPTextConfig):
+def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPTextEmbeddings(config)
+# Initialize weights and apply final processing with `self.post_init()`
self.encoder = CLIPEncoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
@@ -505,7 +504,7 @@ class CLIPTextTransformer(nn.Module):
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-pooled_output = last_hidden_state[
+last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),
@@ -515,7 +514,7 @@ class CLIPTextTransformer(nn.Module):
]
else:
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
-pooled_output = last_hidden_state[
+last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),
@@ -565,9 +564,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
-return_dict = (
-return_dict if return_dict is not None else self.config.use_return_dict
-)
return self.text_model(
input_ids=input_ids,
@@ -580,7 +576,6 @@ class CLIPVisionTransformer(nn.Module):
def __init__(self, prefix, config: CLIPVisionConfig, weights):
super().__init__()
self.config = config
-embed_dim = config.hidden_size
self.embeddings = CLIPVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
@@ -661,9 +656,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled CLS states
```"""
-return_dict = (
-return_dict if return_dict is not None else self.config.use_return_dict
-)
return self.vision_model(
pixel_values=pixel_values,
@@ -799,14 +791,12 @@ class CLIPModel(nn.Module):
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
vision_outputs = self.vision_model(
pixel_values=pixel_values,
-return_dict=return_dict,
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
-return_dict=return_dict,
)
image_embeds = vision_outputs[1]

View File

@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
attention,
reshape_and_cache,
)
-from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,

View File

@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
-from text_generation_server.utils.log import log_once
class DbrxAttentionConfig(PretrainedConfig):

View File

@@ -39,6 +39,12 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
+if SYSTEM == "rocm":
+try:
+from vllm import _custom_C
+except Exception as e:
+raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class DeepseekV2Config(PretrainedConfig):
def __init__(

View File

@@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import (
-UnquantizedWeight,
Weights,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
@@ -277,7 +276,7 @@ class LlamaMLP(nn.Module):
bias=bias,
)
else:
-prefixes = [f"gate_proj", f"up_proj"]
+prefixes = ["gate_proj", "up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,

View File

@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
-Seqlen,
paged_attention,
attention,
reshape_and_cache,
@@ -38,7 +37,6 @@ from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
-get_linear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)

View File

@@ -21,7 +21,6 @@
import torch
import torch.distributed
-import numpy as np
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
@@ -31,7 +30,6 @@ if SYSTEM != "ipex":
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
-from loguru import logger
from text_generation_server.layers.attention import (
paged_attention,

View File

@@ -16,7 +16,6 @@
import torch
import torch.distributed
from torch import nn
-from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear

View File

@@ -15,7 +15,6 @@ from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
-get_linear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (

View File

@@ -14,7 +14,7 @@
# limitations under the License.
""" PyTorch Idefics2 model."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
@@ -22,10 +22,8 @@ from torch import nn
import math
from transformers.activations import ACT2FN
-from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
-load_vision_model,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

View File

@@ -19,6 +19,7 @@ import numpy as np
from PIL import Image
+import transformers
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
resize,
@@ -293,6 +294,4 @@ class IdeficsImageProcessor(BaseImageProcessor):
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
-import transformers
transformers.IdeficsImageProcessor = IdeficsImageProcessor

View File

@@ -21,10 +21,8 @@
from typing import List, Optional, Tuple, Union
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
-from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
@@ -33,13 +31,6 @@ from transformers.modeling_outputs import (
CausalLMOutputWithPast,
dataclass,
)
-from transformers.modeling_utils import PretrainedConfig
-from transformers.utils import (
-add_start_docstrings,
-add_start_docstrings_to_model_forward,
-logging,
-replace_return_docstrings,
-)
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_vision import (
IdeficsVisionTransformer,
@@ -56,6 +47,7 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.import_utils import SYSTEM
+from loguru import logger
if SYSTEM == "cuda":
import dropout_layer_norm
@@ -237,7 +229,7 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module):
prefix="model.embed_tokens", weights=weights
)
self.additional_weight = nn.Parameter(
-weights.get_tensor(f"model.embed_tokens.additional_embedding.weight")
+weights.get_tensor("model.embed_tokens.additional_embedding.weight")
)
def forward(self, input_ids):
@@ -499,7 +491,6 @@ class IdeficsAttention(nn.Module):
# if not hasattr(nn.functional, "scaled_dot_product_attention"):
# raise ValueError("this model requires pytorch 2.0 or higher")
-process_group = weights.process_group
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@@ -1024,7 +1015,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
if config.use_resampler:
perceiver_config = config.perceiver_config
self.perceiver_resampler = IdeficsPerceiverResampler(
-prefix=f"model.perceiver_resampler",
+prefix="model.perceiver_resampler",
config=config,
embed_dim=config.vision_config.embed_dim,
depth=perceiver_config.resampler_depth,
@@ -1052,7 +1043,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
# self.gradient_checkpointing = False
self.norm = IdeficsRMSNorm(
-prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps
+prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
# self.gradient_checkpointing = False

View File

@@ -169,7 +169,6 @@ class IdeficsPerceiverAttention(nn.Module):
self.qk_scale = self.head_dim**-0.5
-process_group = weights.process_group
if n_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "

View File

@@ -28,9 +28,6 @@ from transformers.tokenization_utils_base import (
TruncationStrategy,
)
from transformers.utils import TensorType, is_torch_available
-from text_generation_server.models.custom_modeling.idefics_image_processing import (
-IdeficsImageProcessor,
-)
if is_torch_available():

View File

@@ -129,7 +129,6 @@ class IdeficsVisionAttention(nn.Module):
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
-process_group = weights.process_group
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@@ -460,7 +459,6 @@ class IdeficsVisionTransformer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
-embed_dim = config.hidden_size
self.embeddings = IdeficsVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights

View File

@@ -14,7 +14,7 @@
# limitations under the License.
""" PyTorch Llava-NeXT model."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint

View File

@@ -4,7 +4,6 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
import math
-import os
import warnings
from typing import List, Optional, Tuple, Union
import torch
@@ -194,7 +193,7 @@ def flash_attn_fn(
):
try:
from flash_attn import bert_padding, flash_attn_interface
-except:
+except Exception:
raise RuntimeError("Please install flash-attn==1.0.3.post0")
check_valid_inputs(query, key, value)
if past_key_value is not None:
@@ -207,7 +206,7 @@ def flash_attn_fn(
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
if attn_bias is not None:
-raise NotImplementedError(f"attn_bias not implemented for flash attn.")
+raise NotImplementedError("attn_bias not implemented for flash attn.")
(batch_size, seqlen) = query.shape[:2]
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
@@ -269,13 +268,13 @@ def triton_flash_attn_fn(
):
try:
from .flash_attn_triton import flash_attn_func
-except:
+except Exception:
_installed = False
if version.parse(torch.__version__) < version.parse("2.0.0"):
_installed = True
try:
from flash_attn.flash_attn_triton import flash_attn_func
-except:
+except Exception:
_installed = False
if not _installed:
raise RuntimeError(
@@ -292,9 +291,9 @@ def triton_flash_attn_fn(
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
if dropout_p:
-raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
+raise NotImplementedError("Dropout not implemented for attn_impl: triton.")
if needs_weights:
-raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
+raise NotImplementedError("attn_impl: triton cannot return attn weights.")
if key_padding_mask is not None:
warnings.warn(
"Propagating key_padding_mask to the attention module "
@@ -428,7 +427,7 @@ class MultiQueryAttention(nn.Module):
additive bias.
"""
-def __init__(self, config, prefix, weights):
+def __init__(self, config, prefix, weights, verbose=False):
super().__init__()
attn_impl = config.attn_config.attn_impl
self.attn_impl = config.attn_config.attn_impl
@@ -445,7 +444,7 @@ class MultiQueryAttention(nn.Module):
self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
-fuse_splits = (d_model, d_model + self.head_dim)
+(d_model, d_model + self.head_dim)
if self.qk_ln:
raise NotImplementedError("qk_ln not supported")
if self.attn_impl == "flash":
@@ -795,7 +794,9 @@ class MPTModel(MPTPreTrainedModel):
self.alibi = config.attn_config.alibi
self.alibi_bias_max = config.attn_config.alibi_bias_max
if config.init_device == "mixed":
-if dist.get_local_rank() == 0:
+# TODO: reimplement mixed device initialization
+# dist.get_local_rank() == 0:
+if True:
config.init_device = "cpu"
else:
config.init_device = "meta"
@@ -1016,7 +1017,7 @@ class MPTModel(MPTPreTrainedModel):
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
-f"past_key_values must provide a past_key_value for each attention "
+"past_key_values must provide a past_key_value for each attention "
+ f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
)
past_position = past_key_values[0][0].size(1)
@@ -1182,7 +1183,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
input_ids = input_ids[:, -1].unsqueeze(-1)
if self.transformer.prefix_lm:
prefix_mask = torch.ones_like(attention_mask)
-if kwargs.get("use_cache") == False:
+if kwargs.get("use_cache") is False:
raise NotImplementedError(
"MPT with prefix_lm=True does not support use_cache=False."
)

View File

@@ -21,25 +21,14 @@ import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
-from transformers.file_utils import (
-add_code_sample_docstrings,
-add_start_docstrings,
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
-QuestionAnsweringModelOutput,
-SequenceClassifierOutputWithPast,
-TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
-from transformers import GPTNeoXConfig
-from loguru import logger
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
@@ -133,7 +122,6 @@ class GPTNeoXAttention(nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
-max_positions = config.max_position_embeddings
# ??? TODO
# self.register_buffer(
# "bias",

View File

@@ -5,7 +5,7 @@ import torch.distributed
import math
from torch import nn
-from typing import Optional, List, Tuple, Any
+from typing import Optional, List, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

View File

@ -1,20 +1,15 @@
from typing import Optional, Tuple, Union from typing import Optional, Tuple
import warnings
import math import math
import torch import torch
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
_create_4d_causal_attention_mask,
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import ( from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
ImageClassifierOutput,
) )
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig from transformers import SiglipConfig, SiglipVisionConfig
from torch.nn.init import _calculate_fan_in_and_fan_out
from text_generation_server.layers.tensor_parallel import ( from text_generation_server.layers.tensor_parallel import (
TensorParallelEmbedding, TensorParallelEmbedding,
@ -244,9 +239,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
return hidden_state[:, 0] return hidden_state[:, 0]
import warnings
def _trunc_normal_(tensor, mean, std, a, b): def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW # Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@ -264,12 +256,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
# Values are generated by using a truncated uniform distribution and # Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution. # then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values # Get upper and lower cdf values
l = norm_cdf((a - mean) / std) lower = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std) upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to # Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1]. # [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1) tensor.uniform_(2 * lower - 1, 2 * upper - 1)
# Use inverse cdf transform for normal distribution to get truncated # Use inverse cdf transform for normal distribution to get truncated
# standard normal # standard normal
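The renamed bounds (lower/upper instead of the ambiguous single letters flagged by ruff's E741) feed the inverse-CDF truncation used here. For reference, a standalone version of that sampling trick looks roughly like the following sketch; it is not the repo's exact helper, and the default bounds are illustrative:

import math
import torch

def truncated_normal_(tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    def norm_cdf(x):
        # CDF of the standard normal distribution
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    lower = norm_cdf((a - mean) / std)
    upper = norm_cdf((b - mean) / std)
    # Draw uniformly in [2*lower - 1, 2*upper - 1], then map through erfinv so the
    # result is a standard normal restricted to [(a - mean)/std, (b - mean)/std].
    tensor.uniform_(2 * lower - 1, 2 * upper - 1)
    tensor.erfinv_()
    tensor.mul_(std * math.sqrt(2.0)).add_(mean)
    return tensor.clamp_(min=a, max=b)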
@ -313,9 +305,6 @@ def trunc_normal_tf_(
tensor.mul_(std).add_(mean) tensor.mul_(std).add_(mean)
from torch.nn.init import _calculate_fan_in_and_fan_out
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in": if mode == "fan_in":
@ -349,9 +338,6 @@ def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal") variance_scaling_(tensor, mode="fan_in", distribution="normal")
from transformers import PreTrainedModel
class SiglipEncoder(nn.Module): class SiglipEncoder(nn.Module):
""" """
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@ -393,7 +379,6 @@ class SiglipVisionTransformer(nn.Module):
def __init__(self, prefix, config: SiglipVisionConfig, weights): def __init__(self, prefix, config: SiglipVisionConfig, weights):
super().__init__() super().__init__()
self.config = config self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings( self.embeddings = SiglipVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights prefix=f"{prefix}.embeddings", config=config, weights=weights


@ -45,6 +45,15 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
) )
# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
class PartialTPEmbedding(nn.Module): class PartialTPEmbedding(nn.Module):
def __init__(self, prefix: str, weights): def __init__(self, prefix: str, weights):
@ -1132,12 +1141,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100) loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP # move labels to correct device to enable PP
labels = labels.to(lm_logits.device) labels = labels.to(logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
if not return_dict: if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs output = (logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return ( return (
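The rename from lm_logits to the logits name actually in scope keeps the loss path intact. In isolation, that path is roughly the following sketch (tensor names illustrative, not the full forward):

import torch
from torch.nn import CrossEntropyLoss

def lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq_len, vocab]; labels: [batch, seq_len], with -100 marking
    # positions that should not contribute to the loss.
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    labels = labels.to(logits.device)  # keep labels on the same device as the logits
    return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))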


@ -42,7 +42,7 @@ def load_vision_model(prefix, config, weights):
) )
return SiglipVisionTransformer( return SiglipVisionTransformer(
prefix=f"vision_tower.vision_model", config=config, weights=weights prefix="vision_tower.vision_model", config=config, weights=weights
) )
else: else:
raise RuntimeError(f"Unsupported model type {config.model_type}") raise RuntimeError(f"Unsupported model type {config.model_type}")


@ -1194,7 +1194,7 @@ class FlashCausalLM(Model):
if self.speculate is None or self.speculate + 1 <= bs: if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt) self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError: except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed") logger.exception("Decode cuda graph warmup failed")
else: else:
log_master( log_master(
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."


@ -2,23 +2,15 @@ import re
import torch import torch
import torch.distributed import torch.distributed
from typing import List, Optional, Type
from transformers import ( from transformers import (
AutoTokenizer,
AutoConfig,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.utils import ( from text_generation_server.utils import (
NextTokenChooser, NextTokenChooser,
StoppingCriteria, StoppingCriteria,
initialize_torch_distributed,
weight_files,
Weights,
) )
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks


@ -1,13 +1,8 @@
import torch import torch
import torch.distributed import torch.distributed
from typing import List, Optional, Tuple from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_processing import ( from text_generation_server.models.custom_modeling.idefics_processing import (


@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch):
image_hidden_states = self.image_hidden_states[keep_indices] image_hidden_states = self.image_hidden_states[keep_indices]
# Ensure that past_key_values tensors can be updated in-place # Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple: if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values] self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection # Update tensors in-place to allow incremental garbage collection
@ -456,7 +456,7 @@ class IdeficsCausalLMBatch(Batch):
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim] # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place # And ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple: if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [ batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer] [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values for layer in batch.past_key_values
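Both hunks address ruff's E721 (type comparison). The practical difference between the three spellings is sketched below with illustrative values:

past_key_values = [((1, 2), (3, 4))]
layer = past_key_values[0]

print(type(layer) == tuple)      # flagged by E721: equality comparison on type objects
print(type(layer) is tuple)      # identity check, True only for exactly `tuple`
print(isinstance(layer, tuple))  # also True for tuple subclasses (e.g. namedtuple instances)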


@ -2,7 +2,6 @@ import torch
import torch.distributed import torch.distributed
from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional from typing import Optional
import os
from text_generation_server.models.custom_modeling.mamba_modeling import ( from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig, MambaConfig,
) )
@ -20,7 +19,7 @@ from text_generation_server.models.custom_modeling.mamba_modeling import (
InferenceParams, InferenceParams,
) )
from text_generation_server.models import Model from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type, Dict from typing import Any, List, Tuple, Type, Dict
from text_generation_server.models.types import ( from text_generation_server.models.types import (
Batch, Batch,
Tokens, Tokens,
@ -31,7 +30,7 @@ from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation_server.utils import NextTokenChooser, StoppingCriteria
def new_inference_params( def new_inference_params(
@ -299,7 +298,6 @@ class MambaBatch(Batch):
stopping_criterias = [] stopping_criterias = []
top_n_tokens = [] top_n_tokens = []
max_tokens = 0 max_tokens = 0
max_seqlen = 0
seqlen_offset = 0 seqlen_offset = 0
(n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
@ -485,7 +483,7 @@ class Mamba(Model):
for bs in CUDA_GRAPHS: for bs in CUDA_GRAPHS:
self.cuda_graph_warmup(bs) self.cuda_graph_warmup(bs)
except Exception: except Exception:
logger.exception(f"Decode cuda graph warmup failed") logger.exception("Decode cuda graph warmup failed")
else: else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
@ -534,7 +532,7 @@ class Mamba(Model):
} }
self.cuda_graphs[batch_size] = graph_dict self.cuda_graphs[batch_size] = graph_dict
def tunableop_warmup(self, seqlen: int): def tunableop_warmup(self, batch_size: int, seqlen: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
n_blocks = len(self.model.blocks) n_blocks = len(self.model.blocks)
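The extra parameter fixes an undefined name (ruff's F821): the old signature only accepted seqlen while the body referenced batch_size. A stripped-down sketch of the corrected shape, with device and sizes illustrative:

import torch

def tunableop_warmup(batch_size: int, seqlen: int, device: str = "cpu"):
    # batch_size is now an explicit argument instead of an unbound name.
    input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=device)
    position_ids = torch.arange(seqlen, dtype=torch.int64, device=device)
    return input_ids, position_ids

tunableop_warmup(batch_size=2, seqlen=128)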


@ -2,7 +2,7 @@ import inspect
import torch import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict from collections import defaultdict
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase


@ -3,16 +3,11 @@ from PIL import Image
import torch import torch
import torch.distributed import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from typing import Iterable, Optional, Tuple from typing import Iterable
from text_generation_server.models.vlm_causal_lm import ( from text_generation_server.models.vlm_causal_lm import (
VlmCausalLM,
VlmCausalLMBatch, VlmCausalLMBatch,
image_text_replacement, image_text_replacement,
) )
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig
from text_generation_server.pb.generate_pb2 import Request from text_generation_server.pb.generate_pb2 import Request


@ -1,7 +1,6 @@
import torch import torch
import torch.distributed import torch.distributed
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import ( from transformers import (
@ -11,7 +10,7 @@ from transformers import (
AutoConfig, AutoConfig,
) )
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
@ -254,7 +253,7 @@ class Seq2SeqLMBatch(Batch):
] ]
# Ensure that past_key_values tensors can be updated in-place # Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple: if type(self.past_key_values[0]) is tuple:
self.past_key_values = [ self.past_key_values = [
[t for t in layer] for layer in self.past_key_values [t for t in layer] for layer in self.past_key_values
] ]
@ -430,7 +429,7 @@ class Seq2SeqLMBatch(Batch):
batch.encoder_last_hidden_state = None batch.encoder_last_hidden_state = None
# Ensure that we can update tensors in-place # Ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple: if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [ batch.past_key_values = [
[t for t in layer] for layer in batch.past_key_values [t for t in layer] for layer in batch.past_key_values
] ]


@ -1,4 +1,3 @@
from functools import total_ordering
import torch import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod


@ -9,7 +9,7 @@ from loguru import logger
from grpc_reflection.v1alpha import reflection from grpc_reflection.v1alpha import reflection
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict from typing import List, Optional
from text_generation_server.cache import Cache from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.interceptor import ExceptionInterceptor


@ -1,15 +1,13 @@
import torch import torch
from loguru import logger from loguru import logger
import subprocess
import os import os
import importlib.util
def is_ipex_available(): def is_ipex_available():
try: return importlib.util.find_spec("intel_extension_for_pytorch") is not None
import intel_extension_for_pytorch
except ImportError:
return False
return True
def get_cuda_free_memory(device, memory_fraction): def get_cuda_free_memory(device, memory_fraction):
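The rewritten check relies on importlib.util.find_spec, which consults import metadata without executing the package. A generic version of the same pattern (a sketch; the package name is only an example):

import importlib.util

def is_package_available(name: str) -> bool:
    # find_spec returns None when the package cannot be found; unlike a bare
    # `import`, it does not run the package's own import-time code.
    return importlib.util.find_spec(name) is not None

print(is_package_available("intel_extension_for_pytorch"))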


@ -2,9 +2,17 @@ import copy
from abc import ABC from abc import ABC
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
from text_generation_server.utils.merges.utils import (
calculate_majority_sign_mask,
disjoint_merge,
prune,
)
import torch import torch
if TYPE_CHECKING:
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.utils.adapter import ModuleMap
class AdapterParameters: class AdapterParameters:
def __init__( def __init__(
@ -17,17 +25,6 @@ class AdapterParameters:
self.majority_sign_method = majority_sign_method self.majority_sign_method = majority_sign_method
from text_generation_server.utils.merges.utils import (
calculate_majority_sign_mask,
disjoint_merge,
prune,
)
if TYPE_CHECKING:
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.utils.adapter import ModuleMap
def _apply_weights( def _apply_weights(
tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
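Moving these imports and the TYPE_CHECKING block above the first class definition satisfies ruff's E402 (module-level import not at top of file). The resulting layout follows the usual pattern sketched here; the module and function names below are hypothetical, not the repo's:

from typing import TYPE_CHECKING, List, Union

import torch

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so annotation-only
    # imports cannot introduce import cycles.
    from mypackage.config import AdapterConfig  # hypothetical module

def combine(tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor) -> torch.Tensor:
    # Illustrative body: stack a list of tensors and apply per-tensor weights.
    stacked = torch.stack(tensors) if isinstance(tensors, list) else tensors
    return (stacked * w).sum(dim=0)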


@ -28,7 +28,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
) )
logger.info("Peft model detected.") logger.info("Peft model detected.")
logger.info(f"Merging the lora weights.") logger.info("Merging the lora weights.")
base_model_id = model.peft_config["default"].base_model_name_or_path base_model_id = model.peft_config["default"].base_model_name_or_path


@ -6,7 +6,6 @@ from typing import Optional
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
DefaultWeightsLoader, DefaultWeightsLoader,
UnquantizedWeight,
WeightsLoader, WeightsLoader,
) )


@ -1,7 +1,6 @@
import re import re
from typing import List, Optional, Tuple, Set, Union from typing import List, Optional, Tuple, Set, Union
import math
import torch import torch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType