commit 73a4d65d26
parent 1da642bd0e

integration-tests/models/test_flash_llama_gptq.py

@@ -48,8 +48,12 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot):
-    responses = await generate_load(flash_llama_gptq, "Test request", max_new_tokens=10, n=4)
+async def test_flash_llama_gptq_load(
+    flash_llama_gptq, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_llama_gptq, "Test request", max_new_tokens=10, n=4
+    )
 
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])

integration-tests/models/test_flash_starcoder_gptq.py

@@ -17,7 +17,9 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 @pytest.mark.private
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
     response = await flash_starcoder_gptq.generate(
-        "def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True,
+        "def geometric_mean(L: List[float]):",
+        max_new_tokens=20,
+        decoder_input_details=True,
     )
     assert response.details.generated_tokens == 20
     assert response == response_snapshot
@@ -25,7 +27,9 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot):
+async def test_flash_starcoder_gptq_default_params(
+    flash_starcoder_gptq, response_snapshot
+):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
         max_new_tokens=20,
@@ -40,10 +44,17 @@ async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, respons
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot):
-    responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=10, n=4)
+async def test_flash_starcoder_gptq_load(
+    flash_starcoder_gptq, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_starcoder_gptq,
+        "def geometric_mean(L: List[float]):",
+        max_new_tokens=10,
+        n=4,
+    )
 
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
 
     assert responses == response_snapshot

launcher/src/main.rs

@@ -245,6 +245,11 @@ struct Args {
     #[clap(long, env)]
     disable_custom_kernels: bool,
 
+    /// Limit the CUDA available memory.
+    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
+    #[clap(default_value = "1.0", long, env)]
+    cuda_memory_fraction: f32,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
@@ -299,6 +304,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    cuda_memory_fraction: f32,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -368,6 +374,12 @@ fn shard_manager(
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
     envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
 
+    // CUDA memory fraction
+    envs.push((
+        "CUDA_MEMORY_FRACTION".into(),
+        cuda_memory_fraction.to_string().into(),
+    ));
+
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
 
@@ -771,6 +783,7 @@ fn spawn_shards(
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
+       let cuda_memory_fraction = args.cuda_memory_fraction;
        thread::spawn(move || {
            shard_manager(
                model_id,
@@ -788,6 +801,7 @@ fn spawn_shards(
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
+               cuda_memory_fraction,
                otlp_endpoint,
                status_sender,
                shutdown,

router/client/src/sharded_client.rs

@@ -101,8 +101,12 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
             .collect();
-        // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        // Take the minimum value
+        let results = join_all(futures)
+            .await
+            .into_iter()
+            .collect::<Result<Vec<Option<u32>>>>()?;
+        Ok(results.into_iter().flatten().min())
     }
 
     /// Generate one token for each request in the given batch

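Note: the warmup change above stops assuming every shard reports the same value and instead keeps the smallest one, since any shard may end up with the tightest memory budget. Below is a minimal Python sketch of the same reduction over per-shard results; the helper name and example values are illustrative, not part of the codebase.

from typing import List, Optional


def min_across_shards(results: List[Optional[int]]) -> Optional[int]:
    # Mirrors `Ok(results.into_iter().flatten().min())`: drop None entries and
    # return the smallest reported value, or None if no shard reported one.
    reported = [r for r in results if r is not None]
    return min(reported) if reported else None


print(min_across_shards([16384, 15872, None]))  # -> 15872
print(min_across_shards([None, None]))          # -> None
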
server/exllama_kernels/setup.py

@@ -11,7 +11,7 @@ setup(
                 "exllama_kernels/cuda_buffers.cu",
                 "exllama_kernels/cuda_func/column_remap.cu",
                 "exllama_kernels/cuda_func/q4_matmul.cu",
-                "exllama_kernels/cuda_func/q4_matrix.cu"
+                "exllama_kernels/cuda_func/q4_matrix.cu",
             ],
         )
     ],

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

@@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
 )
 from safetensors import SafetensorError
 
+
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
@@ -78,6 +79,7 @@ def _load_multi_mqa_gptq(
         bits, groupsize = weights._get_gptq_qparams()
 
         from text_generation_server.utils.layers import HAS_EXLLAMA
+
         use_exllama = HAS_EXLLAMA
         weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
 

server/text_generation_server/models/flash_causal_lm.py

@@ -19,6 +19,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
+from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
@@ -738,7 +739,12 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+
+        free_memory = max(
+            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
+        )
 
         num_blocks = (
             int(free_memory // total_cache_size)

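Note: the warmup hunk above sizes the KV cache from the memory that remains after excluding the slice reserved by CUDA_MEMORY_FRACTION. A small sketch of just the arithmetic visible in the hunk, with made-up numbers (the full expression in the file continues past what the hunk shows):

# Illustrative values only.
BLOCK_SIZE = 16        # tokens per KV-cache block
num_layers = 32
num_kv_heads = 8
head_size = 128
dtype_size = 2         # bytes per element for float16
MEMORY_FRACTION = 0.8  # CUDA_MEMORY_FRACTION

total_gpu_memory = 24 * 1024**3   # e.g. a 24 GiB card
total_free_memory = 20 * 1024**3  # as torch.cuda.mem_get_info would report

cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # keys + values

# Memory this shard may use: free memory minus the reserved (1 - fraction) slice.
free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)

num_blocks = int(free_memory // total_cache_size)
print(num_blocks)  # 7782 blocks, i.e. ~124k cacheable tokens with these numbers
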
server/text_generation_server/models/model.py

@@ -10,6 +10,7 @@ from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
 
+
 class Model(ABC):
     def __init__(
         self,
@@ -21,9 +22,6 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
     ):
-        if torch.cuda.is_available():
-            torch.cuda.set_per_process_memory_fraction(1.0)
-
         self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)

server/text_generation_server/server.py

@@ -16,7 +16,6 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 
 
-
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
@@ -146,7 +145,10 @@ def serve(
                 # When using GPTQ, Exllama kernels need some global kernels
                 # For which we have the finale shapes only after the model has loaded
                 # This will allocate those buffers.
-                from text_generation_server.utils.gptq.exllama import create_exllama_buffers
+                from text_generation_server.utils.gptq.exllama import (
+                    create_exllama_buffers,
+                )
+
                 create_exllama_buffers()
             except ImportError:
                 pass

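Note: the import reshaping above keeps the existing behaviour: exllama scratch buffers are only allocated once the model has loaded, because their size depends on the largest quantized layer that was constructed. A short sketch of that "record sizes during construction, allocate once afterwards" pattern; the class and function names here are illustrative, not the real exllama API.

import torch

# Global high-water mark recorded while layers are built, in the spirit of
# MAX_DQ / MAX_INNER / create_exllama_buffers in the gptq.exllama module.
MAX_ROWS = 1
TEMP_BUFFER = None


class FakeQuantLayer:
    def __init__(self, rows: int):
        global MAX_ROWS
        MAX_ROWS = max(MAX_ROWS, rows)


def create_buffers(device: str = "cpu") -> None:
    global TEMP_BUFFER
    TEMP_BUFFER = torch.zeros((1, MAX_ROWS), dtype=torch.float16, device=device)


# Build the "model" first, then allocate scratch space sized for the largest layer.
layers = [FakeQuantLayer(rows) for rows in (2048, 8192, 4096)]
create_buffers()
print(TEMP_BUFFER.shape)  # torch.Size([1, 8192])
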
server/text_generation_server/utils/dist.py

@@ -4,6 +4,13 @@ import torch
 from datetime import timedelta
 from loguru import logger
 
+# Tensor Parallelism settings
+RANK = int(os.getenv("RANK", "0"))
+WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+
+# CUDA memory fraction
+MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
+
 
 class FakeBarrier:
     def wait(self):
@@ -37,16 +44,14 @@ class FakeGroup:
 
 
 def initialize_torch_distributed():
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
 
         # Set the device id.
-        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
-        device = rank % torch.cuda.device_count()
+        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
+        device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
+        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
@@ -55,22 +60,22 @@ def initialize_torch_distributed():
         backend = "gloo"
         options = None
 
-    if world_size == 1:
-        return FakeGroup(rank, world_size), rank, world_size
+    if WORLD_SIZE == 1:
+        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
     else:
         if os.getenv("DEBUG", None) == "1":
-            return FakeGroup(rank, world_size), rank, world_size
+            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
 
         if not torch.distributed.is_initialized():
             # Call the init process.
             torch.distributed.init_process_group(
                 backend=backend,
-                world_size=world_size,
-                rank=rank,
+                world_size=WORLD_SIZE,
+                rank=RANK,
                 timeout=timedelta(seconds=60),
                 pg_options=options,
             )
         else:
             logger.warning("torch.distributed is already initialized.")
 
-    return torch.distributed.group.WORLD, rank, world_size
+    return torch.distributed.group.WORLD, RANK, WORLD_SIZE

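Note: dist.py is where the launcher-side flag takes effect: each shard reads CUDA_MEMORY_FRACTION from its environment and hands it to the allocator, so allocations beyond that share of the device fail instead of crowding out other processes. A standalone sketch of the mechanism; the fraction and tensor size are arbitrary example values.

import os

import torch

# Same default the launcher uses when --cuda-memory-fraction is not given.
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))

if torch.cuda.is_available():
    device = 0
    torch.cuda.set_device(device)
    # Cap this process at MEMORY_FRACTION of the visible memory on `device`.
    torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)

    try:
        # With e.g. CUDA_MEMORY_FRACTION=0.1 on a 24 GiB card, this ~4 GiB
        # allocation should be rejected by the cap rather than succeed.
        _ = torch.empty((1024, 1024, 1024), dtype=torch.float32, device="cuda")
    except torch.cuda.OutOfMemoryError:
        print("allocation rejected by the memory fraction cap")
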
server/text_generation_server/utils/gptq/exllama.py

@@ -1,28 +1,28 @@
 import torch
 from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
 
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
-none_tensor = torch.empty((1, 1), device = "meta")
+none_tensor = torch.empty((1, 1), device="meta")
+
 
 def ext_make_q4(qweight, qzeros, scales, g_idx, device):
     """Construct Q4Matrix, return handle"""
-    return make_q4(qweight,
-                   qzeros,
-                   scales,
-                   g_idx if g_idx is not None else none_tensor,
-                   device)
+    return make_q4(
+        qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
+    )
+
 
 def ext_q4_matmul(x, q4, q4_width):
     """Matrix multiplication, returns x @ q4"""
     outshape = x.shape[:-1] + (q4_width,)
     x = x.view(-1, x.shape[-1])
-    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
+    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
 
     q4_matmul(x, q4, output)
 
     return output.view(outshape)
 
 
 MAX_DQ = 1
 MAX_INNER = 1
 ACT_ORDER = False
@@ -31,9 +31,10 @@ DEVICE = None
 TEMP_STATE = None
 TEMP_DQ = None
 
+
 def create_exllama_buffers():
     global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
 
     if ACT_ORDER:
         # TODO: this should be set to rust side `max_total_tokens`, but TGI
         # does not offer an API to expose this variable to python, as this variable
@@ -45,7 +46,9 @@ def create_exllama_buffers():
         max_total_tokens = 1
 
     # This temp_state buffer is required to reorder X in the act-order case.
-    temp_state = torch.zeros((max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE)
+    temp_state = torch.zeros(
+        (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
+    )
     temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
 
     # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
@@ -56,10 +59,12 @@ def create_exllama_buffers():
     matmul_no_half2 = False
     set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
 
     TEMP_STATE, TEMP_DQ = temp_state, temp_dq
 
+
 class Ex4bitLinear:
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
+
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
         assert bits == 4
@@ -70,20 +75,24 @@ class Ex4bitLinear:
         self.scales = scales
         self.g_idx = g_idx.cpu() if g_idx is not None else None
         self.bias = bias if bias is not None else None
 
-        if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
+        if self.g_idx is not None and (
+            (self.g_idx == 0).all()
+            or torch.equal(
+                g_idx.cpu(),
+                torch.tensor(
+                    [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
+                ),
+            )
+        ):
             self.empty_g_idx = True
             self.g_idx = None
 
         assert self.device.type == "cuda"
         assert self.device.index is not None
 
         self.q4 = ext_make_q4(
-            self.qweight,
-            self.qzeros,
-            self.scales,
-            self.g_idx,
-            self.device.index
+            self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
         )
 
         self.height = qweight.shape[0] * 8
@@ -99,7 +108,8 @@ class Ex4bitLinear:
 
         # Handle act-order matrix
         if self.g_idx is not None:
-            if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
+            if self.groupsize is None:
+                raise ValueError("Found group index but no groupsize. What do?")
             self.act_order = True
         else:
             self.act_order = False
@@ -112,7 +122,7 @@ class Ex4bitLinear:
             MAX_INNER = max(MAX_INNER, self.height, self.width)
 
             ACT_ORDER = True
-
+
     def forward(self, x):
         out = ext_q4_matmul(x, self.q4, self.width)
 

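Note: in the Ex4bitLinear constructor reformatted above, a g_idx that is all zeros or exactly row // groupsize carries no act-order information, so it is dropped and the layer is handled as a plain per-group layout. A small sketch of that check in isolation; the tensor sizes and helper name are hypothetical.

import torch


def is_trivial_g_idx(g_idx: torch.Tensor, groupsize: int) -> bool:
    # True when the group index encodes no act-order reordering: either all
    # zeros, or simply row_index // groupsize for every row.
    expected = torch.tensor(
        [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
    )
    return bool((g_idx == 0).all()) or torch.equal(g_idx.cpu(), expected)


trivial = torch.arange(4096, dtype=torch.int32) // 128  # plain per-group layout
reordered = trivial.flip(0)                             # act-order style permutation

print(is_trivial_g_idx(trivial, 128))    # True  -> g_idx can be dropped
print(is_trivial_g_idx(reordered, 128))  # False -> act-order handling is needed
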
server/text_generation_server/utils/gptq/quantize.py

@@ -815,11 +815,7 @@ def load_weights_pre_hook(module_name, weights, recursive=False):
                 tensor = current_tensor.to(device=torch.device("cuda:0"))
                 if current_tensor.requires_grad:
                     tensor = nn.Parameter(tensor)
-                setdeepattr(
-                    module,
-                    local_param,
-                    tensor
-                )
+                setdeepattr(module, local_param, tensor)
 
     return inner
 

server/text_generation_server/utils/layers.py

@@ -17,9 +17,10 @@ except ImportError:
     from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
+
 HAS_EXLLAMA = True
 if os.getenv("DISABLE_EXLLAMA") == "True":
-    HAS_EXLLAMA=False
+    HAS_EXLLAMA = False
 try:
     from text_generation_server.utils.gptq.exllama import Ex4bitLinear
 except ImportError:

server/text_generation_server/utils/weights.py

@@ -146,7 +146,16 @@ class Weights:
         if self.process_group.size() > 1:
             g_idx = self.get_tensor(f"{prefix}.g_idx")
             if g_idx is not None:
-                if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
+                if (
+                    not torch.equal(
+                        g_idx.cpu(),
+                        torch.tensor(
+                            [i // groupsize for i in range(g_idx.shape[0])],
+                            dtype=torch.int32,
+                        ),
+                    )
+                    and not (g_idx == 0).all()
+                ):
                     # Exllama implementation does not support row tensor parallelism with act-order, as
                     # it would require to reorder input activations that are split unto several GPUs
                     use_exllama = False
@@ -154,18 +163,21 @@ class Weights:
         try:
             qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
-            raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+            raise RuntimeError(
+                "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+            )
 
         from text_generation_server.utils.layers import HAS_EXLLAMA
 
         if use_exllama:
             if not HAS_EXLLAMA:
-                logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True")
+                logger.warning(
+                    "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
+                )
                 use_exllama = False
             else:
                 logger.info("Using exllama kernels")
 
-
         if use_exllama:
             if groupsize >= 0:
                 # Exllama reorders the weights in advance and the activations on the fly, thus
@@ -173,7 +185,9 @@ class Weights:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
             else:
-                raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported")
+                raise RuntimeError(
+                    "Using exllama GPTQ kernel with groupsize<1 is not supported"
+                )
                 # qzeros = self.get_tensor(f"{prefix}.qzeros")
                 # scales = self.get_tensor(f"{prefix}.scales")