from typing import Dict


# Text Generation Inference Errors
class ValidationError(Exception):
    """Error returned by `parse_error` when `error_type` is `"validation"`."""

    def __init__(self, message: str):
        super().__init__(message)


class GenerationError(Exception):
    """Error returned by `parse_error` when `error_type` is `"generation"`."""

    def __init__(self, message: str):
        super().__init__(message)


class OverloadedError(Exception):
    """Error returned by `parse_error` when `error_type` is `"overloaded"`."""

    def __init__(self, message: str):
        super().__init__(message)


class IncompleteGenerationError(Exception):
    """Error returned by `parse_error` when `error_type` is `"incomplete_generation"`."""

    def __init__(self, message: str):
        super().__init__(message)


# API Inference Errors
class BadRequestError(Exception):
    """Error returned by `parse_error` for HTTP status code 400."""

    def __init__(self, message: str):
        super().__init__(message)


class ShardNotReadyError(Exception):
    """Error returned by `parse_error` for HTTP status codes 403 and 424."""

    def __init__(self, message: str):
        super().__init__(message)


class ShardTimeoutError(Exception):
    """Error returned by `parse_error` for HTTP status code 504."""

    def __init__(self, message: str):
        super().__init__(message)


class NotFoundError(Exception):
    """Error returned by `parse_error` for HTTP status code 404."""

    def __init__(self, message: str):
        super().__init__(message)


class RateLimitExceededError(Exception):
    """Error returned by `parse_error` for HTTP status code 429."""

    def __init__(self, message: str):
        super().__init__(message)


class NotSupportedError(Exception):
    """Error stating that a model is not available for inference with this client.

    Args:
        model_id (`str`):
            Identifier of the model, interpolated into the error message
    """

    def __init__(self, model_id: str):
        message = (
            f"Model `{model_id}` is not available for inference with this client. \n"
            "Use `huggingface_hub.inference_api.InferenceApi` instead."
        )
        super(NotSupportedError, self).__init__(message)


# Unknown error
class UnknownError(Exception):
    """Fallback error returned by `parse_error` when nothing else matches."""

    def __init__(self, message: str):
        super().__init__(message)


# Lookup tables used by `parse_error`. The `error_type` table is consulted
# first; the status-code table is only used when no `error_type` matched.
_ERROR_TYPE_TO_EXCEPTION = {
    "generation": GenerationError,
    "incomplete_generation": IncompleteGenerationError,
    "overloaded": OverloadedError,
    "validation": ValidationError,
}

_STATUS_CODE_TO_EXCEPTION = {
    400: BadRequestError,
    403: ShardNotReadyError,
    424: ShardNotReadyError,
    504: ShardTimeoutError,
    404: NotFoundError,
    429: RateLimitExceededError,
}


def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
    """
    Parse error given an HTTP status code and a json payload

    The exception is returned, not raised. A recognised `error_type` in the
    payload takes precedence over the HTTP status code; anything unrecognised
    becomes an `UnknownError`.

    Args:
        status_code (`int`):
            HTTP status code
        payload (`Dict[str, str]`):
            Json payload. Its `"error"` entry is used as the exception message
            and its optional `"error_type"` entry selects the exception class.

    Returns:
        Exception: parsed exception
    """
    # A payload without an "error" entry must not make the parser itself fail
    # with a KeyError, so fall back to a message built from the status code.
    message = payload.get("error")
    if message is None:
        message = f"Request failed with status code {status_code}"

    # Try to parse a Text Generation Inference error
    exception_cls = _ERROR_TYPE_TO_EXCEPTION.get(payload.get("error_type"))

    # Try to parse a APIInference error
    if exception_cls is None:
        exception_cls = _STATUS_CODE_TO_EXCEPTION.get(status_code)

    # Fallback to an unknown error
    if exception_cls is None:
        exception_cls = UnknownError

    return exception_cls(message)