fix: pass model_id for all causal and seq2seq lms

drbh 2024-06-06 21:13:14 +00:00
parent 88bd5c2c92
commit dc0f76553c
11 changed files with 11 additions and 0 deletions
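All eleven hunks apply the same one-line fix: each wrapper now forwards model_id to the base initializer. The causal-LM wrappers call super(CausalLM, self).__init__, which skips CausalLM.__init__ in the MRO and dispatches straight to Model.__init__, so model_id has to be passed explicitly at each call site. A minimal sketch of the pattern follows; the Model.__init__ signature shown is an assumption for illustration, not taken from this diff.

# Minimal sketch of the pattern (assumed signatures, for illustration only).
class Model:
    def __init__(self, model_id, model, tokenizer, requires_padding):
        self.model_id = model_id  # now recorded on every model instance
        self.model = model
        self.tokenizer = tokenizer
        self.requires_padding = requires_padding

class CausalLM(Model):
    # The real CausalLM.__init__ builds the model and tokenizer itself;
    # the wrappers below bypass it and call Model.__init__ directly.
    pass

class BLOOMSharded(CausalLM):
    def __init__(self, model_id, model, tokenizer):
        # super(CausalLM, self) skips CausalLM.__init__ in the MRO and
        # dispatches straight to Model.__init__, so model_id must be
        # forwarded explicitly here -- the one-line fix in this commit.
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
        )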

@@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

@@ -538,6 +538,7 @@ class CausalLM(Model):
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

@@ -212,6 +212,7 @@ class GalacticaSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

@@ -65,6 +65,7 @@ class GPTNeoxSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

@@ -90,6 +90,7 @@ class MPTSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=False,

@@ -63,6 +63,7 @@ class OPTSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

@@ -60,6 +60,7 @@ class Phi(CausalLM):
         model = PhiForCausalLM(config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

@@ -62,6 +62,7 @@ class RW(CausalLM):
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

@@ -62,6 +62,7 @@ class SantaCoder(CausalLM):
         )
         super(CausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

@@ -575,6 +575,7 @@ class Seq2SeqLM(Model):
         tokenizer.bos_token_id = model.config.decoder_start_token_id
         super(Seq2SeqLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

@@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM):
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,