feat(server): shard token decode (#303)
This commit is contained in:
parent
1585404464
commit
68e9d6ab33
|
@ -23,6 +23,8 @@ pub enum ClientError {
|
|||
Connection(String),
|
||||
#[error("Server error: {0}")]
|
||||
Generation(String),
|
||||
#[error("Sharded results are empty")]
|
||||
EmptyResults,
|
||||
}
|
||||
|
||||
impl From<Status> for ClientError {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/// Multi shard Client
|
||||
use crate::Result;
|
||||
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
|
||||
use crate::{ClientError, Result};
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
@ -98,8 +98,9 @@ impl ShardedClient {
|
|||
.iter_mut()
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
merge_generations(results?)
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
|
@ -116,7 +117,20 @@ impl ShardedClient {
|
|||
.iter_mut()
|
||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
merge_generations(results?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge generations from the different model shards
|
||||
fn merge_generations(
|
||||
mut results: Vec<(Vec<Generation>, Option<Batch>)>,
|
||||
) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||
let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
for (mut shard_generations, _) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
}
|
||||
Ok((generations, next_batch))
|
||||
}
|
||||
|
|
|
@ -63,10 +63,10 @@ class BLOOMSharded(BLOOM):
|
|||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
self.master = rank == 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
@ -94,8 +94,8 @@ class BLOOMSharded(BLOOM):
|
|||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
self.model = model.eval()
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
@ -105,6 +105,8 @@ class BLOOMSharded(BLOOM):
|
|||
dtype=dtype,
|
||||
device=device,
|
||||
decode_buffer=1,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -549,7 +549,7 @@ class CausalLM(Model):
|
|||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
all_input_ids.view(1, -1), logits
|
||||
all_input_ids.view(1, -1), logits[-1:, :]
|
||||
)
|
||||
|
||||
# Append next token to all tokens
|
||||
|
@ -569,54 +569,60 @@ class CausalLM(Model):
|
|||
next_token_text,
|
||||
)
|
||||
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text = self.decode(
|
||||
all_input_ids[-stopping_criteria.current_tokens :, 0]
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
# Keep request in the batch
|
||||
generated_text = None
|
||||
if not stop:
|
||||
stopped = False
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + logprobs.gather(
|
||||
1, all_input_ids[1:]
|
||||
).squeeze(1)[-new_input_length:-1].tolist()
|
||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = PrefillTokens(
|
||||
prefill_token_ids, prefill_logprobs, prefill_texts
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
# Shard generations
|
||||
# All generations will be appended in the rust sharded client
|
||||
if i % self.world_size == self.rank:
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text = self.decode(
|
||||
all_input_ids[-stopping_criteria.current_tokens :, 0]
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
next_token_id_squeezed,
|
||||
next_token_logprob,
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
)
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
|
||||
generations.append(generation)
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||
logits, -1
|
||||
).gather(1, all_input_ids[1:]).squeeze(1)[
|
||||
-new_input_length:-1
|
||||
].tolist()
|
||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = PrefillTokens(
|
||||
prefill_token_ids, prefill_logprobs, prefill_texts
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
next_token_id_squeezed,
|
||||
next_token_logprob,
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
||||
# Update values
|
||||
batch.input_ids[i, 0] = next_token_id
|
||||
|
|
|
@ -622,10 +622,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
self.process_group = process_group
|
||||
if self.process_group is not None:
|
||||
self.world_size = self.process_group.size()
|
||||
self.rank = self.process_group.rank()
|
||||
else:
|
||||
self.world_size = 1
|
||||
self.rank = 0
|
||||
|
||||
self.model = FlashLlamaModel(config, process_group)
|
||||
|
||||
|
|
|
@ -685,10 +685,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
self.process_group = process_group
|
||||
if self.process_group is not None:
|
||||
self.world_size = self.process_group.size()
|
||||
self.rank = self.process_group.rank()
|
||||
else:
|
||||
self.world_size = 1
|
||||
self.rank = 0
|
||||
|
||||
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
|
||||
|
||||
|
|
|
@ -687,53 +687,59 @@ class FlashCausalLM(Model):
|
|||
next_token_text,
|
||||
)
|
||||
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text = self.decode(
|
||||
all_input_ids[-stopping_criteria.current_tokens :]
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
if not stop:
|
||||
stopped = False
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if prefill:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
|
||||
start_index : end_index - 1
|
||||
]
|
||||
prefill_token_ids = all_input_ids[:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
# Shard generations
|
||||
# All generations will be appended in the rust sharded client
|
||||
if i % self.world_size == self.rank:
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text = self.decode(
|
||||
all_input_ids[-stopping_criteria.current_tokens :]
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if prefill:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
|
||||
start_index : end_index - 1
|
||||
]
|
||||
prefill_token_ids = all_input_ids[:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = PrefillTokens(
|
||||
prefill_token_ids, request_prefill_logprobs, prefill_texts
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
next_token_id,
|
||||
next_token_logprob,
|
||||
next_token_text,
|
||||
next_token_id in self.all_special_ids,
|
||||
generated_text,
|
||||
)
|
||||
prefill_tokens = PrefillTokens(
|
||||
prefill_token_ids, request_prefill_logprobs, prefill_texts
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
next_token_id,
|
||||
next_token_logprob,
|
||||
next_token_text,
|
||||
next_token_id in self.all_special_ids,
|
||||
generated_text,
|
||||
)
|
||||
generations.append(generation)
|
||||
|
||||
generations.append(generation)
|
||||
new_input_length = input_length + 1
|
||||
|
||||
# Update values
|
||||
|
|
|
@ -157,10 +157,10 @@ class FlashLlamaSharded(FlashLlama):
|
|||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.past_pad = None
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
self.master = rank == 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||
|
@ -190,8 +190,8 @@ class FlashLlamaSharded(FlashLlama):
|
|||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
self.model = model.eval().to(device)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
@ -200,6 +200,8 @@ class FlashLlamaSharded(FlashLlama):
|
|||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -34,10 +34,10 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.past_pad = None
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
self.master = rank == 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||
|
@ -64,8 +64,8 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
self.model = model.eval().to(device)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
@ -74,6 +74,8 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -174,10 +174,10 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.past_pad = None
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
self.master = rank == 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
||||
|
@ -204,8 +204,8 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
transpose=config.architectures[0].startswith("GPT2"),
|
||||
)
|
||||
self.model = model.eval().to(device)
|
||||
|
@ -215,6 +215,8 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -195,10 +195,10 @@ class GalacticaSharded(Galactica):
|
|||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
self.master = rank == 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
@ -226,8 +226,8 @@ class GalacticaSharded(Galactica):
|
|||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
self.model = model.eval()
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
@ -236,6 +236,8 @@ class GalacticaSharded(Galactica):
|
|||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -34,10 +34,10 @@ class GPTNeoxSharded(CausalLM):
|
|||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
self.master = rank == 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
@ -65,8 +65,8 @@ class GPTNeoxSharded(CausalLM):
|
|||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
self.model = model.eval()
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
@ -75,6 +75,8 @@ class GPTNeoxSharded(CausalLM):
|
|||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -18,6 +18,8 @@ class Model(ABC):
|
|||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
decode_buffer: int = 3,
|
||||
rank: int = 0,
|
||||
world_size: int = 1,
|
||||
):
|
||||
if decode_buffer < 1:
|
||||
raise ValueError("decode_buffer must be >= 1")
|
||||
|
@ -28,6 +30,8 @@ class Model(ABC):
|
|||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.decode_buffer = decode_buffer
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
|
||||
@property
|
||||
def info(self) -> InfoResponse:
|
||||
|
|
|
@ -50,10 +50,10 @@ class OPTSharded(OPT):
|
|||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
self.master = rank == 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
@ -81,8 +81,8 @@ class OPTSharded(OPT):
|
|||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
self.model = model.eval()
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
@ -91,6 +91,8 @@ class OPTSharded(OPT):
|
|||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -631,7 +631,7 @@ class Seq2SeqLM(Model):
|
|||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
all_decoder_input_ids.view(1, -1), logits
|
||||
all_decoder_input_ids.view(1, -1), logits[-1:, :]
|
||||
)
|
||||
|
||||
# Append next token to decoder tokens
|
||||
|
@ -650,46 +650,52 @@ class Seq2SeqLM(Model):
|
|||
# Evaluate stopping criteria
|
||||
stop, reason = stopping_criteria(next_token_id, next_token_text)
|
||||
|
||||
if stop:
|
||||
# Slice with decoder_input_length to remove padding
|
||||
# Decode all tokens
|
||||
output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])
|
||||
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
# Keep request in the batch
|
||||
generated_text = None
|
||||
if not stop:
|
||||
stopped = False
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1:
|
||||
prefill_tokens = PrefillTokens(
|
||||
[self.tokenizer.bos_token_id],
|
||||
[float("nan")],
|
||||
[self.tokenizer.bos_token],
|
||||
# Shard generations
|
||||
# All generations will be appended in the rust sharded client
|
||||
if i % self.world_size == self.rank:
|
||||
if stop:
|
||||
# Slice with decoder_input_length to remove padding
|
||||
# Decode all tokens
|
||||
output_text = self.decode(
|
||||
all_decoder_input_ids[-decoder_input_length:]
|
||||
)
|
||||
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1:
|
||||
prefill_tokens = PrefillTokens(
|
||||
[self.tokenizer.bos_token_id],
|
||||
[float("nan")],
|
||||
[self.tokenizer.bos_token],
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
next_token_id_squeezed,
|
||||
next_token_logprob,
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
next_token_id_squeezed,
|
||||
next_token_logprob,
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
generations.append(generation)
|
||||
|
||||
# Update values
|
||||
batch.decoder_input_ids[i] = next_token_id
|
||||
|
|
|
@ -34,10 +34,10 @@ class T5Sharded(Seq2SeqLM):
|
|||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
self.master = rank == 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
@ -65,8 +65,8 @@ class T5Sharded(Seq2SeqLM):
|
|||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
self.model = model.eval()
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
@ -75,6 +75,8 @@ class T5Sharded(Seq2SeqLM):
|
|||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -75,11 +75,7 @@ class NextTokenChooser:
|
|||
|
||||
def __call__(self, input_ids, scores):
|
||||
# Warp logits
|
||||
if scores.shape[0] > 1:
|
||||
# only warp the last token logits
|
||||
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
|
||||
else:
|
||||
scores = self.warpers(input_ids, scores)
|
||||
scores = self.warpers(input_ids, scores)
|
||||
|
||||
# Compute logprobs
|
||||
logprobs = torch.log_softmax(scores, -1)
|
||||
|
|
Loading…
Reference in New Issue