fix: pass model_id for all flash causal lms
parent 73eb2ae255
commit 88bd5c2c92
@@ -953,12 +953,13 @@ class FlashCausalLM(Model):
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
         free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+        batch_num_blocks = batch.num_blocks if batch is not None else 0
 
         num_blocks = (
             # Leave 5% for some wiggle room
             int((free_memory * 0.95) // total_cache_size)
             # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
-            + batch.num_blocks
+            + batch_num_blocks
         )
 
         del batch
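For reference, a minimal standalone sketch of the block-count heuristic in the hunk above; the function name and argument list are illustrative assumptions, not the real FlashCausalLM API:

# Illustrative sketch, not the actual FlashCausalLM code: size one KV-cache
# block, keep 5% of free memory as headroom, and count how many blocks fit.
# batch_num_blocks defaults to 0, mirroring the new
# `batch.num_blocks if batch is not None else 0` guard in the diff.
def estimate_num_blocks(
    num_layers: int,
    cache_block_size: int,
    dtype_size: int,
    free_memory: int,
    batch_num_blocks: int = 0,
) -> int:
    # Keys and values are cached for every layer, hence the factor of 2.
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size
    # Leave 5% for some wiggle room, then add the blocks already allocated
    # for the warmup batch so they count toward peak memory.
    return int((free_memory * 0.95) // total_cache_size) + batch_num_blocks

With made-up sizes, estimate_num_blocks(num_layers=32, cache_block_size=16384, dtype_size=2, free_memory=24_000_000_000) returns 10871.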
@@ -62,6 +62,7 @@ class FlashCohere(FlashCausalLM):
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashCohere, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
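The hunks below all apply the same one-line change; a hedged sketch of the pattern with stand-in classes (FlashCausalLMBase and ExampleFlashModel are made up for illustration, only the model_id=model_id line mirrors the diff):

# Stand-in classes, not the real ones: each FlashCausalLM subclass now forwards
# the model_id it was constructed with to the base class instead of dropping it.
class FlashCausalLMBase:  # stand-in for FlashCausalLM
    def __init__(self, model_id, model, tokenizer, num_layers):
        self.model_id = model_id  # the base class can now see which model it serves
        self.model = model
        self.tokenizer = tokenizer
        self.num_layers = num_layers


class ExampleFlashModel(FlashCausalLMBase):  # stand-in for FlashCohere, FlashDbrx, ...
    def __init__(self, model_id, model, tokenizer, num_layers):
        super().__init__(
            model_id=model_id,  # the newly forwarded argument
            model=model,
            tokenizer=tokenizer,
            num_layers=num_layers,
        )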
@@ -87,6 +87,7 @@ class FlashDbrx(FlashCausalLM):
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashDbrx, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
@@ -62,6 +62,7 @@ class FlashGemma(FlashCausalLM):
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
@@ -65,6 +65,7 @@ class FlashGPT2(FlashCausalLM):
         model = FlashGPT2ForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashGPT2, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
@@ -79,6 +79,7 @@ class BaseFlashMistral(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         num_layers, num_kv_heads, head_size = self.get_layer_config(model)
         super().__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=num_layers,
@@ -110,6 +111,7 @@ class FlashMistral(BaseFlashMistral):
         trust_remote_code: bool = False,
     ):
         super(FlashMistral, self).__init__(
+            model_id=model_id,
             config_cls=MistralConfig,
             model_cls=FlashMistralForCausalLM,
             model_id=model_id,
@@ -20,6 +20,7 @@ class FlashMixtral(BaseFlashMistral):
         trust_remote_code: bool = False,
     ):
         super(FlashMixtral, self).__init__(
+            model_id=model_id,
             config_cls=MixtralConfig,
             model_cls=FlashMixtralForCausalLM,
             model_id=model_id,
@@ -65,6 +65,7 @@ class FlashNeoXSharded(FlashCausalLM):
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashNeoXSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.gpt_neox.layers),
@@ -91,6 +91,7 @@ class FlashPhi(FlashCausalLM):
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashPhi, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
@@ -71,6 +71,7 @@ class FlashQwen2(BaseFlashMistral):
 
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
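Note that FlashQwen2 above and FlashStarcoder2 below call super(BaseFlashMistral, self), so the lookup starts after BaseFlashMistral in the MRO and model_id lands directly in FlashCausalLM.__init__; a minimal sketch of that behaviour with stand-in classes:

# Stand-in hierarchy mirroring FlashCausalLM -> BaseFlashMistral -> FlashQwen2:
# two-argument super() begins the method lookup *after* the named class, so the
# middle __init__ is skipped entirely.
class Base:  # stand-in for FlashCausalLM
    def __init__(self, model_id):
        self.model_id = model_id


class Middle(Base):  # stand-in for BaseFlashMistral
    def __init__(self, model_id):
        raise RuntimeError("never reached by the subclass below")


class Leaf(Middle):  # stand-in for FlashQwen2 / FlashStarcoder2
    def __init__(self, model_id):
        super(Middle, self).__init__(model_id=model_id)  # skips Middle.__init__


assert Leaf("some/model-id").model_id == "some/model-id"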
@@ -74,6 +74,7 @@ class FlashRWSharded(FlashCausalLM):
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashRWSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.transformer.h),
@@ -76,6 +76,7 @@ class FlashSantacoderSharded(FlashCausalLM):
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashSantacoderSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.transformer.h),
@@ -70,6 +70,7 @@ class FlashStarcoder2(BaseFlashMistral):
 
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),