diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 5cb26ad6db..befeb55b25 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -64,7 +64,11 @@ class SQLBaseStore(object): def interaction(txn): cursor = txn.execute(query, args) - return decoder(cursor) + if decoder: + return decoder(cursor) + else: + return cursor + return self._db_pool.runInteraction(interaction) def _execut_query(self, query, *args): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index ef73be4af4..60296380e6 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -31,6 +31,38 @@ logger = logging.getLogger(__name__) class RoomMemberStore(SQLBaseStore): + @defer.inlineCallbacks + def _store_room_member(self, event): + """Store a room member in the database. + """ + domain = self.hs.parse_userid(event.target_user_id).domain + + yield self._simple_insert( + "room_memberships", + { + "event_id": event.event_id, + "user_id": event.target_user_id, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + } + ) + + # Update room hosts table + if event.membership == Membership.JOIN: + sql = ( + "INSERT OR IGNORE INTO room_hosts (room_id, host) " + "VALUES (?, ?)" + ) + yield self._execute(None, sql, room_id, domain) + else: + sql = ( + "DELETE FROM room_hosts WHERE room_id = ? AND host = ?" + ) + + yield self._execute(None, sql, room_id, domain) + + def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. @@ -38,36 +70,13 @@ class RoomMemberStore(SQLBaseStore): user_id (str): The member's user ID. room_id (str): The room the member is in. Returns: - namedtuple: The room member from the database, or None if this - member does not exist. + Deferred: Results in a MembershipEvent or None. """ - query = RoomMemberTable.select_statement( - "room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1") - return self._execute( - RoomMemberTable.decode_single_result, - query, room_id, user_id, + return self._get_members_by_dict( + room_id=room_id, + user_id=user_id ) - def store_room_member(self, user_id, sender, room_id, membership, content): - """Store a room member in the database. - - Args: - user_id (str): The member's user ID. - room_id (str): The room in relation to the member. - membership (synapse.api.constants.Membership): The new membership - state. - content (dict): The content of the membership (JSON). - """ - content_json = json.dumps(content) - return self._simple_insert(RoomMemberTable.table_name, dict( - user_id=user_id, - sender=sender, - room_id=room_id, - membership=membership, - content=content_json, - )) - - @defer.inlineCallbacks def get_room_members(self, room_id, membership=None): """Retrieve the current room member list for a room. @@ -79,17 +88,12 @@ class RoomMemberStore(SQLBaseStore): Returns: list of namedtuples representing the members in this room. """ - query = RoomMemberTable.select_statement( - "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name - + " WHERE room_id = ? GROUP BY user_id)" - ) - res = yield self._execute( - RoomMemberTable.decode_results, query, room_id, - ) - # strip memberships which don't match + + where = {"room_id": room_id} if membership: - res = [entry for entry in res if entry.membership == membership] - defer.returnValue(res) + where["membership"] = membership + + return self._get_members_by_dict(**membership) def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user @@ -106,67 +110,37 @@ class RoomMemberStore(SQLBaseStore): return defer.succeed(None) args = [user_id] - membership_placeholder = ["membership=?"] * len(membership_list) - where_membership = "(" + " OR ".join(membership_placeholder) + ")" - for membership in membership_list: - args.append(membership) + args.extend(membership_list) - query = ("SELECT room_id, membership FROM room_memberships" - + " WHERE user_id=? AND " + where_membership - + " GROUP BY room_id ORDER BY id DESC") - return self._execute( - self.cursor_to_dict, query, *args + where_clause "user_id = ? AND (%s)" % ( + " OR ".join(["membership = ?" for _ in membership_list]), ) + return self._get_members_query(where_clause, args) + + def get_joined_hosts_for_room(self, room_id): + return self._simple_select_onecol( + "room_hosts", + {"room_id": room_id}, + "host" + ) + + def _get_members_by_dict(self, where_dict): + clause = " AND ".join("%s = ?" % k for k in where.keys()) + vals = where.values() + return self._get_members_query(clause, vals) + @defer.inlineCallbacks - def get_joined_hosts_for_room(self, room_id): - query = RoomMemberTable.select_statement( - "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name - + " WHERE room_id = ? GROUP BY user_id)" - ) + def _get_members_query(self, where_clause, where_values): + sql = ( + "SELECT e.* FROM events as e " + "INNER JOIN room_memberships as m " + "ON e.event_id = m.event_id " + "INNER JOIN current_state as c " + "ON m.event_id = c.event_id " + "WHERE %s " + ) % (where_clause,) - res = yield self._execute( - RoomMemberTable.decode_results, query, room_id, - ) - - def host_from_user_id_string(user_id): - domain = UserID.from_string(entry.user_id, self.hs).domain - return domain - - # strip memberships which don't match - hosts = [ - host_from_user_id_string(entry.user_id) - for entry in res - if entry.membership == Membership.JOIN - ] - - logger.debug("Returning hosts: %s from results: %s", hosts, res) - - defer.returnValue(hosts) - - def get_max_room_member_id(self): - return self._simple_max_id(RoomMemberTable.table_name) - - -class RoomMemberTable(Table): - table_name = "room_memberships" - - fields = [ - "id", - "user_id", - "sender", - "room_id", - "membership", - "content" - ] - - class EntryType(collections.namedtuple("RoomMemberEntry", fields)): - - def as_event(self, event_factory): - return event_factory.create_event( - etype=RoomMemberEvent.TYPE, - room_id=self.room_id, - target_user_id=self.user_id, - user_id=self.sender, - content=json.loads(self.content), - ) + rows = yield self._execute_query(sql, where_values) + results = [self._parse_event_from_row(r) for r in rows] + defer.returnValue(results) diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql index 37b7c6c74f..7f564c8540 100644 --- a/synapse/storage/schema/im.sql +++ b/synapse/storage/schema/im.sql @@ -17,7 +17,6 @@ CREATE TABLE IF NOT EXISTS events( ordering INTEGER PRIMARY KEY AUTOINCREMENT, event_id TEXT NOT NULL, type TEXT NOT NULL, --- sender TEXT, room_id TEXT, content TEXT, unrecognized_keys TEXT @@ -57,3 +56,8 @@ CREATE TABLE IF NOT EXISTS rooms( is_public INTEGER, creator TEXT ); + +CREATE TABLE IF NOT EXISTS room_hosts( + room_id TEXT NOT NULL, + host TEXT NOT NULL +);