Annotate `log_function` decorator (#10943)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
reivilibre 2021-10-27 17:27:23 +01:00 committed by GitHub
parent 4e393af52f
commit 75ca0a6168
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 58 additions and 18 deletions

1
changelog.d/10943.misc Normal file
View File

@ -0,0 +1 @@
Add type annotations for the `log_function` decorator.

View File

@ -227,7 +227,7 @@ class FederationClient(FederationBase):
) )
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str] self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> Optional[List[EventBase]]: ) -> Optional[List[EventBase]]:
"""Requests some more historic PDUs for the given room from the """Requests some more historic PDUs for the given room from the
given destination server. given destination server.
@ -237,6 +237,8 @@ class FederationClient(FederationBase):
room_id: The room_id to backfill. room_id: The room_id to backfill.
limit: The maximum number of events to return. limit: The maximum number of events to return.
extremities: our current backwards extremities, to backfill from extremities: our current backwards extremities, to backfill from
Must be a Collection that is falsy when empty.
(Iterable is not enough here!)
""" """
logger.debug("backfill extrem=%s", extremities) logger.debug("backfill extrem=%s", extremities)
@ -250,11 +252,22 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%r", transaction_data) logger.debug("backfill transaction_data=%r", transaction_data)
if not isinstance(transaction_data, dict):
# TODO we probably want an exception type specific to federation
# client validation.
raise TypeError("Backfill transaction_data is not a dict.")
transaction_data_pdus = transaction_data.get("pdus")
if not isinstance(transaction_data_pdus, list):
# TODO we probably want an exception type specific to federation
# client validation.
raise TypeError("transaction_data.pdus is not a list.")
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
pdus = [ pdus = [
event_from_pdu_json(p, room_version, outlier=False) event_from_pdu_json(p, room_version, outlier=False)
for p in transaction_data["pdus"] for p in transaction_data_pdus
] ]
# Check signatures and hash of pdus, removing any from the list that fail checks # Check signatures and hash of pdus, removing any from the list that fail checks

View File

@ -295,14 +295,16 @@ class FederationServer(FederationBase):
Returns: Returns:
HTTP response code and body HTTP response code and body
""" """
response = await self.transaction_actions.have_responded(origin, transaction) existing_response = await self.transaction_actions.have_responded(
origin, transaction
)
if response: if existing_response:
logger.debug( logger.debug(
"[%s] We've already responded to this request", "[%s] We've already responded to this request",
transaction.transaction_id, transaction.transaction_id,
) )
return response return existing_response
logger.debug("[%s] Transaction is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id)
@ -632,7 +634,7 @@ class FederationServer(FederationBase):
async def on_make_knock_request( async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str, supported_versions: List[str] self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
) -> Dict[str, Union[EventBase, str]]: ) -> JsonDict:
"""We've received a /make_knock/ request, so we create a partial knock """We've received a /make_knock/ request, so we create a partial knock
event for the room and hand that back, along with the room version, to the knocking event for the room and hand that back, along with the room version, to the knocking
homeserver. We do *not* persist or process this event until the other server has homeserver. We do *not* persist or process this event until the other server has

View File

@ -149,7 +149,6 @@ class TransactionManager:
) )
except HttpResponseException as e: except HttpResponseException as e:
code = e.code code = e.code
response = e.response
set_tag(tags.ERROR, True) set_tag(tags.ERROR, True)

View File

