Merge pull request #789 from matrix-org/markjh/member_cleanup
Cleanup room member handler
This commit is contained in:
commit
1ed33784a6
|
@ -29,6 +29,8 @@ class ReceiptsHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ReceiptsHandler, self).__init__(hs)
|
super(ReceiptsHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
self.server_name = hs.config.server_name
|
||||||
|
self.store = hs.get_datastore()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_replication_layer()
|
||||||
self.federation.register_edu_handler(
|
self.federation.register_edu_handler(
|
||||||
|
@ -131,12 +133,9 @@ class ReceiptsHandler(BaseHandler):
|
||||||
event_ids = receipt["event_ids"]
|
event_ids = receipt["event_ids"]
|
||||||
data = receipt["data"]
|
data = receipt["data"]
|
||||||
|
|
||||||
remotedomains = set()
|
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||||
|
remotedomains = remotedomains.copy()
|
||||||
rm_handler = self.hs.get_handlers().room_member_handler
|
remotedomains.discard(self.server_name)
|
||||||
yield rm_handler.fetch_room_distributions_into(
|
|
||||||
room_id, localusers=None, remotedomains=remotedomains
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Sending receipt to: %r", remotedomains)
|
logger.debug("Sending receipt to: %r", remotedomains)
|
||||||
|
|
||||||
|
|
|
@ -55,35 +55,6 @@ class RoomMemberHandler(BaseHandler):
|
||||||
self.distributor.declare("user_joined_room")
|
self.distributor.declare("user_joined_room")
|
||||||
self.distributor.declare("user_left_room")
|
self.distributor.declare("user_left_room")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_room_members(self, room_id):
|
|
||||||
users = yield self.store.get_users_in_room(room_id)
|
|
||||||
|
|
||||||
defer.returnValue([UserID.from_string(u) for u in users])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def fetch_room_distributions_into(self, room_id, localusers=None,
|
|
||||||
remotedomains=None, ignore_user=None):
|
|
||||||
"""Fetch the distribution of a room, adding elements to either
|
|
||||||
'localusers' or 'remotedomains', which should be a set() if supplied.
|
|
||||||
If ignore_user is set, ignore that user.
|
|
||||||
|
|
||||||
This function returns nothing; its result is performed by the
|
|
||||||
side-effect on the two passed sets. This allows easy accumulation of
|
|
||||||
member lists of multiple rooms at once if required.
|
|
||||||
"""
|
|
||||||
members = yield self.get_room_members(room_id)
|
|
||||||
for member in members:
|
|
||||||
if ignore_user is not None and member == ignore_user:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self.hs.is_mine(member):
|
|
||||||
if localusers is not None:
|
|
||||||
localusers.add(member)
|
|
||||||
else:
|
|
||||||
if remotedomains is not None:
|
|
||||||
remotedomains.add(member.domain)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _local_membership_update(
|
def _local_membership_update(
|
||||||
self, requester, target, room_id, membership,
|
self, requester, target, room_id, membership,
|
||||||
|
@ -426,21 +397,6 @@ class RoomMemberHandler(BaseHandler):
|
||||||
if invite:
|
if invite:
|
||||||
defer.returnValue(UserID.from_string(invite.sender))
|
defer.returnValue(UserID.from_string(invite.sender))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_joined_rooms_for_user(self, user):
|
|
||||||
"""Returns a list of roomids that the user has any of the given
|
|
||||||
membership states in."""
|
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms_for_user(
|
|
||||||
user.to_string(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# For some reason the list of events contains duplicates
|
|
||||||
# TODO(paul): work out why because I really don't think it should
|
|
||||||
room_ids = set(r.room_id for r in rooms)
|
|
||||||
|
|
||||||
defer.returnValue(room_ids)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_3pid_invite(
|
def do_3pid_invite(
|
||||||
self,
|
self,
|
||||||
|
@ -457,8 +413,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if invitee:
|
if invitee:
|
||||||
handler = self.hs.get_handlers().room_member_handler
|
yield self.update_membership(
|
||||||
yield handler.update_membership(
|
|
||||||
requester,
|
requester,
|
||||||
UserID.from_string(invitee),
|
UserID.from_string(invitee),
|
||||||
room_id,
|
room_id,
|
||||||
|
|
|
@ -485,7 +485,6 @@ class SyncHandler(BaseHandler):
|
||||||
sync_config, now_token, since_token
|
sync_config, now_token, since_token
|
||||||
)
|
)
|
||||||
|
|
||||||
rm_handler = self.hs.get_handlers().room_member_handler
|
|
||||||
app_service = yield self.store.get_app_service_by_user_id(
|
app_service = yield self.store.get_app_service_by_user_id(
|
||||||
sync_config.user.to_string()
|
sync_config.user.to_string()
|
||||||
)
|
)
|
||||||
|
@ -493,9 +492,10 @@ class SyncHandler(BaseHandler):
|
||||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||||
joined_room_ids = set(r.room_id for r in rooms)
|
joined_room_ids = set(r.room_id for r in rooms)
|
||||||
else:
|
else:
|
||||||
joined_room_ids = yield rm_handler.get_joined_rooms_for_user(
|
rooms = yield self.store.get_rooms_for_user(
|
||||||
sync_config.user
|
sync_config.user.to_string()
|
||||||
)
|
)
|
||||||
|
joined_room_ids = set(r.room_id for r in rooms)
|
||||||
|
|
||||||
user_id = sync_config.user.to_string()
|
user_id = sync_config.user.to_string()
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,8 @@ class TypingNotificationHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(TypingNotificationHandler, self).__init__(hs)
|
super(TypingNotificationHandler, self).__init__(hs)
|
||||||
|
|
||||||
self.homeserver = hs
|
self.store = hs.get_datastore()
|
||||||
|
self.server_name = hs.config.server_name
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@ -157,23 +158,17 @@ class TypingNotificationHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _push_update(self, room_id, user, typing):
|
def _push_update(self, room_id, user, typing):
|
||||||
localusers = set()
|
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||||
remotedomains = set()
|
|
||||||
|
|
||||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
deferreds = []
|
||||||
yield rm_handler.fetch_room_distributions_into(
|
for domain in domains:
|
||||||
room_id, localusers=localusers, remotedomains=remotedomains
|
if domain == self.server_name:
|
||||||
)
|
|
||||||
|
|
||||||
if localusers:
|
|
||||||
self._push_update_local(
|
self._push_update_local(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user=user,
|
user=user,
|
||||||
typing=typing
|
typing=typing
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
deferreds = []
|
|
||||||
for domain in remotedomains:
|
|
||||||
deferreds.append(self.federation.send_edu(
|
deferreds.append(self.federation.send_edu(
|
||||||
destination=domain,
|
destination=domain,
|
||||||
edu_type="m.typing",
|
edu_type="m.typing",
|
||||||
|
@ -191,14 +186,9 @@ class TypingNotificationHandler(BaseHandler):
|
||||||
room_id = content["room_id"]
|
room_id = content["room_id"]
|
||||||
user = UserID.from_string(content["user_id"])
|
user = UserID.from_string(content["user_id"])
|
||||||
|
|
||||||
localusers = set()
|
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||||
|
|
||||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
if self.server_name in domains:
|
||||||
yield rm_handler.fetch_room_distributions_into(
|
|
||||||
room_id, localusers=localusers
|
|
||||||
)
|
|
||||||
|
|
||||||
if localusers:
|
|
||||||
self._push_update_local(
|
self._push_update_local(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user=user,
|
user=user,
|
||||||
|
@ -239,7 +229,6 @@ class TypingNotificationEventSource(object):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self._handler = None
|
self._handler = None
|
||||||
self._room_member_handler = None
|
|
||||||
|
|
||||||
def handler(self):
|
def handler(self):
|
||||||
# Avoid cyclic dependency in handler setup
|
# Avoid cyclic dependency in handler setup
|
||||||
|
@ -247,11 +236,6 @@ class TypingNotificationEventSource(object):
|
||||||
self._handler = self.hs.get_handlers().typing_notification_handler
|
self._handler = self.hs.get_handlers().typing_notification_handler
|
||||||
return self._handler
|
return self._handler
|
||||||
|
|
||||||
def room_member_handler(self):
|
|
||||||
if not self._room_member_handler:
|
|
||||||
self._room_member_handler = self.hs.get_handlers().room_member_handler
|
|
||||||
return self._room_member_handler
|
|
||||||
|
|
||||||
def _make_event_for(self, room_id):
|
def _make_event_for(self, room_id):
|
||||||
typing = self.handler()._room_typing[room_id]
|
typing = self.handler()._room_typing[room_id]
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -137,24 +137,6 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
return [r["user_id"] for r in rows]
|
return [r["user_id"] for r in rows]
|
||||||
return self.runInteraction("get_users_in_room", f)
|
return self.runInteraction("get_users_in_room", f)
|
||||||
|
|
||||||
def get_room_members(self, room_id, membership=None):
|
|
||||||
"""Retrieve the current room member list for a room.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
room_id (str): The room to get the list of members.
|
|
||||||
membership (synapse.api.constants.Membership): The filter to apply
|
|
||||||
to this list, or None to return all members with some state
|
|
||||||
associated with this room.
|
|
||||||
Returns:
|
|
||||||
list of namedtuples representing the members in this room.
|
|
||||||
"""
|
|
||||||
return self.runInteraction(
|
|
||||||
"get_room_members",
|
|
||||||
self._get_members_events_txn,
|
|
||||||
room_id,
|
|
||||||
membership=membership,
|
|
||||||
).addCallback(self._get_events)
|
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_invited_rooms_for_user(self, user_id):
|
def get_invited_rooms_for_user(self, user_id):
|
||||||
""" Get all the rooms the user is invited to
|
""" Get all the rooms the user is invited to
|
||||||
|
|
|
@ -71,6 +71,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
self.auth = Mock(spec=[])
|
self.auth = Mock(spec=[])
|
||||||
|
|
||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
|
"test",
|
||||||
auth=self.auth,
|
auth=self.auth,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
datastore=Mock(spec=[
|
datastore=Mock(spec=[
|
||||||
|
@ -110,56 +111,16 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.room_id = "a-room"
|
self.room_id = "a-room"
|
||||||
|
|
||||||
# Mock the RoomMemberHandler
|
|
||||||
hs.handlers.room_member_handler = Mock(spec=[])
|
|
||||||
self.room_member_handler = hs.handlers.room_member_handler
|
|
||||||
|
|
||||||
self.room_members = []
|
self.room_members = []
|
||||||
|
|
||||||
def get_rooms_for_user(user):
|
|
||||||
if user in self.room_members:
|
|
||||||
return defer.succeed([self.room_id])
|
|
||||||
else:
|
|
||||||
return defer.succeed([])
|
|
||||||
self.room_member_handler.get_rooms_for_user = get_rooms_for_user
|
|
||||||
|
|
||||||
def get_room_members(room_id):
|
|
||||||
if room_id == self.room_id:
|
|
||||||
return defer.succeed(self.room_members)
|
|
||||||
else:
|
|
||||||
return defer.succeed([])
|
|
||||||
self.room_member_handler.get_room_members = get_room_members
|
|
||||||
|
|
||||||
def get_joined_rooms_for_user(user):
|
|
||||||
if user in self.room_members:
|
|
||||||
return defer.succeed([self.room_id])
|
|
||||||
else:
|
|
||||||
return defer.succeed([])
|
|
||||||
self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def fetch_room_distributions_into(
|
|
||||||
room_id, localusers=None, remotedomains=None, ignore_user=None
|
|
||||||
):
|
|
||||||
members = yield get_room_members(room_id)
|
|
||||||
for member in members:
|
|
||||||
if ignore_user is not None and member == ignore_user:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if hs.is_mine(member):
|
|
||||||
if localusers is not None:
|
|
||||||
localusers.add(member)
|
|
||||||
else:
|
|
||||||
if remotedomains is not None:
|
|
||||||
remotedomains.add(member.domain)
|
|
||||||
self.room_member_handler.fetch_room_distributions_into = (
|
|
||||||
fetch_room_distributions_into
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_joined_room(room_id, user_id):
|
def check_joined_room(room_id, user_id):
|
||||||
if user_id not in [u.to_string() for u in self.room_members]:
|
if user_id not in [u.to_string() for u in self.room_members]:
|
||||||
raise AuthError(401, "User is not in the room")
|
raise AuthError(401, "User is not in the room")
|
||||||
|
|
||||||
|
def get_joined_hosts_for_room(room_id):
|
||||||
|
return set(member.domain for member in self.room_members)
|
||||||
|
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||||
|
|
||||||
self.auth.check_joined_room = check_joined_room
|
self.auth.check_joined_room = check_joined_room
|
||||||
|
|
||||||
# Some local users to test with
|
# Some local users to test with
|
||||||
|
|
|
@ -70,12 +70,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
||||||
def test_one_member(self):
|
def test_one_member(self):
|
||||||
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
||||||
|
|
||||||
self.assertEquals(
|
|
||||||
[self.u_alice.to_string()],
|
|
||||||
[m.user_id for m in (
|
|
||||||
yield self.store.get_room_members(self.room.to_string())
|
|
||||||
)]
|
|
||||||
)
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
[self.room.to_string()],
|
[self.room.to_string()],
|
||||||
[m.room_id for m in (
|
[m.room_id for m in (
|
||||||
|
@ -85,18 +79,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
||||||
)]
|
)]
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_two_members(self):
|
|
||||||
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
|
||||||
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
|
|
||||||
|
|
||||||
self.assertEquals(
|
|
||||||
{self.u_alice.to_string(), self.u_bob.to_string()},
|
|
||||||
{m.user_id for m in (
|
|
||||||
yield self.store.get_room_members(self.room.to_string())
|
|
||||||
)}
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_room_hosts(self):
|
def test_room_hosts(self):
|
||||||
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
||||||
|
|
|
@ -50,7 +50,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
config.enable_registration = True
|
config.enable_registration = True
|
||||||
config.macaroon_secret_key = "not even a little secret"
|
config.macaroon_secret_key = "not even a little secret"
|
||||||
config.expire_access_token = False
|
config.expire_access_token = False
|
||||||
config.server_name = "server.under.test"
|
config.server_name = name
|
||||||
config.trusted_third_party_id_servers = []
|
config.trusted_third_party_id_servers = []
|
||||||
config.room_invite_state_types = []
|
config.room_invite_state_types = []
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue