feat(server): shard token decode (#303)

OlivierDehaene 2023-05-10 15:48:21 +02:00 committed by GitHub
parent 1585404464
commit 68e9d6ab33
17 changed files with 224 additions and 178 deletions


@@ -23,6 +23,8 @@ pub enum ClientError {
Connection(String),
#[error("Server error: {0}")]
Generation(String),
#[error("Sharded results are empty")]
EmptyResults,
}
impl From<Status> for ClientError {


@@ -1,6 +1,6 @@
/// Multi shard Client
use crate::Result;
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
use crate::{ClientError, Result};
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
@@ -98,8 +98,9 @@ impl ShardedClient {
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
join_all(futures).await.into_iter().collect();
merge_generations(results?)
}
/// Generate one token for each request in the given cached batches
@@ -116,7 +117,20 @@ impl ShardedClient {
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
join_all(futures).await.into_iter().collect();
merge_generations(results?)
}
}
/// Merge generations from the different model shards
fn merge_generations(
mut results: Vec<(Vec<Generation>, Option<Batch>)>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
for (mut shard_generations, _) in results.into_iter() {
generations.append(&mut shard_generations);
}
Ok((generations, next_batch))
}
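The merge above replaces the old assumption that every shard returns the identical message: each shard now returns only the generations it decoded, one shard's result supplies the cached next batch, and an empty result list maps to the new ClientError::EmptyResults. A minimal Python sketch of the same merge, for illustration only (the real implementation is the Rust function above, and the types here are hypothetical):

# Python mirror of merge_generations, for illustration.
def merge_generations(results):
    # results: one (generations, next_batch) tuple per shard
    if not results:
        raise ValueError("Sharded results are empty")  # ClientError::EmptyResults in Rust
    generations, next_batch = results.pop()            # keep the last shard's next_batch
    for shard_generations, _ in results:
        generations.extend(shard_generations)          # concatenate the other shards' slices
    return generations, next_batch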


@@ -63,10 +63,10 @@ class BLOOMSharded(BLOOM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
@@ -94,8 +94,8 @@ class BLOOMSharded(BLOOM):
quantize=quantize,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
@@ -105,6 +105,8 @@ class BLOOMSharded(BLOOM):
dtype=dtype,
device=device,
decode_buffer=1,
rank=rank,
world_size=world_size,
)
@staticmethod


@@ -549,7 +549,7 @@ class CausalLM(Model):
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits
all_input_ids.view(1, -1), logits[-1:, :]
)
# Append next token to all tokens
@@ -569,54 +569,60 @@ class CausalLM(Model):
next_token_text,
)
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
if not stop:
stopped = False
# Prefill
if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids[1:]
).squeeze(1)[-new_input_length:-1].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
)
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
generations.append(generation)
# Prefill
if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
)
generations.append(generation)
# Update values
batch.input_ids[i, 0] = next_token_id
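The new `if i % self.world_size == self.rank` guard is the core of this change: each shard detokenizes and builds Generation objects only for its round-robin slice of the batch, and the Rust sharded client above concatenates the per-shard lists back into a full response. The same guard appears in the FlashCausalLM and Seq2SeqLM hunks below. A small sketch of the partition, using a hypothetical helper name purely for illustration:

# Which request indices a given shard handles under the i % world_size == rank check.
def requests_for_rank(num_requests, rank, world_size):
    return [i for i in range(num_requests) if i % world_size == rank]

# e.g. with 2 shards and 5 requests in the batch:
assert requests_for_rank(5, 0, 2) == [0, 2, 4]
assert requests_for_rank(5, 1, 2) == [1, 3]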


@@ -622,10 +622,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
self.rank = self.process_group.rank()
else:
self.world_size = 1
self.rank = 0
self.model = FlashLlamaModel(config, process_group)


@@ -685,10 +685,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
self.rank = self.process_group.rank()
else:
self.world_size = 1
self.rank = 0
self.gpt_neox = FlashGPTNeoXModel(config, process_group)


@@ -687,53 +687,59 @@ class FlashCausalLM(Model):
next_token_text,
)
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
if not stop:
stopped = False
generated_text = None
# Prefill
if prefill:
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
start_index : end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
# Prefill
if prefill:
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
start_index : end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id,
next_token_logprob,
next_token_text,
next_token_id in self.all_special_ids,
generated_text,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id,
next_token_logprob,
next_token_text,
next_token_id in self.all_special_ids,
generated_text,
)
generations.append(generation)
generations.append(generation)
new_input_length = input_length + 1
# Update values


@@ -157,10 +157,10 @@ class FlashLlamaSharded(FlashLlama):
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.past_pad = None
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
else:
raise NotImplementedError("FlashLlama is only available on GPU")
@@ -190,8 +190,8 @@ class FlashLlamaSharded(FlashLlama):
quantize=quantize,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
rank=rank,
world_size=world_size,
)
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
@@ -200,6 +200,8 @@ class FlashLlamaSharded(FlashLlama):
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod


@@ -34,10 +34,10 @@ class FlashNeoXSharded(FlashNeoX):
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.past_pad = None
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
else:
raise NotImplementedError("FlashNeoX is only available on GPU")
@@ -64,8 +64,8 @@ class FlashNeoXSharded(FlashNeoX):
quantize=quantize,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
rank=rank,
world_size=world_size,
)
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
@@ -74,6 +74,8 @@ class FlashNeoXSharded(FlashNeoX):
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod


@@ -174,10 +174,10 @@ class FlashSantacoderSharded(FlashSantacoder):
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.past_pad = None
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
@@ -204,8 +204,8 @@ class FlashSantacoderSharded(FlashSantacoder):
quantize=quantize,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
rank=rank,
world_size=world_size,
transpose=config.architectures[0].startswith("GPT2"),
)
self.model = model.eval().to(device)
@@ -215,6 +215,8 @@ class FlashSantacoderSharded(FlashSantacoder):
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod


@@ -195,10 +195,10 @@ class GalacticaSharded(Galactica):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
@@ -226,8 +226,8 @@ class GalacticaSharded(Galactica):
quantize=quantize,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
@@ -236,6 +236,8 @@ class GalacticaSharded(Galactica):
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod


@@ -34,10 +34,10 @@ class GPTNeoxSharded(CausalLM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
@@ -65,8 +65,8 @@ class GPTNeoxSharded(CausalLM):
quantize=quantize,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
@@ -75,6 +75,8 @@ class GPTNeoxSharded(CausalLM):
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod


@@ -18,6 +18,8 @@ class Model(ABC):
dtype: torch.dtype,
device: torch.device,
decode_buffer: int = 3,
rank: int = 0,
world_size: int = 1,
):
if decode_buffer < 1:
raise ValueError("decode_buffer must be >= 1")
@@ -28,6 +30,8 @@
self.dtype = dtype
self.device = device
self.decode_buffer = decode_buffer
self.rank = rank
self.world_size = world_size
@property
def info(self) -> InfoResponse:
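Since the base Model defaults to rank=0 and world_size=1, unsharded models are unaffected by the new guard: with a world size of 1 the modulo check is true for every request, so a Generation is still appended for each one.

# With the defaults, the sharding guard is a no-op:
rank, world_size = 0, 1
assert all(i % world_size == rank for i in range(10))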


@@ -50,10 +50,10 @@ class OPTSharded(OPT):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
@@ -81,8 +81,8 @@ class OPTSharded(OPT):
quantize=quantize,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
@@ -91,6 +91,8 @@ class OPTSharded(OPT):
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod


@@ -631,7 +631,7 @@ class Seq2SeqLM(Model):
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_decoder_input_ids.view(1, -1), logits
all_decoder_input_ids.view(1, -1), logits[-1:, :]
)
# Append next token to decoder tokens
@@ -650,46 +650,52 @@
# Evaluate stopping criteria
stop, reason = stopping_criteria(next_token_id, next_token_text)
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
if not stop:
stopped = False
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_tokens = PrefillTokens(
[self.tokenizer.bos_token_id],
[float("nan")],
[self.tokenizer.bos_token],
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(
all_decoder_input_ids[-decoder_input_length:]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_tokens = PrefillTokens(
[self.tokenizer.bos_token_id],
[float("nan")],
[self.tokenizer.bos_token],
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
)
generations.append(generation)
generations.append(generation)
# Update values
batch.decoder_input_ids[i] = next_token_id


@@ -34,10 +34,10 @@ class T5Sharded(Seq2SeqLM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
self.process_group, rank, world_size = initialize_torch_distributed()
self.master = rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
@@ -65,8 +65,8 @@ class T5Sharded(Seq2SeqLM):
quantize=quantize,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
@@ -75,6 +75,8 @@ class T5Sharded(Seq2SeqLM):
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod


@@ -75,11 +75,7 @@ class NextTokenChooser:
def __call__(self, input_ids, scores):
# Warp logits
if scores.shape[0] > 1:
# only warp the last token logits
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
else:
scores = self.warpers(input_ids, scores)
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
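Because CausalLM and Seq2SeqLM now pass only the last position's logits (`logits[-1:, :]`) to the chooser, the warpers always receive a single row and the old "only warp the last token logits" branch is no longer needed; prefill logprobs are computed separately from the full prefill logits with `torch.log_softmax(logits, -1)` inside CausalLM. A sketch of the simplified call path, assuming an HF-style warper callable (illustrative only, not the actual NextTokenChooser code):

import torch

def choose_next_token(input_ids, logits, warpers):
    scores = logits[-1:, :]                    # callers slice to the last position
    scores = warpers(input_ids, scores)        # single-row warp, no special-casing
    logprobs = torch.log_softmax(scores, -1)
    next_id = torch.argmax(logprobs, dim=-1, keepdim=True)  # greedy choice for brevity
    return next_id, logprobs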