Merge pull request #789 from matrix-org/markjh/member_cleanup

Cleanup room member handler
This commit is contained in:
Mark Haines 2016-05-17 10:43:19 +01:00
commit 1ed33784a6
8 changed files with 37 additions and 174 deletions

View File

@ -29,6 +29,8 @@ class ReceiptsHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs) super(ReceiptsHandler, self).__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_edu_handler( self.federation.register_edu_handler(
@ -131,12 +133,9 @@ class ReceiptsHandler(BaseHandler):
event_ids = receipt["event_ids"] event_ids = receipt["event_ids"]
data = receipt["data"] data = receipt["data"]
remotedomains = set() remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = remotedomains.copy()
rm_handler = self.hs.get_handlers().room_member_handler remotedomains.discard(self.server_name)
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
logger.debug("Sending receipt to: %r", remotedomains) logger.debug("Sending receipt to: %r", remotedomains)

View File

@ -55,35 +55,6 @@ class RoomMemberHandler(BaseHandler):
self.distributor.declare("user_joined_room") self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room") self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks @defer.inlineCallbacks
def _local_membership_update( def _local_membership_update(
self, requester, target, room_id, membership, self, requester, target, room_id, membership,
@ -426,21 +397,6 @@ class RoomMemberHandler(BaseHandler):
if invite: if invite:
defer.returnValue(UserID.from_string(invite.sender)) defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_3pid_invite( def do_3pid_invite(
self, self,
@ -457,8 +413,7 @@ class RoomMemberHandler(BaseHandler):
) )
if invitee: if invitee:
handler = self.hs.get_handlers().room_member_handler yield self.update_membership(
yield handler.update_membership(
requester, requester,
UserID.from_string(invitee), UserID.from_string(invitee),
room_id, room_id,

View File

@ -485,7 +485,6 @@ class SyncHandler(BaseHandler):
sync_config, now_token, since_token sync_config, now_token, since_token
) )
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id( app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string() sync_config.user.to_string()
) )
@ -493,9 +492,10 @@ class SyncHandler(BaseHandler):
rooms = yield self.store.get_app_service_rooms(app_service) rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms) joined_room_ids = set(r.room_id for r in rooms)
else: else:
joined_room_ids = yield rm_handler.get_joined_rooms_for_user( rooms = yield self.store.get_rooms_for_user(
sync_config.user sync_config.user.to_string()
) )
joined_room_ids = set(r.room_id for r in rooms)
user_id = sync_config.user.to_string() user_id = sync_config.user.to_string()

View File

@ -39,7 +39,8 @@ class TypingNotificationHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(TypingNotificationHandler, self).__init__(hs) super(TypingNotificationHandler, self).__init__(hs)
self.homeserver = hs self.store = hs.get_datastore()
self.server_name = hs.config.server_name
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -157,23 +158,17 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_update(self, room_id, user, typing): def _push_update(self, room_id, user, typing):
localusers = set() domains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = set()
rm_handler = self.homeserver.get_handlers().room_member_handler deferreds = []
yield rm_handler.fetch_room_distributions_into( for domain in domains:
room_id, localusers=localusers, remotedomains=remotedomains if domain == self.server_name:
)
if localusers:
self._push_update_local( self._push_update_local(
room_id=room_id, room_id=room_id,
user=user, user=user,
typing=typing typing=typing
) )
else:
deferreds = []
for domain in remotedomains:
deferreds.append(self.federation.send_edu( deferreds.append(self.federation.send_edu(
destination=domain, destination=domain,
edu_type="m.typing", edu_type="m.typing",
@ -191,14 +186,9 @@ class TypingNotificationHandler(BaseHandler):
room_id = content["room_id"] room_id = content["room_id"]
user = UserID.from_string(content["user_id"]) user = UserID.from_string(content["user_id"])
localusers = set() domains = yield self.store.get_joined_hosts_for_room(room_id)
rm_handler = self.homeserver.get_handlers().room_member_handler if self.server_name in domains:
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers
)
if localusers:
self._push_update_local( self._push_update_local(
room_id=room_id, room_id=room_id,
user=user, user=user,
@ -239,7 +229,6 @@ class TypingNotificationEventSource(object):
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._handler = None self._handler = None
self._room_member_handler = None
def handler(self): def handler(self):
# Avoid cyclic dependency in handler setup # Avoid cyclic dependency in handler setup
@ -247,11 +236,6 @@ class TypingNotificationEventSource(object):
self._handler = self.hs.get_handlers().typing_notification_handler self._handler = self.hs.get_handlers().typing_notification_handler
return self._handler return self._handler
def room_member_handler(self):
if not self._room_member_handler:
self._room_member_handler = self.hs.get_handlers().room_member_handler
return self._room_member_handler
def _make_event_for(self, room_id): def _make_event_for(self, room_id):
typing = self.handler()._room_typing[room_id] typing = self.handler()._room_typing[room_id]
return { return {

View File

@ -137,24 +137,6 @@ class RoomMemberStore(SQLBaseStore):
return [r["user_id"] for r in rows] return [r["user_id"] for r in rows]
return self.runInteraction("get_users_in_room", f) return self.runInteraction("get_users_in_room", f)
def get_room_members(self, room_id, membership=None):
"""Retrieve the current room member list for a room.
Args:
room_id (str): The room to get the list of members.
membership (synapse.api.constants.Membership): The filter to apply
to this list, or None to return all members with some state
associated with this room.
Returns:
list of namedtuples representing the members in this room.
"""
return self.runInteraction(
"get_room_members",
self._get_members_events_txn,
room_id,
membership=membership,
).addCallback(self._get_events)
@cached() @cached()
def get_invited_rooms_for_user(self, user_id): def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to """ Get all the rooms the user is invited to

View File

@ -71,6 +71,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.auth = Mock(spec=[]) self.auth = Mock(spec=[])
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"test",
auth=self.auth, auth=self.auth,
clock=self.clock, clock=self.clock,
datastore=Mock(spec=[ datastore=Mock(spec=[
@ -110,56 +111,16 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.room_id = "a-room" self.room_id = "a-room"
# Mock the RoomMemberHandler
hs.handlers.room_member_handler = Mock(spec=[])
self.room_member_handler = hs.handlers.room_member_handler
self.room_members = [] self.room_members = []
def get_rooms_for_user(user):
if user in self.room_members:
return defer.succeed([self.room_id])
else:
return defer.succeed([])
self.room_member_handler.get_rooms_for_user = get_rooms_for_user
def get_room_members(room_id):
if room_id == self.room_id:
return defer.succeed(self.room_members)
else:
return defer.succeed([])
self.room_member_handler.get_room_members = get_room_members
def get_joined_rooms_for_user(user):
if user in self.room_members:
return defer.succeed([self.room_id])
else:
return defer.succeed([])
self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user
@defer.inlineCallbacks
def fetch_room_distributions_into(
room_id, localusers=None, remotedomains=None, ignore_user=None
):
members = yield get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
self.room_member_handler.fetch_room_distributions_into = (
fetch_room_distributions_into
)
def check_joined_room(room_id, user_id): def check_joined_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]: if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room") raise AuthError(401, "User is not in the room")
def get_joined_hosts_for_room(room_id):
return set(member.domain for member in self.room_members)
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
self.auth.check_joined_room = check_joined_room self.auth.check_joined_room = check_joined_room
# Some local users to test with # Some local users to test with

View File

@ -70,12 +70,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
def test_one_member(self): def test_one_member(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN) yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
self.assertEquals(
[self.u_alice.to_string()],
[m.user_id for m in (
yield self.store.get_room_members(self.room.to_string())
)]
)
self.assertEquals( self.assertEquals(
[self.room.to_string()], [self.room.to_string()],
[m.room_id for m in ( [m.room_id for m in (
@ -85,18 +79,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
)] )]
) )
@defer.inlineCallbacks
def test_two_members(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.assertEquals(
{self.u_alice.to_string(), self.u_bob.to_string()},
{m.user_id for m in (
yield self.store.get_room_members(self.room.to_string())
)}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_room_hosts(self): def test_room_hosts(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN) yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)

View File

@ -50,7 +50,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.enable_registration = True config.enable_registration = True
config.macaroon_secret_key = "not even a little secret" config.macaroon_secret_key = "not even a little secret"
config.expire_access_token = False config.expire_access_token = False
config.server_name = "server.under.test" config.server_name = name
config.trusted_third_party_id_servers = [] config.trusted_third_party_id_servers = []
config.room_invite_state_types = [] config.room_invite_state_types = []