Stop user directory from failing if it encounters users not in the `users` table. (#11053)

The following scenarios would halt the user directory updater:

- user joins room
- user leaves room
- user present in room which switches from private to public, or vice versa.

for two classes of users:

- appservice senders
- users missing from the `users` table.

If this happened, the user directory would be stuck, unable to make forward progress.

Exclude both cases from the user directory, so that we ignore them.

Co-authored-by: Eric Eastwood <erice@element.io>
Co-authored-by: reivilibre <oliverw@matrix.org>
Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
This commit is contained in:
David Robertson 2021-10-13 10:38:22 +01:00 committed by GitHub
parent 1db9282dfa
commit b83e822556
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 918 additions and 90 deletions

1
changelog.d/10825.misc Normal file
View File

@ -0,0 +1 @@
Add an 'approximate difference' method to `StateFilter`.

1
changelog.d/10970.misc Normal file
View File

@ -0,0 +1 @@
Fix inconsistent behavior of `get_last_client_by_ip` when reporting data that has not been stored in the database yet.

1
changelog.d/10996.misc Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.21.0 that causes opentracing and Prometheus metrics for replication requests to be measured incorrectly.

2
changelog.d/11053.bugfix Normal file
View File

@ -0,0 +1,2 @@
Fix a bug introduced in Synapse 1.45.0rc1 where the user directory would stop updating if it processed an event from a
user not in the `users` table.

View File

@ -807,6 +807,14 @@ def trace(func=None, opname=None):
result.addCallbacks(call_back, err_back) result.addCallbacks(call_back, err_back)
else: else:
if inspect.isawaitable(result):
logger.error(
"@trace may not have wrapped %s correctly! "
"The function is not async but returned a %s.",
func.__qualname__,
type(result).__name__,
)
scope.__exit__(None, None, None) scope.__exit__(None, None, None)
return result return result

View File

@ -182,85 +182,87 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
) )
@trace(opname="outgoing_replication_request") @trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
async def send_request(*, instance_name="master", **kwargs): async def send_request(*, instance_name="master", **kwargs):
if instance_name == local_instance_name: with outgoing_gauge.track_inprogress():
raise Exception("Trying to send HTTP request to self") if instance_name == local_instance_name:
if instance_name == "master": raise Exception("Trying to send HTTP request to self")
host = master_host if instance_name == "master":
port = master_port host = master_host
elif instance_name in instance_map: port = master_port
host = instance_map[instance_name].host elif instance_name in instance_map:
port = instance_map[instance_name].port host = instance_map[instance_name].host
else: port = instance_map[instance_name].port
raise Exception( else:
"Instance %r not in 'instance_map' config" % (instance_name,) raise Exception(
"Instance %r not in 'instance_map' config" % (instance_name,)
)
data = await cls._serialize_payload(**kwargs)
url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
if cls.CACHE:
txn_id = random_string(10)
url_args.append(txn_id)
if cls.METHOD == "POST":
request_func = client.post_json_get_json
elif cls.METHOD == "PUT":
request_func = client.put_json
elif cls.METHOD == "GET":
request_func = client.get_json
else:
# We have already asserted in the constructor that a
# compatible was picked, but lets be paranoid.
raise Exception(
"Unknown METHOD on %s replication endpoint" % (cls.NAME,)
)
uri = "http://%s:%s/_synapse/replication/%s/%s" % (
host,
port,
cls.NAME,
"/".join(url_args),
) )
data = await cls._serialize_payload(**kwargs) try:
# We keep retrying the same request for timeouts. This is so that we
# have a good idea that the request has either succeeded or failed
# on the master, and so whether we should clean up or not.
while True:
headers: Dict[bytes, List[bytes]] = {}
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [
b"Bearer " + replication_secret
]
opentracing.inject_header_dict(headers, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
break
except RequestTimedOutError:
if not cls.RETRY_ON_TIMEOUT:
raise
url_args = [ logger.warning("%s request timed out; retrying", cls.NAME)
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
if cls.CACHE: # If we timed out we probably don't need to worry about backing
txn_id = random_string(10) # off too much, but lets just wait a little anyway.
url_args.append(txn_id) await clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the main process that we should send to the client. (And
# importantly, not stack traces everywhere)
_outgoing_request_counter.labels(cls.NAME, e.code).inc()
raise e.to_synapse_error()
except Exception as e:
_outgoing_request_counter.labels(cls.NAME, "ERR").inc()
raise SynapseError(502, "Failed to talk to main process") from e
if cls.METHOD == "POST": _outgoing_request_counter.labels(cls.NAME, 200).inc()
request_func = client.post_json_get_json return result
elif cls.METHOD == "PUT":
request_func = client.put_json
elif cls.METHOD == "GET":
request_func = client.get_json
else:
# We have already asserted in the constructor that a
# compatible was picked, but lets be paranoid.
raise Exception(
"Unknown METHOD on %s replication endpoint" % (cls.NAME,)
)
uri = "http://%s:%s/_synapse/replication/%s/%s" % (
host,
port,
cls.NAME,
"/".join(url_args),
)
try:
# We keep retrying the same request for timeouts. This is so that we
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
headers: Dict[bytes, List[bytes]] = {}
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
opentracing.inject_header_dict(headers, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
break
except RequestTimedOutError:
if not cls.RETRY_ON_TIMEOUT:
raise
logger.warning("%s request timed out; retrying", cls.NAME)
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
await clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the main process that we should send to the client. (And
# importantly, not stack traces everywhere)
_outgoing_request_counter.labels(cls.NAME, e.code).inc()
raise e.to_synapse_error()
except Exception as e:
_outgoing_request_counter.labels(cls.NAME, "ERR").inc()
raise SynapseError(502, "Failed to talk to main process") from e
_outgoing_request_counter.labels(cls.NAME, 200).inc()
return result
return send_request return send_request

View File

@ -538,15 +538,20 @@ class ClientIpStore(ClientIpWorkerStore):
""" """
ret = await super().get_last_client_ip_by_device(user_id, device_id) ret = await super().get_last_client_ip_by_device(user_id, device_id)
# Update what is retrieved from the database with data which is pending insertion. # Update what is retrieved from the database with data which is pending
# insertion, as if it has already been stored in the database.
for key in self._batch_row_update: for key in self._batch_row_update:
uid, access_token, ip = key uid, _access_token, ip = key
if uid == user_id: if uid == user_id:
user_agent, did, last_seen = self._batch_row_update[key] user_agent, did, last_seen = self._batch_row_update[key]
if did is None:
# These updates don't make it to the `devices` table
continue
if not device_id or did == device_id: if not device_id or did == device_id:
ret[(user_id, device_id)] = { ret[(user_id, did)] = {
"user_id": user_id, "user_id": user_id,
"access_token": access_token,
"ip": ip, "ip": ip,
"user_agent": user_agent, "user_agent": user_agent,
"device_id": did, "device_id": did,

View File

@ -26,6 +26,8 @@ from typing import (
cast, cast,
) )
from synapse.api.errors import StoreError
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -383,7 +385,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"""Certain classes of local user are omitted from the user directory. """Certain classes of local user are omitted from the user directory.
Is this user one of them? Is this user one of them?
""" """
# App service users aren't usually contactable, so exclude them. # We're opting to exclude the appservice sender (user defined by the
# `sender_localpart` in the appservice registration) even though
# technically it could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice sender can be
# contacted.
if self.get_app_service_by_user_id(user) is not None:
return False
# We're opting to exclude appservice users (anyone matching the user
# namespace regex in the appservice registration) even though technically
# they could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice users can be
# contacted.
if self.get_if_app_services_interested_in_user(user): if self.get_if_app_services_interested_in_user(user):
# TODO we might want to make this configurable for each app service # TODO we might want to make this configurable for each app service
return False return False
@ -393,8 +407,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False return False
# Deactivated users aren't contactable, so should not appear in the user directory. # Deactivated users aren't contactable, so should not appear in the user directory.
if await self.get_user_deactivated_status(user): try:
if await self.get_user_deactivated_status(user):
return False
except StoreError:
# No such user in the users table. No need to do this when calling
# is_support_user---that returns False if the user is missing.
return False return False
return True return True
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:

View File

@ -15,9 +15,11 @@ import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Awaitable, Awaitable,
Collection,
Dict, Dict,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Set, Set,
Tuple, Tuple,
@ -29,7 +31,7 @@ from frozendict import frozendict
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateKey, StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
@ -134,6 +136,23 @@ class StateFilter:
include_others=True, include_others=True,
) )
@staticmethod
def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
    """
    Build a StateFilter from inputs that may be made of mutable types.

    Args:
        types: maps each state type to a collection of state keys, or to
            None. None values are carried over unchanged; every other value
            is copied into a frozenset.
        include_others: passed straight through to the StateFilter
            constructor.

    Returns:
        A StateFilter whose `types` is a frozendict of the converted values.
    """
    frozen_by_type: Dict[str, Optional[FrozenSet[str]]] = {
        state_type: None if state_keys is None else frozenset(state_keys)
        for state_type, state_keys in types.items()
    }
    return StateFilter(frozendict(frozen_by_type), include_others=include_others)
def return_expanded(self) -> "StateFilter": def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed """Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the (except for memberships). The returned filter is a superset of the
@ -356,6 +375,157 @@ class StateFilter:
return member_filter, non_member_filter return member_filter, non_member_filter
def _decompose_into_four_parts(
    self,
) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
    """
    Split this state filter into 4 constituent parts, which can be read as:
        all? - minus_wildcards + plus_wildcards + plus_state_keys
    where
    * all is whether we start from ALL state (`self.include_others`)
    * minus_wildcards are entire state types to remove; when `all` is set
      this is every state type named in `self.types`, otherwise it is empty
    * plus_wildcards are entire state types to add (types mapped to None)
    * plus_state_keys are individual state keys to add, taken from
      `self.concrete_types()`

    `_recompose_from_four_parts` performs the opposite conversion.

    Returns:
        `((all, minus_wildcards), (plus_wildcards, plus_state_keys))`
    """
    starts_with_all = self.include_others

    # Exclusions only mean something when we start from everything.
    minus_wildcards: Set[str] = set(self.types) if starts_with_all else set()

    plus_wildcards: Set[str] = {
        state_type
        for state_type, state_keys in self.types.items()
        if state_keys is None
    }
    plus_state_keys: Set[StateKey] = set(self.concrete_types())

    return (starts_with_all, minus_wildcards), (plus_wildcards, plus_state_keys)
@staticmethod
def _recompose_from_four_parts(
    all_part: bool,
    minus_wildcards: Set[str],
    plus_wildcards: Set[str],
    plus_state_keys: Set[StateKey],
) -> "StateFilter":
    """
    Rebuild a state filter from its 4 parts.

    This is the opposite conversion to `_decompose_into_four_parts`, whose
    docstring describes what each part means.

    Args:
        all_part: whether the filter starts from all state; becomes the
            `include_others` flag of the result.
        minus_wildcards: state types to exclude; only consulted when
            `all_part` is set.
        plus_wildcards: state types to include in full.
        plus_state_keys: individual (state type, state key) pairs to include.

    Returns:
        The frozen StateFilter built from the parts.
    """
    # {state type -> set of state keys OR None for wildcard}
    # (The same structure as that of a StateFilter.)
    rebuilt: Dict[str, Optional[Set[str]]] = {}

    # When starting from everything, an empty key set is what keeps an
    # excluded state type out of the result.
    if all_part:
        for state_type in minus_wildcards:
            rebuilt[state_type] = set()

    # Wildcards come second, so they override an exclusion of the same type.
    for state_type in plus_wildcards:
        rebuilt[state_type] = None

    # Finally the individual state keys.
    for state_type, state_key in plus_state_keys:
        if state_type not in rebuilt:
            # With include_others set, the whole type is already included;
            # adding an entry for it would actually shrink the state allowed
            # by this filter, so only insert when we don't start with all.
            if not all_part:
                rebuilt[state_type] = {state_key}
            continue

        keys_for_type = rebuilt[state_type]
        if keys_for_type is not None:
            # (A None entry is a wildcard and already covers this key.)
            keys_for_type.add(state_key)

    return StateFilter.freeze(rebuilt, include_others=all_part)
def approx_difference(self, other: "StateFilter") -> "StateFilter":
    """
    Compute a state filter which represents `self - other`.

    This is useful for working out what state still has to be pulled out of
    the database when we want the state included by `self` but already have
    the state included by `other`.

    The returned state filter
    - MUST include all state events that are included by this filter (`self`)
      unless they are included by `other`;
    - MUST NOT include state events not included by this filter (`self`); and
    - MAY be an over-approximation: the returned state filter
      MAY additionally include some state events from `other`.

    This implementation attempts to return the narrowest such state filter.
    Where `self` has a wildcard for a state type and `other` only names
    specific state keys of that type, the wildcard is kept, because a state
    filter cannot express 'all state keys except some given examples'.

    e.g.
        StateFilter(m.room.member -> None (wildcard))
        minus
        StateFilter(m.room.member -> {'@wombat:example.org'})
        is approximated as
        StateFilter(m.room.member -> None (wildcard))

    Args:
        other: the filter describing the state we already have.

    Returns:
        A state filter approximating `self - other`.
    """
    # Work in the four-part representation of both filters:
    #  - whether or not they include all events to begin with ('all')
    #  - if so, which event types are excluded? ('excludes')
    #  - which entire event types to include ('wildcards')
    #  - which concrete state keys to include ('concrete state keys')
    self_head, self_tail = self._decompose_into_four_parts()
    self_all, self_excludes = self_head
    self_wildcards, self_concrete_keys = self_tail

    other_head, other_tail = other._decompose_into_four_parts()
    other_all, other_excludes = other_head
    other_wildcards, other_concrete_keys = other_tail

    # Wildcards that the other filter also has as wildcards are not needed.
    remaining_wildcards = self_wildcards - other_wildcards

    # Concrete state keys are not needed if the other filter has them,
    # either via a wildcard on their type or as the same concrete key.
    remaining_keys = {
        (state_type, state_key)
        for (state_type, state_key) in self_concrete_keys
        if state_type not in other_wildcards
        and (state_type, state_key) not in other_concrete_keys
    }

    if not other_all:
        # The other filter only adds specific things, so keep self's 'all'
        # flag and push the other's wildcards into the exclusion list.
        return StateFilter._recompose_from_four_parts(
            self_all,
            self_excludes | other_wildcards,
            remaining_wildcards,
            remaining_keys,
        )

    # From here on the other filter is an `include_others`, so the
    # difference isn't, and only types excluded by the other can survive.
    if self_all:
        # The other filter returns everything BUT the types in its
        # exclusion list, so any of those types that self does not itself
        # exclude must be requested as wildcards.
        remaining_wildcards = remaining_wildcards | (other_excludes - self_excludes)

    # Drop every state type that isn't in the exclusion list of the other.
    remaining_wildcards = remaining_wildcards & other_excludes
    remaining_keys = {
        (state_type, state_key)
        for (state_type, state_key) in remaining_keys
        if state_type in other_excludes
    }

    # (No excludes are needed when we don't start with all, as there is
    # nothing to exclude.)
    return StateFilter._recompose_from_four_parts(
        False, set(), remaining_wildcards, remaining_keys
    )
class StateGroupStorage: class StateGroupStorage:
"""High level interface to fetching state for event.""" """High level interface to fetching state for event."""

View File

@ -63,7 +63,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
hostname="test", hostname="test",
id="1234", id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test", # Note: this user does not match the regex above, so that tests
# can distinguish the sender from the AS user.
sender="@as_main:test",
) )
mock_load_appservices = Mock(return_value=[self.appservice]) mock_load_appservices = Mock(return_value=[self.appservice])
@ -122,7 +124,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
{(alice, bob, private), (bob, alice, private)}, {(alice, bob, private), (bob, alice, private)},
) )
# The next three tests (test_population_excludes_*) all setup # The next four tests (test_excludes_*) all setup
# - A normal user included in the user dir # - A normal user included in the user dir
# - A public and private room created by that user # - A public and private room created by that user
# - A user excluded from the room dir, belonging to both rooms # - A user excluded from the room dir, belonging to both rooms
@ -179,6 +181,34 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
self._check_only_one_user_in_directory(user, public) self._check_only_one_user_in_directory(user, public)
def test_excludes_appservice_sender(self) -> None:
    """The appservice sender joining a public room must not be added to
    the user directory alongside the room's creator.
    """
    creator = self.register_user("user", "pass")
    creator_token = self.login(creator, "pass")
    public_room = self.helper.create_room_as(
        creator, is_public=True, tok=creator_token
    )

    # The sender joins using the appservice's own token.
    self.helper.join(
        public_room, self.appservice.sender, tok=self.appservice.token
    )

    self._check_only_one_user_in_directory(creator, public_room)
def test_user_not_in_users_table(self) -> None:
    """Join events have been seen on matrix.org for users who are absent
    from the users table (how that happens is unclear). Check that
    processing such a user does not make us fall over.
    """
    creator = self.register_user("user1", "pass")
    creator_token = self.login(creator, "pass")
    room = self.helper.create_room_as(creator, is_public=True, tok=creator_token)

    # A membership event for a user who was never registered.
    dodgy_join = inject_member_event(self.hs, room, "@not-a-user:test", "join")
    self.get_success(dodgy_join)

    # Afterwards, a genuine new user registers and joins the same room.
    latecomer = self.register_user("user2", "pass")
    latecomer_token = self.login(latecomer, "pass")
    self.helper.join(room, latecomer, tok=latecomer_token)

    # The latecomer's join must still have been processed despite the
    # dodgy event that came before it.
    in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
    expected = {(creator, room), (latecomer, room)}
    self.assertEqual(set(in_public), expected)
def _create_rooms_and_inject_memberships( def _create_rooms_and_inject_memberships(
self, creator: str, token: str, joiner: str self, creator: str, token: str, joiner: str
) -> Tuple[str, str]: ) -> Tuple[str, str]:
@ -230,7 +260,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
) )
profile = self.get_success(self.store.get_user_in_directory(support_user_id)) profile = self.get_success(self.store.get_user_in_directory(support_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
display_name = "display_name" display_name = "display_name"
profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
@ -264,7 +294,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is not in directory # profile is not in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id)) profile = self.get_success(self.store.get_user_in_directory(r_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
# update profile after deactivation # update profile after deactivation
self.get_success( self.get_success(
@ -273,7 +303,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is furthermore not in directory # profile is furthermore not in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id)) profile = self.get_success(self.store.get_user_in_directory(r_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
def test_handle_local_profile_change_with_appservice_user(self) -> None: def test_handle_local_profile_change_with_appservice_user(self) -> None:
# create user # create user
@ -283,7 +313,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is not in directory # profile is not in directory
profile = self.get_success(self.store.get_user_in_directory(as_user_id)) profile = self.get_success(self.store.get_user_in_directory(as_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
# update profile # update profile
profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3") profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
@ -293,7 +323,28 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is still not in directory # profile is still not in directory
profile = self.get_success(self.store.get_user_in_directory(as_user_id)) profile = self.get_success(self.store.get_user_in_directory(as_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
def test_handle_local_profile_change_with_appservice_sender(self) -> None:
    """A profile change for the appservice sender must not put the sender
    into the user directory.
    """
    sender = self.appservice.sender

    # Before the update: the sender has no directory entry.
    profile_before = self.get_success(self.store.get_user_in_directory(sender))
    self.assertIsNone(profile_before)

    # Push a profile update through the handler.
    new_profile = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
    self.get_success(self.handler.handle_local_profile_change(sender, new_profile))

    # After the update: still no directory entry.
    profile_after = self.get_success(self.store.get_user_in_directory(sender))
    self.assertIsNone(profile_after)
def test_handle_user_deactivated_support_user(self) -> None: def test_handle_user_deactivated_support_user(self) -> None:
s_user_id = "@support:test" s_user_id = "@support:test"

View File

@ -146,6 +146,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
], ],
) )
@parameterized.expand([(False,), (True,)])
def test_get_last_client_ip_by_device(self, after_persisting: bool):
    """Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
    self.reactor.advance(12345678)

    user_id = "@user:id"
    device_id = "MY_DEVICE"

    # Create the device, then record a client IP against it.
    self.get_success(self.store.store_device(user_id, device_id, "display name"))
    self.get_success(
        self.store.insert_client_ip(
            user_id, "access_token", "ip", "user_agent", device_id
        )
    )

    if after_persisting:
        # Advance the clock so that the storage loop runs.
        self.reactor.advance(10)

    # The same entry is expected in both parameterisations.
    expected = {
        (user_id, device_id): {
            "user_id": user_id,
            "device_id": device_id,
            "ip": "ip",
            "user_agent": "user_agent",
            "last_seen": 12345678000,
        },
    }

    result = self.get_success(
        self.store.get_last_client_ip_by_device(user_id, device_id)
    )
    self.assertEqual(result, expected)
@parameterized.expand([(False,), (True,)]) @parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool): def test_get_user_ip_and_agents(self, after_persisting: bool):
"""Test `get_user_ip_and_agents` for persisted and unpersisted data""" """Test `get_user_ip_and_agents` for persisted and unpersisted data"""

View File

@ -21,7 +21,7 @@ from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase, TestCase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -105,7 +105,6 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
def test_get_state_for_event(self): def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever # this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room. # forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@ -483,3 +482,513 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
):
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
def test_state_filter_difference_no_include_other_minus_no_include_other(self):
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=False,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.CanonicalAlias: {""}},
include_others=False,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
def test_state_filter_difference_include_other_minus_no_include_other(self):
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Create: None,
EventTypes.Member: set(),
EventTypes.CanonicalAlias: set(),
},
include_others=True,
),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
# This also shows that the resultant state filter is normalised.
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
StateFilter(types=frozendict(), include_others=True),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter(
types=frozendict(),
include_others=True,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.CanonicalAlias: {""},
EventTypes.Member: set(),
},
include_others=True,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
def test_state_filter_difference_include_other_minus_include_other(self):
    """
    Exercises StateFilter.approx_difference for a.approx_difference(b)
    when the include_others flag is set on both a and b.
    """
    check = self.assert_difference
    member = EventTypes.Member
    create = EventTypes.Create
    alias = EventTypes.CanonicalAlias

    def with_others(types):
        # Filter over `types` that also admits every unlisted event type.
        return StateFilter.freeze(types, include_others=True)

    def without_others(types):
        # Filter restricted to exactly the event types listed in `types`.
        return StateFilter.freeze(types, include_others=False)

    def nothing():
        # Filter with no listed types and include_others switched off.
        return StateFilter(types=frozendict(), include_others=False)

    # wildcard state keys minus wildcard state keys
    check(
        with_others({member: None, create: None}),
        with_others({member: None, alias: None}),
        nothing(),
    )

    # wildcard state keys minus specific state keys.
    # The expected value is an over-approximation: 'all state keys except
    # a few named examples' cannot be represented.
    check(
        with_others({member: None}),
        with_others({member: {"@wombat:spqr"}, alias: {""}}),
        without_others({member: None, alias: None}),
    )

    # wildcard state keys minus no state keys
    check(
        with_others({member: None}),
        with_others({member: set()}),
        without_others({member: None}),
    )

    # specific state keys minus wildcard state keys
    check(
        with_others({member: {"@wombat:spqr", "@spqr:spqr"}, alias: {""}}),
        with_others({member: None}),
        nothing(),
    )

    # specific state keys minus specific state keys.
    # The expected value is an over-approximation: 'all state keys except
    # a few named examples' cannot be represented.
    check(
        with_others(
            {
                member: {"@wombat:spqr", "@spqr:spqr"},
                alias: {""},
                create: {""},
            }
        ),
        with_others({member: {"@wombat:spqr"}, create: set()}),
        without_others({member: {"@spqr:spqr"}, create: {""}}),
    )

    # specific state keys minus no state keys
    check(
        with_others({member: {"@wombat:spqr", "@spqr:spqr"}, alias: {""}}),
        with_others({member: set()}),
        without_others({member: {"@wombat:spqr", "@spqr:spqr"}}),
    )