@ -15,7 +15,19 @@
import logging import logging
import urllib import urllib
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)
import attr import attr
import ijson import ijson
@ -100,7 +112,7 @@ class TransportLayerClient:
@log_function @log_function
async def backfill( async def backfill(
self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[JsonDict]: ) -> Optional[JsonDict]:
"""Requests `limit` previous PDUs in a given context before list of """Requests `limit` previous PDUs in a given context before list of
PDUs. PDUs.
@ -108,7 +120,9 @@ class TransportLayerClient:
Args: Args:
destination destination
room_id room_id
event_tuples event_tuples:
Must be a Collection that is falsy when empty.
(Iterable is not enough here!)
limit limit
Returns: Returns:
@ -786,7 +800,7 @@ class TransportLayerClient:
@log_function @log_function
def join_group( def join_group(
self, destination: str, group_id: str, user_id: str, content: JsonDict self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> JsonDict: ) -> Awaitable[JsonDict]:
"""Attempts to join a group""" """Attempts to join a group"""
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)

View File

@ -245,7 +245,7 @@ class DirectoryHandler:
servers = result.servers servers = result.servers
else: else:
try: try:
fed_result = await self.federation.make_query( fed_result: Optional[JsonDict] = await self.federation.make_query(
destination=room_alias.domain, destination=room_alias.domain,
query_type="directory", query_type="directory",
args={"room_alias": room_alias.to_string()}, args={"room_alias": room_alias.to_string()},

View File

@ -477,7 +477,7 @@ class FederationEventHandler:
@log_function @log_function
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str] self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> None: ) -> None:
"""Trigger a backfill request to `dest` for the given `room_id` """Trigger a backfill request to `dest` for the given `room_id`

View File

@ -52,6 +52,7 @@ import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.appservice import ApplicationService
from synapse.events.presence_router import PresenceRouter from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
@ -1551,6 +1552,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
is_guest: bool = False, is_guest: bool = False,
explicit_room_id: Optional[str] = None, explicit_room_id: Optional[str] = None,
include_offline: bool = True, include_offline: bool = True,
service: Optional[ApplicationService] = None,
) -> Tuple[List[UserPresenceState], int]: ) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are: # The process for getting presence events are:
# 1. Get the rooms the user is in. # 1. Get the rooms the user is in.

View File

@ -456,7 +456,11 @@ class ProfileHandler:
continue continue
new_name = profile.get("displayname") new_name = profile.get("displayname")
if not isinstance(new_name, str):
new_name = None
new_avatar = profile.get("avatar_url") new_avatar = profile.get("avatar_url")
if not isinstance(new_avatar, str):
new_avatar = None
# We always hit update to update the last_check timestamp # We always hit update to update the last_check timestamp
await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)

View File

@ -16,6 +16,7 @@
import logging import logging
from functools import wraps from functools import wraps
from inspect import getcallargs from inspect import getcallargs
from typing import Callable, TypeVar, cast
_TIME_FUNC_ID = 0 _TIME_FUNC_ID = 0
@ -41,7 +42,10 @@ def _log_debug_as_f(f, msg, msg_args):
logger.handle(record) logger.handle(record)
def log_function(f): F = TypeVar("F", bound=Callable)
def log_function(f: F) -> F:
"""Function decorator that logs every call to that function.""" """Function decorator that logs every call to that function."""
func_name = f.__name__ func_name = f.__name__
@ -69,4 +73,4 @@ def log_function(f):
return f(*args, **kwargs) return f(*args, **kwargs)
wrapped.__name__ = func_name wrapped.__name__ = func_name
return wrapped return cast(F, wrapped)

View File

@ -26,6 +26,7 @@ from typing import (
FrozenSet, FrozenSet,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -519,7 +520,7 @@ class StateResolutionHandler:
self, self,
room_id: str, room_id: str,
room_version: str, room_version: str,
state_groups_ids: Dict[int, StateMap[str]], state_groups_ids: Mapping[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore", state_res_store: "StateResolutionStore",
) -> _StateCacheEntry: ) -> _StateCacheEntry:
@ -703,7 +704,7 @@ class StateResolutionHandler:
def _make_state_cache_entry( def _make_state_cache_entry(
new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] new_state: StateMap[str], state_groups_ids: Mapping[int, StateMap[str]]
) -> _StateCacheEntry: ) -> _StateCacheEntry:
"""Given a resolved state, and a set of input state groups, pick one to base """Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed a new state group on (if any), and return an appropriately-constructed

View File

@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore):
) )
async def update_remote_profile_cache( async def update_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
) -> int: ) -> int:
return await self.db_pool.simple_update( return await self.db_pool.simple_update(
table="remote_profile_cache", table="remote_profile_cache",