107 lines
2.8 KiB
Python
107 lines
2.8 KiB
Python
from typing import Dict
|
|
|
|
|
|
# Text Generation Inference Errors
|
|
class ValidationError(Exception):
    """Error produced by `parse_error` for payloads whose `error_type` is `validation`."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
class GenerationError(Exception):
    """Error produced by `parse_error` for payloads whose `error_type` is `generation`."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
class OverloadedError(Exception):
    """Error produced by `parse_error` for payloads whose `error_type` is `overloaded`."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
class IncompleteGenerationError(Exception):
    """Error produced by `parse_error` for payloads whose `error_type` is `incomplete_generation`."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
# API Inference Errors
|
|
class BadRequestError(Exception):
    """Error produced by `parse_error` for HTTP status code 400."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
class ShardNotReadyError(Exception):
    """Error produced by `parse_error` for HTTP status codes 403 and 424."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
class ShardTimeoutError(Exception):
    """Error produced by `parse_error` for HTTP status code 504."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
class NotFoundError(Exception):
    """Error produced by `parse_error` for HTTP status code 404."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
class RateLimitExceededError(Exception):
    """Error produced by `parse_error` for HTTP status code 429."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
class NotSupportedError(Exception):
    """
    Error stating that a model is not available for inference with this client

    Unlike the other errors in this module, it is constructed from a model id
    rather than from a ready-made message; the message is built here and points
    the user to `huggingface_hub.inference_api.InferenceApi`.

    Args:
        model_id (`str`):
            Id of the model, interpolated into the error message
    """

    def __init__(self, model_id: str):
        # Two lines: what went wrong, then what to use instead.
        unavailable = f"Model `{model_id}` is not available for inference with this client. \n"
        alternative = "Use `huggingface_hub.inference_api.InferenceApi` instead."
        super().__init__(unavailable + alternative)
|
|
|
|
|
|
# Unknown error
|
|
class UnknownError(Exception):
    """Fallback error produced by `parse_error` when nothing more specific matches."""

    def __init__(self, message: str):
        # The message becomes the only entry of `args`, and thus `str(exc)`.
        super().__init__(message)
|
|
|
|
|
|
def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
    """
    Parse error given an HTTP status code and a json payload

    The payload's `error_type` (when present and recognised) takes precedence
    over the HTTP status code. The exception is returned, not raised; raising
    it is left to the caller.

    Args:
        status_code (`int`):
            HTTP status code
        payload (`Dict[str, str]`):
            Json payload. The `error` entry is used as the exception message;
            if it is missing, a generic message built from the status code and
            the payload is used instead.

    Returns:
        Exception: parsed exception

    """
    # A payload without an "error" key would otherwise raise KeyError here and
    # hide the actual HTTP failure; fall back to a generic message instead.
    message = payload.get(
        "error", f"Request failed with status code {status_code}: {payload}"
    )

    # Try to parse a Text Generation Inference error
    if "error_type" in payload:
        error_type = payload["error_type"]
        if error_type == "generation":
            return GenerationError(message)
        if error_type == "incomplete_generation":
            return IncompleteGenerationError(message)
        if error_type == "overloaded":
            return OverloadedError(message)
        if error_type == "validation":
            return ValidationError(message)
        # Unrecognised error types fall through to the status-code mapping.

    # Try to parse an API Inference error
    if status_code == 400:
        return BadRequestError(message)
    if status_code in (403, 424):
        return ShardNotReadyError(message)
    if status_code == 504:
        return ShardTimeoutError(message)
    if status_code == 404:
        return NotFoundError(message)
    if status_code == 429:
        return RateLimitExceededError(message)

    # Fallback to an unknown error
    return UnknownError(message)
|