Prefer `type(x) is int` to `isinstance(x, int)` (#14945)

* Perfer `type(x) is int` to `isinstance(x, int)`

This covered all additional instances I could see where `x` was
user-controlled.
The remaining cases are

```
$ rg -s 'isinstance.*[^_]int'
tests/replication/_base.py
576:        if isinstance(obj, int):

synapse/util/caches/stream_change_cache.py
136:        assert isinstance(stream_pos, int)
214:        assert isinstance(stream_pos, int)
246:        assert isinstance(stream_pos, int)
267:        assert isinstance(stream_pos, int)

synapse/replication/tcp/external_cache.py
133:        if isinstance(result, int):

synapse/metrics/__init__.py
100:        if isinstance(calls, (int, float)):

synapse/handlers/appservice.py
262:        assert isinstance(new_token, int)

synapse/config/_util.py
62:        if isinstance(p, int):
```

which cover metrics, logic related to `jsonschema`, and replication and
data streams. AFAICS these are all internal to Synapse

* Changelog
This commit is contained in:
David Robertson 2023-01-31 10:33:07 +00:00 committed by GitHub
parent 510d4b06e7
commit 796a4b7482
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 75 additions and 47 deletions

1
changelog.d/14945.misc Normal file
View File

@ -0,0 +1 @@
Fix various long-standing bugs in Synapse's config, event and request handling where booleans were unintentionally accepted where an integer was expected.

View File

@ -174,15 +174,29 @@ class Config:
@staticmethod @staticmethod
def parse_size(value: Union[str, int]) -> int: def parse_size(value: Union[str, int]) -> int:
if isinstance(value, int): """Interpret `value` as a number of bytes.
If an integer is provided it is treated as bytes and is unchanged.
String byte sizes can have a suffix of 'K' or `M`, representing kibibytes and
mebibytes respectively. No suffix is understood as a plain byte count.
Raises:
TypeError, if given something other than an integer or a string
ValueError: if given a string not of the form described above.
"""
if type(value) is int:
return value return value
sizes = {"K": 1024, "M": 1024 * 1024} elif type(value) is str:
size = 1 sizes = {"K": 1024, "M": 1024 * 1024}
suffix = value[-1] size = 1
if suffix in sizes: suffix = value[-1]
value = value[:-1] if suffix in sizes:
size = sizes[suffix] value = value[:-1]
return int(value) * size size = sizes[suffix]
return int(value) * size
else:
raise TypeError(f"Bad byte size {value!r}")
@staticmethod @staticmethod
def parse_duration(value: Union[str, int]) -> int: def parse_duration(value: Union[str, int]) -> int:
@ -198,22 +212,36 @@ class Config:
Returns: Returns:
The number of milliseconds in the duration. The number of milliseconds in the duration.
Raises:
TypeError, if given something other than an integer or a string
ValueError: if given a string not of the form described above.
""" """
if isinstance(value, int): if type(value) is int:
return value return value
second = 1000 elif type(value) is str:
minute = 60 * second second = 1000
hour = 60 * minute minute = 60 * second
day = 24 * hour hour = 60 * minute
week = 7 * day day = 24 * hour
year = 365 * day week = 7 * day
sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year} year = 365 * day
size = 1 sizes = {
suffix = value[-1] "s": second,
if suffix in sizes: "m": minute,
value = value[:-1] "h": hour,
size = sizes[suffix] "d": day,
return int(value) * size "w": week,
"y": year,
}
size = 1
suffix = value[-1]
if suffix in sizes:
value = value[:-1]
size = sizes[suffix]
return int(value) * size
else:
raise TypeError(f"Bad duration {value!r}")
@staticmethod @staticmethod
def abspath(file_path: str) -> str: def abspath(file_path: str) -> str:

View File

@ -126,7 +126,7 @@ class CacheConfig(Config):
cache_config = config.get("caches") or {} cache_config = config.get("caches") or {}
self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE) self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE)
if not isinstance(self.global_factor, (int, float)): if type(self.global_factor) not in (int, float):
raise ConfigError("caches.global_factor must be a number.") raise ConfigError("caches.global_factor must be a number.")
# Load cache factors from the config # Load cache factors from the config
@ -151,7 +151,7 @@ class CacheConfig(Config):
) )
for cache, factor in individual_factors.items(): for cache, factor in individual_factors.items():
if not isinstance(factor, (int, float)): if type(factor) not in (int, float):
raise ConfigError( raise ConfigError(
"caches.per_cache_factors.%s must be a number" % (cache,) "caches.per_cache_factors.%s must be a number" % (cache,)
) )

