fix(batching): Avoid theoretical hang in batcher loop (#5)
- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in the router generate method
- Keep attention mask tensors as integers
- Remove num_heads attribute

Co-authored-by: OlivierDehaene <Olivier.dehaene@gmail.com>
parent daa1d81d5e
commit 31d76e238d
@@ -104,10 +104,9 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        let mut waiting_tokens = 0;
-        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
+        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
-            waiting_tokens += 1;
+            let mut waiting_tokens = 1;

             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
@@ -131,11 +130,11 @@ async fn batching_task(
                     if let Some((new_request_ids, new_batch)) =
                         db.next_batch(min_size, max_batch_size)
                     {
-                        // Reset waiting counter
-                        waiting_tokens = 0;
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                        // Reset waiting counter
+                        waiting_tokens = 1;
                         // Extend current batch with the new batch
                         if let Some(new_cached_batch) = new_cached_batch {
                             request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
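
The two hunks above are the batcher-loop fix itself: the outer `while let` over `db.next_batch` becomes an `if let`, and the waiting-token counter is created inside it, starting at 1 right after the first `generate` call and reset to 1 (rather than 0) whenever a new batch is merged in, since a token has already been generated for that batch by then. The sketch below is a minimal, runnable Python toy of that control flow only; `queue`, `next_batch`, and `generate` are stand-ins rather than the router's API, and the `min_size` decision that consumes `waiting_tokens` is elided.

from collections import deque

# Toy state: each "batch" is a list of remaining tokens per request.
queue = deque([[3, 2], [1]])


def next_batch(queue):
    # Stand-in for db.next_batch: pop the next pending batch, if any.
    return queue.popleft() if queue else None


def generate(batch):
    # Stand-in for client.generate: decode one token per request and drop
    # finished requests; None means no cached batch remains on the server.
    remaining = [tokens - 1 for tokens in batch if tokens > 1]
    return remaining or None


batch = next_batch(queue)
if batch is not None:                          # was: an outer while over next_batch(...)
    cached_batch = generate(batch)
    waiting_tokens = 1                         # was: an outer counter bumped with += 1

    # Loop until the inference server returns no cached batch
    # (== all requests have met their stopping criteria).
    while cached_batch is not None:
        new_batch = next_batch(queue)
        if new_batch is not None:
            # One generate call so the new batch has its attention past cached.
            new_cached_batch = generate(new_batch)
            waiting_tokens = 1                 # was: reset to 0 before that call
            if new_cached_batch is not None:
                cached_batch += new_cached_batch
        cached_batch = generate(cached_batch)
        waiting_tokens += 1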
@@ -90,11 +90,7 @@ async fn generate(
     // Validate request
     let (input_length, validated_request) = state
         .validation
-        // FIXME: can't we get rid of the cloning here??
-        .validate(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
+        .validate(req.0)
         .await
         .map_err(|err| {
             tracing::error!("{}", err.to_string());
@@ -155,7 +155,7 @@ type ValidationRequest = (
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
     Temperature,
-    #[error("top_p must be >= 0.0 or < 1.0")]
+    #[error("top_p must be > 0.0 and <= 1.0")]
     TopP,
     #[error("top_k must be strictly positive")]
     TopK,
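
The old wording described a vacuous range (any ordinary value is either >= 0.0 or < 1.0); the new message matches the bound that makes sense for nucleus sampling, 0 < top_p <= 1. A small Python illustration of that bound (the validation itself lives in the Rust ValidationError code above; check_top_p is a made-up helper):

def check_top_p(top_p: float) -> float:
    # Hypothetical helper mirroring the corrected message: top_p must be > 0.0 and <= 1.0.
    if not (0.0 < top_p <= 1.0):
        raise ValueError("top_p must be > 0.0 and <= 1.0")
    return top_p


check_top_p(0.9)    # accepted
# check_top_p(0.0)  # rejected: nucleus sampling needs a non-empty probability mass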
@@ -82,7 +82,6 @@ class BLOOMSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.n_head // self.process_group.size(),
             device=device,
         )

@@ -251,7 +251,6 @@ class CausalLM(Model):

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )

@@ -358,7 +357,7 @@ class CausalLM(Model):
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
-                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
+                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                     for t in layer
                 ]
                 for layer in past
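
This hunk is why the num_heads attribute can be removed: the cached past tensors have batch and heads flattened into the first dimension, so viewing them as [batch_size, -1, ...] selects exactly the same rows per kept request as the old [-1, num_heads, ...] view, without the model having to know the head count. A self-contained check (the shapes are illustrative, not BLOOM's real ones):

import torch

batch_size, num_heads, seq_len, head_dim = 3, 4, 5, 8
# One past key/value tensor with batch and heads flattened together, as in the diff above.
t = torch.randn(batch_size * num_heads, seq_len, head_dim)
next_batch_keep_indices = [0, 2]   # requests that are still generating

old = t.view(-1, num_heads, *t.shape[-2:])[next_batch_keep_indices]   # needed self.num_heads
new = t.view(batch_size, -1, *t.shape[-2:])[next_batch_keep_indices]  # only needs the batch size
assert torch.equal(old, new)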
@@ -381,7 +380,7 @@ class CausalLM(Model):
             next_batch_attention_mask = torch.cat(
                 [
                     next_batch_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )
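
This is the "keep attention mask tensors as integers" part of the commit: torch.ones defaults to float32, so the appended column had the wrong dtype for an integer mask and also needed an explicit .to(self.device) copy, whereas Tensor.new_ones creates the column with the mask's own dtype and device in one step. A quick check of the dtype behaviour:

import torch

# Tokenizer attention masks are integer tensors.
next_batch_attention_mask = torch.ones(2, 4, dtype=torch.int64)
next_batch_size = next_batch_attention_mask.shape[0]

float_col = torch.ones((next_batch_size, 1))                       # float32 by default
int_col = next_batch_attention_mask.new_ones(next_batch_size, 1)   # matches dtype and device

assert float_col.dtype == torch.float32
assert int_col.dtype == torch.int64

# Concatenating the new_ones column keeps the mask integral.
next_batch_attention_mask = torch.cat([next_batch_attention_mask, int_col], dim=1)
assert next_batch_attention_mask.dtype == torch.int64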
@@ -185,7 +185,6 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.num_attention_heads // self.process_group.size(),
             device=device,
         )

@@ -10,9 +10,8 @@ B = TypeVar("B", bound=Batch)


 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
+    def __init__(self, tokenizer: Tokenizer, device: torch.device):
         self.tokenizer = tokenizer
-        self.num_heads = num_heads
         self.device = device

     @property
@@ -87,7 +87,7 @@ class Seq2SeqLMBatch:
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
+        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

         return cls(
             batch_id=pb.id,
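
Passing device= at construction avoids allocating a separate CPU tensor and then copying it over with .to(device); the resulting tensor is the same. A small illustration (the token id values are placeholders, not the model's real decoder start token):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
decoder_input_ids = [0, 0, 0, 0]   # placeholder ids, one per request in the batch

old = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)       # CPU tensor, then a copy
new = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)   # allocated on `device` directly

assert torch.equal(old, new)
assert new.shape == (len(decoder_input_ids), 1)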
@@ -319,7 +319,6 @@ class Seq2SeqLM(Model):

         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )

@@ -499,7 +498,7 @@ class Seq2SeqLM(Model):
             next_batch_decoder_attention_mask = torch.cat(
                 [
                     next_batch_decoder_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )