Merge branch 'develop' of github.com:matrix-org/synapse into erikj/reindex_state_groups

This commit is contained in:
Erik Johnston 2016-09-08 16:52:09 +01:00
commit 5c688739d6
3 changed files with 63 additions and 24 deletions

View File

@ -443,6 +443,16 @@ class RoomListHandler(BaseHandler):
self.remote_list_request_cache.set((), deferred) self.remote_list_request_cache.set((), deferred)
self.remote_list_cache = yield deferred self.remote_list_cache = yield deferred
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name):
res = yield self.hs.get_replication_layer().get_public_rooms(
[server_name]
)
if server_name not in res:
raise SynapseError(404, "Server not found")
defer.returnValue(res[server_name])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_aggregated_public_room_list(self): def get_aggregated_public_room_list(self):
""" """

View File

@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request, parse_string
import logging import logging
import urllib import urllib
@ -295,14 +295,25 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
server = parse_string(request, "server", default=None)
try: try:
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
except AuthError: except AuthError as e:
# This endpoint isn't authed, but its useful to know who's hitting # We allow people to not be authed if they're just looking at our
# it if they *do* supply an access token # room list, but require auth when we proxy the request.
# In both cases we call the auth function, as that has the side
# effect of logging who issued this request if an access token was
# provided.
if server:
raise e
else:
pass pass
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(server)
else:
data = yield handler.get_aggregated_public_room_list() data = yield handler.get_aggregated_public_room_list()
defer.returnValue((200, data)) defer.returnValue((200, data))

View File

@ -130,11 +130,26 @@ class DeviceInboxStore(SQLBaseStore):
def _add_messages_to_local_device_inbox_txn(self, txn, stream_id, def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
messages_by_user_then_device): messages_by_user_then_device):
local_users_and_devices = set() local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items(): for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
devices = messages_by_device.keys() devices = messages_by_device.keys()
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = ( sql = (
"SELECT user_id, device_id FROM devices" "SELECT device_id FROM devices"
" WHERE user_id = ?"
)
txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"])
for row in txn.fetchall():
# Add the message for all devices for this user on this
# server.
device = row[0]
messages_json_for_user[device] = message_json
else:
sql = (
"SELECT device_id FROM devices"
" WHERE user_id = ? AND device_id IN (" " WHERE user_id = ? AND device_id IN ("
+ ",".join("?" * len(devices)) + ",".join("?" * len(devices))
+ ")" + ")"
@ -142,7 +157,14 @@ class DeviceInboxStore(SQLBaseStore):
# TODO: Maybe this needs to be done in batches if there are # TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user. # too many local devices for a given user.
txn.execute(sql, [user_id] + devices) txn.execute(sql, [user_id] + devices)
local_users_and_devices.update(map(tuple, txn.fetchall())) for row in txn.fetchall():
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
message_json = ujson.dumps(messages_by_device[device])
messages_json_for_user[device] = message_json
local_by_user_then_device[user_id] = messages_json_for_user
sql = ( sql = (
"INSERT INTO device_inbox" "INSERT INTO device_inbox"
@ -150,12 +172,8 @@ class DeviceInboxStore(SQLBaseStore):
" VALUES (?,?,?,?)" " VALUES (?,?,?,?)"
) )
rows = [] rows = []
for user_id, messages_by_device in messages_by_user_then_device.items(): for user_id, messages_by_device in local_by_user_then_device.items():
for device_id, message in messages_by_device.items(): for device_id, message_json in messages_by_device.items():
message_json = ujson.dumps(message)
# Only insert into the local inbox if the device exists on
# this server
if (user_id, device_id) in local_users_and_devices:
rows.append((user_id, device_id, stream_id, message_json)) rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows) txn.executemany(sql, rows)