diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py
index bc525f6..b87f054 100644
--- a/integration-tests/models/test_flash_llama_gptq.py
+++ b/integration-tests/models/test_flash_llama_gptq.py
@@ -48,8 +48,12 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot):
-    responses = await generate_load(flash_llama_gptq, "Test request", max_new_tokens=10, n=4)
+async def test_flash_llama_gptq_load(
+    flash_llama_gptq, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_llama_gptq, "Test request", max_new_tokens=10, n=4
+    )
 
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py
index b6bed6a..608101f 100644
--- a/integration-tests/models/test_flash_starcoder_gptq.py
+++ b/integration-tests/models/test_flash_starcoder_gptq.py
@@ -17,7 +17,9 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 @pytest.mark.private
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
     response = await flash_starcoder_gptq.generate(
-        "def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True,
+        "def geometric_mean(L: List[float]):",
+        max_new_tokens=20,
+        decoder_input_details=True,
     )
     assert response.details.generated_tokens == 20
     assert response == response_snapshot
@@ -25,7 +27,9 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot):
+async def test_flash_starcoder_gptq_default_params(
+    flash_starcoder_gptq, response_snapshot
+):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
         max_new_tokens=20,
@@ -40,10 +44,17 @@ async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, respons
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot):
-    responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=10, n=4)
+async def test_flash_starcoder_gptq_load(
+    flash_starcoder_gptq, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_starcoder_gptq,
+        "def geometric_mean(L: List[float]):",
+        max_new_tokens=10,
+        n=4,
+    )
 
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
 
-    assert responses == response_snapshot
\ No newline at end of file
+    assert responses == response_snapshot
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 53de36b..2ad788a 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -245,6 +245,11 @@ struct Args {
     #[clap(long, env)]
     disable_custom_kernels: bool,
 
+    /// Limit the CUDA available memory.
+    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.
+    #[clap(default_value = "1.0", long, env)]
+    cuda_memory_fraction: f32,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
@@ -299,6 +304,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    cuda_memory_fraction: f32,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -368,6 +374,12 @@ fn shard_manager(
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
     envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
 
+    // CUDA memory fraction
+    envs.push((
+        "CUDA_MEMORY_FRACTION".into(),
+        cuda_memory_fraction.to_string().into(),
+    ));
+
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
 
@@ -771,6 +783,7 @@ fn spawn_shards(
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
+        let cuda_memory_fraction = args.cuda_memory_fraction;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -788,6 +801,7 @@ fn spawn_shards(
                 disable_custom_kernels,
                 watermark_gamma,
                 watermark_delta,
+                cuda_memory_fraction,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 6d146bc..112b003 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -101,8 +101,12 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
             .collect();
-        // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        // Take the minimum value
+        let results = join_all(futures)
+            .await
+            .into_iter()
+            .collect::<Result<Vec<Option<u32>>>>()?;
+        Ok(results.into_iter().flatten().min())
     }
 
     /// Generate one token for each request in the given batch
diff --git a/server/exllama_kernels/setup.py b/server/exllama_kernels/setup.py
index f06a72b..987d181 100644
--- a/server/exllama_kernels/setup.py
+++ b/server/exllama_kernels/setup.py
@@ -11,7 +11,7 @@ setup(
                 "exllama_kernels/cuda_buffers.cu",
                 "exllama_kernels/cuda_func/column_remap.cu",
                 "exllama_kernels/cuda_func/q4_matmul.cu",
-                "exllama_kernels/cuda_func/q4_matrix.cu"
+                "exllama_kernels/cuda_func/q4_matrix.cu",
             ],
         )
     ],
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 04bd422..c16b2bf 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
 )
 from safetensors import SafetensorError
 
+
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
@@ -78,6 +79,7 @@ def _load_multi_mqa_gptq(
         bits, groupsize = weights._get_gptq_qparams()
 
         from text_generation_server.utils.layers import HAS_EXLLAMA
+
         use_exllama = HAS_EXLLAMA
 
         weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 547678a..7de5135 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -19,6 +19,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
+from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
@@ -738,7 +739,12 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+
+        free_memory = max(
+            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
+        )
 
         num_blocks = (
             int(free_memory // total_cache_size)
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 89e6e99..9d74247 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -10,6 +10,7 @@ from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
 
+
 class Model(ABC):
     def __init__(
         self,
@@ -21,9 +22,6 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
     ):
-        if torch.cuda.is_available():
-            torch.cuda.set_per_process_memory_fraction(1.0)
-
         self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index b279426..0929b46 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -16,7 +16,6 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 
 
-
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
@@ -146,7 +145,10 @@ def serve(
             # When using GPTQ, Exllama kernels need some global kernels
             # For which we have the finale shapes only after the model has loaded
             # This will allocate those buffers.
-            from text_generation_server.utils.gptq.exllama import create_exllama_buffers
+            from text_generation_server.utils.gptq.exllama import (
+                create_exllama_buffers,
+            )
+
             create_exllama_buffers()
         except ImportError:
             pass
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 41a8e01..d02bfc5 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -4,6 +4,13 @@ import torch
 from datetime import timedelta
 from loguru import logger
 
+# Tensor Parallelism settings
+RANK = int(os.getenv("RANK", "0"))
+WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+
+# CUDA memory fraction
+MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
+
 
 class FakeBarrier:
     def wait(self):
@@ -37,16 +44,14 @@ class FakeGroup:
 
 
 def initialize_torch_distributed():
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
 
         # Set the device id.
-        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
-        device = rank % torch.cuda.device_count()
+        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
+        device = RANK % torch.cuda.device_count()
         torch.cuda.set_device(device)
+        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
@@ -55,22 +60,22 @@ def initialize_torch_distributed():
         backend = "gloo"
         options = None
 
-    if world_size == 1:
-        return FakeGroup(rank, world_size), rank, world_size
+    if WORLD_SIZE == 1:
+        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
     else:
         if os.getenv("DEBUG", None) == "1":
-            return FakeGroup(rank, world_size), rank, world_size
+            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
 
         if not torch.distributed.is_initialized():
             # Call the init process.
             torch.distributed.init_process_group(
                 backend=backend,
-                world_size=world_size,
-                rank=rank,
+                world_size=WORLD_SIZE,
+                rank=RANK,
                 timeout=timedelta(seconds=60),
                 pg_options=options,
             )
         else:
             logger.warning("torch.distributed is already initialized.")
 
-    return torch.distributed.group.WORLD, rank, world_size
+    return torch.distributed.group.WORLD, RANK, WORLD_SIZE
diff --git a/server/text_generation_server/utils/gptq/exllama.py b/server/text_generation_server/utils/gptq/exllama.py
index aba6679..e89b725 100644
--- a/server/text_generation_server/utils/gptq/exllama.py
+++ b/server/text_generation_server/utils/gptq/exllama.py
@@ -1,28 +1,28 @@
-
 import torch
 from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
 
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
-none_tensor = torch.empty((1, 1), device = "meta")
+none_tensor = torch.empty((1, 1), device="meta")
+
 
 def ext_make_q4(qweight, qzeros, scales, g_idx, device):
     """Construct Q4Matrix, return handle"""
-    return make_q4(qweight,
-                   qzeros,
-                   scales,
-                   g_idx if g_idx is not None else none_tensor,
-                   device)
+    return make_q4(
+        qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
+    )
+
 
 def ext_q4_matmul(x, q4, q4_width):
     """Matrix multiplication, returns x @ q4"""
     outshape = x.shape[:-1] + (q4_width,)
     x = x.view(-1, x.shape[-1])
-    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
+    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
 
     q4_matmul(x, q4, output)
 
     return output.view(outshape)
+
 MAX_DQ = 1
 MAX_INNER = 1
 ACT_ORDER = False
@@ -31,9 +31,10 @@ DEVICE = None
 TEMP_STATE = None
 TEMP_DQ = None
 
+
 def create_exllama_buffers():
     global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
-
+
     if ACT_ORDER:
         # TODO: this should be set to rust side `max_total_tokens`, but TGI
         # does not offer an API to expose this variable to python, as this variable
@@ -45,7 +46,9 @@ def create_exllama_buffers():
         max_total_tokens = 1
 
     # This temp_state buffer is required to reorder X in the act-order case.
-    temp_state = torch.zeros((max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE)
+    temp_state = torch.zeros(
+        (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
+    )
     temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
 
     # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
@@ -56,10 +59,12 @@ def create_exllama_buffers():
     matmul_no_half2 = False
     set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
 
-    TEMP_STATE, TEMP_DQ = temp_state, temp_dq
+    TEMP_STATE, TEMP_DQ = temp_state, temp_dq
+
 
 class Ex4bitLinear:
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
+
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
         assert bits == 4
@@ -70,20 +75,24 @@ class Ex4bitLinear:
         self.scales = scales
         self.g_idx = g_idx.cpu() if g_idx is not None else None
         self.bias = bias if bias is not None else None
-
-        if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
+
+        if self.g_idx is not None and (
+            (self.g_idx == 0).all()
+            or torch.equal(
+                g_idx.cpu(),
+                torch.tensor(
+                    [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
+                ),
+            )
+        ):
             self.empty_g_idx = True
             self.g_idx = None
-
+
         assert self.device.type == "cuda"
         assert self.device.index is not None
 
         self.q4 = ext_make_q4(
-            self.qweight,
-            self.qzeros,
-            self.scales,
-            self.g_idx,
-            self.device.index
+            self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
         )
 
         self.height = qweight.shape[0] * 8
@@ -99,7 +108,8 @@ class Ex4bitLinear:
 
         # Handle act-order matrix
         if self.g_idx is not None:
-            if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
+            if self.groupsize is None:
+                raise ValueError("Found group index but no groupsize. What do?")
             self.act_order = True
         else:
             self.act_order = False
@@ -112,7 +122,7 @@ class Ex4bitLinear:
             MAX_INNER = max(MAX_INNER, self.height, self.width)
 
             ACT_ORDER = True
-
+
     def forward(self, x):
         out = ext_q4_matmul(x, self.q4, self.width)
 
diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py
index 45b01ae..6eb44e4 100644
--- a/server/text_generation_server/utils/gptq/quantize.py
+++ b/server/text_generation_server/utils/gptq/quantize.py
@@ -815,11 +815,7 @@ def load_weights_pre_hook(module_name, weights, recursive=False):
                 tensor = current_tensor.to(device=torch.device("cuda:0"))
                 if current_tensor.requires_grad:
                     tensor = nn.Parameter(tensor)
-                setdeepattr(
-                    module,
-                    local_param,
-                    tensor
-                )
+                setdeepattr(module, local_param, tensor)
 
     return inner
 
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 4f28016..183cf2c 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -17,9 +17,10 @@ except ImportError:
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
+
 HAS_EXLLAMA = True
 if os.getenv("DISABLE_EXLLAMA") == "True":
-    HAS_EXLLAMA=False
+    HAS_EXLLAMA = False
 try:
     from text_generation_server.utils.gptq.exllama import Ex4bitLinear
 except ImportError:
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 4f284ea..dae5350 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -146,7 +146,16 @@ class Weights:
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
-                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
+                    if (
+                        not torch.equal(
+                            g_idx.cpu(),
+                            torch.tensor(
+                                [i // groupsize for i in range(g_idx.shape[0])],
+                                dtype=torch.int32,
+                            ),
+                        )
+                        and not (g_idx == 0).all()
+                    ):
                         # Exllama implementation does not support row tensor parallelism with act-order, as
                         # it would require to reorder input activations that are split unto several GPUs
                         use_exllama = False
@@ -154,18 +163,21 @@ class Weights:
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
-                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
-
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
 
             from text_generation_server.utils.layers import HAS_EXLLAMA
+
             if use_exllama:
                 if not HAS_EXLLAMA:
-                    logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True")
+                    logger.warning(
+                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
+                    )
                     use_exllama = False
                 else:
                     logger.info("Using exllama kernels")
 
-
             if use_exllama:
                 if groupsize >= 0:
                     # Exllama reorders the weights in advance and the activations on the fly, thus
@@ -173,7 +185,9 @@ class Weights:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
             else:
-                raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported")
+                raise RuntimeError(
+                    "Using exllama GPTQ kernel with groupsize<1 is not supported"
+                )
 
             # qzeros = self.get_tensor(f"{prefix}.qzeros")
             # scales = self.get_tensor(f"{prefix}.scales")
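Note (reviewer addition, not part of the patch): the pieces above fit together as follows. The launcher flag --cuda-memory-fraction (default 1.0) is exported to each shard as CUDA_MEMORY_FRACTION; dist.py reads it into MEMORY_FRACTION and caps the process with torch.cuda.set_per_process_memory_fraction; flash_causal_lm.py then subtracts the reserved share of the card when sizing the KV cache. Below is a minimal standalone sketch of that budget computation, using only names introduced in this diff; kv_cache_budget is a hypothetical helper name, not a function in the codebase.

    import os

    import torch

    # Same default as the launcher's --cuda-memory-fraction flag.
    MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))


    def kv_cache_budget(device: torch.device) -> int:
        # mem_get_info reports (free, total) bytes for the whole GPU; subtract the
        # share of the card this process may not use, as flash_causal_lm.py does.
        total_free_memory, _ = torch.cuda.mem_get_info(device)
        total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
        return max(0, int(total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory))

For example, with CUDA_MEMORY_FRACTION=0.5 on an otherwise idle 24 GB card, roughly 12 GB is left for the KV cache.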