feat(server): shard token decode (#303)
parent 1585404464
commit 68e9d6ab33
@@ -23,6 +23,8 @@ pub enum ClientError {
     Connection(String),
     #[error("Server error: {0}")]
     Generation(String),
+    #[error("Sharded results are empty")]
+    EmptyResults,
 }

 impl From<Status> for ClientError {
@@ -1,6 +1,6 @@
 /// Multi shard Client
-use crate::Result;
 use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
+use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
@@ -98,8 +98,9 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+            join_all(futures).await.into_iter().collect();
+        merge_generations(results?)
     }

     /// Generate one token for each request in the given cached batches
@@ -116,7 +117,20 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+            join_all(futures).await.into_iter().collect();
+        merge_generations(results?)
     }
 }
+
+/// Merge generations from the different model shards
+fn merge_generations(
+    mut results: Vec<(Vec<Generation>, Option<Batch>)>,
+) -> Result<(Vec<Generation>, Option<Batch>)> {
+    let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
+
+    for (mut shard_generations, _) in results.into_iter() {
+        generations.append(&mut shard_generations);
+    }
+    Ok((generations, next_batch))
+}
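
With this change the router no longer takes a single shard's response; it collects every shard's (generations, next_batch) pair and concatenates the generation lists. A rough Python model of that merge logic (the authoritative version is the Rust merge_generations above; the names here are illustrative only):

def merge_generations(shard_results):
    # shard_results: one (generations, next_batch) tuple per shard
    if not shard_results:
        raise ValueError("Sharded results are empty")  # mirrors ClientError::EmptyResults
    generations, next_batch = shard_results.pop()
    for shard_generations, _ in shard_results:
        # every shard now returns only the generations it owns
        generations.extend(shard_generations)
    return generations, next_batch
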
@@ -63,10 +63,10 @@ class BLOOMSharded(BLOOM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -94,8 +94,8 @@ class BLOOMSharded(BLOOM):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -105,6 +105,8 @@ class BLOOMSharded(BLOOM):
             dtype=dtype,
             device=device,
             decode_buffer=1,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod
@@ -549,7 +549,7 @@ class CausalLM(Model):
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                all_input_ids.view(1, -1), logits
+                all_input_ids.view(1, -1), logits[-1:, :]
             )

             # Append next token to all tokens
@@ -569,54 +569,60 @@ class CausalLM(Model):
                 next_token_text,
             )

-            if stop:
-                # Decode generated tokens
-                output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :, 0]
-                )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
-                # Keep request in the batch
-                generated_text = None
-                stopped = False
-
-            # Prefill
-            if stopping_criteria.current_tokens == 1:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids[1:]
-                ).squeeze(1)[-new_input_length:-1].tolist()
-                prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
-                )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
-                generated_text,
-            )
-
-            generations.append(generation)
+            if not stop:
+                stopped = False
+
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text = self.decode(
+                        all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    )
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                # Prefill
+                if stopping_criteria.current_tokens == 1:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    prefill_logprobs = [float("nan")] + torch.log_softmax(
+                        logits, -1
+                    ).gather(1, all_input_ids[1:]).squeeze(1)[
+                        -new_input_length:-1
+                    ].tolist()
+                    prefill_token_ids = all_input_ids[-new_input_length:-1]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens = PrefillTokens(
+                        prefill_token_ids, prefill_logprobs, prefill_texts
+                    )
+                else:
+                    prefill_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    next_token_id_squeezed,
+                    next_token_logprob,
+                    next_token_text,
+                    next_token_id_squeezed.item() in self.all_special_ids,
+                    generated_text,
+                )
+
+                generations.append(generation)

             # Update values
             batch.input_ids[i, 0] = next_token_id
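
The Python servers now split the per-request work round-robin: a shard only builds Generation objects for the requests whose batch index satisfies i % self.world_size == self.rank, and the Rust client above reassembles the full list. A toy sketch of that ownership rule (the request count and world size below are made up for illustration):

world_size = 4
requests = list(range(10))  # batch indices i as enumerated above

owned_by_shard = {
    rank: [i for i in requests if i % world_size == rank]
    for rank in range(world_size)
}

# Each request is produced by exactly one shard, so concatenating the shards'
# outputs yields one Generation per request.
assert sorted(i for owned in owned_by_shard.values() for i in owned) == requests
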
@@ -622,10 +622,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         self.process_group = process_group
         if self.process_group is not None:
             self.world_size = self.process_group.size()
-            self.rank = self.process_group.rank()
         else:
             self.world_size = 1
-            self.rank = 0

         self.model = FlashLlamaModel(config, process_group)

@@ -685,10 +685,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self.process_group = process_group
         if self.process_group is not None:
             self.world_size = self.process_group.size()
-            self.rank = self.process_group.rank()
         else:
             self.world_size = 1
-            self.rank = 0

         self.gpt_neox = FlashGPTNeoXModel(config, process_group)

@@ -687,53 +687,59 @@ class FlashCausalLM(Model):
                 next_token_text,
             )

-            if stop:
-                # Decode generated tokens
-                output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :]
-                )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
-                stopped = False
-                generated_text = None
-
-            # Prefill
-            if prefill:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    start_index : end_index - 1
-                ]
-                prefill_token_ids = all_input_ids[:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, request_prefill_logprobs, prefill_texts
-                )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id,
-                next_token_logprob,
-                next_token_text,
-                next_token_id in self.all_special_ids,
-                generated_text,
-            )
-
-            generations.append(generation)
+            if not stop:
+                stopped = False
+
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text = self.decode(
+                        all_input_ids[-stopping_criteria.current_tokens :]
+                    )
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                # Prefill
+                if prefill:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
+                        start_index : end_index - 1
+                    ]
+                    prefill_token_ids = all_input_ids[:-1]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens = PrefillTokens(
+                        prefill_token_ids, request_prefill_logprobs, prefill_texts
+                    )
+                else:
+                    prefill_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    next_token_id,
+                    next_token_logprob,
+                    next_token_text,
+                    next_token_id in self.all_special_ids,
+                    generated_text,
+                )
+
+                generations.append(generation)
+
             new_input_length = input_length + 1

             # Update values
@@ -157,10 +157,10 @@ class FlashLlamaSharded(FlashLlama):
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.past_pad = None
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
@@ -190,8 +190,8 @@ class FlashLlamaSharded(FlashLlama):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
@@ -200,6 +200,8 @@ class FlashLlamaSharded(FlashLlama):
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod
@@ -34,10 +34,10 @@ class FlashNeoXSharded(FlashNeoX):
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.past_pad = None
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
@@ -64,8 +64,8 @@ class FlashNeoXSharded(FlashNeoX):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
@@ -74,6 +74,8 @@ class FlashNeoXSharded(FlashNeoX):
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod
@@ -174,10 +174,10 @@ class FlashSantacoderSharded(FlashSantacoder):
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.past_pad = None
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
@@ -204,8 +204,8 @@ class FlashSantacoderSharded(FlashSantacoder):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
         self.model = model.eval().to(device)
@@ -215,6 +215,8 @@ class FlashSantacoderSharded(FlashSantacoder):
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod
@@ -195,10 +195,10 @@ class GalacticaSharded(Galactica):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -226,8 +226,8 @@ class GalacticaSharded(Galactica):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -236,6 +236,8 @@ class GalacticaSharded(Galactica):
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod
@@ -34,10 +34,10 @@ class GPTNeoxSharded(CausalLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -65,8 +65,8 @@ class GPTNeoxSharded(CausalLM):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -75,6 +75,8 @@ class GPTNeoxSharded(CausalLM):
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod
@@ -18,6 +18,8 @@ class Model(ABC):
         dtype: torch.dtype,
         device: torch.device,
         decode_buffer: int = 3,
+        rank: int = 0,
+        world_size: int = 1,
     ):
         if decode_buffer < 1:
             raise ValueError("decode_buffer must be >= 1")
@@ -28,6 +30,8 @@ class Model(ABC):
         self.dtype = dtype
         self.device = device
         self.decode_buffer = decode_buffer
+        self.rank = rank
+        self.world_size = world_size

     @property
     def info(self) -> InfoResponse:
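
The Model base class gains rank and world_size with defaults of 0 and 1, so unsharded models keep their previous behaviour: the ownership test in generate_token always passes. A minimal sketch under that assumption (the class below is hypothetical, for illustration only):

class UnshardedToy:
    def __init__(self, rank: int = 0, world_size: int = 1):
        self.rank = rank
        self.world_size = world_size

m = UnshardedToy()
# With world_size == 1, every batch index i satisfies i % world_size == rank == 0.
assert all(i % m.world_size == m.rank for i in range(8))
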
@@ -50,10 +50,10 @@ class OPTSharded(OPT):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -81,8 +81,8 @@ class OPTSharded(OPT):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -91,6 +91,8 @@ class OPTSharded(OPT):
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod
@@ -631,7 +631,7 @@ class Seq2SeqLM(Model):
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                all_decoder_input_ids.view(1, -1), logits
+                all_decoder_input_ids.view(1, -1), logits[-1:, :]
             )

             # Append next token to decoder tokens
@@ -650,46 +650,52 @@ class Seq2SeqLM(Model):
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(next_token_id, next_token_text)

-            if stop:
-                # Slice with decoder_input_length to remove padding
-                # Decode all tokens
-                output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])
-
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
-                # Keep request in the batch
-                generated_text = None
-                stopped = False
-
-            # Prefill
-            if stopping_criteria.current_tokens == 1:
-                prefill_tokens = PrefillTokens(
-                    [self.tokenizer.bos_token_id],
-                    [float("nan")],
-                    [self.tokenizer.bos_token],
-                )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
-                generated_text,
-            )
-
-            generations.append(generation)
+            if not stop:
+                stopped = False
+
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Slice with decoder_input_length to remove padding
+                    # Decode all tokens
+                    output_text = self.decode(
+                        all_decoder_input_ids[-decoder_input_length:]
+                    )
+
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                # Prefill
+                if stopping_criteria.current_tokens == 1:
+                    prefill_tokens = PrefillTokens(
+                        [self.tokenizer.bos_token_id],
+                        [float("nan")],
+                        [self.tokenizer.bos_token],
+                    )
+                else:
+                    prefill_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    next_token_id_squeezed,
+                    next_token_logprob,
+                    next_token_text,
+                    next_token_id_squeezed.item() in self.all_special_ids,
+                    generated_text,
+                )
+
+                generations.append(generation)

             # Update values
             batch.decoder_input_ids[i] = next_token_id
@@ -34,10 +34,10 @@ class T5Sharded(Seq2SeqLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
-        self.master = self.rank == 0
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        self.master = rank == 0
         if torch.cuda.is_available():
-            device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
@@ -65,8 +65,8 @@ class T5Sharded(Seq2SeqLM):
             quantize=quantize,
             device=device,
             dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+            rank=rank,
+            world_size=world_size,
         )
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
@@ -75,6 +75,8 @@ class T5Sharded(Seq2SeqLM):
             requires_padding=True,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )

     @staticmethod
@@ -75,11 +75,7 @@ class NextTokenChooser:

     def __call__(self, input_ids, scores):
         # Warp logits
-        if scores.shape[0] > 1:
-            # only warp the last token logits
-            scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
-        else:
-            scores = self.warpers(input_ids, scores)
+        scores = self.warpers(input_ids, scores)

         # Compute logprobs
         logprobs = torch.log_softmax(scores, -1)
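
Because callers now pass only the last position's logits (logits[-1:, :]), the warpers always see a single-row tensor and the shape check inside NextTokenChooser is no longer needed. A small shape illustration with dummy values (the vocabulary size below is made up):

import torch

logits = torch.randn(5, 32000)   # one row per input position
last = logits[-1:, :]            # what next_token_chooser now receives
assert last.shape == (1, 32000)  # warping this is the same as warping only the last token
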