Merge pull request #1009 from matrix-org/erikj/event_split

Split out /events to synchrotron
This commit is contained in:
Erik Johnston 2016-08-15 15:39:05 +01:00 committed by GitHub
commit 75299af4fc
24 changed files with 367 additions and 92 deletions

View File

@ -80,11 +80,6 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__ DataStore.get_profile_displayname.__func__
) )
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
@ -168,7 +163,6 @@ class PusherServer(HomeServer):
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.worker_replication_url replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool() pusher_pool = self.get_pusherpool()
clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey): def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey) key = "%s:%s" % (app_id, pushkey)
@ -220,21 +214,11 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids min_stream_id, max_stream_id, affected_room_ids
) )
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
next_expire_broken_caches_ms = 0
while True: while True:
try: try:
args = store.stream_positions() args = store.stream_positions()
args["timeout"] = 30000 args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args) result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result) yield store.process_replication(result)
poke_pushers(result) poke_pushers(result)
except: except:

View File

@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[ get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted" "get_presence_list_accepted"
] ]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000 UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object): class SynchrotronPresence(object):
def __init__(self, hs): def __init__(self, hs):
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info() active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = { self.user_to_current_state = {
@ -124,6 +126,8 @@ class SynchrotronPresence(object):
pass pass
get_states = PresenceHandler.get_states.__func__ get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks @defer.inlineCallbacks
@ -194,19 +198,39 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False self._need_to_send_sync = False
yield self._send_syncing_users_now() yield self._send_syncing_users_now()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield self._get_interested_parties(
states, calculate_remote_hosts=False
)
room_ids_to_states, users_to_states, _ = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
)
@defer.inlineCallbacks
def process_replication(self, result): def process_replication(self, result):
stream = result.get("presence", {"rows": []}) stream = result.get("presence", {"rows": []})
states = []
for row in stream["rows"]: for row in stream["rows"]:
( (
position, user_id, state, last_active_ts, position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) = row ) = row
self.user_to_current_state[user_id] = UserPresenceState( state = UserPresenceState(
user_id, state, last_active_ts, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) )
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream:
stream_id = int(stream["position"])
yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object): class SynchrotronTyping(object):
@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
elif name == "client": elif name == "client":
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource) sync.register_servlets(self, resource)
events.register_servlets(self, resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource, "/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource, "/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client() http_client = self.get_simple_http_client()
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.worker_replication_url replication_url = self.config.worker_replication_url
clock = self.get_clock()
notifier = self.get_notifier() notifier = self.get_notifier()
presence_handler = self.get_presence_handler() presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler() typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream( def notify_from_stream(
result, stream_name, stream_key, room=None, user=None result, stream_name, stream_key, room=None, user=None
): ):
@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
result, "typing", "typing_key", room="room_id" result, "typing", "typing_key", room="room_id"
) )
next_expire_broken_caches_ms = 0
while True: while True:
try: try:
args = store.stream_positions() args = store.stream_positions()
args.update(typing_handler.stream_positions()) args.update(typing_handler.stream_positions())
args["timeout"] = 30000 args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args) result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result) yield store.process_replication(result)
typing_handler.process_replication(result) typing_handler.process_replication(result)
presence_handler.process_replication(result) yield presence_handler.process_replication(result)
notify(result) notify(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)

View File

@ -19,7 +19,6 @@ from .room import (
) )
from .room_member import RoomMemberHandler from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler from .federation import FederationHandler
from .profile import ProfileHandler from .profile import ProfileHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
@ -53,8 +52,6 @@ class Handlers(object):
self.message_handler = MessageHandler(hs) self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs) self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs) self.room_member_handler = RoomMemberHandler(hs)
self.event_stream_handler = EventStreamHandler(hs)
self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs) self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)

View File

@ -503,7 +503,7 @@ class PresenceHandler(object):
defer.returnValue(states) defer.returnValue(states)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_interested_parties(self, states): def _get_interested_parties(self, states, calculate_remote_hosts=True):
"""Given a list of states return which entities (rooms, users, servers) """Given a list of states return which entities (rooms, users, servers)
are interested in the given states. are interested in the given states.
@ -526,6 +526,7 @@ class PresenceHandler(object):
users_to_states.setdefault(state.user_id, []).append(state) users_to_states.setdefault(state.user_id, []).append(state)
hosts_to_states = {} hosts_to_states = {}
if calculate_remote_hosts:
for room_id, states in room_ids_to_states.items(): for room_id, states in room_ids_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states) local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states: if not local_states:
@ -565,6 +566,16 @@ class PresenceHandler(object):
self._push_to_remotes(hosts_to_states) self._push_to_remotes(hosts_to_states)
@defer.inlineCallbacks
def notify_for_states(self, state, stream_id):
parties = yield self._get_interested_parties([state])
room_ids_to_states, users_to_states, hosts_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=[UserID.from_string(u) for u in users_to_states.keys()]
)
def _push_to_remotes(self, hosts_to_states): def _push_to_remotes(self, hosts_to_states):
"""Sends state updates to remote servers. """Sends state updates to remote servers.

