fix: adjust test to only run on cuda
commit c396c54231
parent 541c476492
@@ -1,32 +1,30 @@
-import pytest
 import torch
 from transformers import AutoTokenizer
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.flash_causal_lm import (
-    FlashCausalLMBatch,
-    FlashCausalLM,
-)
-from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-    FlashLlamaForCausalLM,
-)
 from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.import_utils import SYSTEM, empty_cache, synchronize
 from unittest.mock import Mock
 import base64
 
+if SYSTEM == "cuda":
+    from text_generation_server.models.flash_causal_lm import (
+        FlashCausalLMBatch,
+        FlashCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+        FlashLlamaForCausalLM,
+    )
 
 model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
 set_adapter_to_index({})
 
 
-def test_flash_causal_lm_warmup():
-    if SYSTEM == "cuda":
-        flash_causal_lm_warmup()
-    else:
-        pytest.skip("Test only runs on CUDA")
+if SYSTEM == "cuda":
+    def test_flash_causal_lm_warmup():
+        flash_causal_lm_warmup()
 
 
 def flash_causal_lm_warmup():
     revision = "main"
 
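Read as a whole, this hunk moves the flash-model imports and the test definition behind a CUDA check, so the test module can at least be imported on other systems. Below is a condensed sketch of the resulting module layout, reconstructed from the added side of the hunk above; the file path is not shown in this view, a few unrelated imports are omitted, and the body of flash_causal_lm_warmup is truncated in the hunk, so it is elided here.

import torch
from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.import_utils import SYSTEM

# Both the flash model imports and the test definition sit behind the SYSTEM
# check (presumably because these classes only import cleanly with CUDA).
# On any other system the test function is never defined, so pytest simply
# has nothing to collect.
if SYSTEM == "cuda":
    from text_generation_server.models.flash_causal_lm import (
        FlashCausalLMBatch,
        FlashCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

set_adapter_to_index({})


if SYSTEM == "cuda":
    def test_flash_causal_lm_warmup():
        flash_causal_lm_warmup()


def flash_causal_lm_warmup():
    revision = "main"
    ...  # body continues beyond this hunk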
@@ -398,7 +398,7 @@ class HeterogeneousNextTokenChooser:
 
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        if speculate > 0:
+        if speculate and speculate > 0:
             if speculative_scores is not None:
                 # Medusa provided some scores
                 speculative_ids = Greedy()(speculative_scores)
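The only change in this hunk is the extra truthiness check on speculate. In Python 3 an expression like None > 0 raises a TypeError, so the new guard lets a speculate value of None (or 0) fall through to the non-speculative path instead of erroring; presumably that is the case being guarded against here. A minimal standalone sketch of the difference, using a hypothetical helper that is not part of the TGI code:

def wants_speculation(speculate):
    # Old guard: `speculate > 0` raises TypeError when speculate is None.
    # New guard: `speculate and speculate > 0` short-circuits on None/0.
    return bool(speculate and speculate > 0)


assert wants_speculation(2) is True
assert wants_speculation(0) is False
assert wants_speculation(None) is False  # the old guard would raise here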