fix: adjust test to only run on cuda

This commit is contained in:
David Holtz 2024-10-09 20:02:29 +00:00
parent 541c476492
commit c396c54231
2 changed files with 12 additions and 14 deletions

View File

@@ -1,32 +1,30 @@
import pytest
import torch import torch
from transformers import AutoTokenizer from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.models.globals import set_adapter_to_index from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.import_utils import SYSTEM, empty_cache, synchronize from text_generation_server.utils.import_utils import SYSTEM, empty_cache, synchronize
from unittest.mock import Mock from unittest.mock import Mock
import base64 import base64
if SYSTEM == "cuda":
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct" model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
set_adapter_to_index({}) set_adapter_to_index({})
def test_flash_causal_lm_warmup(): if SYSTEM == "cuda":
if SYSTEM == "cuda": def test_flash_causal_lm_warmup():
flash_causal_lm_warmup() flash_causal_lm_warmup()
else:
pytest.skip("Test only runs on CUDA")
def flash_causal_lm_warmup(): def flash_causal_lm_warmup():
revision = "main" revision = "main"

View File

@@ -398,7 +398,7 @@ class HeterogeneousNextTokenChooser:
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0: if speculate and speculate > 0:
if speculative_scores is not None: if speculative_scores is not None:
# Medusa provided some scores # Medusa provided some scores
speculative_ids = Greedy()(speculative_scores) speculative_ids = Greedy()(speculative_scores)