View File

@ -904,7 +904,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig:
raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type")) raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type"))
port = listener.get("port") port = listener.get("port")
if not isinstance(port, int): if type(port) is not int:
raise ConfigError("Listener configuration is lacking a valid 'port' option") raise ConfigError("Listener configuration is lacking a valid 'port' option")
tls = listener.get("tls", False) tls = listener.get("tls", False)

View File

@ -139,7 +139,7 @@ class EventValidator:
max_lifetime = event.content.get("max_lifetime") max_lifetime = event.content.get("max_lifetime")
if min_lifetime is not None: if min_lifetime is not None:
if not isinstance(min_lifetime, int): if type(min_lifetime) is not int:
raise SynapseError( raise SynapseError(
code=400, code=400,
msg="'min_lifetime' must be an integer", msg="'min_lifetime' must be an integer",
@ -147,7 +147,7 @@ class EventValidator:
) )
if max_lifetime is not None: if max_lifetime is not None:
if not isinstance(max_lifetime, int): if type(max_lifetime) is not int:
raise SynapseError( raise SynapseError(
code=400, code=400,
msg="'max_lifetime' must be an integer", msg="'max_lifetime' must be an integer",

View File

@ -1864,7 +1864,7 @@ class TimestampToEventResponse:
) )
origin_server_ts = d.get("origin_server_ts") origin_server_ts = d.get("origin_server_ts")
if not isinstance(origin_server_ts, int): if type(origin_server_ts) is not int:
raise ValueError( raise ValueError(
"Invalid response: 'origin_server_ts' must be a int but received %r" "Invalid response: 'origin_server_ts' must be a int but received %r"
% origin_server_ts % origin_server_ts

View File

@ -377,7 +377,7 @@ class MessageHandler:
""" """
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
if not isinstance(expiry_ts, int) or event.is_state(): if type(expiry_ts) is not int or event.is_state():
return return
# _schedule_expiry_for_event won't actually schedule anything if there's already # _schedule_expiry_for_event won't actually schedule anything if there's already

View File

@ -152,7 +152,7 @@ class PurgeHistoryRestServlet(RestServlet):
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body: elif "purge_up_to_ts" in body:
ts = body["purge_up_to_ts"] ts = body["purge_up_to_ts"]
if not isinstance(ts, int): if type(ts) is not int:
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"purge_up_to_ts must be an int", "purge_up_to_ts must be an int",

View File

@ -143,7 +143,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
else: else:
# Get length of token to generate (default is 16) # Get length of token to generate (default is 16)
length = body.get("length", 16) length = body.get("length", 16)
if not isinstance(length, int): if type(length) is not int:
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"length must be an integer", "length must be an integer",
@ -163,8 +163,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
uses_allowed = body.get("uses_allowed", None) uses_allowed = body.get("uses_allowed", None)
if not ( if not (
uses_allowed is None uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0)
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
): ):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
@ -173,13 +172,13 @@ class NewRegistrationTokenRestServlet(RestServlet):
) )
expiry_time = body.get("expiry_time", None) expiry_time = body.get("expiry_time", None)
if not isinstance(expiry_time, (int, type(None))): if type(expiry_time) not in (int, type(None)):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"expiry_time must be an integer or null", "expiry_time must be an integer or null",
Codes.INVALID_PARAM, Codes.INVALID_PARAM,
) )
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): if type(expiry_time) is int and expiry_time < self.clock.time_msec():
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"expiry_time must not be in the past", "expiry_time must not be in the past",
@ -284,7 +283,7 @@ class RegistrationTokenRestServlet(RestServlet):
uses_allowed = body["uses_allowed"] uses_allowed = body["uses_allowed"]
if not ( if not (
uses_allowed is None uses_allowed is None
or (isinstance(uses_allowed, int) and uses_allowed >= 0) or (type(uses_allowed) is int and uses_allowed >= 0)
): ):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
@ -295,13 +294,13 @@ class RegistrationTokenRestServlet(RestServlet):
if "expiry_time" in body: if "expiry_time" in body:
expiry_time = body["expiry_time"] expiry_time = body["expiry_time"]
if not isinstance(expiry_time, (int, type(None))): if type(expiry_time) not in (int, type(None)):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"expiry_time must be an integer or null", "expiry_time must be an integer or null",
Codes.INVALID_PARAM, Codes.INVALID_PARAM,
) )
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): if type(expiry_time) is int and expiry_time < self.clock.time_msec():
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"expiry_time must not be in the past", "expiry_time must not be in the past",

