diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 401082c5..66f2055a 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -23,6 +23,8 @@ pub enum ClientError {
     Connection(String),
     #[error("Server error: {0}")]
     Generation(String),
+    #[error("Sharded results are empty")]
+    EmptyResults,
 }
 
 impl From<Status> for ClientError {
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 73524740..60b81fe6 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,6 +1,6 @@
 /// Multi shard Client
-use crate::Result;
 use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
+use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
@@ -98,8 +98,9 @@
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+            join_all(futures).await.into_iter().collect();
+        merge_generations(results?)
     }
 
     /// Generate one token for each request in the given cached batches
@@ -116,7 +117,20 @@
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+            join_all(futures).await.into_iter().collect();
+        merge_generations(results?)
     }
 }
+
+/// Merge generations from the different model shards
+fn merge_generations(
+    mut results: Vec<(Vec<Generation>, Option<Batch>)>,
+) -> Result<(Vec<Generation>, Option<Batch>)> {
+    let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
+
+    for (mut shard_generations, _) in results.into_iter() {
+        generations.append(&mut shard_generations);
+    }
+    Ok((generations, next_batch))
+}
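Note on the router change above: instead of assuming that every shard returns the same message, each shard now returns only the generations it owns, and `merge_generations` concatenates the per-shard lists, keeps `next_batch` from one of the results (they are expected to be identical across shards), and maps an empty result set to the new `ClientError::EmptyResults`. The block below is a minimal sketch of that merge contract written in Python, not part of the patch; the names and string placeholders are illustrative only.

```python
# Minimal sketch (not part of the patch): the merge contract implemented in Rust
# by merge_generations() in router/client/src/sharded_client.rs.
# Each shard returns (generations, next_batch); the generation lists are disjoint
# and next_batch is expected to be identical across shards, so any copy is kept.
from typing import List, Optional, Tuple


def merge_generations(
    results: List[Tuple[list, Optional[str]]]
) -> Tuple[list, Optional[str]]:
    if not results:
        # Mirrors the new ClientError::EmptyResults variant.
        raise ValueError("Sharded results are empty")
    generations, next_batch = results.pop()
    for shard_generations, _ in results:
        generations.extend(shard_generations)
    return generations, next_batch


# Two shards, each returning only the generations for the requests it owns:
merged, batch = merge_generations([(["g0", "g2"], "batch-1"), (["g1", "g3"], "batch-1")])
assert sorted(merged) == ["g0", "g1", "g2", "g3"] and batch == "batch-1"
```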
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index f528a430..25ec8cb8 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -63,10 +63,10 @@ class BLOOMSharded(BLOOM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -94,8 +94,8 @@ class BLOOMSharded(BLOOM):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -105,6 +105,8 @@
             dtype=dtype,
             device=device,
             decode_buffer=1,
+            rank=rank,
+            world_size=world_size,
         )
 
     @staticmethod
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 26a9a661..610dc4e2 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -549,7 +549,7 @@ class CausalLM(Model):
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                all_input_ids.view(1, -1), logits
+                all_input_ids.view(1, -1), logits[-1:, :]
             )
 
             # Append next token to all tokens
@@ -569,54 +569,60 @@ class CausalLM(Model):
                 next_token_text,
             )
 
-            if stop:
-                # Decode generated tokens
-                output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :, 0]
-                )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
-                # Keep request in the batch
-                generated_text = None
+            if not stop:
                 stopped = False
 
-            # Prefill
-            if stopping_criteria.current_tokens == 1:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids[1:]
-                ).squeeze(1)[-new_input_length:-1].tolist()
-                prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
-                )
-            else:
-                prefill_tokens = None
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text = self.decode(
+                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    )
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
 
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
-                generated_text,
-            )
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
 
-            generations.append(generation)
+                # Prefill
+                if stopping_criteria.current_tokens == 1:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    prefill_logprobs = [float("nan")] + torch.log_softmax(
+                        logits, -1
+                    ).gather(1, all_input_ids[1:]).squeeze(1)[
+                        -new_input_length:-1
+                    ].tolist()
+                    prefill_token_ids = all_input_ids[-new_input_length:-1]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens = PrefillTokens(
+                        prefill_token_ids, prefill_logprobs, prefill_texts
+                    )
+                else:
+                    prefill_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    next_token_id_squeezed,
+                    next_token_logprob,
+                    next_token_text,
+                    next_token_id_squeezed.item() in self.all_special_ids,
+                    generated_text,
+                )
+
+                generations.append(generation)
 
             # Update values
             batch.input_ids[i, 0] = next_token_id
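The server-side counterpart of the router change relies on `self.rank` and `self.world_size` (added to the `Model` base class later in this diff): a shard only builds `Generation` objects for requests whose batch index satisfies `i % self.world_size == self.rank`, while still updating the batch state for every request. Below is a small sketch of how indices are distributed under that rule, assuming a hypothetical two-shard setup; it is not part of the patch.

```python
# Minimal sketch (not part of the patch): how the `i % self.world_size == self.rank`
# test in generate_token() splits work across shards. world_size=2 is an assumption.
world_size = 2
batch_indices = list(range(5))

per_shard = {
    rank: [i for i in batch_indices if i % world_size == rank]
    for rank in range(world_size)
}
assert per_shard == {0: [0, 2, 4], 1: [1, 3]}

# Every shard still runs the forward pass and updates batch.input_ids & co. for
# all requests; only decoding text and building the Generation message is
# skipped for indices owned by another shard, and the router re-merges the lists.
```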
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 1293124a..6ae869db 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -622,10 +622,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         self.process_group = process_group
         if self.process_group is not None:
             self.world_size = self.process_group.size()
-            self.rank = self.process_group.rank()
         else:
             self.world_size = 1
-            self.rank = 0
 
         self.model = FlashLlamaModel(config, process_group)
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index ae1465ab..d706df33 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -685,10 +685,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self.process_group = process_group
         if self.process_group is not None:
             self.world_size = self.process_group.size()
-            self.rank = self.process_group.rank()
         else:
             self.world_size = 1
-            self.rank = 0
 
         self.gpt_neox = FlashGPTNeoXModel(config, process_group)
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b51a3dc6..e862cfeb 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -687,53 +687,59 @@ class FlashCausalLM(Model):
                 next_token_text,
             )
 
-            if stop:
-                # Decode generated tokens
-                output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :]
-                )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
+            if not stop:
                 stopped = False
-                generated_text = None
 
-            # Prefill
-            if prefill:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    start_index : end_index - 1
-                ]
-                prefill_token_ids = all_input_ids[:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text = self.decode(
+                        all_input_ids[-stopping_criteria.current_tokens :]
+                    )
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                # Prefill
+                if prefill:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
+                        start_index : end_index - 1
+                    ]
+                    prefill_token_ids = all_input_ids[:-1]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens = PrefillTokens(
+                        prefill_token_ids, request_prefill_logprobs, prefill_texts
+                    )
+                else:
+                    prefill_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    next_token_id,
+                    next_token_logprob,
+                    next_token_text,
+                    next_token_id in self.all_special_ids,
+                    generated_text,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, request_prefill_logprobs, prefill_texts
-                )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id,
-                next_token_logprob,
-                next_token_text,
-                next_token_id in self.all_special_ids,
-                generated_text,
-            )
+
+                generations.append(generation)
 
-            generations.append(generation)
             new_input_length = input_length + 1
 
             # Update values
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index e4426771..a3ba2084 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -157,10 +157,10 @@ class FlashLlamaSharded(FlashLlama):
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.past_pad = None
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
@@ -190,8 +190,8 @@ class FlashLlamaSharded(FlashLlama):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
@@ -200,6 +200,8 @@
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
 
     @staticmethod
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index f439e812..b3e1876f 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -34,10 +34,10 @@ class FlashNeoXSharded(FlashNeoX):
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.past_pad = None
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
@@ -64,8 +64,8 @@
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
@@ -74,6 +74,8 @@
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
 
     @staticmethod
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index f0825ab9..6e5c231e 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -174,10 +174,10 @@ class FlashSantacoderSharded(FlashSantacoder):
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.past_pad = None
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
@@ -204,8 +204,8 @@
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
         self.model = model.eval().to(device)
@@ -215,6 +215,8 @@
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
 
     @staticmethod
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 2577f1b1..57df0bab 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -195,10 +195,10 @@ class GalacticaSharded(Galactica):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -226,8 +226,8 @@
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -236,6 +236,8 @@
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
 
     @staticmethod
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index e73a3c82..c71bf366 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -34,10 +34,10 @@ class GPTNeoxSharded(CausalLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -65,8 +65,8 @@
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -75,6 +75,8 @@
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
 
     @staticmethod
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index bfae829d..4c85c952 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -18,6 +18,8 @@ class Model(ABC):
         dtype: torch.dtype,
         device: torch.device,
         decode_buffer: int = 3,
+        rank: int = 0,
+        world_size: int = 1,
     ):
         if decode_buffer < 1:
             raise ValueError("decode_buffer must be >= 1")
@@ -28,6 +30,8 @@ class Model(ABC):
         self.dtype = dtype
         self.device = device
         self.decode_buffer = decode_buffer
+        self.rank = rank
+        self.world_size = world_size
 
     @property
     def info(self) -> InfoResponse:
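`Model` now carries `rank` and `world_size` with defaults of 0 and 1, so non-sharded models keep the old behavior: they own every request. The sketch below is hypothetical (the `ModelSketch`/`owns` names are not in the codebase) and only illustrates the ownership rule implied by those defaults; it is not part of the patch.

```python
# Hypothetical sketch (ModelSketch/owns are illustrative names, not in the repo):
# the ownership rule enabled by the new rank/world_size fields on Model.
class ModelSketch:
    def __init__(self, rank: int = 0, world_size: int = 1):
        # Defaults match Model.__init__: a non-sharded model owns every request.
        self.rank = rank
        self.world_size = world_size

    def owns(self, i: int) -> bool:
        # Each batch index is handled by exactly one shard.
        return i % self.world_size == self.rank


single = ModelSketch()
shards = [ModelSketch(rank=r, world_size=2) for r in range(2)]
assert all(single.owns(i) for i in range(4))
assert [s.owns(3) for s in shards] == [False, True]
```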
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 50e5271e..cdc32c56 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -50,10 +50,10 @@ class OPTSharded(OPT):
    def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -81,8 +81,8 @@
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -91,6 +91,8 @@
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
 
     @staticmethod
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 4ac5ed3c..d4a0ddcc 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -631,7 +631,7 @@ class Seq2SeqLM(Model):
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                all_decoder_input_ids.view(1, -1), logits
+                all_decoder_input_ids.view(1, -1), logits[-1:, :]
             )
 
             # Append next token to decoder tokens
@@ -650,46 +650,52 @@ class Seq2SeqLM(Model):
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(next_token_id, next_token_text)
 
-            if stop:
-                # Slice with decoder_input_length to remove padding
-                # Decode all tokens
-                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])
-
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
-                # Keep request in the batch
-                generated_text = None
+            if not stop:
                 stopped = False
 
-            # Prefill
-            if stopping_criteria.current_tokens == 1:
-                prefill_tokens = PrefillTokens(
-                    [self.tokenizer.bos_token_id],
-                    [float("nan")],
-                    [self.tokenizer.bos_token],
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Slice with decoder_input_length to remove padding
+                    # Decode all tokens
+                    output_text = self.decode(
+                        all_decoder_input_ids[-decoder_input_length:]
+                    )
+
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                # Prefill
+                if stopping_criteria.current_tokens == 1:
+                    prefill_tokens = PrefillTokens(
+                        [self.tokenizer.bos_token_id],
+                        [float("nan")],
+                        [self.tokenizer.bos_token],
+                    )
+                else:
+                    prefill_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    next_token_id_squeezed,
+                    next_token_logprob,
+                    next_token_text,
+                    next_token_id_squeezed.item() in self.all_special_ids,
+                    generated_text,
                 )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
-                generated_text,
-            )
-
-            generations.append(generation)
+
+                generations.append(generation)
 
             # Update values
             batch.decoder_input_ids[i] = next_token_id
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 9e8c3c4c..5691c005 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -34,10 +34,10 @@ class T5Sharded(Seq2SeqLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -65,8 +65,8 @@
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -75,6 +75,8 @@
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
 
     @staticmethod
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 23f504c6..d5a77170 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -75,11 +75,7 @@ class NextTokenChooser:
 
     def __call__(self, input_ids, scores):
         # Warp logits
-        if scores.shape[0] > 1:
-            # only warp the last token logits
-            scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
-        else:
-            scores = self.warpers(input_ids, scores)
+        scores = self.warpers(input_ids, scores)
 
         # Compute logprobs
         logprobs = torch.log_softmax(scores, -1)
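Because `CausalLM` and `Seq2SeqLM` now pass only the last position (`logits[-1:, :]`) to `NextTokenChooser`, the `scores.shape[0] > 1` special case in `__call__` can be dropped: the warpers always see a single row, and prefill logprobs are recomputed from the full prompt logits with `torch.log_softmax` inside the model code instead. Below is a minimal shape sketch under assumed dimensions; it is not part of the patch.

```python
# Minimal sketch (not part of the patch): shape bookkeeping behind the
# NextTokenChooser simplification. Dimensions are assumptions for illustration.
import torch

seq_len, vocab = 5, 11
logits = torch.randn(seq_len, vocab)  # one row per prompt position during prefill

# What next_token_chooser now receives from CausalLM / Seq2SeqLM: the last row only,
# so the warpers always operate on a [1, vocab] tensor.
last = logits[-1:, :]
assert last.shape == (1, vocab)

# Prefill logprobs are instead computed on the full prompt logits in the model,
# in the spirit of the new torch.log_softmax(logits, -1).gather(...) code path.
token_ids = torch.randint(vocab, (seq_len, 1))  # hypothetical next-token ids per position
picked = torch.log_softmax(logits, -1).gather(1, token_ids).squeeze(1)
assert picked.shape == (seq_len,)
```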