View File

@ -41,6 +41,7 @@ STREAM_NAMES = (
("push_rules",), ("push_rules",),
("pushers",), ("pushers",),
("state",), ("state",),
("caches",),
) )
@ -70,6 +71,7 @@ class ReplicationResource(Resource):
* "backfill": Old events that have been backfilled from other servers. * "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules. * "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers. * "pushers": Per user changes to their pushers.
* "caches": Cache invalidations.
The API takes two additional query parameters: The API takes two additional query parameters:
@ -129,6 +131,7 @@ class ReplicationResource(Resource):
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token() state_token = self.store.get_state_stream_token()
caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -140,6 +143,7 @@ class ReplicationResource(Resource):
push_rules_token, push_rules_token,
pushers_token, pushers_token,
state_token, state_token,
caches_token,
)) ))
@request_handler() @request_handler()
@ -188,6 +192,7 @@ class ReplicationResource(Resource):
yield self.push_rules(writer, current_token, limit, request_streams) yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams) yield self.state(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total) logger.info("Replicated %d rows", writer.total)
@ -379,6 +384,20 @@ class ReplicationResource(Resource):
"position", "type", "state_key", "event_id" "position", "type", "state_key", "event_id"
)) ))
@defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches
caches = request_streams.get("caches")
if caches is not None:
updated_caches = yield self.store.get_all_updated_caches(
caches, current_position, limit
)
writer.write_header_and_rows("caches", updated_caches, (
"position", "cache_func", "keys", "invalidation_ts"
))
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -407,7 +426,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state" "push_rules", "pushers", "state", "caches",
))): ))):
__slots__ = [] __slots__ = []

View File

@ -14,15 +14,43 @@
# limitations under the License. # limitations under the License.
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker
import logging
logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore): class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs) super(BaseSlavedStore, self).__init__(hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
def stream_positions(self): def stream_positions(self):
return {} pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def process_replication(self, result): def process_replication(self, result):
stream = result.get("caches")
if stream:
for row in stream["rows"]:
(
position, cache_func, keys, invalidation_ts,
) = row
try:
getattr(self, cache_func).invalidate(tuple(keys))
except AttributeError:
logger.warn("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None) return defer.succeed(None)

View File

@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore): class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[ get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room" "get_aliases_for_room"
].orig ]

View File

@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet): class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(WhoisRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
) )
def __init__(self, hs):
super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id): def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)

View File

@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
self.hs = hs self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth() self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()

View File

@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs) super(ClientDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):

View File

@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs)
self.event_stream_handler = hs.get_event_stream_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req( requester = yield self.auth.get_user_by_req(
@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
if "room_id" in request.args: if "room_id" in request.args:
room_id = request.args["room_id"][0] room_id = request.args["room_id"][0]
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args: if "timeout" in request.args:
@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
chunk = yield handler.get_stream( chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(), requester.user.to_string(),
pagin_config, pagin_config,
timeout=timeout, timeout=timeout,
@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(EventRestServlet, self).__init__(hs) super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler event = yield self.event_handler.get_event(requester.user, event_id)
event = yield handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:

View File

@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
class InitialSyncRestServlet(ClientV1RestServlet): class InitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/initialSync$") PATTERNS = client_path_patterns("/initialSync$")
def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)

View File

@ -56,6 +56,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler() self.device_handler = self.hs.get_device_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request): def on_GET(self, request):
flows = [] flows = []
@ -260,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs) super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path self.sp_config = hs.config.saml2_config_path
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -329,6 +331,7 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)

View File

@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
self.sessions = {} self.sessions = {}
self.enable_registration = hs.config.enable_registration self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request): def on_GET(self, request):
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
super(CreateUserRestServlet, self).__init__(hs) super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):

View File

@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet): class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet): class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet): class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
@ -296,6 +311,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
@ -322,6 +341,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -351,6 +374,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -368,6 +395,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@ -388,6 +419,7 @@ class RoomEventContext(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomEventContext, self).__init__(hs) super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
@ -424,6 +456,10 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -462,6 +498,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet): class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@ -542,6 +582,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -624,6 +668,10 @@ class SearchRestServlet(ClientV1RestServlet):
"/search$" "/search$"
) )
def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)

View File

@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
@ -94,6 +95,8 @@ class HomeServer(object):
'auth_handler', 'auth_handler',
'device_handler', 'device_handler',
'e2e_keys_handler', 'e2e_keys_handler',
'event_handler',
'event_stream_handler',
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
@ -214,6 +217,12 @@ class HomeServer(object):
def build_application_service_handler(self): def build_application_service_handler(self):
return ApplicationServicesHandler(self) return ApplicationServicesHandler(self)
def build_event_handler(self):
return EventHandler(self)
def build_event_stream_handler(self):
return EventStreamHandler(self)
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)

