Correctly exclude users when making a room public or private (#11075)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
parent
5573133348
commit
e09be0c87a
|
@ -0,0 +1 @@
|
||||||
|
Fix a long-standing bug where users excluded from the user directory were added into the directory if they belonged to a room which became public or private.
|
|
@ -266,14 +266,17 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||||
for user_id in users_in_room:
|
for user_id in users_in_room:
|
||||||
await self.store.remove_user_who_share_room(user_id, room_id)
|
await self.store.remove_user_who_share_room(user_id, room_id)
|
||||||
|
|
||||||
# Then, re-add them to the tables.
|
# Then, re-add all remote users and some local users to the tables.
|
||||||
# NOTE: this is not the most efficient method, as _track_user_joined_room sets
|
# NOTE: this is not the most efficient method, as _track_user_joined_room sets
|
||||||
# up local_user -> other_user and other_user_whos_local -> local_user,
|
# up local_user -> other_user and other_user_whos_local -> local_user,
|
||||||
# which when ran over an entire room, will result in the same values
|
# which when ran over an entire room, will result in the same values
|
||||||
# being added multiple times. The batching upserts shouldn't make this
|
# being added multiple times. The batching upserts shouldn't make this
|
||||||
# too bad, though.
|
# too bad, though.
|
||||||
for user_id in users_in_room:
|
for user_id in users_in_room:
|
||||||
await self._track_user_joined_room(room_id, user_id)
|
if not self.is_mine_id(
|
||||||
|
user_id
|
||||||
|
) or await self.store.should_include_local_user_in_dir(user_id):
|
||||||
|
await self._track_user_joined_room(room_id, user_id)
|
||||||
|
|
||||||
async def _handle_room_membership_event(
|
async def _handle_room_membership_event(
|
||||||
self,
|
self,
|
||||||
|
@ -364,8 +367,8 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||||
"""Someone's just joined a room. Update `users_in_public_rooms` or
|
"""Someone's just joined a room. Update `users_in_public_rooms` or
|
||||||
`users_who_share_private_rooms` as appropriate.
|
`users_who_share_private_rooms` as appropriate.
|
||||||
|
|
||||||
The caller is responsible for ensuring that the given user is not excluded
|
The caller is responsible for ensuring that the given user should be
|
||||||
from the user directory.
|
included in the user directory.
|
||||||
"""
|
"""
|
||||||
is_public = await self.store.is_room_world_readable_or_publicly_joinable(
|
is_public = await self.store.is_room_world_readable_or_publicly_joinable(
|
||||||
room_id
|
room_id
|
||||||
|
|
|
@ -109,18 +109,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
tok=alice_token,
|
tok=alice_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
|
# The user directory should reflect the room memberships above.
|
||||||
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
|
users, in_public, in_private = self.get_success(
|
||||||
in_private = self.get_success(
|
self.user_dir_helper.get_tables()
|
||||||
self.user_dir_helper.get_users_who_share_private_rooms()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(users, {alice, bob})
|
self.assertEqual(users, {alice, bob})
|
||||||
|
self.assertEqual(in_public, {(alice, public), (bob, public), (alice, public2)})
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(in_public), {(alice, public), (bob, public), (alice, public2)}
|
in_private,
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
self.user_dir_helper._compress_shared(in_private),
|
|
||||||
{(alice, bob, private), (bob, alice, private)},
|
{(alice, bob, private), (bob, alice, private)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -209,6 +205,88 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
|
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
|
||||||
self.assertEqual(set(in_public), {(user1, room), (user2, room)})
|
self.assertEqual(set(in_public), {(user1, room), (user2, room)})
|
||||||
|
|
||||||
|
def test_excludes_users_when_making_room_public(self) -> None:
|
||||||
|
# Create a regular user and a support user.
|
||||||
|
alice = self.register_user("alice", "pass")
|
||||||
|
alice_token = self.login(alice, "pass")
|
||||||
|
support = "@support1:test"
|
||||||
|
self.get_success(
|
||||||
|
self.store.register_user(
|
||||||
|
user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make a public and private room containing Alice and the support user
|
||||||
|
public, initially_private = self._create_rooms_and_inject_memberships(
|
||||||
|
alice, alice_token, support
|
||||||
|
)
|
||||||
|
self._check_only_one_user_in_directory(alice, public)
|
||||||
|
|
||||||
|
# Alice makes the private room public.
|
||||||
|
self.helper.send_state(
|
||||||
|
initially_private,
|
||||||
|
"m.room.join_rules",
|
||||||
|
{"join_rule": "public"},
|
||||||
|
tok=alice_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
users, in_public, in_private = self.get_success(
|
||||||
|
self.user_dir_helper.get_tables()
|
||||||
|
)
|
||||||
|
self.assertEqual(users, {alice})
|
||||||
|
self.assertEqual(in_public, {(alice, public), (alice, initially_private)})
|
||||||
|
self.assertEqual(in_private, set())
|
||||||
|
|
||||||
|
def test_switching_from_private_to_public_to_private(self) -> None:
|
||||||
|
"""Check we update the room sharing tables when switching a room
|
||||||
|
from private to public, then back again to private."""
|
||||||
|
# Alice and Bob share a private room.
|
||||||
|
alice = self.register_user("alice", "pass")
|
||||||
|
alice_token = self.login(alice, "pass")
|
||||||
|
bob = self.register_user("bob", "pass")
|
||||||
|
bob_token = self.login(bob, "pass")
|
||||||
|
room = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
|
||||||
|
self.helper.invite(room, alice, bob, tok=alice_token)
|
||||||
|
self.helper.join(room, bob, tok=bob_token)
|
||||||
|
|
||||||
|
# The user directory should reflect this.
|
||||||
|
def check_user_dir_for_private_room() -> None:
|
||||||
|
users, in_public, in_private = self.get_success(
|
||||||
|
self.user_dir_helper.get_tables()
|
||||||
|
)
|
||||||
|
self.assertEqual(users, {alice, bob})
|
||||||
|
self.assertEqual(in_public, set())
|
||||||
|
self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)})
|
||||||
|
|
||||||
|
check_user_dir_for_private_room()
|
||||||
|
|
||||||
|
# Alice makes the room public.
|
||||||
|
self.helper.send_state(
|
||||||
|
room,
|
||||||
|
"m.room.join_rules",
|
||||||
|
{"join_rule": "public"},
|
||||||
|
tok=alice_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The user directory should be updated accordingly
|
||||||
|
users, in_public, in_private = self.get_success(
|
||||||
|
self.user_dir_helper.get_tables()
|
||||||
|
)
|
||||||
|
self.assertEqual(users, {alice, bob})
|
||||||
|
self.assertEqual(in_public, {(alice, room), (bob, room)})
|
||||||
|
self.assertEqual(in_private, set())
|
||||||
|
|
||||||
|
# Alice makes the room private.
|
||||||
|
self.helper.send_state(
|
||||||
|
room,
|
||||||
|
"m.room.join_rules",
|
||||||
|
{"join_rule": "invite"},
|
||||||
|
tok=alice_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The user directory should be updated accordingly
|
||||||
|
check_user_dir_for_private_room()
|
||||||
|
|
||||||
def _create_rooms_and_inject_memberships(
|
def _create_rooms_and_inject_memberships(
|
||||||
self, creator: str, token: str, joiner: str
|
self, creator: str, token: str, joiner: str
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
|
@ -232,15 +310,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
return public_room, private_room
|
return public_room, private_room
|
||||||
|
|
||||||
def _check_only_one_user_in_directory(self, user: str, public: str) -> None:
|
def _check_only_one_user_in_directory(self, user: str, public: str) -> None:
|
||||||
users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
|
"""Check that the user directory DB tables show that:
|
||||||
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
|
|
||||||
in_private = self.get_success(
|
|
||||||
self.user_dir_helper.get_users_who_share_private_rooms()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
- only one user is in the user directory
|
||||||
|
- they belong to exactly one public room
|
||||||
|
- they don't share a private room with anyone.
|
||||||
|
"""
|
||||||
|
users, in_public, in_private = self.get_success(
|
||||||
|
self.user_dir_helper.get_tables()
|
||||||
|
)
|
||||||
self.assertEqual(users, {user})
|
self.assertEqual(users, {user})
|
||||||
self.assertEqual(set(in_public), {(user, public)})
|
self.assertEqual(in_public, {(user, public)})
|
||||||
self.assertEqual(in_private, [])
|
self.assertEqual(in_private, set())
|
||||||
|
|
||||||
def test_handle_local_profile_change_with_support_user(self) -> None:
|
def test_handle_local_profile_change_with_support_user(self) -> None:
|
||||||
support_user_id = "@support:test"
|
support_user_id = "@support:test"
|
||||||
|
@ -581,11 +662,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.user_dir_helper.get_users_in_public_rooms()
|
self.user_dir_helper.get_users_in_public_rooms()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
|
||||||
self.user_dir_helper._compress_shared(shares_private),
|
self.assertEqual(public_users, set())
|
||||||
{(u1, u2, room), (u2, u1, room)},
|
|
||||||
)
|
|
||||||
self.assertEqual(public_users, [])
|
|
||||||
|
|
||||||
# We get one search result when searching for user2 by user1.
|
# We get one search result when searching for user2 by user1.
|
||||||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||||
|
@ -610,8 +688,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.user_dir_helper.get_users_in_public_rooms()
|
self.user_dir_helper.get_users_in_public_rooms()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
|
self.assertEqual(shares_private, set())
|
||||||
self.assertEqual(public_users, [])
|
self.assertEqual(public_users, set())
|
||||||
|
|
||||||
# User1 now gets no search results for any of the other users.
|
# User1 now gets no search results for any of the other users.
|
||||||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||||
|
@ -645,11 +723,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.user_dir_helper.get_users_in_public_rooms()
|
self.user_dir_helper.get_users_in_public_rooms()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
|
||||||
self.user_dir_helper._compress_shared(shares_private),
|
self.assertEqual(public_users, set())
|
||||||
{(u1, u2, room), (u2, u1, room)},
|
|
||||||
)
|
|
||||||
self.assertEqual(public_users, [])
|
|
||||||
|
|
||||||
# We get one search result when searching for user2 by user1.
|
# We get one search result when searching for user2 by user1.
|
||||||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||||
|
@ -704,11 +779,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.user_dir_helper.get_users_in_public_rooms()
|
self.user_dir_helper.get_users_in_public_rooms()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
|
||||||
self.user_dir_helper._compress_shared(shares_private),
|
self.assertEqual(public_users, set())
|
||||||
{(u1, u2, room), (u2, u1, room)},
|
|
||||||
)
|
|
||||||
self.assertEqual(public_users, [])
|
|
||||||
|
|
||||||
# Configure a spam checker.
|
# Configure a spam checker.
|
||||||
spam_checker = self.hs.get_spam_checker()
|
spam_checker = self.hs.get_spam_checker()
|
||||||
|
@ -740,8 +812,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# No users share rooms
|
# No users share rooms
|
||||||
self.assertEqual(public_users, [])
|
self.assertEqual(public_users, set())
|
||||||
self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
|
self.assertEqual(shares_private, set())
|
||||||
|
|
||||||
# Despite not sharing a room, search_all_users means we get a search
|
# Despite not sharing a room, search_all_users means we get a search
|
||||||
# result.
|
# result.
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict, List, Set, Tuple
|
from typing import Any, Dict, Set, Tuple
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
@ -42,18 +42,7 @@ class GetUserDirectoryTables:
|
||||||
def __init__(self, store: DataStore):
|
def __init__(self, store: DataStore):
|
||||||
self.store = store
|
self.store = store
|
||||||
|
|
||||||
def _compress_shared(
|
async def get_users_in_public_rooms(self) -> Set[Tuple[str, str]]:
|
||||||
self, shared: List[Dict[str, str]]
|
|
||||||
) -> Set[Tuple[str, str, str]]:
|
|
||||||
"""
|
|
||||||
Compress a list of users who share rooms dicts to a list of tuples.
|
|
||||||
"""
|
|
||||||
r = set()
|
|
||||||
for i in shared:
|
|
||||||
r.add((i["user_id"], i["other_user_id"], i["room_id"]))
|
|
||||||
return r
|
|
||||||
|
|
||||||
async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
|
|
||||||
"""Fetch the entire `users_in_public_rooms` table.
|
"""Fetch the entire `users_in_public_rooms` table.
|
||||||
|
|
||||||
Returns a list of tuples (user_id, room_id) where room_id is public and
|
Returns a list of tuples (user_id, room_id) where room_id is public and
|
||||||
|
@ -63,24 +52,27 @@ class GetUserDirectoryTables:
|
||||||
"users_in_public_rooms", None, ("user_id", "room_id")
|
"users_in_public_rooms", None, ("user_id", "room_id")
|
||||||
)
|
)
|
||||||
|
|
||||||
retval = []
|
retval = set()
|
||||||
for i in r:
|
for i in r:
|
||||||
retval.append((i["user_id"], i["room_id"]))
|
retval.add((i["user_id"], i["room_id"]))
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]:
|
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
|
||||||
"""Fetch the entire `users_who_share_private_rooms` table.
|
"""Fetch the entire `users_who_share_private_rooms` table.
|
||||||
|
|
||||||
Returns a dict containing "user_id", "other_user_id" and "room_id" keys.
|
Returns a set of tuples (user_id, other_user_id, room_id) corresponding
|
||||||
The dicts can be flattened to Tuples with the `_compress_shared` method.
|
to the rows of `users_who_share_private_rooms`.
|
||||||
(This seems a little awkward---maybe we could clean this up.)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return await self.store.db_pool.simple_select_list(
|
rows = await self.store.db_pool.simple_select_list(
|
||||||
"users_who_share_private_rooms",
|
"users_who_share_private_rooms",
|
||||||
None,
|
None,
|
||||||
["user_id", "other_user_id", "room_id"],
|
["user_id", "other_user_id", "room_id"],
|
||||||
)
|
)
|
||||||
|
rv = set()
|
||||||
|
for row in rows:
|
||||||
|
rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
|
||||||
|
return rv
|
||||||
|
|
||||||
async def get_users_in_user_directory(self) -> Set[str]:
|
async def get_users_in_user_directory(self) -> Set[str]:
|
||||||
"""Fetch the set of users in the `user_directory` table.
|
"""Fetch the set of users in the `user_directory` table.
|
||||||
|
@ -113,6 +105,16 @@ class GetUserDirectoryTables:
|
||||||
for row in rows
|
for row in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def get_tables(
|
||||||
|
self,
|
||||||
|
) -> Tuple[Set[str], Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
|
||||||
|
"""Multiple tests want to inspect these tables, so expose them together."""
|
||||||
|
return (
|
||||||
|
await self.get_users_in_user_directory(),
|
||||||
|
await self.get_users_in_public_rooms(),
|
||||||
|
await self.get_users_who_share_private_rooms(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
|
class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
|
||||||
"""Ensure that rebuilding the directory writes the correct data to the DB.
|
"""Ensure that rebuilding the directory writes the correct data to the DB.
|
||||||
|
@ -166,8 +168,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Nothing updated yet
|
# Nothing updated yet
|
||||||
self.assertEqual(shares_private, [])
|
self.assertEqual(shares_private, set())
|
||||||
self.assertEqual(public_users, [])
|
self.assertEqual(public_users, set())
|
||||||
|
|
||||||
# Ugh, have to reset this flag
|
# Ugh, have to reset this flag
|
||||||
self.store.db_pool.updates._all_done = False
|
self.store.db_pool.updates._all_done = False
|
||||||
|
@ -236,24 +238,15 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
|
||||||
# Do the initial population of the user directory via the background update
|
# Do the initial population of the user directory via the background update
|
||||||
self._purge_and_rebuild_user_dir()
|
self._purge_and_rebuild_user_dir()
|
||||||
|
|
||||||
shares_private = self.get_success(
|
users, in_public, in_private = self.get_success(
|
||||||
self.user_dir_helper.get_users_who_share_private_rooms()
|
self.user_dir_helper.get_tables()
|
||||||
)
|
|
||||||
public_users = self.get_success(
|
|
||||||
self.user_dir_helper.get_users_in_public_rooms()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# User 1 and User 2 are in the same public room
|
# User 1 and User 2 are in the same public room
|
||||||
self.assertEqual(set(public_users), {(u1, room), (u2, room)})
|
self.assertEqual(in_public, {(u1, room), (u2, room)})
|
||||||
|
|
||||||
# User 1 and User 3 share private rooms
|
# User 1 and User 3 share private rooms
|
||||||
self.assertEqual(
|
self.assertEqual(in_private, {(u1, u3, private_room), (u3, u1, private_room)})
|
||||||
self.user_dir_helper._compress_shared(shares_private),
|
|
||||||
{(u1, u3, private_room), (u3, u1, private_room)},
|
|
||||||
)
|
|
||||||
|
|
||||||
# All three should have entries in the directory
|
# All three should have entries in the directory
|
||||||
users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
|
|
||||||
self.assertEqual(users, {u1, u2, u3})
|
self.assertEqual(users, {u1, u2, u3})
|
||||||
|
|
||||||
# The next four tests (test_population_excludes_*) all set up
|
# The next four tests (test_population_excludes_*) all set up
|
||||||
|
@ -289,16 +282,12 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
|
||||||
self, normal_user: str, public_room: str, private_room: str
|
self, normal_user: str, public_room: str, private_room: str
|
||||||
) -> None:
|
) -> None:
|
||||||
# After rebuilding the directory, we should only see the normal user.
|
# After rebuilding the directory, we should only see the normal user.
|
||||||
users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
|
users, in_public, in_private = self.get_success(
|
||||||
|
self.user_dir_helper.get_tables()
|
||||||
|
)
|
||||||
self.assertEqual(users, {normal_user})
|
self.assertEqual(users, {normal_user})
|
||||||
in_public_rooms = self.get_success(
|
self.assertEqual(in_public, {(normal_user, public_room)})
|
||||||
self.user_dir_helper.get_users_in_public_rooms()
|
self.assertEqual(in_private, set())
|
||||||
)
|
|
||||||
self.assertEqual(set(in_public_rooms), {(normal_user, public_room)})
|
|
||||||
in_private_rooms = self.get_success(
|
|
||||||
self.user_dir_helper.get_users_who_share_private_rooms()
|
|
||||||
)
|
|
||||||
self.assertEqual(in_private_rooms, [])
|
|
||||||
|
|
||||||
def test_population_excludes_support_user(self) -> None:
|
def test_population_excludes_support_user(self) -> None:
|
||||||
# Create a normal and support user.
|
# Create a normal and support user.
|
||||||
|
|
Loading…
Reference in New Issue