Merge pull request #489 from matrix-org/markjh/replication
Add a /replication API for extracting the updates that happened on synapse.
This commit is contained in:
commit
a612ce6659
|
@ -0,0 +1,67 @@
|
|||
import requests
|
||||
import collections
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
|
||||
Entry = collections.namedtuple("Entry", "name position rows")
|
||||
|
||||
ROW_TYPES = {}
|
||||
|
||||
|
||||
def row_type_for_columns(name, column_names):
|
||||
column_names = tuple(column_names)
|
||||
row_type = ROW_TYPES.get((name, column_names))
|
||||
if row_type is None:
|
||||
row_type = collections.namedtuple(name, column_names)
|
||||
ROW_TYPES[(name, column_names)] = row_type
|
||||
return row_type
|
||||
|
||||
|
||||
def parse_response(content):
|
||||
streams = json.loads(content)
|
||||
result = {}
|
||||
for name, value in streams.items():
|
||||
row_type = row_type_for_columns(name, value["field_names"])
|
||||
position = value["position"]
|
||||
rows = [row_type(*row) for row in value["rows"]]
|
||||
result[name] = Entry(name, position, rows)
|
||||
return result
|
||||
|
||||
|
||||
def replicate(server, streams):
|
||||
return parse_response(requests.get(
|
||||
server + "/_synapse/replication",
|
||||
verify=False,
|
||||
params=streams
|
||||
).content)
|
||||
|
||||
|
||||
def main():
|
||||
server = sys.argv[1]
|
||||
|
||||
streams = None
|
||||
while not streams:
|
||||
try:
|
||||
streams = {
|
||||
row.name: row.position
|
||||
for row in replicate(server, {"streams":"-1"})["streams"].rows
|
||||
}
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
time.sleep(0.1)
|
||||
|
||||
while True:
|
||||
try:
|
||||
results = replicate(server, streams)
|
||||
except:
|
||||
sys.stdout.write("connection_lost("+ repr(streams) + ")\n")
|
||||
break
|
||||
for update in results.values():
|
||||
for row in update.rows:
|
||||
sys.stdout.write(repr(row) + "\n")
|
||||
streams[update.name] = update.position
|
||||
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
main()
|
|
@ -63,6 +63,7 @@ from synapse.config.homeserver import HomeServerConfig
|
|||
from synapse.crypto import context_factory
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
|
||||
from synapse import events
|
||||
|
@ -169,6 +170,9 @@ class SynapseHomeServer(HomeServer):
|
|||
if name == "metrics" and self.get_config().enable_metrics:
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
|
||||
if name == "replication":
|
||||
resources[REPLICATION_PREFIX] = ReplicationResource(self)
|
||||
|
||||
root_resource = create_resource_tree(resources)
|
||||
if tls:
|
||||
reactor.listenSSL(
|
||||
|
|
|
@ -774,6 +774,25 @@ class PresenceHandler(BaseHandler):
|
|||
|
||||
defer.returnValue(observer_user.to_string() in accepted_observers)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_all_presence_updates(self, last_id, current_id):
|
||||
"""
|
||||
Gets a list of presence update rows from between the given stream ids.
|
||||
Each row has:
|
||||
- stream_id(str)
|
||||
- user_id(str)
|
||||
- state(str)
|
||||
- last_active_ts(int)
|
||||
- last_federation_update_ts(int)
|
||||
- last_user_sync_ts(int)
|
||||
- status_msg(int)
|
||||
- currently_active(int)
|
||||
"""
|
||||
# TODO(markjh): replicate the unpersisted changes.
|
||||
# This could use the in-memory stores for recent changes.
|
||||
rows = yield self.store.get_all_presence_updates(last_id, current_id)
|
||||
defer.returnValue(rows)
|
||||
|
||||
|
||||
def should_notify(old_state, new_state):
|
||||
"""Decides if a presence state change should be sent to interested parties.
|
||||
|
|
|
@ -25,6 +25,7 @@ from synapse.types import UserID
|
|||
import logging
|
||||
|
||||
from collections import namedtuple
|
||||
import ujson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -219,6 +220,19 @@ class TypingNotificationHandler(BaseHandler):
|
|||
"typing_key", self._latest_room_serial, rooms=[room_id]
|
||||
)
|
||||
|
||||
def get_all_typing_updates(self, last_id, current_id):
|
||||
# TODO: Work out a way to do this without scanning the entire state.
|
||||
rows = []
|
||||
for room_id, serial in self._room_serials.items():
|
||||
if last_id < serial and serial <= current_id:
|
||||
typing = self._room_typing[room_id]
|
||||
typing_bytes = json.dumps([
|
||||
u.to_string() for u in typing
|
||||
], ensure_ascii=False)
|
||||
rows.append((serial, room_id, typing_bytes))
|
||||
rows.sort()
|
||||
return rows
|
||||
|
||||
|
||||
class TypingNotificationEventSource(object):
|
||||
def __init__(self, hs):
|
||||
|
|
|
@ -159,6 +159,8 @@ class Notifier(object):
|
|||
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
|
||||
)
|
||||
|
||||
self.replication_deferred = ObservableDeferred(defer.Deferred())
|
||||
|
||||
# This is not a very cheap test to perform, but it's only executed
|
||||
# when rendering the metrics page, which is likely once per minute at
|
||||
# most when scraping it.
|
||||
|
@ -207,6 +209,8 @@ class Notifier(object):
|
|||
))
|
||||
self._notify_pending_new_room_events(max_room_stream_id)
|
||||
|
||||
self.notify_replication()
|
||||
|
||||
def _notify_pending_new_room_events(self, max_room_stream_id):
|
||||
"""Notify for the room events that were queued waiting for a previous
|
||||
event to be persisted.
|
||||
|
@ -276,6 +280,8 @@ class Notifier(object):
|
|||
except:
|
||||
logger.exception("Failed to notify listener")
|
||||
|
||||
self.notify_replication()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
|
||||
from_token=StreamToken("s0", "0", "0", "0", "0")):
|
||||
|
@ -479,3 +485,45 @@ class Notifier(object):
|
|||
room_streams = self.room_to_user_streams.setdefault(room_id, set())
|
||||
room_streams.add(new_user_stream)
|
||||
new_user_stream.rooms.add(room_id)
|
||||
|
||||
def notify_replication(self):
|
||||
"""Notify the any replication listeners that there's a new event"""
|
||||
with PreserveLoggingContext():
|
||||
deferred = self.replication_deferred
|
||||
self.replication_deferred = ObservableDeferred(defer.Deferred())
|
||||
deferred.callback(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_replication(self, callback, timeout):
|
||||
"""Wait for an event to happen.
|
||||
|
||||
:param callback:
|
||||
Gets called whenever an event happens. If this returns a truthy
|
||||
value then ``wait_for_replication`` returns, otherwise it waits
|
||||
for another event.
|
||||
:param int timeout:
|
||||
How many milliseconds to wait for callback return a truthy value.
|
||||
:returns:
|
||||
A deferred that resolves with the value returned by the callback.
|
||||
"""
|
||||
listener = _NotificationListener(None)
|
||||
|
||||
def timed_out():
|
||||
listener.deferred.cancel()
|
||||
|
||||
timer = self.clock.call_later(timeout / 1000., timed_out)
|
||||
while True:
|
||||
listener.deferred = self.replication_deferred.observe()
|
||||
result = yield callback()
|
||||
if result:
|
||||
break
|
||||
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
yield listener.deferred
|
||||
except defer.CancelledError:
|
||||
break
|
||||
|
||||
self.clock.cancel_call_later(timer, ignore_errs=True)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -0,0 +1,320 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.http.server import request_handler, finish_request
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
||||
import ujson as json
|
||||
|
||||
import collections
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REPLICATION_PREFIX = "/_synapse/replication"
|
||||
|
||||
STREAM_NAMES = (
|
||||
("events",),
|
||||
("presence",),
|
||||
("typing",),
|
||||
("receipts",),
|
||||
("user_account_data", "room_account_data", "tag_account_data",),
|
||||
("backfill",),
|
||||
)
|
||||
|
||||
|
||||
class ReplicationResource(Resource):
|
||||
"""
|
||||
HTTP endpoint for extracting data from synapse.
|
||||
|
||||
The streams of data returned by the endpoint are controlled by the
|
||||
parameters given to the API. To return a given stream pass a query
|
||||
parameter with a position in the stream to return data from or the
|
||||
special value "-1" to return data from the start of the stream.
|
||||
|
||||
If there is no data for any of the supplied streams after the given
|
||||
position then the request will block until there is data for one
|
||||
of the streams. This allows clients to long-poll this API.
|
||||
|
||||
The possible streams are:
|
||||
|
||||
* "streams": A special stream returing the positions of other streams.
|
||||
* "events": The new events seen on the server.
|
||||
* "presence": Presence updates.
|
||||
* "typing": Typing updates.
|
||||
* "receipts": Receipt updates.
|
||||
* "user_account_data": Top-level per user account data.
|
||||
* "room_account_data: Per room per user account data.
|
||||
* "tag_account_data": Per room per user tags.
|
||||
* "backfill": Old events that have been backfilled from other servers.
|
||||
|
||||
The API takes two additional query parameters:
|
||||
|
||||
* "timeout": How long to wait before returning an empty response.
|
||||
* "limit": The maximum number of rows to return for the selected streams.
|
||||
|
||||
The response is a JSON object with keys for each stream with updates. Under
|
||||
each key is a JSON object with:
|
||||
|
||||
* "postion": The current position of the stream.
|
||||
* "field_names": The names of the fields in each row.
|
||||
* "rows": The updates as an array of arrays.
|
||||
|
||||
There are a number of ways this API could be used:
|
||||
|
||||
1) To replicate the contents of the backing database to another database.
|
||||
2) To be notified when the contents of a shared backing database changes.
|
||||
3) To "tail" the activity happening on a server for debugging.
|
||||
|
||||
In the first case the client would track all of the streams and store it's
|
||||
own copy of the data.
|
||||
|
||||
In the second case the client might theoretically just be able to follow
|
||||
the "streams" stream to track where the other streams are. However in
|
||||
practise it will probably need to get the contents of the streams in
|
||||
order to expire the any in-memory caches. Whether it gets the contents
|
||||
of the streams from this replication API or directly from the backing
|
||||
store is a matter of taste.
|
||||
|
||||
In the third case the client would use the "streams" stream to find what
|
||||
streams are available and their current positions. Then it can start
|
||||
long-polling this replication API for new data on those streams.
|
||||
"""
|
||||
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs):
|
||||
Resource.__init__(self) # Resource is old-style, so no super()
|
||||
|
||||
self.version_string = hs.version_string
|
||||
self.store = hs.get_datastore()
|
||||
self.sources = hs.get_event_sources()
|
||||
self.presence_handler = hs.get_handlers().presence_handler
|
||||
self.typing_handler = hs.get_handlers().typing_notification_handler
|
||||
self.notifier = hs.notifier
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def current_replication_token(self):
|
||||
stream_token = yield self.sources.get_current_token()
|
||||
backfill_token = yield self.store.get_current_backfill_token()
|
||||
|
||||
defer.returnValue(_ReplicationToken(
|
||||
stream_token.room_stream_id,
|
||||
int(stream_token.presence_key),
|
||||
int(stream_token.typing_key),
|
||||
int(stream_token.receipt_key),
|
||||
int(stream_token.account_data_key),
|
||||
backfill_token,
|
||||
))
|
||||
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
limit = parse_integer(request, "limit", 100)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
|
||||
request.setHeader(b"Content-Type", b"application/json")
|
||||
writer = _Writer(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate():
|
||||
current_token = yield self.current_replication_token()
|
||||
logger.info("Replicating up to %r", current_token)
|
||||
|
||||
yield self.account_data(writer, current_token, limit)
|
||||
yield self.events(writer, current_token, limit)
|
||||
yield self.presence(writer, current_token) # TODO: implement limit
|
||||
yield self.typing(writer, current_token) # TODO: implement limit
|
||||
yield self.receipts(writer, current_token, limit)
|
||||
self.streams(writer, current_token)
|
||||
|
||||
logger.info("Replicated %d rows", writer.total)
|
||||
defer.returnValue(writer.total)
|
||||
|
||||
yield self.notifier.wait_for_replication(replicate, timeout)
|
||||
|
||||
writer.finish()
|
||||
|
||||
def streams(self, writer, current_token):
|
||||
request_token = parse_string(writer.request, "streams")
|
||||
|
||||
streams = []
|
||||
|
||||
if request_token is not None:
|
||||
if request_token == "-1":
|
||||
for names, position in zip(STREAM_NAMES, current_token):
|
||||
streams.extend((name, position) for name in names)
|
||||
else:
|
||||
items = zip(
|
||||
STREAM_NAMES,
|
||||
current_token,
|
||||
_ReplicationToken(request_token)
|
||||
)
|
||||
for names, current_id, last_id in items:
|
||||
if last_id < current_id:
|
||||
streams.extend((name, current_id) for name in names)
|
||||
|
||||
if streams:
|
||||
writer.write_header_and_rows(
|
||||
"streams", streams, ("name", "position"),
|
||||
position=str(current_token)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def events(self, writer, current_token, limit):
|
||||
request_events = parse_integer(writer.request, "events")
|
||||
request_backfill = parse_integer(writer.request, "backfill")
|
||||
|
||||
if request_events is not None or request_backfill is not None:
|
||||
if request_events is None:
|
||||
request_events = current_token.events
|
||||
if request_backfill is None:
|
||||
request_backfill = current_token.backfill
|
||||
events_rows, backfill_rows = yield self.store.get_all_new_events(
|
||||
request_backfill, request_events,
|
||||
current_token.backfill, current_token.events,
|
||||
limit
|
||||
)
|
||||
writer.write_header_and_rows(
|
||||
"events", events_rows, ("position", "internal", "json")
|
||||
)
|
||||
writer.write_header_and_rows(
|
||||
"backfill", backfill_rows, ("position", "internal", "json")
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def presence(self, writer, current_token):
|
||||
current_position = current_token.presence
|
||||
|
||||
request_presence = parse_integer(writer.request, "presence")
|
||||
|
||||
if request_presence is not None:
|
||||
presence_rows = yield self.presence_handler.get_all_presence_updates(
|
||||
request_presence, current_position
|
||||
)
|
||||
writer.write_header_and_rows("presence", presence_rows, (
|
||||
"position", "user_id", "state", "last_active_ts",
|
||||
"last_federation_update_ts", "last_user_sync_ts",
|
||||
"status_msg", "currently_active",
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def typing(self, writer, current_token):
|
||||
current_position = current_token.presence
|
||||
|
||||
request_typing = parse_integer(writer.request, "typing")
|
||||
|
||||
if request_typing is not None:
|
||||
typing_rows = yield self.typing_handler.get_all_typing_updates(
|
||||
request_typing, current_position
|
||||
)
|
||||
writer.write_header_and_rows("typing", typing_rows, (
|
||||
"position", "room_id", "typing"
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def receipts(self, writer, current_token, limit):
|
||||
current_position = current_token.receipts
|
||||
|
||||
request_receipts = parse_integer(writer.request, "receipts")
|
||||
|
||||
if request_receipts is not None:
|
||||
receipts_rows = yield self.store.get_all_updated_receipts(
|
||||
request_receipts, current_position, limit
|
||||
)
|
||||
writer.write_header_and_rows("receipts", receipts_rows, (
|
||||
"position", "room_id", "receipt_type", "user_id", "event_id", "data"
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def account_data(self, writer, current_token, limit):
|
||||
current_position = current_token.account_data
|
||||
|
||||
user_account_data = parse_integer(writer.request, "user_account_data")
|
||||
room_account_data = parse_integer(writer.request, "room_account_data")
|
||||
tag_account_data = parse_integer(writer.request, "tag_account_data")
|
||||
|
||||
if user_account_data is not None or room_account_data is not None:
|
||||
if user_account_data is None:
|
||||
user_account_data = current_position
|
||||
if room_account_data is None:
|
||||
room_account_data = current_position
|
||||
user_rows, room_rows = yield self.store.get_all_updated_account_data(
|
||||
user_account_data, room_account_data, current_position, limit
|
||||
)
|
||||
writer.write_header_and_rows("user_account_data", user_rows, (
|
||||
"position", "user_id", "type", "content"
|
||||
))
|
||||
writer.write_header_and_rows("room_account_data", room_rows, (
|
||||
"position", "user_id", "room_id", "type", "content"
|
||||
))
|
||||
|
||||
if tag_account_data is not None:
|
||||
tag_rows = yield self.store.get_all_updated_tags(
|
||||
tag_account_data, current_position, limit
|
||||
)
|
||||
writer.write_header_and_rows("tag_account_data", tag_rows, (
|
||||
"position", "user_id", "room_id", "tags"
|
||||
))
|
||||
|
||||
|
||||
class _Writer(object):
|
||||
"""Writes the streams as a JSON object as the response to the request"""
|
||||
def __init__(self, request):
|
||||
self.streams = {}
|
||||
self.request = request
|
||||
self.total = 0
|
||||
|
||||
def write_header_and_rows(self, name, rows, fields, position=None):
|
||||
if not rows:
|
||||
return
|
||||
|
||||
if position is None:
|
||||
position = rows[-1][0]
|
||||
|
||||
self.streams[name] = {
|
||||
"position": str(position),
|
||||
"field_names": fields,
|
||||
"rows": rows,
|
||||
}
|
||||
|
||||
self.total += len(rows)
|
||||
|
||||
def finish(self):
|
||||
self.request.write(json.dumps(self.streams, ensure_ascii=False))
|
||||
finish_request(self.request)
|
||||
|
||||
|
||||
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
||||
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
||||
))):
|
||||
__slots__ = []
|
||||
|
||||
def __new__(cls, *args):
|
||||
if len(args) == 1:
|
||||
return cls(*(int(value) for value in args[0].split("_")))
|
||||
else:
|
||||
return super(_ReplicationToken, cls).__new__(cls, *args)
|
||||
|
||||
def __str__(self):
|
||||
return "_".join(str(value) for value in self)
|
|
@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore):
|
|||
"get_account_data_for_room", get_account_data_for_room_txn
|
||||
)
|
||||
|
||||
def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
|
||||
"""Get all the client account_data for a that's changed.
|
||||
def get_all_updated_account_data(self, last_global_id, last_room_id,
|
||||
current_id, limit):
|
||||
"""Get all the client account_data that has changed on the server
|
||||
Args:
|
||||
last_global_id(int): The position to fetch from for top level data
|
||||
last_room_id(int): The position to fetch from for per room data
|
||||
current_id(int): The position to fetch up to.
|
||||
Returns:
|
||||
A deferred pair of lists of tuples of stream_id int, user_id string,
|
||||
room_id string, type string, and content string.
|
||||
"""
|
||||
def get_updated_account_data_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, account_data_type, content"
|
||||
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_global_id, current_id, limit))
|
||||
global_results = txn.fetchall()
|
||||
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, room_id, account_data_type, content"
|
||||
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_room_id, current_id, limit))
|
||||
room_results = txn.fetchall()
|
||||
return (global_results, room_results)
|
||||
return self.runInteraction(
|
||||
"get_all_updated_account_data_txn", get_updated_account_data_txn
|
||||
)
|
||||
|
||||
def get_updated_account_data_for_user(self, user_id, stream_id):
|
||||
"""Get all the client account_data for a that's changed for a user
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the account_data for.
|
||||
|
|
|
@ -1064,3 +1064,48 @@ class EventsStore(SQLBaseStore):
|
|||
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
def get_current_backfill_token(self):
|
||||
"""The current minimum token that backfilled events have reached"""
|
||||
|
||||
# TODO: Fix race with the persit_event txn by using one of the
|
||||
# stream id managers
|
||||
return -self.min_stream_token
|
||||
|
||||
def get_all_new_events(self, last_backfill_id, last_forward_id,
|
||||
current_backfill_id, current_forward_id, limit):
|
||||
"""Get all the new events that have arrived at the server either as
|
||||
new events or as backfilled events"""
|
||||
def get_all_new_events_txn(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, ej.internal_metadata, ej.json"
|
||||
" FROM events as e"
|
||||
" JOIN event_json as ej"
|
||||
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
|
||||
" WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
|
||||
" ORDER BY e.stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
if last_forward_id != current_forward_id:
|
||||
txn.execute(sql, (last_forward_id, current_forward_id, limit))
|
||||
new_forward_events = txn.fetchall()
|
||||
else:
|
||||
new_forward_events = []
|
||||
|
||||
sql = (
|
||||
"SELECT -e.stream_ordering, ej.internal_metadata, ej.json"
|
||||
" FROM events as e"
|
||||
" JOIN event_json as ej"
|
||||
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
|
||||
" WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
|
||||
" ORDER BY e.stream_ordering DESC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
if last_backfill_id != current_backfill_id:
|
||||
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
|
||||
new_backfill_events = txn.fetchall()
|
||||
else:
|
||||
new_backfill_events = []
|
||||
|
||||
return (new_forward_events, new_backfill_events)
|
||||
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
|
||||
|
|
|
@ -115,6 +115,22 @@ class PresenceStore(SQLBaseStore):
|
|||
args
|
||||
)
|
||||
|
||||
def get_all_presence_updates(self, last_id, current_id):
|
||||
def get_all_presence_updates_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, state, last_active_ts,"
|
||||
" last_federation_update_ts, last_user_sync_ts, status_msg,"
|
||||
" currently_active"
|
||||
" FROM presence_stream"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_presence_updates", get_all_presence_updates_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_presence_for_users(self, user_ids):
|
||||
rows = yield self._simple_select_many_batch(
|
||||
|
|
|
@ -390,3 +390,19 @@ class ReceiptsStore(SQLBaseStore):
|
|||
"data": json.dumps(data),
|
||||
}
|
||||
)
|
||||
|
||||
def get_all_updated_receipts(self, last_id, current_id, limit):
|
||||
def get_all_updated_receipts_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
|
||||
" FROM receipts_linearized"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
|
||||
return txn.fetchall()
|
||||
return self.runInteraction(
|
||||
"get_all_updated_receipts", get_all_updated_receipts_txn
|
||||
)
|
||||
|
|
|
@ -58,6 +58,59 @@ class TagsStore(SQLBaseStore):
|
|||
|
||||
return deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_all_updated_tags(self, last_id, current_id, limit):
|
||||
"""Get all the client tags that have changed on the server
|
||||
Args:
|
||||
last_id(int): The position to fetch from.
|
||||
current_id(int): The position to fetch up to.
|
||||
Returns:
|
||||
A deferred list of tuples of stream_id int, user_id string,
|
||||
room_id string, tag string and content string.
|
||||
"""
|
||||
def get_all_updated_tags_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, room_id"
|
||||
" FROM room_tags_revisions as r"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
tag_ids = yield self.runInteraction(
|
||||
"get_all_updated_tags", get_all_updated_tags_txn
|
||||
)
|
||||
|
||||
def get_tag_content(txn, tag_ids):
|
||||
sql = (
|
||||
"SELECT tag, content"
|
||||
" FROM room_tags"
|
||||
" WHERE user_id=? AND room_id=?"
|
||||
)
|
||||
results = []
|
||||
for stream_id, user_id, room_id in tag_ids:
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
tags = []
|
||||
for tag, content in txn.fetchall():
|
||||
tags.append(json.dumps(tag) + ":" + content)
|
||||
tag_json = "{" + ",".join(tags) + "}"
|
||||
results.append((stream_id, user_id, room_id, tag_json))
|
||||
|
||||
return results
|
||||
|
||||
batch_size = 50
|
||||
results = []
|
||||
for i in xrange(0, len(tag_ids), batch_size):
|
||||
tags = yield self.runInteraction(
|
||||
"get_all_updated_tag_content",
|
||||
get_tag_content,
|
||||
tag_ids[i:i + batch_size],
|
||||
)
|
||||
results.extend(tags)
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_updated_tags(self, user_id, stream_id):
|
||||
"""Get all the tags for the rooms where the tags have changed since the
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -0,0 +1,179 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.resource import ReplicationResource
|
||||
from synapse.types import Requester, UserID
|
||||
|
||||
from twisted.internet import defer
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
from mock import Mock, NonCallableMock
|
||||
import json
|
||||
import contextlib
|
||||
|
||||
|
||||
class ReplicationResourceCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hs = yield setup_test_homeserver(
|
||||
"red",
|
||||
http_client=None,
|
||||
replication_layer=Mock(),
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
)
|
||||
self.user = UserID.from_string("@seeing:red")
|
||||
|
||||
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
|
||||
|
||||
self.resource = ReplicationResource(self.hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_streams(self):
|
||||
# Passing "-1" returns the current stream positions
|
||||
code, body = yield self.get(streams="-1")
|
||||
self.assertEquals(code, 200)
|
||||
self.assertEquals(body["streams"]["field_names"], ["name", "position"])
|
||||
position = body["streams"]["position"]
|
||||
# Passing the current position returns an empty response after the
|
||||
# timeout
|
||||
get = self.get(streams=str(position), timeout="0")
|
||||
self.hs.clock.advance_time_msec(1)
|
||||
code, body = yield get
|
||||
self.assertEquals(code, 200)
|
||||
self.assertEquals(body, {})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_events(self):
|
||||
get = self.get(events="-1", timeout="0")
|
||||
yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||
Requester(self.user, "", False), {}
|
||||
)
|
||||
code, body = yield get
|
||||
self.assertEquals(code, 200)
|
||||
self.assertEquals(body["events"]["field_names"], [
|
||||
"position", "internal", "json"
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_presence(self):
|
||||
get = self.get(presence="-1")
|
||||
yield self.hs.get_handlers().presence_handler.set_state(
|
||||
self.user, {"presence": "online"}
|
||||
)
|
||||
code, body = yield get
|
||||
self.assertEquals(code, 200)
|
||||
self.assertEquals(body["presence"]["field_names"], [
|
||||
"position", "user_id", "state", "last_active_ts",
|
||||
"last_federation_update_ts", "last_user_sync_ts",
|
||||
"status_msg", "currently_active",
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_typing(self):
|
||||
room_id = yield self.create_room()
|
||||
get = self.get(typing="-1")
|
||||
yield self.hs.get_handlers().typing_notification_handler.started_typing(
|
||||
self.user, self.user, room_id, timeout=2
|
||||
)
|
||||
code, body = yield get
|
||||
self.assertEquals(code, 200)
|
||||
self.assertEquals(body["typing"]["field_names"], [
|
||||
"position", "room_id", "typing"
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_receipts(self):
|
||||
room_id = yield self.create_room()
|
||||
event_id = yield self.send_text_message(room_id, "Hello, World")
|
||||
get = self.get(receipts="-1")
|
||||
yield self.hs.get_handlers().receipts_handler.received_client_receipt(
|
||||
room_id, "m.read", self.user.to_string(), event_id
|
||||
)
|
||||
code, body = yield get
|
||||
self.assertEquals(code, 200)
|
||||
self.assertEquals(body["receipts"]["field_names"], [
|
||||
"position", "room_id", "receipt_type", "user_id", "event_id", "data"
|
||||
])
|
||||
|
||||
def _test_timeout(stream):
|
||||
"""Check that a request for the given stream timesout"""
|
||||
@defer.inlineCallbacks
|
||||
def test_timeout(self):
|
||||
get = self.get(**{stream: "-1", "timeout": "0"})
|
||||
self.hs.clock.advance_time_msec(1)
|
||||
code, body = yield get
|
||||
self.assertEquals(code, 200)
|
||||
self.assertEquals(body, {})
|
||||
test_timeout.__name__ = "test_timeout_%s" % (stream)
|
||||
return test_timeout
|
||||
|
||||
test_timeout_events = _test_timeout("events")
|
||||
test_timeout_presence = _test_timeout("presence")
|
||||
test_timeout_typing = _test_timeout("typing")
|
||||
test_timeout_receipts = _test_timeout("receipts")
|
||||
test_timeout_user_account_data = _test_timeout("user_account_data")
|
||||
test_timeout_room_account_data = _test_timeout("room_account_data")
|
||||
test_timeout_tag_account_data = _test_timeout("tag_account_data")
|
||||
test_timeout_backfill = _test_timeout("backfill")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_text_message(self, room_id, message):
|
||||
handler = self.hs.get_handlers().message_handler
|
||||
event = yield handler.create_and_send_nonmember_event({
|
||||
"type": "m.room.message",
|
||||
"content": {"body": "message", "msgtype": "m.text"},
|
||||
"room_id": room_id,
|
||||
"sender": self.user.to_string(),
|
||||
})
|
||||
defer.returnValue(event.event_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self):
|
||||
result = yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||
Requester(self.user, "", False), {}
|
||||
)
|
||||
defer.returnValue(result["room_id"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get(self, **params):
|
||||
request = NonCallableMock(spec_set=[
|
||||
"write", "finish", "setResponseCode", "setHeader", "args",
|
||||
"method", "processing"
|
||||
])
|
||||
|
||||
request.method = "GET"
|
||||
request.args = {k: [v] for k, v in params.items()}
|
||||
|
||||
@contextlib.contextmanager
|
||||
def processing():
|
||||
yield
|
||||
request.processing = processing
|
||||
|
||||
yield self.resource._async_render_GET(request)
|
||||
self.assertTrue(request.finish.called)
|
||||
|
||||
if request.setResponseCode.called:
|
||||
response_code = request.setResponseCode.call_args[0][0]
|
||||
else:
|
||||
response_code = 200
|
||||
|
||||
response_json = "".join(
|
||||
call[0][0] for call in request.write.call_args_list
|
||||
)
|
||||
response_body = json.loads(response_json)
|
||||
|
||||
defer.returnValue((response_code, response_body))
|
|
@ -239,9 +239,10 @@ class MockClock(object):
|
|||
def looping_call(self, function, interval):
|
||||
pass
|
||||
|
||||
def cancel_call_later(self, timer):
|
||||
def cancel_call_later(self, timer, ignore_errs=False):
|
||||
if timer[2]:
|
||||
raise Exception("Cannot cancel an expired timer")
|
||||
if not ignore_errs:
|
||||
raise Exception("Cannot cancel an expired timer")
|
||||
|
||||
timer[2] = True
|
||||
self.timers = [t for t in self.timers if t != timer]
|
||||
|
|
Loading…
Reference in New Issue