Prefer `type(x) is int` to `isinstance(x, int)` (#14945)
* Perfer `type(x) is int` to `isinstance(x, int)` This covered all additional instances I could see where `x` was user-controlled. The remaining cases are ``` $ rg -s 'isinstance.*[^_]int' tests/replication/_base.py 576: if isinstance(obj, int): synapse/util/caches/stream_change_cache.py 136: assert isinstance(stream_pos, int) 214: assert isinstance(stream_pos, int) 246: assert isinstance(stream_pos, int) 267: assert isinstance(stream_pos, int) synapse/replication/tcp/external_cache.py 133: if isinstance(result, int): synapse/metrics/__init__.py 100: if isinstance(calls, (int, float)): synapse/handlers/appservice.py 262: assert isinstance(new_token, int) synapse/config/_util.py 62: if isinstance(p, int): ``` which cover metrics, logic related to `jsonschema`, and replication and data streams. AFAICS these are all internal to Synapse * Changelog
This commit is contained in:
parent
510d4b06e7
commit
796a4b7482
|
@ -0,0 +1 @@
|
||||||
|
Fix various long-standing bugs in Synapse's config, event and request handling where booleans were unintentionally accepted where an integer was expected.
|
|
@ -174,15 +174,29 @@ class Config:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_size(value: Union[str, int]) -> int:
|
def parse_size(value: Union[str, int]) -> int:
|
||||||
if isinstance(value, int):
|
"""Interpret `value` as a number of bytes.
|
||||||
|
|
||||||
|
If an integer is provided it is treated as bytes and is unchanged.
|
||||||
|
|
||||||
|
String byte sizes can have a suffix of 'K' or `M`, representing kibibytes and
|
||||||
|
mebibytes respectively. No suffix is understood as a plain byte count.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError, if given something other than an integer or a string
|
||||||
|
ValueError: if given a string not of the form described above.
|
||||||
|
"""
|
||||||
|
if type(value) is int:
|
||||||
return value
|
return value
|
||||||
sizes = {"K": 1024, "M": 1024 * 1024}
|
elif type(value) is str:
|
||||||
size = 1
|
sizes = {"K": 1024, "M": 1024 * 1024}
|
||||||
suffix = value[-1]
|
size = 1
|
||||||
if suffix in sizes:
|
suffix = value[-1]
|
||||||
value = value[:-1]
|
if suffix in sizes:
|
||||||
size = sizes[suffix]
|
value = value[:-1]
|
||||||
return int(value) * size
|
size = sizes[suffix]
|
||||||
|
return int(value) * size
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Bad byte size {value!r}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_duration(value: Union[str, int]) -> int:
|
def parse_duration(value: Union[str, int]) -> int:
|
||||||
|
@ -198,22 +212,36 @@ class Config:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The number of milliseconds in the duration.
|
The number of milliseconds in the duration.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError, if given something other than an integer or a string
|
||||||
|
ValueError: if given a string not of the form described above.
|
||||||
"""
|
"""
|
||||||
if isinstance(value, int):
|
if type(value) is int:
|
||||||
return value
|
return value
|
||||||
second = 1000
|
elif type(value) is str:
|
||||||
minute = 60 * second
|
second = 1000
|
||||||
hour = 60 * minute
|
minute = 60 * second
|
||||||
day = 24 * hour
|
hour = 60 * minute
|
||||||
week = 7 * day
|
day = 24 * hour
|
||||||
year = 365 * day
|
week = 7 * day
|
||||||
sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year}
|
year = 365 * day
|
||||||
size = 1
|
sizes = {
|
||||||
suffix = value[-1]
|
"s": second,
|
||||||
if suffix in sizes:
|
"m": minute,
|
||||||
value = value[:-1]
|
"h": hour,
|
||||||
size = sizes[suffix]
|
"d": day,
|
||||||
return int(value) * size
|
"w": week,
|
||||||
|
"y": year,
|
||||||
|
}
|
||||||
|
size = 1
|
||||||
|
suffix = value[-1]
|
||||||
|
if suffix in sizes:
|
||||||
|
value = value[:-1]
|
||||||
|
size = sizes[suffix]
|
||||||
|
return int(value) * size
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Bad duration {value!r}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def abspath(file_path: str) -> str:
|
def abspath(file_path: str) -> str:
|
||||||
|
|
|
@ -126,7 +126,7 @@ class CacheConfig(Config):
|
||||||
|
|
||||||
cache_config = config.get("caches") or {}
|
cache_config = config.get("caches") or {}
|
||||||
self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE)
|
self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE)
|
||||||
if not isinstance(self.global_factor, (int, float)):
|
if type(self.global_factor) not in (int, float):
|
||||||
raise ConfigError("caches.global_factor must be a number.")
|
raise ConfigError("caches.global_factor must be a number.")
|
||||||
|
|
||||||
# Load cache factors from the config
|
# Load cache factors from the config
|
||||||
|
@ -151,7 +151,7 @@ class CacheConfig(Config):
|
||||||
)
|
)
|
||||||
|
|
||||||
for cache, factor in individual_factors.items():
|
for cache, factor in individual_factors.items():
|
||||||
if not isinstance(factor, (int, float)):
|
if type(factor) not in (int, float):
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"caches.per_cache_factors.%s must be a number" % (cache,)
|
"caches.per_cache_factors.%s must be a number" % (cache,)
|
||||||
)
|
)
|
||||||
|
|
|
@ -904,7 +904,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig:
|
||||||
raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type"))
|
raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type"))
|
||||||
|
|
||||||
port = listener.get("port")
|
port = listener.get("port")
|
||||||
if not isinstance(port, int):
|
if type(port) is not int:
|
||||||
raise ConfigError("Listener configuration is lacking a valid 'port' option")
|
raise ConfigError("Listener configuration is lacking a valid 'port' option")
|
||||||
|
|
||||||
tls = listener.get("tls", False)
|
tls = listener.get("tls", False)
|
||||||
|
|
|
@ -139,7 +139,7 @@ class EventValidator:
|
||||||
max_lifetime = event.content.get("max_lifetime")
|
max_lifetime = event.content.get("max_lifetime")
|
||||||
|
|
||||||
if min_lifetime is not None:
|
if min_lifetime is not None:
|
||||||
if not isinstance(min_lifetime, int):
|
if type(min_lifetime) is not int:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
code=400,
|
code=400,
|
||||||
msg="'min_lifetime' must be an integer",
|
msg="'min_lifetime' must be an integer",
|
||||||
|
@ -147,7 +147,7 @@ class EventValidator:
|
||||||
)
|
)
|
||||||
|
|
||||||
if max_lifetime is not None:
|
if max_lifetime is not None:
|
||||||
if not isinstance(max_lifetime, int):
|
if type(max_lifetime) is not int:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
code=400,
|
code=400,
|
||||||
msg="'max_lifetime' must be an integer",
|
msg="'max_lifetime' must be an integer",
|
||||||
|
|
|
@ -1864,7 +1864,7 @@ class TimestampToEventResponse:
|
||||||
)
|
)
|
||||||
|
|
||||||
origin_server_ts = d.get("origin_server_ts")
|
origin_server_ts = d.get("origin_server_ts")
|
||||||
if not isinstance(origin_server_ts, int):
|
if type(origin_server_ts) is not int:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid response: 'origin_server_ts' must be a int but received %r"
|
"Invalid response: 'origin_server_ts' must be a int but received %r"
|
||||||
% origin_server_ts
|
% origin_server_ts
|
||||||
|
|
|
@ -377,7 +377,7 @@ class MessageHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
|
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
|
||||||
if not isinstance(expiry_ts, int) or event.is_state():
|
if type(expiry_ts) is not int or event.is_state():
|
||||||
return
|
return
|
||||||
|
|
||||||
# _schedule_expiry_for_event won't actually schedule anything if there's already
|
# _schedule_expiry_for_event won't actually schedule anything if there's already
|
||||||
|
|
|
@ -152,7 +152,7 @@ class PurgeHistoryRestServlet(RestServlet):
|
||||||
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
|
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
|
||||||
elif "purge_up_to_ts" in body:
|
elif "purge_up_to_ts" in body:
|
||||||
ts = body["purge_up_to_ts"]
|
ts = body["purge_up_to_ts"]
|
||||||
if not isinstance(ts, int):
|
if type(ts) is not int:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"purge_up_to_ts must be an int",
|
"purge_up_to_ts must be an int",
|
||||||
|
|
|
@ -143,7 +143,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
|
||||||
else:
|
else:
|
||||||
# Get length of token to generate (default is 16)
|
# Get length of token to generate (default is 16)
|
||||||
length = body.get("length", 16)
|
length = body.get("length", 16)
|
||||||
if not isinstance(length, int):
|
if type(length) is not int:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"length must be an integer",
|
"length must be an integer",
|
||||||
|
@ -163,8 +163,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
|
||||||
|
|
||||||
uses_allowed = body.get("uses_allowed", None)
|
uses_allowed = body.get("uses_allowed", None)
|
||||||
if not (
|
if not (
|
||||||
uses_allowed is None
|
uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0)
|
||||||
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
|
|
||||||
):
|
):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
@ -173,13 +172,13 @@ class NewRegistrationTokenRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
expiry_time = body.get("expiry_time", None)
|
expiry_time = body.get("expiry_time", None)
|
||||||
if not isinstance(expiry_time, (int, type(None))):
|
if type(expiry_time) not in (int, type(None)):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"expiry_time must be an integer or null",
|
"expiry_time must be an integer or null",
|
||||||
Codes.INVALID_PARAM,
|
Codes.INVALID_PARAM,
|
||||||
)
|
)
|
||||||
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
|
if type(expiry_time) is int and expiry_time < self.clock.time_msec():
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"expiry_time must not be in the past",
|
"expiry_time must not be in the past",
|
||||||
|
@ -284,7 +283,7 @@ class RegistrationTokenRestServlet(RestServlet):
|
||||||
uses_allowed = body["uses_allowed"]
|
uses_allowed = body["uses_allowed"]
|
||||||
if not (
|
if not (
|
||||||
uses_allowed is None
|
uses_allowed is None
|
||||||
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
|
or (type(uses_allowed) is int and uses_allowed >= 0)
|
||||||
):
|
):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
@ -295,13 +294,13 @@ class RegistrationTokenRestServlet(RestServlet):
|
||||||
|
|
||||||
if "expiry_time" in body:
|
if "expiry_time" in body:
|
||||||
expiry_time = body["expiry_time"]
|
expiry_time = body["expiry_time"]
|
||||||
if not isinstance(expiry_time, (int, type(None))):
|
if type(expiry_time) not in (int, type(None)):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"expiry_time must be an integer or null",
|
"expiry_time must be an integer or null",
|
||||||
Codes.INVALID_PARAM,
|
Codes.INVALID_PARAM,
|
||||||
)
|
)
|
||||||
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
|
if type(expiry_time) is int and expiry_time < self.clock.time_msec():
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"expiry_time must not be in the past",
|
"expiry_time must not be in the past",
|
||||||
|
|
|
@ -973,7 +973,7 @@ class UserTokenRestServlet(RestServlet):
|
||||||
body = parse_json_object_from_request(request, allow_empty_body=True)
|
body = parse_json_object_from_request(request, allow_empty_body=True)
|
||||||
|
|
||||||
valid_until_ms = body.get("valid_until_ms")
|
valid_until_ms = body.get("valid_until_ms")
|
||||||
if valid_until_ms and not isinstance(valid_until_ms, int):
|
if type(valid_until_ms) not in (int, type(None)):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int"
|
HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int"
|
||||||
)
|
)
|
||||||
|
@ -1125,14 +1125,14 @@ class RateLimitRestServlet(RestServlet):
|
||||||
messages_per_second = body.get("messages_per_second", 0)
|
messages_per_second = body.get("messages_per_second", 0)
|
||||||
burst_count = body.get("burst_count", 0)
|
burst_count = body.get("burst_count", 0)
|
||||||
|
|
||||||
if not isinstance(messages_per_second, int) or messages_per_second < 0:
|
if type(messages_per_second) is not int or messages_per_second < 0:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"%r parameter must be a positive int" % (messages_per_second,),
|
"%r parameter must be a positive int" % (messages_per_second,),
|
||||||
errcode=Codes.INVALID_PARAM,
|
errcode=Codes.INVALID_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(burst_count, int) or burst_count < 0:
|
if type(burst_count) is not int or burst_count < 0:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"%r parameter must be a positive int" % (burst_count,),
|
"%r parameter must be a positive int" % (burst_count,),
|
||||||
|
|
|
@ -54,7 +54,7 @@ class ReportEventRestServlet(RestServlet):
|
||||||
"Param 'reason' must be a string",
|
"Param 'reason' must be a string",
|
||||||
Codes.BAD_JSON,
|
Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
if not isinstance(body.get("score", 0), int):
|
if type(body.get("score", 0)) is not int:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"Param 'score' must be an integer",
|
"Param 'score' must be an integer",
|
||||||
|
|
|
@ -200,7 +200,7 @@ class OEmbedProvider:
|
||||||
calc_description_and_urls(open_graph_response, oembed["html"])
|
calc_description_and_urls(open_graph_response, oembed["html"])
|
||||||
for size in ("width", "height"):
|
for size in ("width", "height"):
|
||||||
val = oembed.get(size)
|
val = oembed.get(size)
|
||||||
if val is not None and isinstance(val, int):
|
if type(val) is int:
|
||||||
open_graph_response[f"og:video:{size}"] = val
|
open_graph_response[f"og:video:{size}"] = val
|
||||||
|
|
||||||
elif oembed_type == "link":
|
elif oembed_type == "link":
|
||||||
|
|
|
@ -77,7 +77,7 @@ class Thumbnailer:
|
||||||
image_exif = self.image._getexif() # type: ignore
|
image_exif = self.image._getexif() # type: ignore
|
||||||
if image_exif is not None:
|
if image_exif is not None:
|
||||||
image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
|
image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
|
||||||
assert isinstance(image_orientation, int)
|
assert type(image_orientation) is int
|
||||||
self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
|
self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# A lot of parsing errors can happen when parsing EXIF
|
# A lot of parsing errors can happen when parsing EXIF
|
||||||
|
|
|
@ -1651,7 +1651,7 @@ class PersistEventsStore:
|
||||||
if self._ephemeral_messages_enabled:
|
if self._ephemeral_messages_enabled:
|
||||||
# If there's an expiry timestamp on the event, store it.
|
# If there's an expiry timestamp on the event, store it.
|
||||||
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
|
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
|
||||||
if isinstance(expiry_ts, int) and not event.is_state():
|
if type(expiry_ts) is int and not event.is_state():
|
||||||
self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
|
self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
|
||||||
|
|
||||||
# Insert into the room_memberships table.
|
# Insert into the room_memberships table.
|
||||||
|
@ -2133,10 +2133,10 @@ class PersistEventsStore:
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
"min_lifetime" in event.content
|
"min_lifetime" in event.content
|
||||||
and not isinstance(event.content.get("min_lifetime"), int)
|
and type(event.content["min_lifetime"]) is not int
|
||||||
) or (
|
) or (
|
||||||
"max_lifetime" in event.content
|
"max_lifetime" in event.content
|
||||||
and not isinstance(event.content.get("max_lifetime"), int)
|
and type(event.content["max_lifetime"]) is not int
|
||||||
):
|
):
|
||||||
# Ignore the event if one of the value isn't an integer.
|
# Ignore the event if one of the value isn't an integer.
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue