feat(server): shard token decode (#303)

OlivierDehaene 2023-05-10 15:48:21 +02:00 committed by GitHub
parent 1585404464
commit 68e9d6ab33
17 changed files with 224 additions and 178 deletions

View File

@@ -23,6 +23,8 @@ pub enum ClientError {
     Connection(String),
     #[error("Server error: {0}")]
     Generation(String),
+    #[error("Sharded results are empty")]
+    EmptyResults,
 }

 impl From<Status> for ClientError {

View File

@@ -1,6 +1,6 @@
 /// Multi shard Client
-use crate::Result;
 use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
+use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
@@ -98,8 +98,9 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+            join_all(futures).await.into_iter().collect();
+        merge_generations(results?)
     }

     /// Generate one token for each request in the given cached batches
@@ -116,7 +117,20 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+            join_all(futures).await.into_iter().collect();
+        merge_generations(results?)
     }
 }
+
+/// Merge generations from the different model shards
+fn merge_generations(
+    mut results: Vec<(Vec<Generation>, Option<Batch>)>,
+) -> Result<(Vec<Generation>, Option<Batch>)> {
+    let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
+    for (mut shard_generations, _) in results.into_iter() {
+        generations.append(&mut shard_generations);
+    }
+    Ok((generations, next_batch))
+}
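
Note: the prefill and decode paths above no longer assume that every shard returns the full set of generations; each shard returns only the requests it owns, and merge_generations concatenates the per-shard lists, surfacing ClientError::EmptyResults if no shard responded. A minimal Python sketch of the merge semantics (illustration only; names and types are placeholders, the real implementation is the Rust helper above):

# Sketch of how per-shard results are recombined (placeholder types).
from typing import List, Optional, Tuple

def merge_shard_results(
    results: List[Tuple[List[str], Optional[dict]]],
) -> Tuple[List[str], Optional[dict]]:
    if not results:
        raise ValueError("Sharded results are empty")  # mirrors ClientError::EmptyResults
    generations, next_batch = results.pop()
    for shard_generations, _ in results:
        generations.extend(shard_generations)
    # The next-batch metadata is assumed identical across shards (the removed
    # "all shards return the same message" comment now only holds for it),
    # so keeping the one from the popped result is enough.
    return generations, next_batch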

View File

@@ -63,10 +63,10 @@ class BLOOMSharded(BLOOM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -94,8 +94,8 @@ class BLOOMSharded(BLOOM):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -105,6 +105,8 @@ class BLOOMSharded(BLOOM):
             dtype=dtype,
             device=device,
             decode_buffer=1,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod

View File

@@ -549,7 +549,7 @@ class CausalLM(Model):
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                all_input_ids.view(1, -1), logits
+                all_input_ids.view(1, -1), logits[-1:, :]
             )

             # Append next token to all tokens
@@ -569,54 +569,60 @@ class CausalLM(Model):
                 next_token_text,
             )

-            if stop:
-                # Decode generated tokens
-                output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :, 0]
-                )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
-                # Keep request in the batch
-                generated_text = None
+            if not stop:
                 stopped = False

-            # Prefill
-            if stopping_criteria.current_tokens == 1:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids[1:]
-                ).squeeze(1)[-new_input_length:-1].tolist()
-                prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
-                )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
-                generated_text,
-            )
-
-            generations.append(generation)
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text = self.decode(
+                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    )
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                # Prefill
+                if stopping_criteria.current_tokens == 1:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    prefill_logprobs = [float("nan")] + torch.log_softmax(
+                        logits, -1
+                    ).gather(1, all_input_ids[1:]).squeeze(1)[
+                        -new_input_length:-1
+                    ].tolist()
+                    prefill_token_ids = all_input_ids[-new_input_length:-1]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens = PrefillTokens(
+                        prefill_token_ids, prefill_logprobs, prefill_texts
+                    )
+                else:
+                    prefill_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    next_token_id_squeezed,
+                    next_token_logprob,
+                    next_token_text,
+                    next_token_id_squeezed.item() in self.all_special_ids,
+                    generated_text,
+                )
+
+                generations.append(generation)

             # Update values
             batch.input_ids[i, 0] = next_token_id
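
The i % self.world_size == self.rank guard is the core of this commit: instead of every shard decoding text and building a Generation for every request, each shard handles only the requests whose batch index maps to its rank, and the Rust sharded client reassembles the full list. A small sketch of the round-robin ownership rule (hypothetical request names, world_size of 2, illustration only):

# Each shard keeps only the requests whose index falls on its rank.
world_size = 2
requests = ["r0", "r1", "r2", "r3", "r4"]

for rank in range(world_size):
    owned = [r for i, r in enumerate(requests) if i % world_size == rank]
    print(rank, owned)  # 0 -> ['r0', 'r2', 'r4'], 1 -> ['r1', 'r3']

# The per-shard lists are disjoint and cover every request exactly once,
# which is what lets merge_generations simply concatenate them.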

View File

@@ -622,10 +622,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         self.process_group = process_group
         if self.process_group is not None:
             self.world_size = self.process_group.size()
-            self.rank = self.process_group.rank()
         else:
             self.world_size = 1
-            self.rank = 0

         self.model = FlashLlamaModel(config, process_group)

View File

@@ -685,10 +685,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self.process_group = process_group
         if self.process_group is not None:
             self.world_size = self.process_group.size()
-            self.rank = self.process_group.rank()
         else:
             self.world_size = 1
-            self.rank = 0

         self.gpt_neox = FlashGPTNeoXModel(config, process_group)

View File

@@ -687,53 +687,59 @@ class FlashCausalLM(Model):
                 next_token_text,
             )

-            if stop:
-                # Decode generated tokens
-                output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :]
-                )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
+            if not stop:
                 stopped = False
-                generated_text = None

-            # Prefill
-            if prefill:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    start_index : end_index - 1
-                ]
-                prefill_token_ids = all_input_ids[:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, request_prefill_logprobs, prefill_texts
-                )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id,
-                next_token_logprob,
-                next_token_text,
-                next_token_id in self.all_special_ids,
-                generated_text,
-            )
-
-            generations.append(generation)
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text = self.decode(
+                        all_input_ids[-stopping_criteria.current_tokens :]
+                    )
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                # Prefill
+                if prefill:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
+                        start_index : end_index - 1
+                    ]
+                    prefill_token_ids = all_input_ids[:-1]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens = PrefillTokens(
+                        prefill_token_ids, request_prefill_logprobs, prefill_texts
+                    )
+                else:
+                    prefill_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    next_token_id,
+                    next_token_logprob,
+                    next_token_text,
+                    next_token_id in self.all_special_ids,
+                    generated_text,
+                )
+
+                generations.append(generation)

             new_input_length = input_length + 1

             # Update values

View File

@@ -157,10 +157,10 @@ class FlashLlamaSharded(FlashLlama):
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.past_pad = None
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
@@ -190,8 +190,8 @@ class FlashLlamaSharded(FlashLlama):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
@@ -200,6 +200,8 @@ class FlashLlamaSharded(FlashLlama):
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod

View File

@@ -34,10 +34,10 @@ class FlashNeoXSharded(FlashNeoX):
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.past_pad = None
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
@@ -64,8 +64,8 @@ class FlashNeoXSharded(FlashNeoX):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
@@ -74,6 +74,8 @@ class FlashNeoXSharded(FlashNeoX):
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod

View File

@@ -174,10 +174,10 @@ class FlashSantacoderSharded(FlashSantacoder):
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.past_pad = None
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
@@ -204,8 +204,8 @@ class FlashSantacoderSharded(FlashSantacoder):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
         self.model = model.eval().to(device)
@@ -215,6 +215,8 @@ class FlashSantacoderSharded(FlashSantacoder):
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod

View File

@@ -195,10 +195,10 @@ class GalacticaSharded(Galactica):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -226,8 +226,8 @@ class GalacticaSharded(Galactica):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -236,6 +236,8 @@ class GalacticaSharded(Galactica):
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod

View File

@@ -34,10 +34,10 @@ class GPTNeoxSharded(CausalLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -65,8 +65,8 @@ class GPTNeoxSharded(CausalLM):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -75,6 +75,8 @@ class GPTNeoxSharded(CausalLM):
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod

View File

@@ -18,6 +18,8 @@ class Model(ABC):
         dtype: torch.dtype,
         device: torch.device,
         decode_buffer: int = 3,
+        rank: int = 0,
+        world_size: int = 1,
     ):
         if decode_buffer < 1:
             raise ValueError("decode_buffer must be >= 1")
@@ -28,6 +30,8 @@ class Model(ABC):
         self.dtype = dtype
         self.device = device
         self.decode_buffer = decode_buffer
+        self.rank = rank
+        self.world_size = world_size

     @property
     def info(self) -> InfoResponse:
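
Since rank defaults to 0 and world_size to 1, non-sharded models keep their previous behavior: the i % self.world_size == self.rank filter in generate_token is always true, so the single process still emits every generation. A one-line check of that invariant (illustration only):

# With the defaults rank=0, world_size=1, no request is ever skipped.
rank, world_size = 0, 1
assert all(i % world_size == rank for i in range(10))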

View File

@@ -50,10 +50,10 @@ class OPTSharded(OPT):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -81,8 +81,8 @@ class OPTSharded(OPT):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -91,6 +91,8 @@ class OPTSharded(OPT):
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod

View File

@@ -631,7 +631,7 @@ class Seq2SeqLM(Model):
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                all_decoder_input_ids.view(1, -1), logits
+                all_decoder_input_ids.view(1, -1), logits[-1:, :]
             )

             # Append next token to decoder tokens
@@ -650,46 +650,52 @@ class Seq2SeqLM(Model):
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(next_token_id, next_token_text)

-            if stop:
-                # Slice with decoder_input_length to remove padding
-                # Decode all tokens
-                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])
-
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
-                # Keep request in the batch
-                generated_text = None
+            if not stop:
                 stopped = False

-            # Prefill
-            if stopping_criteria.current_tokens == 1:
-                prefill_tokens = PrefillTokens(
-                    [self.tokenizer.bos_token_id],
-                    [float("nan")],
-                    [self.tokenizer.bos_token],
-                )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
-                generated_text,
-            )
-
-            generations.append(generation)
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Slice with decoder_input_length to remove padding
+                    # Decode all tokens
+                    output_text = self.decode(
+                        all_decoder_input_ids[-decoder_input_length:]
+                    )
+
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                # Prefill
+                if stopping_criteria.current_tokens == 1:
+                    prefill_tokens = PrefillTokens(
+                        [self.tokenizer.bos_token_id],
+                        [float("nan")],
+                        [self.tokenizer.bos_token],
+                    )
+                else:
+                    prefill_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    next_token_id_squeezed,
+                    next_token_logprob,
+                    next_token_text,
+                    next_token_id_squeezed.item() in self.all_special_ids,
+                    generated_text,
+                )
+
+                generations.append(generation)

             # Update values
             batch.decoder_input_ids[i] = next_token_id

View File

@@ -34,10 +34,10 @@ class T5Sharded(Seq2SeqLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -65,8 +65,8 @@ class T5Sharded(Seq2SeqLM):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -75,6 +75,8 @@ class T5Sharded(Seq2SeqLM):
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod

View File

@@ -75,11 +75,7 @@ class NextTokenChooser:
     def __call__(self, input_ids, scores):
         # Warp logits
-        if scores.shape[0] > 1:
-            # only warp the last token logits
-            scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
-        else:
-            scores = self.warpers(input_ids, scores)
+        scores = self.warpers(input_ids, scores)

         # Compute logprobs
         logprobs = torch.log_softmax(scores, -1)
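
Since CausalLM and Seq2SeqLM now pass only the last position's logits (logits[-1:, :]), the shape-dependent branch is no longer needed and the warpers always see a single row of scores; prefill logprobs are instead recomputed from the full logits on the shard that owns the request. A hedged sketch of the new calling convention (hypothetical shapes, not the full NextTokenChooser):

import torch

# The chooser only ever warps/samples the last position's logits.
sequence_length, vocab_size = 8, 32
logits = torch.randn(sequence_length, vocab_size)

last_token_scores = logits[-1:, :]            # shape [1, vocab_size]
logprobs = torch.log_softmax(last_token_scores, -1)
next_token_id = torch.argmax(logprobs, dim=-1, keepdim=True)

# Prefill logprobs for the prompt tokens are now derived separately,
# from the full logits, and only on the shard that owns the request.
prefill_logprobs = torch.log_softmax(logits, -1)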