From c396c5423142681f076c602d9f2795a2627246dc Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Wed, 9 Oct 2024 20:02:29 +0000
Subject: [PATCH] fix: adjust test to only run on cuda

---
 server/tests/models/test_flash_causal_lm.py   | 24 +++++++++++-------------
 server/text_generation_server/utils/tokens.py |  2 +-
 2 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/server/tests/models/test_flash_causal_lm.py b/server/tests/models/test_flash_causal_lm.py
index 82b9ab80..1e27a859 100644
--- a/server/tests/models/test_flash_causal_lm.py
+++ b/server/tests/models/test_flash_causal_lm.py
@@ -1,32 +1,30 @@
-import pytest
 import torch
 
 from transformers import AutoTokenizer
 
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.flash_causal_lm import (
-    FlashCausalLMBatch,
-    FlashCausalLM,
-)
-from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-    FlashLlamaForCausalLM,
-)
 from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.import_utils import SYSTEM, empty_cache, synchronize
 from unittest.mock import Mock
 import base64
 
+if SYSTEM == "cuda":
+    from text_generation_server.models.flash_causal_lm import (
+        FlashCausalLMBatch,
+        FlashCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+        FlashLlamaForCausalLM,
+    )
+
 model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
 set_adapter_to_index({})
 
 
-def test_flash_causal_lm_warmup():
-    if SYSTEM == "cuda":
+if SYSTEM == "cuda":
+    def test_flash_causal_lm_warmup():
         flash_causal_lm_warmup()
-    else:
-        pytest.skip("Test only runs on CUDA")
-
 
 def flash_causal_lm_warmup():
     revision = "main"
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 9ab49665..1fdc4167 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -398,7 +398,7 @@ class HeterogeneousNextTokenChooser:
 
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        if speculate > 0:
+        if speculate and speculate > 0:
             if speculative_scores is not None:
                 # Medusa provided some scores
                 speculative_ids = Greedy()(speculative_scores)