feat: add ruff and resolve issue (#2262)

* feat: add ruff and resolve issue
* fix: update client exports and adjust after rebase
* fix: adjust syntax to avoid circular import
* fix: adjust client ruff settings
* fix: lint and refactor import check and avoid model enum as global names
* fix: improve fbgemm_gpu check and lints
* fix: update lints
* fix: prefer comparing model enum over str
* fix: adjust lints and ignore specific rules
* fix: avoid unneeded quantize check
parent 4b49c50f4c
commit bab02ff2bc
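Most of the changes in the diff below are mechanical ruff fixes. As a rough illustration of the recurring patterns (a hedged sketch with made-up function names, not code from the repository; the rule codes are the usual pycodestyle/pyflakes codes that ruff reports), the commit repeatedly swaps `== None` for `is None`, `type(x) == tuple` for `isinstance(x, tuple)`, drops the `f` prefix from strings without placeholders, and replaces bare `except:` with `except Exception:`:

from typing import Any, Optional

def normalize_past_key_values(past: Any) -> Any:
    # E721: prefer isinstance(...) over comparing types with ==
    if isinstance(past, tuple):
        return list(past)
    return past

def content_is_empty(content: Optional[str]) -> bool:
    # E711: compare against None with `is`, not `==`
    return content is None

def load_or_default(path: str) -> str:
    # F541: no placeholders, so the string keeps no `f` prefix
    message = "falling back to default config"
    try:
        with open(path) as f:
            return f.read()
    except Exception:  # E722: avoid bare `except:`
        return message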
@@ -16,3 +16,8 @@ repos:
- id: fmt
- id: cargo-check
- id: clippy
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

@@ -19,5 +19,15 @@ DEPRECATION_WARNING = (
"Please use the `InferenceClient` from the `huggingface_hub` package instead."
)

from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
from text_generation.client import Client, AsyncClient  # noqa E402
from text_generation.inference_api import (  # noqa E402
InferenceAPIClient,
InferenceAPIAsyncClient,
)

__all__ = [
"Client",
"AsyncClient",
"InferenceAPIClient",
"InferenceAPIAsyncClient",
]

@@ -21,7 +21,7 @@ def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
List[DeployedModel]: list of all currently deployed models
"""
resp = requests.get(
f"https://api-inference.huggingface.co/framework/text-generation-inference",
"https://api-inference.huggingface.co/framework/text-generation-inference",
headers=headers,
timeout=5,
)

@@ -4,7 +4,6 @@ import json
import math
import os
import random
import re
import shutil
import subprocess
import sys

@@ -271,7 +270,7 @@ class LauncherHandle:
try:
await self.client.generate("test")
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
time.sleep(1)
raise RuntimeError("Health check failed")

@@ -1,7 +1,4 @@
import pytest
import json

from text_generation.types import GrammarType


@pytest.fixture(scope="module")

@@ -1,6 +1,4 @@
import pytest
import requests
import io
import base64

@@ -74,9 +74,7 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
generated_texts = [r.generated_text for r in responses]

assert (
generated_texts[0] == " \nAssistant: A rooster stands"
), f"{response.generated_text}"
assert generated_texts[0] == " \nAssistant: A rooster stands"
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]

@@ -1,7 +1,4 @@
import pytest
import json

from text_generation.types import GrammarType


@pytest.fixture(scope="module")

@@ -91,7 +88,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,

@@ -129,7 +126,7 @@ async def test_flash_llama_grammar_tools_auto(
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,

@@ -168,7 +165,7 @@ async def test_flash_llama_grammar_tools_choice(
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,

@@ -241,7 +238,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
stream=False,
)

assert responses.choices[0].message.content == None
assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls == [
{
"function": {
@@ -20,7 +20,7 @@ def main():
break

with open("./small.json", "w") as f:
data = json.dump(conversations, f, indent=4)
json.dump(conversations, f, indent=4)


if __name__ == "__main__":

@@ -1,11 +1,9 @@
import os
import requests
import tempfile

import pytest

import huggingface_hub.constants
from huggingface_hub import hf_api

import text_generation_server.utils.hub
from text_generation_server.utils.hub import (

@@ -2,7 +2,6 @@ import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
from text_generation_server.utils.weights import DefaultWeightsLoader


class ProcessGroup:

@@ -2,7 +2,6 @@ import pytest
import torch
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)

@@ -86,15 +85,6 @@ dummy_file_system = {
],
dtype=torch.float32,
),
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
},
"test_get_weights_row": {
"weight.weight": torch.tensor(

@@ -966,7 +956,7 @@ def test_get_multi_weights_col_exl2():
prefix = "weight"

try:
w = weights.get_multi_weights_col(
weights.get_multi_weights_col(
prefixes=[prefix],
dim=0,
)

@@ -4,15 +4,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Set, Tuple
from typing import Dict, Set, Tuple

import torch

from text_generation_server.adapters.weights import AdapterWeights

if TYPE_CHECKING:
from text_generation_server.models.model import Model


@dataclass
class ModuleMap:

@@ -4,7 +4,7 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import torch
from peft import LoraConfig as _LoraConfig

@@ -26,9 +26,6 @@ from text_generation_server.utils.sgmv import (
use_cutlass_shrink,
)

if TYPE_CHECKING:
from text_generation_server.models.model import Model


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size

@@ -4,12 +4,11 @@ import typer
from pathlib import Path
from loguru import logger
from typing import Optional, List, Dict
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters

from text_generation_server.utils.log import log_master

app = typer.Typer()

@@ -165,7 +164,7 @@ def download_weights(
# currently by default we don't merge the weights with the base model
if merge_lora:
try:
adapter_config_filename = hf_hub_download(
hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(

@@ -285,9 +284,9 @@ def download_weights(
if auto_convert:
if not trust_remote_code:
logger.warning(
f"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
f"Pickle files are unsafe and can essentially contain remote code execution!"
f"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
"Pickle files are unsafe and can essentially contain remote code execution!"
"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
)

logger.warning(

@@ -319,7 +318,7 @@ def download_weights(
# Name for this varible depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])

except Exception as e:
except Exception:
discard_names = []
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files, discard_names)
@@ -18,3 +18,17 @@ from text_generation_server.layers.lora import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)

__all__ = [
"get_linear",
"FastLinear",
"TensorParallelColumnLinear",
"TensorParallelRowLinear",
"TensorParallelEmbedding",
"SpeculativeHead",
"LoraLinear",
"TensorParallelMultiAdapterLinear",
"TensorParallelAdapterRowLinear",
"load_layer_norm",
"load_conv2d",
]

@@ -13,3 +13,12 @@ elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")


__all__ = [
"attention",
"paged_attention",
"reshape_and_cache",
"SUPPORTS_WINDOWING",
"Seqlen",
]

@@ -10,7 +10,6 @@ _PARTITION_SIZE = 512

try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"

@@ -747,11 +747,8 @@ class _attention(torch.autograd.Function):
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)

grid = lambda META: (
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
nheads_q,
batch,
)
def grid(META):
return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch

encoded_softmax = None

@@ -15,7 +15,6 @@ ENGINE = "triton" if use_triton else "ck"

try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"

@@ -1,6 +1,5 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py

import math
from typing import Optional
import torch
import torch.nn as nn

@@ -1,5 +1,4 @@
from dataclasses import dataclass
from functools import lru_cache

import bitsandbytes as bnb
import torch

@@ -12,17 +12,26 @@ from text_generation_server.utils.weights import (
Weights,
)
from text_generation_server.utils.log import log_master, log_once
import importlib.util


FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
try:
import fbgemm_gpu.experimental.gen_ai


def is_fbgemm_gpu_available():
try:
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
except ModuleNotFoundError:
return False


if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
except (ImportError, ModuleNotFoundError):
else:
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
@@ -8,6 +8,34 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

try:
major, _minor = torch.cuda.get_device_capability()
except Exception:
major = 1

HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
if V2:
from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear,  # noqa: F401
)

HAS_EXLLAMA = "2"
else:
from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear,  # noqa: F401
)

HAS_EXLLAMA = "1"

except ImportError:
pass


@dataclass
class GPTQWeight(Weight):

@@ -55,7 +83,7 @@ class GPTQWeight(Weight):
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)

return ExllamaQuantLinear(self, bias)

@@ -73,45 +101,6 @@ class GPTQWeight(Weight):
)


try:
major, _minor = torch.cuda.get_device_capability()
except Exception:
major = 1

HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
if V2:
from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllamav2 import (
create_exllama_buffers,
set_device,
)

HAS_EXLLAMA = "2"
else:
from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllama import (
create_exllama_buffers,
set_device,
)

HAS_EXLLAMA = "1"

except ImportError:
pass

from text_generation_server.layers.gptq.quant_linear import QuantLinear


class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.

@@ -206,10 +206,13 @@ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
output = torch.empty(
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
)
grid = lambda META: (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)

def grid(META):
return (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)

matmul_248_kernel[grid](
input,
qweight,
@@ -15,6 +15,7 @@ from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error

from text_generation_server.utils.weights import DefaultWeightsLoader

@@ -372,7 +373,7 @@ def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)

@@ -404,7 +405,7 @@ def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)

@@ -448,7 +449,7 @@ def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)

@@ -504,7 +505,7 @@ def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)

@@ -546,7 +547,7 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)

@@ -700,6 +701,8 @@ def sequential(
pass

def add_batch(name):
nonlocal gptq

def tmp(_, inp, out):
gptq[name].add_batch(inp[0].data, out.data)
@@ -0,0 +1,56 @@
import torch


# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(
y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
"""
Compute SNR between y_pred(tensor) and y_real(tensor)

SNR can be calcualted as following equation:

SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2

if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.

SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)

Args:
y_pred (torch.Tensor): _description_
y_real (torch.Tensor): _description_
reduction (str, optional): _description_. Defaults to 'mean'.

Raises:
ValueError: _description_
ValueError: _description_

Returns:
torch.Tensor: _description_
"""
if y_pred.shape != y_real.shape:
raise ValueError(
f"Can not compute snr loss for tensors with different shape. "
f"({y_pred.shape} and {y_real.shape})"
)
reduction = str(reduction).lower()

if y_pred.ndim == 1:
y_pred = y_pred.unsqueeze(0)
y_real = y_real.unsqueeze(0)

y_pred = y_pred.flatten(start_dim=1)
y_real = y_real.flatten(start_dim=1)

noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
signal_power = torch.pow(y_real, 2).sum(dim=-1)
snr = (noise_power) / (signal_power + 1e-7)

if reduction == "mean":
return torch.mean(snr)
elif reduction == "sum":
return torch.sum(snr)
elif reduction == "none":
return snr
else:
raise ValueError("Unsupported reduction method.")
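For reference, a minimal usage sketch of the torch_snr_error helper added above (the import path mirrors the one used in the quantize.py hunk earlier in this diff; the tensors here are made up):

import torch

from text_generation_server.layers.gptq.utils import torch_snr_error

y_real = torch.randn(4, 128)
y_pred = y_real + 0.01 * torch.randn(4, 128)  # slightly perturbed copy of the reference

# Mean noise-to-signal power ratio over all rows; closer tensors give smaller values.
print(torch_snr_error(y_pred, y_real, reduction="mean"))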
@@ -1,5 +1,3 @@
from typing import Optional

import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F

@@ -1,12 +1,8 @@
import math
import os
from typing import TYPE_CHECKING, Optional, Tuple, List
from typing import TYPE_CHECKING, Optional, List

import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from torch.distributed import ProcessGroup

from text_generation_server.utils.sgmv import (

@@ -1,6 +1,3 @@
from typing import List, Tuple

import torch
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinLinear,

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union

import torch
import torch.nn as nn

@@ -85,7 +85,7 @@ class MarlinWeightsLoader(WeightsLoader):
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
"Cannot load `marlin` weight, make sure the model is already quantized"
)

B_meta = torch.cat(

@@ -104,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader):
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1

@@ -2,12 +2,9 @@ import os
import math
import torch
from torch import nn
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "cuda":
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops

@@ -33,7 +33,7 @@ class SpeculativeHead(torch.nn.Module):
except KeyError:
try:
speculator = MedusaHeadV1.load(config, prefix, weights)
except:
except Exception:
speculator = MedusaHeadV2(config, prefix, weights)
lm_head = None
else:

@@ -2,7 +2,6 @@ import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":

@@ -50,7 +49,7 @@ class TensorParallelHead(SuperLayer):
# If the piece and LM head embeddings are shared, we have
# non-quantized weights...
weight = weights.get_tensor(f"{prefix}.weight")
except:
except Exception:
# ...otherwise they are quantized.
weight = weights.get_weights_col(prefix)
should_gather = weights.process_group.size() > 1

@@ -67,15 +66,6 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False

# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
quantize = None
# See above, exl2 LM head can be quantized or not.
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
quantize = None
else:
quantize = config.quantize

return TensorParallelHead(
get_linear(weight, bias=None),
process_group=weights.process_group,
@@ -1,3 +1,6 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables

import torch
import enum
import os

@@ -712,6 +715,7 @@ def get_model(
)

elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
print(f">>> model_type: {model_type}")
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,

@@ -856,7 +860,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
else:
if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashCausalLM(

@@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
]

# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]

# Update tensors in-place to allow incremental garbage collection

@@ -377,7 +377,7 @@ class CausalLMBatch(Batch):
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values

@@ -908,7 +908,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
loss = None

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return (
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Optional, Tuple

import torch
from torch import nn

@@ -9,9 +9,7 @@ from transformers.modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig

@@ -446,11 +444,12 @@ class CLIPEncoder(nn.Module):


class CLIPTextTransformer(nn.Module):
def __init__(self, prefix: str, config: CLIPTextConfig):
def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPTextEmbeddings(config)
# Initialize weights and apply final processing with `self.post_init()`
self.encoder = CLIPEncoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)

@@ -505,7 +504,7 @@ class CLIPTextTransformer(nn.Module):
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),

@@ -515,7 +514,7 @@ class CLIPTextTransformer(nn.Module):
]
else:
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
pooled_output = last_hidden_state[
last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),

@@ -565,9 +564,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

return self.text_model(
input_ids=input_ids,

@@ -580,7 +576,6 @@ class CLIPVisionTransformer(nn.Module):
def __init__(self, prefix, config: CLIPVisionConfig, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size

self.embeddings = CLIPVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights

@@ -661,9 +656,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled CLS states
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

return self.vision_model(
pixel_values=pixel_values,

@@ -799,14 +791,12 @@ class CLIPModel(nn.Module):
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
vision_outputs = self.vision_model(
pixel_values=pixel_values,
return_dict=return_dict,
)

text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=return_dict,
)

image_embeds = vision_outputs[1]

@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
attention,
reshape_and_cache,
)
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.utils.log import log_once


class DbrxAttentionConfig(PretrainedConfig):
@@ -39,6 +39,12 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


class DeepseekV2Config(PretrainedConfig):
def __init__(

@@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import (
UnquantizedWeight,
Weights,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

@@ -277,7 +276,7 @@ class LlamaMLP(nn.Module):
bias=bias,
)
else:
prefixes = [f"gate_proj", f"up_proj"]
prefixes = ["gate_proj", "up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,

@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
Seqlen,
paged_attention,
attention,
reshape_and_cache,

@@ -38,7 +37,6 @@ from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)

@@ -21,7 +21,6 @@
import torch
import torch.distributed

import numpy as np

from torch import nn
from text_generation_server.utils.import_utils import SYSTEM

@@ -31,7 +30,6 @@ if SYSTEM != "ipex":
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from loguru import logger

from text_generation_server.layers.attention import (
paged_attention,

@@ -16,7 +16,6 @@
import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear

@@ -15,7 +15,6 @@ from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (

@@ -14,7 +14,7 @@
# limitations under the License.
""" PyTorch Idefics2 model."""

from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple

import torch
import torch.utils.checkpoint

@@ -22,10 +22,8 @@ from torch import nn
import math

from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

@@ -19,6 +19,7 @@ import numpy as np

from PIL import Image

import transformers
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
resize,

@@ -293,6 +294,4 @@ class IdeficsImageProcessor(BaseImageProcessor):
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)


import transformers

transformers.IdeficsImageProcessor = IdeficsImageProcessor
@@ -21,10 +21,8 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import PreTrainedModel
from transformers.activations import ACT2FN

@@ -33,13 +31,6 @@ from transformers.modeling_outputs import (
CausalLMOutputWithPast,
dataclass,
)
from transformers.modeling_utils import PretrainedConfig
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_vision import (
IdeficsVisionTransformer,

@@ -56,6 +47,7 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.import_utils import SYSTEM
from loguru import logger

if SYSTEM == "cuda":
import dropout_layer_norm

@@ -237,7 +229,7 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module):
prefix="model.embed_tokens", weights=weights
)
self.additional_weight = nn.Parameter(
weights.get_tensor(f"model.embed_tokens.additional_embedding.weight")
weights.get_tensor("model.embed_tokens.additional_embedding.weight")
)

def forward(self, input_ids):

@@ -499,7 +491,6 @@ class IdeficsAttention(nn.Module):
# if not hasattr(nn.functional, "scaled_dot_product_attention"):
# raise ValueError("this model requires pytorch 2.0 or higher")

process_group = weights.process_group
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "

@@ -1024,7 +1015,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
if config.use_resampler:
perceiver_config = config.perceiver_config
self.perceiver_resampler = IdeficsPerceiverResampler(
prefix=f"model.perceiver_resampler",
prefix="model.perceiver_resampler",
config=config,
embed_dim=config.vision_config.embed_dim,
depth=perceiver_config.resampler_depth,

@@ -1052,7 +1043,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
# self.gradient_checkpointing = False

self.norm = IdeficsRMSNorm(
prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)

# self.gradient_checkpointing = False

@@ -169,7 +169,6 @@ class IdeficsPerceiverAttention(nn.Module):

self.qk_scale = self.head_dim**-0.5

process_group = weights.process_group
if n_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "

@@ -28,9 +28,6 @@ from transformers.tokenization_utils_base import (
TruncationStrategy,
)
from transformers.utils import TensorType, is_torch_available
from text_generation_server.models.custom_modeling.idefics_image_processing import (
IdeficsImageProcessor,
)


if is_torch_available():

@@ -129,7 +129,6 @@ class IdeficsVisionAttention(nn.Module):
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout

process_group = weights.process_group
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "

@@ -460,7 +459,6 @@ class IdeficsVisionTransformer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size

self.embeddings = IdeficsVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights

@@ -14,7 +14,7 @@
# limitations under the License.
""" PyTorch Llava-NeXT model."""

from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple

import torch
import torch.utils.checkpoint
@@ -4,7 +4,6 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""

import math
import os
import warnings
from typing import List, Optional, Tuple, Union
import torch

@@ -194,7 +193,7 @@ def flash_attn_fn(
):
try:
from flash_attn import bert_padding, flash_attn_interface
except:
except Exception:
raise RuntimeError("Please install flash-attn==1.0.3.post0")
check_valid_inputs(query, key, value)
if past_key_value is not None:

@@ -207,7 +206,7 @@ def flash_attn_fn(
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
if attn_bias is not None:
raise NotImplementedError(f"attn_bias not implemented for flash attn.")
raise NotImplementedError("attn_bias not implemented for flash attn.")
(batch_size, seqlen) = query.shape[:2]
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)

@@ -269,13 +268,13 @@ def triton_flash_attn_fn(
):
try:
from .flash_attn_triton import flash_attn_func
except:
except Exception:
_installed = False
if version.parse(torch.__version__) < version.parse("2.0.0"):
_installed = True
try:
from flash_attn.flash_attn_triton import flash_attn_func
except:
except Exception:
_installed = False
if not _installed:
raise RuntimeError(

@@ -292,9 +291,9 @@ def triton_flash_attn_fn(
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
if dropout_p:
raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
raise NotImplementedError("Dropout not implemented for attn_impl: triton.")
if needs_weights:
raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
raise NotImplementedError("attn_impl: triton cannot return attn weights.")
if key_padding_mask is not None:
warnings.warn(
"Propagating key_padding_mask to the attention module "

@@ -428,7 +427,7 @@ class MultiQueryAttention(nn.Module):
additive bias.
"""

def __init__(self, config, prefix, weights):
def __init__(self, config, prefix, weights, verbose=False):
super().__init__()
attn_impl = config.attn_config.attn_impl
self.attn_impl = config.attn_config.attn_impl

@@ -445,7 +444,7 @@ class MultiQueryAttention(nn.Module):
self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
fuse_splits = (d_model, d_model + self.head_dim)
(d_model, d_model + self.head_dim)
if self.qk_ln:
raise NotImplementedError("qk_ln not supported")
if self.attn_impl == "flash":

@@ -795,7 +794,9 @@ class MPTModel(MPTPreTrainedModel):
self.alibi = config.attn_config.alibi
self.alibi_bias_max = config.attn_config.alibi_bias_max
if config.init_device == "mixed":
if dist.get_local_rank() == 0:
# TODO: reimplement mixed device initialization
# dist.get_local_rank() == 0:
if True:
config.init_device = "cpu"
else:
config.init_device = "meta"

@@ -1016,7 +1017,7 @@ class MPTModel(MPTPreTrainedModel):
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
f"past_key_values must provide a past_key_value for each attention "
"past_key_values must provide a past_key_value for each attention "
+ f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
)
past_position = past_key_values[0][0].size(1)

@@ -1182,7 +1183,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
input_ids = input_ids[:, -1].unsqueeze(-1)
if self.transformer.prefix_lm:
prefix_mask = torch.ones_like(attention_mask)
if kwargs.get("use_cache") == False:
if kwargs.get("use_cache") is False:
raise NotImplementedError(
"MPT with prefix_lm=True does not support use_cache=False."
)
@@ -21,25 +21,14 @@ import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import GPTNeoXConfig
from loguru import logger
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,

@@ -133,7 +122,6 @@ class GPTNeoXAttention(nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
max_positions = config.max_position_embeddings
# ??? TODO
# self.register_buffer(
# "bias",

@@ -5,7 +5,7 @@ import torch.distributed

import math
from torch import nn
from typing import Optional, List, Tuple, Any
from typing import Optional, List, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -1,20 +1,15 @@
from typing import Optional, Tuple, Union

from typing import Optional, Tuple
import warnings
import math
import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
_create_4d_causal_attention_mask,
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from transformers import SiglipConfig, SiglipVisionConfig
from torch.nn.init import _calculate_fan_in_and_fan_out

from text_generation_server.layers.tensor_parallel import (
TensorParallelEmbedding,

@@ -244,9 +239,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
return hidden_state[:, 0]


import warnings


def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

@@ -264,12 +256,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)

# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.uniform_(2 * lower - 1, 2 * upper - 1)

# Use inverse cdf transform for normal distribution to get truncated
# standard normal

@@ -313,9 +305,6 @@ def trunc_normal_tf_(
tensor.mul_(std).add_(mean)


from torch.nn.init import _calculate_fan_in_and_fan_out


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":

@@ -349,9 +338,6 @@ def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")


from transformers import PreTrainedModel


class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a

@@ -393,7 +379,6 @@ class SiglipVisionTransformer(nn.Module):
def __init__(self, prefix, config: SiglipVisionConfig, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size

self.embeddings = SiglipVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
@@ -45,6 +45,15 @@ from text_generation_server.layers import (
SpeculativeHead,
)

# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""


class PartialTPEmbedding(nn.Module):
def __init__(self, prefix: str, weights):

@@ -1132,12 +1141,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
output = (logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output

return (

@@ -42,7 +42,7 @@ def load_vision_model(prefix, config, weights):
)

return SiglipVisionTransformer(
prefix=f"vision_tower.vision_model", config=config, weights=weights
prefix="vision_tower.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")

@@ -1194,7 +1194,7 @@ class FlashCausalLM(Model):
if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed")
logger.exception("Decode cuda graph warmup failed")
else:
log_master(
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
@@ -2,23 +2,15 @@ import re
import torch
import torch.distributed

from typing import List, Optional, Type

from transformers import (
AutoTokenizer,
AutoConfig,
PreTrainedTokenizerBase,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks

@@ -1,13 +1,8 @@
import torch
import torch.distributed

from typing import List, Optional, Tuple
from typing import Optional

from transformers import (
AutoTokenizer,
AutoConfig,
AutoProcessor,
)

from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_processing import (

@@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch):
image_hidden_states = self.image_hidden_states[keep_indices]

# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]

# Update tensors in-place to allow incremental garbage collection

@@ -456,7 +456,7 @@ class IdeficsCausalLMBatch(Batch):
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
@@ -2,7 +2,6 @@ import torch
import torch.distributed
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional
import os
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig,
)

@@ -20,7 +19,7 @@ from text_generation_server.models.custom_modeling.mamba_modeling import (
InferenceParams,
)
from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type, Dict
from typing import Any, List, Tuple, Type, Dict
from text_generation_server.models.types import (
Batch,
Tokens,

@@ -31,7 +30,7 @@ from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.utils import NextTokenChooser, StoppingCriteria


def new_inference_params(

@@ -299,7 +298,6 @@ class MambaBatch(Batch):
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
max_seqlen = 0
seqlen_offset = 0

(n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape

@@ -485,7 +483,7 @@ class Mamba(Model):
for bs in CUDA_GRAPHS:
self.cuda_graph_warmup(bs)
except Exception:
logger.exception(f"Decode cuda graph warmup failed")
logger.exception("Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

@@ -534,7 +532,7 @@ class Mamba(Model):
}
self.cuda_graphs[batch_size] = graph_dict

def tunableop_warmup(self, seqlen: int):
def tunableop_warmup(self, batch_size: int, seqlen: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
n_blocks = len(self.model.blocks)

@@ -2,7 +2,7 @@ import inspect
import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase

@@ -3,16 +3,11 @@ from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable, Optional, Tuple
from typing import Iterable
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLM,
VlmCausalLMBatch,
image_text_replacement,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig

from text_generation_server.pb.generate_pb2 import Request
@@ -1,7 +1,6 @@
import torch
import torch.distributed
import time

from dataclasses import dataclass
from opentelemetry import trace
from transformers import (

@@ -11,7 +10,7 @@ from transformers import (
AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,

@@ -254,7 +253,7 @@ class Seq2SeqLMBatch(Batch):
]

# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [
[t for t in layer] for layer in self.past_key_values
]

@@ -430,7 +429,7 @@ class Seq2SeqLMBatch(Batch):
batch.encoder_last_hidden_state = None

# Ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t for t in layer] for layer in batch.past_key_values
]

@@ -1,4 +1,3 @@
from functools import total_ordering
import torch

from abc import ABC, abstractmethod

@@ -9,7 +9,7 @@ from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional, Dict
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
@@ -1,15 +1,13 @@
import torch
from loguru import logger
import subprocess
import os


import importlib.util


def is_ipex_available():
try:
import intel_extension_for_pytorch
except ImportError:
return False
return True
return importlib.util.find_spec("intel_extension_for_pytorch") is not None


def get_cuda_free_memory(device, memory_fraction):
@@ -2,9 +2,17 @@ import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union

from text_generation_server.utils.merges.utils import (
calculate_majority_sign_mask,
disjoint_merge,
prune,
)
import torch

if TYPE_CHECKING:
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.utils.adapter import ModuleMap


class AdapterParameters:
def __init__(

@@ -17,17 +25,6 @@ class AdapterParameters:
self.majority_sign_method = majority_sign_method


from text_generation_server.utils.merges.utils import (
calculate_majority_sign_mask,
disjoint_merge,
prune,
)

if TYPE_CHECKING:
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.utils.adapter import ModuleMap


def _apply_weights(
tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:

@@ -28,7 +28,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
low_cpu_mem_usage=True,
)
logger.info("Peft model detected.")
logger.info(f"Merging the lora weights.")
logger.info("Merging the lora weights.")

base_model_id = model.peft_config["default"].base_model_name_or_path

@@ -6,7 +6,6 @@ from typing import Optional
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
WeightsLoader,
)
@@ -1,7 +1,6 @@
import re
from typing import List, Optional, Tuple, Set, Union

import math
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType