fix: pass model_id for all flash causal lms

drbh 2024-06-06 21:02:03 +00:00
parent 73eb2ae255
commit 88bd5c2c92
13 changed files with 15 additions and 1 deletion
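
Apart from the first hunk, every file below makes the same one-line change: the flash model wrapper now forwards its model_id argument to the FlashCausalLM constructor instead of dropping it. A minimal sketch of the pattern, assuming the base class simply stores the identifier (the real constructor takes more arguments, as the hunks below show):

    # Sketch only; class and attribute names follow the diff, the bodies are assumed.
    class FlashCausalLM:
        def __init__(self, model_id: str, model, tokenizer, **kwargs):
            self.model_id = model_id  # assumed: kept for later use by the base class
            self.model = model
            self.tokenizer = tokenizer

    class FlashCohere(FlashCausalLM):
        def __init__(self, model_id: str, model, tokenizer):
            # The fix: forward model_id explicitly rather than omitting it.
            super().__init__(model_id=model_id, model=model, tokenizer=tokenizer)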

@@ -953,12 +953,13 @@ class FlashCausalLM(Model):
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
         free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+        batch_num_blocks = batch.num_blocks if batch is not None else 0
         num_blocks = (
             # Leave 5% for some wiggle room
             int((free_memory * 0.95) // total_cache_size)
             # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
-            + batch.num_blocks
+            + batch_num_blocks
         )
         del batch
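
The only behavioural change in this hunk is the None guard: warmup can apparently be reached without a batch, in which case reading batch.num_blocks would fail, so the block count now falls back to 0. A standalone sketch of the computation with illustrative numbers (the function name and constants here are hypothetical, not taken from the repository):

    def compute_num_blocks(free_memory: int, total_cache_size: int, batch=None) -> int:
        # Default to 0 when no warmup batch was allocated.
        batch_num_blocks = batch.num_blocks if batch is not None else 0
        # Keep 5% of free memory as wiggle room, then add the blocks already
        # allocated for the warmup batch so they count toward the peak.
        return int((free_memory * 0.95) // total_cache_size) + batch_num_blocks

    # Example: 20 GiB free, 2 MiB per KV-cache block, no batch -> 9728 blocks.
    print(compute_num_blocks(20 * 1024**3, 2 * 1024**2))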

@@ -62,6 +62,7 @@ class FlashCohere(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCohere, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),

@@ -87,6 +87,7 @@ class FlashDbrx(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(FlashDbrx, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),

@@ -62,6 +62,7 @@ class FlashGemma(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),

@@ -65,6 +65,7 @@ class FlashGPT2(FlashCausalLM):
         model = FlashGPT2ForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashGPT2, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),

@@ -79,6 +79,7 @@ class BaseFlashMistral(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         num_layers, num_kv_heads, head_size = self.get_layer_config(model)
         super().__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=num_layers,
@@ -110,6 +111,7 @@ class FlashMistral(BaseFlashMistral):
         trust_remote_code: bool = False,
     ):
         super(FlashMistral, self).__init__(
+            model_id=model_id,
             config_cls=MistralConfig,
             model_cls=FlashMistralForCausalLM,
             model_id=model_id,

@@ -20,6 +20,7 @@ class FlashMixtral(BaseFlashMistral):
         trust_remote_code: bool = False,
     ):
         super(FlashMixtral, self).__init__(
+            model_id=model_id,
             config_cls=MixtralConfig,
             model_cls=FlashMixtralForCausalLM,
             model_id=model_id,

@@ -65,6 +65,7 @@ class FlashNeoXSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(FlashNeoXSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.gpt_neox.layers),

@@ -91,6 +91,7 @@ class FlashPhi(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(FlashPhi, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),

@@ -71,6 +71,7 @@ class FlashQwen2(BaseFlashMistral):
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
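
FlashQwen2 (like FlashStarcoder2 further down) subclasses BaseFlashMistral but calls super(BaseFlashMistral, self).__init__, which starts the method lookup after BaseFlashMistral and so reaches FlashCausalLM.__init__ directly; that is why it needs its own model_id= line rather than inheriting the one added in BaseFlashMistral.__init__ above. A toy sketch of that two-argument super() behaviour (class names are stand-ins):

    class Base:
        def __init__(self, model_id: str):
            print(f"Base got model_id={model_id}")

    class Middle(Base):
        def __init__(self, model_id: str):
            super().__init__(model_id=model_id)  # ordinary one-step super()

    class Leaf(Middle):
        def __init__(self, model_id: str):
            # Two-argument form: begin the MRO search *after* Middle,
            # so this calls Base.__init__ and bypasses Middle.__init__.
            super(Middle, self).__init__(model_id=model_id)

    Leaf("some/model-id")  # prints: Base got model_id=some/model-id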

@@ -74,6 +74,7 @@ class FlashRWSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(FlashRWSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.transformer.h),

@@ -76,6 +76,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(FlashSantacoderSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.transformer.h),

@@ -70,6 +70,7 @@ class FlashStarcoder2(BaseFlashMistral):
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),