View File

@ -50,6 +50,7 @@ from .openid import OpenIdStore
from .client_ips import ClientIpStore from .client_ips import ClientIpStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")], extra_tables=[("deleted_pushers", "stream_id")],
) )
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events", db_conn, "events",

View File

@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.storage.engines import PostgresEngine
import synapse.metrics import synapse.metrics
@ -305,11 +306,12 @@ class SQLBaseStore(object):
func, *args, **kwargs func, *args, **kwargs
) )
try:
with PreserveLoggingContext(): with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection( result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
finally:
for after_callback, after_args in after_callbacks: for after_callback, after_args in after_callbacks:
after_callback(*after_args) after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)
@ -860,6 +862,58 @@ class SQLBaseStore(object):
return cache, min_val return cache, min_val
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
txn.call_after(cache_func.invalidate, keys)
if isinstance(self.database_engine, PostgresEngine):
# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
txn.call_after(ctx.__exit__, None, None, None)
self._simple_insert_txn(
txn,
table="cache_invalidation_stream",
values={
"stream_id": stream_id,
"cache_func": cache_func.__name__,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
}
)
def get_all_updated_caches(self, last_id, current_id, limit):
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit,))
return txn.fetchall()
return self.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying

View File

@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
Returns: Returns:
Deferred Deferred
""" """
try: def alias_txn(txn):
yield self._simple_insert( self._simple_insert_txn(
txn,
"room_aliases", "room_aliases",
{ {
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
"room_id": room_id, "room_id": room_id,
"creator": creator, "creator": creator,
}, },
desc="create_room_alias_association", )
self._simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[{
"room_alias": room_alias.to_string(),
"server": server,
} for server in servers],
)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (room_id,)
)
try:
ret = yield self.runInteraction(
"create_room_alias_association", alias_txn
) )
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
raise SynapseError( raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string() 409, "Room alias %s already exists" % room_alias.to_string()
) )
defer.returnValue(ret)
for server in servers:
# TODO(erikj): Fix this to bulk insert
yield self._simple_insert(
"room_alias_servers",
{
"room_alias": room_alias.to_string(),
"server": server,
},
desc="create_room_alias_association",
)
self.get_aliases_for_room.invalidate((room_id,))
def get_room_alias_creator(self, room_alias): def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 33 SCHEMA_VERSION = 34
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
desc="add_presence_list_pending", desc="add_presence_list_pending",
) )
@defer.inlineCallbacks
def set_presence_list_accepted(self, observer_localpart, observed_userid): def set_presence_list_accepted(self, observer_localpart, observed_userid):
result = yield self._simple_update_one( def update_presence_list_txn(txn):
result = self._simple_update_one_txn(
txn,
table="presence_list", table="presence_list",
keyvalues={"user_id": observer_localpart, keyvalues={
"observed_user_id": observed_userid}, "user_id": observer_localpart,
"observed_user_id": observed_userid
},
updatevalues={"accepted": True}, updatevalues={"accepted": True},
desc="set_presence_list_accepted",
) )
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,)) self._invalidate_cache_and_stream(
defer.returnValue(result) txn, self.get_presence_list_accepted, (observer_localpart,)
)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_observers_accepted, (observed_userid,)
)
return result
return self.runInteraction(
"set_presence_list_accepted", update_presence_list_txn,
)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
if accepted: if accepted:

View File

@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],
) )
@defer.inlineCallbacks
def forget(self, user_id, room_id): def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id.""" """Indicate that user_id wishes to discard history for room_id."""
def f(txn): def f(txn):
@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
" room_id = ?" " room_id = ?"
) )
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all() txn.call_after(self.was_forgotten_at.invalidate_all)
self.who_forgot_in_room.invalidate_all() txn.call_after(self.did_forget.invalidate, (user_id, room_id))
self.did_forget.invalidate((user_id, room_id)) self._invalidate_cache_and_stream(
txn, self.who_forgot_in_room, (room_id,)
)
return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id): def did_forget(self, user_id, room_id):

View File

@ -0,0 +1,46 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
# This stream is used to notify replication slaves that some caches have
# been invalidated that they cannot infer from the other streams.
CREATE_TABLE = """
CREATE TABLE cache_invalidation_stream (
stream_id BIGINT,
cache_func TEXT,
keys TEXT[],
invalidation_ts BIGINT
);
CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
"""
def run_create(cur, database_engine, *args, **kwargs):
if not isinstance(database_engine, PostgresEngine):
return
for statement in get_statements(CREATE_TABLE.splitlines()):
cur.execute(statement)
def run_upgrade(cur, database_engine, *args, **kwargs):
pass