def test_state_filter_difference_no_include_other_minus_include_other(self):
    """
    Exercises StateFilter.approx_difference for a.approx_difference(b)
    when the include_others flag is set on b only.
    """
    check = self.assert_difference
    member = EventTypes.Member
    create = EventTypes.Create
    alias = EventTypes.CanonicalAlias

    def with_others(types):
        # Filter over `types` that also admits every unlisted event type.
        return StateFilter.freeze(types, include_others=True)

    def without_others(types):
        # Filter restricted to exactly the event types listed in `types`.
        return StateFilter.freeze(types, include_others=False)

    def nothing():
        # Filter with no listed types and include_others switched off.
        return StateFilter(types=frozendict(), include_others=False)

    # wildcard state keys minus wildcard state keys
    check(
        without_others({member: None, create: None}),
        with_others({member: None, alias: None}),
        nothing(),
    )

    # wildcard state keys minus specific state keys.
    # The expected value is an over-approximation: 'all state keys except
    # a few named examples' cannot be represented.
    check(
        without_others({member: None}),
        with_others({member: {"@wombat:spqr"}}),
        without_others({member: None}),
    )

    # wildcard state keys minus no state keys
    check(
        without_others({member: None}),
        with_others({member: set()}),
        without_others({member: None}),
    )

    # specific state keys minus wildcard state keys
    check(
        without_others({member: {"@wombat:spqr", "@spqr:spqr"}, alias: {""}}),
        with_others({member: None}),
        nothing(),
    )

    # specific state keys minus specific state keys.
    # (The original comment labels this case an over-approximation, since
    # 'all state keys except a few named examples' cannot be represented.)
    check(
        without_others({member: {"@wombat:spqr", "@spqr:spqr"}, alias: {""}}),
        with_others({member: {"@wombat:spqr"}}),
        without_others({member: {"@spqr:spqr"}}),
    )

    # specific state keys minus no state keys
    check(
        without_others({member: {"@wombat:spqr", "@spqr:spqr"}, alias: {""}}),
        with_others({member: set()}),
        without_others({member: {"@wombat:spqr", "@spqr:spqr"}}),
    )
