Fix (flash) Gemma prefix and enable tests

Daniël de Kok 2024-05-24 15:34:42 +00:00 committed by Daniël de Kok
parent d32e33bd48
commit 9231098f3a
3 changed files with 3 additions and 6 deletions

integration-tests/models/test_flash_gemma.py

@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_gemma_handle(launcher):
-    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
+    with launcher("google/gemma-2b", num_shard=1) as handle:
         yield handle
 
 
@@ -13,7 +13,6 @@ async def flash_gemma(flash_gemma_handle):
     return flash_gemma_handle.client
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma(flash_gemma, response_snapshot):
@@ -25,7 +24,6 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
     assert response == response_snapshot
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@@ -49,7 +47,6 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
     assert response == response_snapshot
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
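Note: the fixture now launches the public google/gemma-2b checkpoint (replacing the pre-release gg-hf/gemma-2b name), and dropping the three @pytest.mark.skip markers re-enables the suite. A minimal sketch of the first re-enabled test; the prompt and generate arguments here are illustrative assumptions, not the exact body from the unchanged part of the file:

    @pytest.mark.asyncio
    @pytest.mark.private
    async def test_flash_gemma(flash_gemma, response_snapshot):
        # Illustrative request; the real prompt/parameters are assumptions here.
        response = await flash_gemma.generate(
            "Test request", max_new_tokens=10, decoder_input_details=True
        )
        assert response.details.generated_tokens == 10
        assert response == response_snapshot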

server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py

@@ -423,7 +423,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()
 
         embed_norm = config.hidden_size**0.5
-        if prefix is None:
+        if not prefix:
             prefix = "model"
         else:
             prefix = f"{prefix}.model"
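Note: the guard moves from `is None` to truthiness because the standalone loader (next file) now passes an empty string instead of None; "" fails an `is None` check and would have produced the malformed weight prefix ".model". A self-contained sketch of the resolution rule, with resolve_prefix as a hypothetical name for illustration:

    def resolve_prefix(prefix):
        # "" (standalone Gemma) and None both resolve to the bare "model"
        # namespace; a non-empty prefix nests it, e.g. when Gemma serves
        # as the text tower of a multimodal model.
        if not prefix:
            return "model"
        return f"{prefix}.model"

    assert resolve_prefix("") == "model"
    assert resolve_prefix(None) == "model"
    assert resolve_prefix("language_model") == "language_model.model"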

server/text_generation_server/models/flash_gemma.py

@@ -57,7 +57,7 @@ class FlashGemma(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         # TODO hardcoded
-        prefix = "language_model"
+        prefix = ""
         model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
 
         torch.distributed.barrier(group=self.process_group)
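Note: with the empty prefix, FlashGemmaForCausalLM resolves to the plain model.* weight namespace that a standalone Gemma checkpoint ships, whereas the old hardcoded "language_model" only matched checkpoints that wrap Gemma inside a larger model. A rough illustration of the keys each prefix requests, using standard Gemma checkpoint names:

    # Old: prefix = "language_model" -> language_model.model.embed_tokens.weight
    # New: prefix = ""               -> model.embed_tokens.weight
    for prefix in ("language_model", ""):
        base = f"{prefix}.model" if prefix else "model"
        print(f"{base}.embed_tokens.weight")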