View File

@ -973,7 +973,7 @@ class UserTokenRestServlet(RestServlet):
body = parse_json_object_from_request(request, allow_empty_body=True) body = parse_json_object_from_request(request, allow_empty_body=True)
valid_until_ms = body.get("valid_until_ms") valid_until_ms = body.get("valid_until_ms")
if valid_until_ms and not isinstance(valid_until_ms, int): if type(valid_until_ms) not in (int, type(None)):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int" HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int"
) )
@ -1125,14 +1125,14 @@ class RateLimitRestServlet(RestServlet):
messages_per_second = body.get("messages_per_second", 0) messages_per_second = body.get("messages_per_second", 0)
burst_count = body.get("burst_count", 0) burst_count = body.get("burst_count", 0)
if not isinstance(messages_per_second, int) or messages_per_second < 0: if type(messages_per_second) is not int or messages_per_second < 0:
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"%r parameter must be a positive int" % (messages_per_second,), "%r parameter must be a positive int" % (messages_per_second,),
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
if not isinstance(burst_count, int) or burst_count < 0: if type(burst_count) is not int or burst_count < 0:
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"%r parameter must be a positive int" % (burst_count,), "%r parameter must be a positive int" % (burst_count,),

View File

@ -54,7 +54,7 @@ class ReportEventRestServlet(RestServlet):
"Param 'reason' must be a string", "Param 'reason' must be a string",
Codes.BAD_JSON, Codes.BAD_JSON,
) )
if not isinstance(body.get("score", 0), int): if type(body.get("score", 0)) is not int:
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"Param 'score' must be an integer", "Param 'score' must be an integer",

View File

@ -200,7 +200,7 @@ class OEmbedProvider:
calc_description_and_urls(open_graph_response, oembed["html"]) calc_description_and_urls(open_graph_response, oembed["html"])
for size in ("width", "height"): for size in ("width", "height"):
val = oembed.get(size) val = oembed.get(size)
if val is not None and isinstance(val, int): if type(val) is int:
open_graph_response[f"og:video:{size}"] = val open_graph_response[f"og:video:{size}"] = val
elif oembed_type == "link": elif oembed_type == "link":

View File

@ -77,7 +77,7 @@ class Thumbnailer:
image_exif = self.image._getexif() # type: ignore image_exif = self.image._getexif() # type: ignore
if image_exif is not None: if image_exif is not None:
image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
assert isinstance(image_orientation, int) assert type(image_orientation) is int
self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
except Exception as e: except Exception as e:
# A lot of parsing errors can happen when parsing EXIF # A lot of parsing errors can happen when parsing EXIF

View File

@ -1651,7 +1651,7 @@ class PersistEventsStore:
if self._ephemeral_messages_enabled: if self._ephemeral_messages_enabled:
# If there's an expiry timestamp on the event, store it. # If there's an expiry timestamp on the event, store it.
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
if isinstance(expiry_ts, int) and not event.is_state(): if type(expiry_ts) is int and not event.is_state():
self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
# Insert into the room_memberships table. # Insert into the room_memberships table.
@ -2133,10 +2133,10 @@ class PersistEventsStore:
): ):
if ( if (
"min_lifetime" in event.content "min_lifetime" in event.content
and not isinstance(event.content.get("min_lifetime"), int) and type(event.content["min_lifetime"]) is not int
) or ( ) or (
"max_lifetime" in event.content "max_lifetime" in event.content
and not isinstance(event.content.get("max_lifetime"), int) and type(event.content["max_lifetime"]) is not int
): ):
# Ignore the event if one of the value isn't an integer. # Ignore the event if one of the value isn't an integer.
return return