fix(batching): Avoid theoretical hang in batcher loop (#5)
- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in the router generate method
- Keep attention mask tensors as integers
- Remove num_heads attribute

Co-authored-by: OlivierDehaene <Olivier.dehaene@gmail.com>
parent daa1d81d5e
commit 31d76e238d
@@ -104,10 +104,9 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        let mut waiting_tokens = 0;
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
+        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
-            waiting_tokens += 1;
+            let mut waiting_tokens = 1;
 
             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
@@ -131,11 +130,11 @@ async fn batching_task(
                 if let Some((new_request_ids, new_batch)) =
                     db.next_batch(min_size, max_batch_size)
                 {
-                    // Reset waiting counter
-                    waiting_tokens = 0;
                     // Generate one token for this new batch to have the attention past in cache
                     let new_cached_batch =
                         wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                    // Reset waiting counter
+                    waiting_tokens = 1;
                     // Extend current batch with the new batch
                     if let Some(new_cached_batch) = new_cached_batch {
                         request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
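A minimal sketch, in Python for brevity, of the control flow the two hunks above change (illustrative names only, not the router's actual Rust API). With a single `if`, requests that arrive while the current batch is decoding can be left sitting in the queue until the next wakeup; re-polling in a `while` drains the queue before going back to sleep, and starting the counter at 1 accounts for the token already generated by the initial prefill call.

    # Illustrative control-flow sketch only; the real batcher is async Rust.
    def batching_task(queue, generate):
        while True:
            queue.wait_for_work()                  # hypothetical blocking wait
            # Old shape: a single `if queue.next_batch()` could miss requests
            # enqueued during decoding; looping re-polls until the queue is empty.
            while (batch := queue.next_batch()) is not None:
                cached = generate(batch)           # prefill generates one token
                waiting_tokens = 1                 # so the counter starts at 1
                while cached is not None:          # decode until all requests stop
                    cached = generate(cached)
                    waiting_tokens += 1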
@@ -90,11 +90,7 @@ async fn generate(
     // Validate request
     let (input_length, validated_request) = state
         .validation
-        // FIXME: can't we get rid of the cloning here??
-        .validate(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
+        .validate(req.0)
         .await
         .map_err(|err| {
             tracing::error!("{}", err.to_string());
@@ -155,7 +155,7 @@ type ValidationRequest = (
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
     Temperature,
-    #[error("top_p must be >= 0.0 or < 1.0")]
+    #[error("top_p must be > 0.0 and <= 1.0")]
     TopP,
     #[error("top_k must be strictly positive")]
     TopK,
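The new message matches the usual nucleus-sampling constraint: top_p is a probability mass, so 0.0 keeps nothing and 1.0 keeps the full distribution. A small Python sketch of the equivalent check, with illustrative names:

    def validate_top_p(top_p: float) -> float:
        # Mirrors the corrected error message: top_p must be > 0.0 and <= 1.0.
        if not 0.0 < top_p <= 1.0:
            raise ValueError("top_p must be > 0.0 and <= 1.0")
        return top_p

    validate_top_p(0.9)    # fine
    # validate_top_p(0.0)  # would raise: zero probability mass keeps no tokens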
@@ -82,7 +82,6 @@ class BLOOMSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.n_head // self.process_group.size(),
             device=device,
         )
 
@@ -251,7 +251,6 @@ class CausalLM(Model):
 
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )
 
@@ -358,7 +357,7 @@ class CausalLM(Model):
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
-                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
+                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                     for t in layer
                 ]
                 for layer in past
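Viewing the flattened past tensors by `batch.size` instead of `self.num_heads` yields the same `[batch_size, num_heads, seq_len, head_dim]` layout, which is what lets the `num_heads` attribute be dropped elsewhere in this commit. A quick PyTorch check with made-up dimensions:

    import torch

    # Made-up dimensions; past tensors are assumed flattened to
    # [batch_size * num_heads, seq_len, head_dim] as in the hunk above.
    batch_size, num_heads, seq_len, head_dim = 4, 8, 10, 64
    keep = torch.tensor([0, 2])  # stand-in for next_batch_keep_indices

    t = torch.randn(batch_size * num_heads, seq_len, head_dim)

    by_heads = t.view(-1, num_heads, *t.shape[-2:])[keep]   # old: needs num_heads
    by_batch = t.view(batch_size, -1, *t.shape[-2:])[keep]  # new: only needs batch size

    assert torch.equal(by_heads, by_batch)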
@@ -381,7 +380,7 @@ class CausalLM(Model):
             next_batch_attention_mask = torch.cat(
                 [
                     next_batch_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )
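This is the "keep attention mask tensors as integers" item from the commit message: `torch.ones(...)` uses the default float32 dtype regardless of what it is concatenated with, while `Tensor.new_ones` inherits both the dtype and the device of the existing mask. A minimal demonstration (shapes are made up):

    import torch

    # Stand-in for an existing integer attention mask of shape [batch, seq_len].
    mask = torch.ones((2, 5), dtype=torch.int64)

    print(torch.ones((2, 1)).dtype)    # torch.float32 -- default dtype, ignores the mask
    print(mask.new_ones(2, 1).dtype)   # torch.int64   -- inherits the mask's dtype/device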
@@ -185,7 +185,6 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.num_attention_heads // self.process_group.size(),
             device=device,
         )
 
@@ -10,9 +10,8 @@ B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
+    def __init__(self, tokenizer: Tokenizer, device: torch.device):
         self.tokenizer = tokenizer
-        self.num_heads = num_heads
         self.device = device
 
     @property
@@ -87,7 +87,7 @@ class Seq2SeqLMBatch:
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
+        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
 
         return cls(
             batch_id=pb.id,
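Passing `device=` to `torch.tensor` builds the tensor on the target device in a single call instead of creating it first and then moving it with `.to(device)`; the result is identical. A small check with illustrative values:

    import torch

    decoder_input_ids = [3, 3, 3]        # illustrative start-token ids
    device = torch.device("cpu")         # same pattern applies to a CUDA device

    old = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
    new = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

    assert torch.equal(old, new)         # both have shape [batch_size, 1]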
@@ -319,7 +319,6 @@ class Seq2SeqLM(Model):
 
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )
 
@@ -499,7 +498,7 @@ class Seq2SeqLM(Model):
             next_batch_decoder_attention_mask = torch.cat(
                 [
                     next_batch_decoder_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )