diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py
index 0585f1fb..d3f2d766 100644
--- a/server/tests/utils/test_tokens.py
+++ b/server/tests/utils/test_tokens.py
@@ -50,19 +50,39 @@ def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
     top_n_tokens_tensor = torch.tensor(top_n_tokens)
     inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
+    accepted_ids = torch.ones_like(top_n_tokens_tensor)

     topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
-        top_n_tokens, top_n_tokens_tensor, inp_logprobs
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
     )

-    assert topn_tok_ids[0] == []
-    assert topn_tok_ids[1] == [0, 3]
-    assert topn_tok_ids[2] == [0, 3, 1, 4]
-    assert topn_tok_ids[3] == [0, 3, 1, 4]
-    assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]

-    assert topn_tok_logprobs[0] == []
-    assert topn_tok_logprobs[1] == [-1, -2]
-    assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
-    assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
+
+    # Now let's make second member of the batch be speculated
+    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
+    accepted_ids[1] = 2
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
+    )
+
+    assert topn_tok_ids[0] == [[]]
+    assert topn_tok_ids[1] == [[0, 3], [0, 3]]
+    assert topn_tok_ids[2] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[3] == [[0, 3, 1, 4]]
+    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
+
+    assert topn_tok_logprobs[0] == [[]]
+    assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
+    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
+    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 7b10256c..29e9f8b1 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -580,10 +580,13 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True

+        # Speculation is not active for causal
+        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )

         start_decode = time.time_ns()
@@ -692,20 +695,24 @@ class CausalLM(Model):
                 prefill_tokens = None

             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 930082cd..53a3d582 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -842,6 +842,8 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out

+
+        speculate = get_speculate()
         (
             next_input_ids,
             next_token_logprobs,
@@ -851,16 +853,15 @@
         ) = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen],
             next_token_logits,
-            get_speculate(),
+            speculate,
             batch.speculative_ids,
             speculative_logits,
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )

-        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -1062,20 +1063,24 @@ class FlashCausalLM(Model):
                 prefill_tokens = None

             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None

diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index f2e4cec6..8b93aecd 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -640,10 +640,13 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )

+        # Speculation is not active for seq2seq
+        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )

         start_decode = time.time_ns()
@@ -746,20 +749,24 @@ class Seq2SeqLM(Model):
                 prefill_tokens = None

             if top_n_tokens > 0:
-                toptoken_texts = self.tokenizer.batch_decode(
-                    top_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                special_toptokens = [
-                    token_id in self.all_special_ids for token_id in top_token_ids
-                ]
-                top_tokens = Tokens(
-                    top_token_ids,
-                    top_token_logprobs,
-                    toptoken_texts,
-                    special_toptokens,
-                )
+                all_top_tokens = []
+                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                    all_top_tokens.append(top_tokens)
+                top_tokens = all_top_tokens
             else:
                 top_tokens = None

diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index f85f27e5..bc68812e 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -95,5 +95,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
+            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
         )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 04cc8d97..270a6990 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -277,7 +277,8 @@ class HeterogeneousNextTokenChooser:
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
-        scores = scores.view(B * S, -1)
+        allscores = scores.view(B * S, -1)
+        alllogprobs = torch.log_softmax(allscores, -1)

         if speculated_ids is not None:
             accepted_ids = []
@@ -305,16 +306,17 @@
                 accepted_ids, device=input_ids.device, dtype=input_ids.dtype
             )
             next_ids = next_ids[indices]
-            scores = scores[indices]
+            logprobs = alllogprobs[indices]
             indices = torch.arange(B, device=input_ids.device) * S
             if speculative_scores is not None:
                 speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
             accepted_ids = torch.ones_like(next_ids)
+            logprobs = alllogprobs

-        logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+

         if speculate > 0:
             if speculative_scores is not None:
                 # Medusa provided some scores
@@ -327,7 +329,7 @@
             else:
                 speculative_ids = None

-        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
+        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None:
@@ -436,8 +438,8 @@ class HeterogeneousSampling:


 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
-) -> Tuple[List[List[int]], List[List[float]]]:
+    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.

     When multiple tokens have equal probabilities and they don't all fit, the
@@ -446,14 +448,19 @@ def batch_top_tokens(
     max_top_n = max(top_n_tokens)
     # Early exit when top_n_tokens is not used
     if max_top_n == 0:
-        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+        return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
+
+    batch_size = accepted_ids.shape[0]
+    speculate_size = logprobs.shape[0] // batch_size
+    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)

     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]

     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
-    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
+
     nth_highest = torch.gather(
         sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
     )
@@ -471,13 +478,33 @@
     top_indices = top_k.indices.tolist()
     top_values = top_k.values.tolist()

-    return (
-        [
-            idxs[:n] if req_n > 0 else []
-            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
-        ],
-        [
-            vals[:n] if req_n > 0 else []
-            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
-        ],
-    )
+    batch_top_token_ids = []
+    batch_top_token_logprobs = []
+    accepted_ids_list = accepted_ids.tolist()
+    for i, n_accepted_ids in enumerate(accepted_ids_list):
+        start = speculate_size * i
+        stop = speculate_size * (i + 1)
+        _top_indices = top_indices[start: stop]
+        _top_values = top_values[start: stop]
+        _top_n_ishes = top_n_ishes[start: stop]
+        _top_n_tokens = top_n_tokens[start: stop]
+
+        _top_indices = _top_indices[:n_accepted_ids]
+        _top_values = _top_values[:n_accepted_ids]
+        _top_n_ishes = _top_n_ishes[:n_accepted_ids]
+        _top_n_tokens = _top_n_tokens[:n_accepted_ids]
+
+        row_top_token_ids = []
+        row_top_token_logprobs = []
+
+        for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
+            indices = idxs[:n] if req_n > 0 else []
+            values = vals[:n] if req_n > 0 else []
+
+            row_top_token_ids.append(indices)
+            row_top_token_logprobs.append(values)
+
+        batch_top_token_ids.append(row_top_token_ids)
+        batch_top_token_logprobs.append(row_top_token_logprobs)
+
+    return batch_top_token_ids, batch_top_token_logprobs
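Reviewer note, not part of the patch: a minimal usage sketch of the updated batch_top_tokens signature, assuming the repo layout shown above for the import path. With speculation, logprobs carries speculate_size rows per request and accepted_ids says how many of them were kept, so each request now gets one inner list of top tokens per accepted position, which is why the tests above assert doubly nested lists. Tensor values below are arbitrary, for illustration only.

# Illustrative sketch only: exercises the new batch_top_tokens signature from this diff.
import torch
from text_generation_server.utils.tokens import batch_top_tokens  # assumed import path

top_n_tokens = [2, 2]                              # both requests ask for their 2 most likely tokens
top_n_tokens_tensor = torch.tensor(top_n_tokens)
# Batch of 2 requests with speculate_size = 2, hence 4 rows of logprobs over a toy vocab of 8.
logprobs = torch.log_softmax(torch.randn(4, 8), dim=-1)
accepted_ids = torch.tensor([1, 2])                # request 0 accepted 1 token, request 1 accepted 2

ids, logps = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, logprobs, accepted_ids)

# One inner list per accepted token, each holding that position's top token ids.
assert len(ids[0]) == 1 and len(ids[1]) == 2

This mirrors the second half of the updated test, where accepted_ids[1] = 2 turns topn_tok_ids[1] into two inner lists while the non-speculated requests keep a single one.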