Add type hints to misc. files. (#9676)
This commit is contained in:
parent
7e8dc9934e
commit
af387cf52a
|
@ -0,0 +1 @@
|
|||
Add type hints to third party event rules and visibility modules.
|
5
mypy.ini
5
mypy.ini
|
@ -20,8 +20,9 @@ files =
|
|||
synapse/crypto,
|
||||
synapse/event_auth.py,
|
||||
synapse/events/builder.py,
|
||||
synapse/events/validator.py,
|
||||
synapse/events/spamcheck.py,
|
||||
synapse/events/third_party_rules.py,
|
||||
synapse/events/validator.py,
|
||||
synapse/federation,
|
||||
synapse/groups,
|
||||
synapse/handlers,
|
||||
|
@ -38,6 +39,7 @@ files =
|
|||
synapse/push,
|
||||
synapse/replication,
|
||||
synapse/rest,
|
||||
synapse/secrets.py,
|
||||
synapse/server.py,
|
||||
synapse/server_notices,
|
||||
synapse/spam_checker_api,
|
||||
|
@ -71,6 +73,7 @@ files =
|
|||
synapse/util/metrics.py,
|
||||
synapse/util/macaroons.py,
|
||||
synapse/util/stringutils.py,
|
||||
synapse/visibility.py,
|
||||
tests/replication,
|
||||
tests/test_utils,
|
||||
tests/handlers/test_password_providers.py,
|
||||
|
|
|
@ -13,12 +13,15 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Union
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.types import Requester, StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class ThirdPartyEventRules:
|
||||
"""Allows server admins to provide a Python module implementing an extra
|
||||
|
@ -28,7 +31,7 @@ class ThirdPartyEventRules:
|
|||
behaviours.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.third_party_rules = None
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -95,10 +98,9 @@ class ThirdPartyEventRules:
|
|||
if self.third_party_rules is None:
|
||||
return True
|
||||
|
||||
ret = await self.third_party_rules.on_create_room(
|
||||
return await self.third_party_rules.on_create_room(
|
||||
requester, config, is_requester_admin
|
||||
)
|
||||
return ret
|
||||
|
||||
async def check_threepid_can_be_invited(
|
||||
self, medium: str, address: str, room_id: str
|
||||
|
@ -119,10 +121,9 @@ class ThirdPartyEventRules:
|
|||
|
||||
state_events = await self._get_state_map_for_room(room_id)
|
||||
|
||||
ret = await self.third_party_rules.check_threepid_can_be_invited(
|
||||
return await self.third_party_rules.check_threepid_can_be_invited(
|
||||
medium, address, state_events
|
||||
)
|
||||
return ret
|
||||
|
||||
async def check_visibility_can_be_modified(
|
||||
self, room_id: str, new_visibility: str
|
||||
|
@ -143,7 +144,7 @@ class ThirdPartyEventRules:
|
|||
check_func = getattr(
|
||||
self.third_party_rules, "check_visibility_can_be_modified", None
|
||||
)
|
||||
if not check_func or not isinstance(check_func, Callable):
|
||||
if not check_func or not callable(check_func):
|
||||
return True
|
||||
|
||||
state_events = await self._get_state_map_for_room(room_id)
|
||||
|
|
|
@ -26,10 +26,10 @@ if sys.version_info[0:2] >= (3, 6):
|
|||
import secrets
|
||||
|
||||
class Secrets:
|
||||
def token_bytes(self, nbytes=32):
|
||||
def token_bytes(self, nbytes: int = 32) -> bytes:
|
||||
return secrets.token_bytes(nbytes)
|
||||
|
||||
def token_hex(self, nbytes=32):
|
||||
def token_hex(self, nbytes: int = 32) -> str:
|
||||
return secrets.token_hex(nbytes)
|
||||
|
||||
|
||||
|
@ -38,8 +38,8 @@ else:
|
|||
import os
|
||||
|
||||
class Secrets:
|
||||
def token_bytes(self, nbytes=32):
|
||||
def token_bytes(self, nbytes: int = 32) -> bytes:
|
||||
return os.urandom(nbytes)
|
||||
|
||||
def token_hex(self, nbytes=32):
|
||||
def token_hex(self, nbytes: int = 32) -> str:
|
||||
return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii")
|
||||
|
|
|
@ -449,7 +449,7 @@ class StateGroupStorage:
|
|||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||
|
||||
async def get_state_for_events(
|
||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
|
||||
) -> Dict[str, StateMap[EventBase]]:
|
||||
"""Given a list of event_ids and type tuples, return a list of state
|
||||
dicts for each event.
|
||||
|
@ -485,7 +485,7 @@ class StateGroupStorage:
|
|||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
async def get_state_ids_for_events(
|
||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
|
||||
) -> Dict[str, StateMap[str]]:
|
||||
"""
|
||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import operator
|
||||
from typing import Dict, FrozenSet, List, Optional
|
||||
|
||||
from synapse.api.constants import (
|
||||
AccountDataTypes,
|
||||
|
@ -21,10 +21,11 @@ from synapse.api.constants import (
|
|||
HistoryVisibility,
|
||||
Membership,
|
||||
)
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.storage import Storage
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.types import StateMap, get_domain_from_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -48,32 +49,32 @@ MEMBERSHIP_PRIORITY = (
|
|||
|
||||
async def filter_events_for_client(
|
||||
storage: Storage,
|
||||
user_id,
|
||||
events,
|
||||
is_peeking=False,
|
||||
always_include_ids=frozenset(),
|
||||
filter_send_to_client=True,
|
||||
):
|
||||
user_id: str,
|
||||
events: List[EventBase],
|
||||
is_peeking: bool = False,
|
||||
always_include_ids: FrozenSet[str] = frozenset(),
|
||||
filter_send_to_client: bool = True,
|
||||
) -> List[EventBase]:
|
||||
"""
|
||||
Check which events a user is allowed to see. If the user can see the event but its
|
||||
sender asked for their data to be erased, prune the content of the event.
|
||||
|
||||
Args:
|
||||
storage
|
||||
user_id(str): user id to be checked
|
||||
events(list[synapse.events.EventBase]): sequence of events to be checked
|
||||
is_peeking(bool): should be True if:
|
||||
user_id: user id to be checked
|
||||
events: sequence of events to be checked
|
||||
is_peeking: should be True if:
|
||||
* the user is not currently a member of the room, and:
|
||||
* the user has not been a member of the room since the given
|
||||
events
|
||||
always_include_ids (set(event_id)): set of event ids to specifically
|
||||
always_include_ids: set of event ids to specifically
|
||||
include (unless sender is ignored)
|
||||
filter_send_to_client (bool): Whether we're checking an event that's going to be
|
||||
filter_send_to_client: Whether we're checking an event that's going to be
|
||||
sent to a client. This might not always be the case since this function can
|
||||
also be called to check whether a user can see the state at a given point.
|
||||
|
||||
Returns:
|
||||
list[synapse.events.EventBase]
|
||||
The filtered events.
|
||||
"""
|
||||
# Filter out events that have been soft failed so that we don't relay them
|
||||
# to clients.
|
||||
|
@ -90,7 +91,7 @@ async def filter_events_for_client(
|
|||
AccountDataTypes.IGNORED_USER_LIST, user_id
|
||||
)
|
||||
|
||||
ignore_list = frozenset()
|
||||
ignore_list = frozenset() # type: FrozenSet[str]
|
||||
if ignore_dict_content:
|
||||
ignored_users_dict = ignore_dict_content.get("ignored_users", {})
|
||||
if isinstance(ignored_users_dict, dict):
|
||||
|
@ -107,19 +108,18 @@ async def filter_events_for_client(
|
|||
room_id
|
||||
] = await storage.main.get_retention_policy_for_room(room_id)
|
||||
|
||||
def allowed(event):
|
||||
def allowed(event: EventBase) -> Optional[EventBase]:
|
||||
"""
|
||||
Args:
|
||||
event (synapse.events.EventBase): event to check
|
||||
event: event to check
|
||||
|
||||
Returns:
|
||||
None|EventBase:
|
||||
None if the user cannot see this event at all
|
||||
None if the user cannot see this event at all
|
||||
|
||||
a redacted copy of the event if they can only see a redacted
|
||||
version
|
||||
a redacted copy of the event if they can only see a redacted
|
||||
version
|
||||
|
||||
the original event if they can see it as normal.
|
||||
the original event if they can see it as normal.
|
||||
"""
|
||||
# Only run some checks if these events aren't about to be sent to clients. This is
|
||||
# because, if this is not the case, we're probably only checking if the users can
|
||||
|
@ -252,48 +252,46 @@ async def filter_events_for_client(
|
|||
|
||||
return event
|
||||
|
||||
# check each event: gives an iterable[None|EventBase]
|
||||
# Check each event: gives an iterable of None or (a potentially modified)
|
||||
# EventBase.
|
||||
filtered_events = map(allowed, events)
|
||||
|
||||
# remove the None entries
|
||||
filtered_events = filter(operator.truth, filtered_events)
|
||||
|
||||
# we turn it into a list before returning it.
|
||||
return list(filtered_events)
|
||||
# Turn it into a list and remove None entries before returning.
|
||||
return [ev for ev in filtered_events if ev]
|
||||
|
||||
|
||||
async def filter_events_for_server(
|
||||
storage: Storage,
|
||||
server_name,
|
||||
events,
|
||||
redact=True,
|
||||
check_history_visibility_only=False,
|
||||
):
|
||||
server_name: str,
|
||||
events: List[EventBase],
|
||||
redact: bool = True,
|
||||
check_history_visibility_only: bool = False,
|
||||
) -> List[EventBase]:
|
||||
"""Filter a list of events based on whether given server is allowed to
|
||||
see them.
|
||||
|
||||
Args:
|
||||
storage
|
||||
server_name (str)
|
||||
events (iterable[FrozenEvent])
|
||||
redact (bool): Whether to return a redacted version of the event, or
|
||||
server_name
|
||||
events
|
||||
redact: Whether to return a redacted version of the event, or
|
||||
to filter them out entirely.
|
||||
check_history_visibility_only (bool): Whether to only check the
|
||||
check_history_visibility_only: Whether to only check the
|
||||
history visibility, rather than things like if the sender has been
|
||||
erased. This is used e.g. during pagination to decide whether to
|
||||
backfill or not.
|
||||
|
||||
Returns
|
||||
list[FrozenEvent]
|
||||
The filtered events.
|
||||
"""
|
||||
|
||||
def is_sender_erased(event, erased_senders):
|
||||
def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
|
||||
if erased_senders and erased_senders[event.sender]:
|
||||
logger.info("Sender of %s has been erased, redacting", event.event_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_event_is_visible(event, state):
|
||||
def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool:
|
||||
history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
|
||||
if history:
|
||||
visibility = history.content.get(
|
||||
|
|
Loading…
Reference in New Issue