def test_state_filter_difference_simple_cases(self):
    """
    Checks approx_difference on the trivial `all` and `none` filters,
    cases which the more in-depth difference tests do not cover explicitly.
    """
    check = self.assert_difference

    # all minus all: expected result is the `none` filter.
    minuend, subtrahend, expected = (
        StateFilter.all(),
        StateFilter.all(),
        StateFilter.none(),
    )
    check(minuend, subtrahend, expected)

    # all minus none: expected result is the `all` filter.
    minuend, subtrahend, expected = (
        StateFilter.all(),
        StateFilter.none(),
        StateFilter.all(),
    )
    check(minuend, subtrahend, expected)

View File

@ -256,7 +256,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
self.assertEqual(users, {u1, u2, u3}) self.assertEqual(users, {u1, u2, u3})
# The next three tests (test_population_excludes_*) all set up # The next four tests (test_population_excludes_*) all set up
# - A normal user included in the user dir # - A normal user included in the user dir
# - A public and private room created by that user # - A public and private room created by that user
# - A user excluded from the room dir, belonging to both rooms # - A user excluded from the room dir, belonging to both rooms
@ -364,6 +364,21 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Check the AS user is not in the directory. # Check the AS user is not in the directory.
self._check_room_sharing_tables(user, public, private) self._check_room_sharing_tables(user, public, private)
def test_population_excludes_appservice_sender(self) -> None:
    """The appservice sender is left out of the rebuilt user directory."""
    # An ordinary registered user, who will own the rooms.
    owner = self.register_user("user", "pass")
    owner_token = self.login(owner, "pass")

    # Have the AS sender join the rooms belonging to that user.
    rooms = self._create_rooms_and_inject_memberships(
        owner, owner_token, self.appservice.sender
    )
    public_room, private_room = rooms

    # Throw away the directory contents and build them again.
    self._purge_and_rebuild_user_dir()

    # Verify that the AS sender does not appear in the directory.
    self._check_room_sharing_tables(owner, public_room, private_room)
def test_population_conceals_private_nickname(self) -> None: def test_population_conceals_private_nickname(self) -> None:
# Make a private room, and set a nickname within # Make a private room, and set a nickname within
user = self.register_user("aaaa", "pass") user = self.register_user("aaaa", "pass")