diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 1fe40c0c..f1a4854f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -670,7 +670,7 @@ class FlashCausalLM(Model):
                 self.device,
             )
             _, batch = self.generate_token(batch)
-        except Exception as e:
+        except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
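
For context, a minimal sketch of the pattern this change applies (the names `warmup_sketch` and `run_prefill` are hypothetical, not from the patched file): only `torch.cuda.OutOfMemoryError` is translated into the `--max-batch-prefill-tokens` hint, while any other exception propagates with its original traceback instead of being misreported as an out-of-memory condition.

```python
import torch

def warmup_sketch(run_prefill, num_prefill_tokens: int):
    # run_prefill is a hypothetical callable standing in for generate_token()
    try:
        run_prefill()
    except torch.cuda.OutOfMemoryError as e:
        # Only CUDA OOM gets the actionable hint; other errors surface unchanged.
        raise RuntimeError(
            f"Not enough memory to handle {num_prefill_tokens} prefill tokens. "
            f"You need to decrease `--max-batch-prefill-tokens`"
        ) from e
```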