Fix (flash) Gemma prefix and enable tests

parent d32e33bd48
commit 9231098f3a
@@ -3,7 +3,7 @@ import pytest
 @pytest.fixture(scope="module")
 def flash_gemma_handle(launcher):
-    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
+    with launcher("google/gemma-2b", num_shard=1) as handle:
         yield handle
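(Context, not part of the diff: the fixture above is module-scoped, so the launcher starts one google/gemma-2b server and every test in the file reuses it. Below is a minimal sketch of the same pattern with a stand-in launcher; the context manager is illustrative, not the repo's conftest implementation.)

import contextlib

import pytest


@contextlib.contextmanager
def launcher(model_id, num_shard=1):
    # Stand-in for the integration-test launcher: start a server,
    # yield a handle to it, and shut it down after the module finishes.
    handle = {"model_id": model_id, "num_shard": num_shard}
    try:
        yield handle
    finally:
        pass  # real teardown would stop the server here


@pytest.fixture(scope="module")
def flash_gemma_handle():
    # scope="module": launched once, shared by all tests in this file.
    with launcher("google/gemma-2b", num_shard=1) as handle:
        yield handle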
@@ -13,7 +13,6 @@ async def flash_gemma(flash_gemma_handle):
     return flash_gemma_handle.client


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma(flash_gemma, response_snapshot):
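(Context: removing the unconditional skip marker is what re-enables each test below; the asyncio and private markers are unchanged. A minimal sketch of the marker's semantics:)

import pytest


@pytest.mark.skip
def test_disabled():
    # Never executed: pytest collects it but reports it as skipped.
    assert False


def test_enabled():
    # With the marker removed, pytest runs the test normally.
    assert True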
@@ -25,7 +24,6 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
     assert response == response_snapshot


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@@ -49,7 +47,6 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
     assert response == response_snapshot


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
@@ -423,7 +423,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()

         embed_norm = config.hidden_size**0.5
-        if prefix is None:
+        if not prefix:
             prefix = "model"
         else:
             prefix = f"{prefix}.model"
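(Context: the truthiness check matters because the loader below now passes an empty string rather than None. A sketch of the branch in isolation; resolve_prefix is a hypothetical helper, written only to show the behavior:)

from typing import Optional


def resolve_prefix(prefix: Optional[str]) -> str:
    # Mirrors the branch in FlashGemmaForCausalLM.__init__: both None and
    # "" resolve to the bare "model" prefix; a real parent prefix nests.
    if not prefix:
        return "model"
    return f"{prefix}.model"


assert resolve_prefix(None) == "model"
# The old "prefix is None" check sent "" into the else branch and
# produced the malformed prefix ".model":
assert resolve_prefix("") == "model"
assert resolve_prefix("language_model") == "language_model.model"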
@@ -57,7 +57,7 @@ class FlashGemma(FlashCausalLM):
         weights._set_gptq_params(model_id, revision)

         # TODO hardcoded
-        prefix = "language_model"
+        prefix = ""
         model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)

         torch.distributed.barrier(group=self.process_group)
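(Context: combined with the modeling change above, the empty prefix makes weight lookups match the standalone google/gemma-2b checkpoint layout, e.g. model.embed_tokens.weight, instead of a language_model.model.* nesting. A hedged sketch; weight_name is a hypothetical helper and the tensor names are illustrative:)

def weight_name(prefix: str, suffix: str) -> str:
    # Expand the model prefix into the flat tensor name that would be
    # looked up in the checkpoint.
    model_prefix = "model" if not prefix else f"{prefix}.model"
    return f"{model_prefix}.{suffix}"


# Standalone checkpoint, as loaded after this commit:
assert weight_name("", "embed_tokens.weight") == "model.embed_tokens.weight"

# Nested layout, e.g. Gemma as the text tower of a multimodal model:
assert weight_name("language_model", "embed_tokens.weight") == (
    "language_model.model.embed_tokens.weight"
)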