fix: adjust test to only run on cuda
parent 541c476492 · commit c396c54231
@@ -1,32 +1,30 @@
import pytest
import torch

from transformers import AutoTokenizer

from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.import_utils import SYSTEM, empty_cache, synchronize
from unittest.mock import Mock
import base64

if SYSTEM == "cuda":
    from text_generation_server.models.flash_causal_lm import (
        FlashCausalLMBatch,
        FlashCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

set_adapter_to_index({})


def test_flash_causal_lm_warmup():
    if SYSTEM == "cuda":
        flash_causal_lm_warmup()
    else:
        pytest.skip("Test only runs on CUDA")


def flash_causal_lm_warmup():
    revision = "main"
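The net effect of this hunk is that the flash-attention imports and the warmup body only execute when SYSTEM == "cuda"; on any other system the test calls pytest.skip. As a minimal sketch (not part of this commit, and assuming the same SYSTEM constant and the flash_causal_lm_warmup helper defined above), the same gating is often written with pytest's skipif marker instead of an explicit branch inside the test:

import pytest

from text_generation_server.utils.import_utils import SYSTEM


# Equivalent guard expressed as a marker: the test is collected everywhere
# but reported as skipped unless SYSTEM == "cuda".
@pytest.mark.skipif(SYSTEM != "cuda", reason="Test only runs on CUDA")
def test_flash_causal_lm_warmup():
    flash_causal_lm_warmup()

A marker keeps the skip reason visible in the collection report and avoids the extra indentation level inside the test body; the commit's explicit if/else achieves the same result.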
@@ -398,7 +398,7 @@ class HeterogeneousNextTokenChooser:
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        if speculate > 0:
+        if speculate and speculate > 0:
             if speculative_scores is not None:
                 # Medusa provided some scores
                 speculative_ids = Greedy()(speculative_scores)
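The rewritten guard in this hunk is a defensive change: with plain speculate > 0, a speculate value of None raises a TypeError in Python 3, whereas speculate and speculate > 0 short-circuits to a falsy value before the comparison runs. A small standalone illustration (hypothetical helper name, not taken from the server code):

def should_speculate(speculate):
    # Short-circuits on None or 0, so the `> 0` comparison never sees None.
    return bool(speculate and speculate > 0)


assert should_speculate(2) is True
assert should_speculate(0) is False
assert should_speculate(None) is False  # plain `None > 0` would raise TypeError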