fix(batching): Avoid theoretical hang in batcher loop (#5)

- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in the router generate method
- Keep attention mask tensors as integers
- Remove num_heads attribute

Co-authored-by: OlivierDehaene <Olivier.dehaene@gmail.com>
Nick Hill 2022-12-05 01:10:59 -08:00 committed by GitHub
parent daa1d81d5e
commit 31d76e238d
8 changed files with 11 additions and 21 deletions


@@ -104,10 +104,9 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        let mut waiting_tokens = 0;
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
+        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
-            waiting_tokens += 1;
+            let mut waiting_tokens = 1;

             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
@@ -131,11 +130,11 @@ async fn batching_task(
                     if let Some((new_request_ids, new_batch)) =
                         db.next_batch(min_size, max_batch_size)
                     {
-                        // Reset waiting counter
-                        waiting_tokens = 0;
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                        // Reset waiting counter
+                        waiting_tokens = 1;
                         // Extend current batch with the new batch
                         if let Some(new_cached_batch) = new_cached_batch {
                             request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
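The hang guarded against here comes from coalesced wake-up notifications: several queued requests can produce a single wake-up, and the old `if let` served only one batch per wake-up, so work could be left sitting in the DB until another notification happened to arrive. Below is a minimal, runnable Python sketch of that control-flow difference; the `queue` and `next_batch` stand-ins are hypothetical, not the router's actual API. Note also that `waiting_tokens` now starts at (and resets to) 1, since one token has already been generated for the batch by the time the counter is set.

    from collections import deque

    MAX_BATCH_SIZE = 2
    queue = deque(["req-1", "req-2", "req-3"])  # requests queued before a single wake-up

    def next_batch(max_size):
        """Pop up to max_size queued requests, or None when the queue is empty."""
        if not queue:
            return None
        return [queue.popleft() for _ in range(min(max_size, len(queue)))]

    # Notifications coalesce: the three requests above may trigger only one wake-up.
    for _ in ["notified"]:
        # Old control flow: `if next_batch(...)` served a single batch per wake-up,
        # so "req-3" could sit in the queue until another notification arrived.
        # New control flow: keep draining the queue before going back to sleep.
        while (batch := next_batch(MAX_BATCH_SIZE)) is not None:
            print("processing", batch)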


@@ -90,11 +90,7 @@ async fn generate(
     // Validate request
     let (input_length, validated_request) = state
         .validation
-        // FIXME: can't we get rid of the cloning here??
-        .validate(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
+        .validate(req.0)
         .await
         .map_err(|err| {
             tracing::error!("{}", err.to_string());


@@ -155,7 +155,7 @@ type ValidationRequest = (
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
     Temperature,
-    #[error("top_p must be >= 0.0 or < 1.0")]
+    #[error("top_p must be > 0.0 and <= 1.0")]
     TopP,
     #[error("top_k must be strictly positive")]
     TopK,


@@ -82,7 +82,6 @@ class BLOOMSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.n_head // self.process_group.size(),
             device=device,
         )


@@ -251,7 +251,6 @@ class CausalLM(Model):
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )
@@ -358,7 +357,7 @@ class CausalLM(Model):
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
-                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
+                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                     for t in layer
                 ]
                 for layer in past
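This reshape is what lets the `num_heads` attribute disappear: slicing the cached past by request only needs the batch dimension up front, and `batch.size` is already at hand. A minimal sketch of the equivalence, assuming the cached key/value tensors are laid out as [batch_size * num_heads, seq_len, head_dim] (the shapes below are made up for illustration):

    import torch

    batch_size, num_heads, seq_len, head_dim = 4, 16, 10, 64  # made-up shapes
    t = torch.randn(batch_size * num_heads, seq_len, head_dim)

    # Old: needs the model's num_heads attribute to recover the batch dimension.
    old_view = t.view(-1, num_heads, *t.shape[-2:])
    # New: only needs the batch size; the head dimension is folded into dim 1.
    new_view = t.view(batch_size, -1, *t.shape[-2:])

    keep = torch.tensor([0, 2])  # requests that are still generating
    assert torch.equal(old_view[keep], new_view[keep])

Using the batch size also works unchanged for sharded models, where each shard holds only a fraction of the heads (the `config.n_head // self.process_group.size()` arguments removed elsewhere in this commit).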
@@ -381,7 +380,7 @@ class CausalLM(Model):
             next_batch_attention_mask = torch.cat(
                 [
                     next_batch_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )
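This is the "keep attention mask tensors as integers" item from the summary: torch.ones allocates float32 on the default device, while Tensor.new_ones inherits dtype and device from the mask it extends, so the mask stays integer and no extra .to(self.device) copy is needed. A small illustrative check (sizes are arbitrary):

    import torch

    attention_mask = torch.ones(2, 5, dtype=torch.int64)  # e.g. a tokenizer-produced mask

    print(torch.ones((2, 1)).dtype)          # torch.float32 (the global default dtype)

    new_col = attention_mask.new_ones(2, 1)  # same dtype and device as the mask
    extended = torch.cat([attention_mask, new_col], dim=1)
    print(extended.dtype)                    # torch.int64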


@@ -185,7 +185,6 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.num_attention_heads // self.process_group.size(),
             device=device,
         )


@@ -10,9 +10,8 @@ B = TypeVar("B", bound=Batch)

 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
+    def __init__(self, tokenizer: Tokenizer, device: torch.device):
         self.tokenizer = tokenizer
-        self.num_heads = num_heads
         self.device = device

     @property


@@ -87,7 +87,7 @@ class Seq2SeqLMBatch:
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
+        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

         return cls(
             batch_id=pb.id,
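Passing device= to torch.tensor builds the [batch_size, 1] tensor directly on the target device instead of allocating it on the CPU and copying it over. A quick sketch (the token ids are made up):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    decoder_input_ids = [0, 0, 0]  # made-up start token ids, one per sequence

    old = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)      # CPU tensor, then a copy
    new = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)  # allocated on the device

    assert torch.equal(old, new) and new.shape == (3, 1)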
@@ -319,7 +319,6 @@ class Seq2SeqLM(Model):
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )
@@ -499,7 +498,7 @@ class Seq2SeqLM(Model):
             next_batch_decoder_attention_mask = torch.cat(
                 [
                     next_batch_decoder_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )