Merge pull request #199 from matrix-org/erikj/receipts
Implement read receipts.
This commit is contained in:
commit
b6d4a4c6d8
|
@ -32,6 +32,7 @@ from .appservice import ApplicationServicesHandler
|
|||
from .sync import SyncHandler
|
||||
from .auth import AuthHandler
|
||||
from .identity import IdentityHandler
|
||||
from .receipts import ReceiptsHandler
|
||||
|
||||
|
||||
class Handlers(object):
|
||||
|
@ -57,6 +58,7 @@ class Handlers(object):
|
|||
self.directory_handler = DirectoryHandler(hs)
|
||||
self.typing_notification_handler = TypingNotificationHandler(hs)
|
||||
self.admin_handler = AdminHandler(hs)
|
||||
self.receipts_handler = ReceiptsHandler(hs)
|
||||
asapi = ApplicationServiceApi(hs)
|
||||
self.appservice_handler = ApplicationServicesHandler(
|
||||
hs, asapi, AppServiceScheduler(
|
||||
|
|
|
@ -334,6 +334,11 @@ class MessageHandler(BaseHandler):
|
|||
user, pagination_config.get_source_config("presence"), None
|
||||
)
|
||||
|
||||
receipt_stream = self.hs.get_event_sources().sources["receipt"]
|
||||
receipt, _ = yield receipt_stream.get_pagination_rows(
|
||||
user, pagination_config.get_source_config("receipt"), None
|
||||
)
|
||||
|
||||
public_room_ids = yield self.store.get_public_room_ids()
|
||||
|
||||
limit = pagin_config.limit
|
||||
|
@ -404,7 +409,8 @@ class MessageHandler(BaseHandler):
|
|||
ret = {
|
||||
"rooms": rooms_ret,
|
||||
"presence": presence,
|
||||
"end": now_token.to_string()
|
||||
"receipts": receipt,
|
||||
"end": now_token.to_string(),
|
||||
}
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
@ -465,9 +471,12 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
defer.returnValue([p for success, p in presence_defs if success])
|
||||
|
||||
presence, (messages, token) = yield defer.gatherResults(
|
||||
receipts_handler = self.hs.get_handlers().receipts_handler
|
||||
|
||||
presence, receipts, (messages, token) = yield defer.gatherResults(
|
||||
[
|
||||
get_presence(),
|
||||
receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key),
|
||||
self.store.get_recent_events_for_room(
|
||||
room_id,
|
||||
limit=limit,
|
||||
|
@ -495,5 +504,6 @@ class MessageHandler(BaseHandler):
|
|||
"end": end_token.to_string(),
|
||||
},
|
||||
"state": state,
|
||||
"presence": presence
|
||||
"presence": presence,
|
||||
"receipts": receipts,
|
||||
})
|
||||
|
|
|
@ -992,7 +992,7 @@ class PresenceHandler(BaseHandler):
|
|||
room_ids([str]): List of room_ids to notify.
|
||||
"""
|
||||
with PreserveLoggingContext():
|
||||
self.notifier.on_new_user_event(
|
||||
self.notifier.on_new_event(
|
||||
"presence_key",
|
||||
self._user_cachemap_latest_serial,
|
||||
users_to_push,
|
||||
|
|
|
@ -0,0 +1,207 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReceiptsHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(ReceiptsHandler, self).__init__(hs)
|
||||
|
||||
self.hs = hs
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation.register_edu_handler(
|
||||
"m.receipt", self._received_remote_receipt
|
||||
)
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
self._receipt_cache = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def received_client_receipt(self, room_id, receipt_type, user_id,
|
||||
event_id):
|
||||
"""Called when a client tells us a local user has read up to the given
|
||||
event_id in the room.
|
||||
"""
|
||||
receipt = {
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_ids": [event_id],
|
||||
"data": {
|
||||
"ts": int(self.clock.time_msec()),
|
||||
}
|
||||
}
|
||||
|
||||
is_new = yield self._handle_new_receipts([receipt])
|
||||
|
||||
if is_new:
|
||||
self._push_remotes([receipt])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _received_remote_receipt(self, origin, content):
|
||||
"""Called when we receive an EDU of type m.receipt from a remote HS.
|
||||
"""
|
||||
receipts = [
|
||||
{
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_ids": user_values["event_ids"],
|
||||
"data": user_values.get("data", {}),
|
||||
}
|
||||
for room_id, room_values in content.items()
|
||||
for receipt_type, users in room_values.items()
|
||||
for user_id, user_values in users.items()
|
||||
]
|
||||
|
||||
yield self._handle_new_receipts(receipts)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_receipts(self, receipts):
|
||||
"""Takes a list of receipts, stores them and informs the notifier.
|
||||
"""
|
||||
for receipt in receipts:
|
||||
room_id = receipt["room_id"]
|
||||
receipt_type = receipt["receipt_type"]
|
||||
user_id = receipt["user_id"]
|
||||
event_ids = receipt["event_ids"]
|
||||
data = receipt["data"]
|
||||
|
||||
res = yield self.store.insert_receipt(
|
||||
room_id, receipt_type, user_id, event_ids, data
|
||||
)
|
||||
|
||||
if not res:
|
||||
# res will be None if this read receipt is 'old'
|
||||
defer.returnValue(False)
|
||||
|
||||
stream_id, max_persisted_id = res
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.notifier.on_new_event(
|
||||
"receipt_key", max_persisted_id, rooms=[room_id]
|
||||
)
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _push_remotes(self, receipts):
|
||||
"""Given a list of receipts, works out which remote servers should be
|
||||
poked and pokes them.
|
||||
"""
|
||||
# TODO: Some of this stuff should be coallesced.
|
||||
for receipt in receipts:
|
||||
room_id = receipt["room_id"]
|
||||
receipt_type = receipt["receipt_type"]
|
||||
user_id = receipt["user_id"]
|
||||
event_ids = receipt["event_ids"]
|
||||
data = receipt["data"]
|
||||
|
||||
remotedomains = set()
|
||||
|
||||
rm_handler = self.hs.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=None, remotedomains=remotedomains
|
||||
)
|
||||
|
||||
logger.debug("Sending receipt to: %r", remotedomains)
|
||||
|
||||
for domain in remotedomains:
|
||||
self.federation.send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.receipt",
|
||||
content={
|
||||
room_id: {
|
||||
receipt_type: {
|
||||
user_id: {
|
||||
"event_ids": event_ids,
|
||||
"data": data,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_receipts_for_room(self, room_id, to_key):
|
||||
"""Gets all receipts for a room, upto the given key.
|
||||
"""
|
||||
result = yield self.store.get_linearized_receipts_for_room(
|
||||
room_id,
|
||||
to_key=to_key,
|
||||
)
|
||||
|
||||
if not result:
|
||||
defer.returnValue([])
|
||||
|
||||
event = {
|
||||
"type": "m.receipt",
|
||||
"room_id": room_id,
|
||||
"content": result,
|
||||
}
|
||||
|
||||
defer.returnValue([event])
|
||||
|
||||
|
||||
class ReceiptEventSource(object):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
from_key = int(from_key)
|
||||
to_key = yield self.get_current_key()
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||
rooms = [room.room_id for room in rooms]
|
||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||
rooms,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
)
|
||||
|
||||
defer.returnValue((events, to_key))
|
||||
|
||||
def get_current_key(self, direction='f'):
|
||||
return self.store.get_max_receipt_stream_id()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_pagination_rows(self, user, config, key):
|
||||
to_key = int(config.from_key)
|
||||
|
||||
if config.to_key:
|
||||
from_key = int(config.to_key)
|
||||
else:
|
||||
from_key = None
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||
rooms = [room.room_id for room in rooms]
|
||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||
rooms,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
)
|
||||
|
||||
defer.returnValue((events, to_key))
|
|
@ -218,7 +218,7 @@ class TypingNotificationHandler(BaseHandler):
|
|||
self._room_serials[room_id] = self._latest_room_serial
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.notifier.on_new_user_event(
|
||||
self.notifier.on_new_event(
|
||||
"typing_key", self._latest_room_serial, rooms=[room_id]
|
||||
)
|
||||
|
||||
|
|
|
@ -221,16 +221,7 @@ class Notifier(object):
|
|||
event
|
||||
)
|
||||
|
||||
room_id = event.room_id
|
||||
|
||||
room_user_streams = self.room_to_user_streams.get(room_id, set())
|
||||
|
||||
user_streams = room_user_streams.copy()
|
||||
|
||||
for user in extra_users:
|
||||
user_stream = self.user_to_user_stream.get(str(user))
|
||||
if user_stream is not None:
|
||||
user_streams.add(user_stream)
|
||||
app_streams = set()
|
||||
|
||||
for appservice in self.appservice_to_user_streams:
|
||||
# TODO (kegan): Redundant appservice listener checks?
|
||||
|
@ -242,24 +233,20 @@ class Notifier(object):
|
|||
app_user_streams = self.appservice_to_user_streams.get(
|
||||
appservice, set()
|
||||
)
|
||||
user_streams |= app_user_streams
|
||||
app_streams |= app_user_streams
|
||||
|
||||
logger.debug("on_new_room_event listeners %s", user_streams)
|
||||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
for user_stream in user_streams:
|
||||
try:
|
||||
user_stream.notify(
|
||||
"room_key", "s%d" % (room_stream_id,), time_now_ms
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to notify listener")
|
||||
self.on_new_event(
|
||||
"room_key", room_stream_id,
|
||||
users=extra_users,
|
||||
rooms=[event.room_id],
|
||||
extra_streams=app_streams,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
|
||||
""" Used to inform listeners that something has happend
|
||||
presence/user event wise.
|
||||
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
|
||||
extra_streams=set()):
|
||||
""" Used to inform listeners that something has happend event wise.
|
||||
|
||||
Will wake up all listeners for the given users and rooms.
|
||||
"""
|
||||
|
@ -283,7 +270,7 @@ class Notifier(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_events(self, user, rooms, timeout, callback,
|
||||
from_token=StreamToken("s0", "0", "0")):
|
||||
from_token=StreamToken("s0", "0", "0", "0")):
|
||||
"""Wait until the callback returns a non empty response or the
|
||||
timeout fires.
|
||||
"""
|
||||
|
|
|
@ -31,6 +31,7 @@ REQUIREMENTS = {
|
|||
"pillow": ["PIL"],
|
||||
"pydenticon": ["pydenticon"],
|
||||
"ujson": ["ujson"],
|
||||
"blist": ["blist"],
|
||||
"pysaml2": ["saml2"],
|
||||
}
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
|
|
|
@ -19,6 +19,7 @@ from . import (
|
|||
account,
|
||||
register,
|
||||
auth,
|
||||
receipts,
|
||||
keys,
|
||||
)
|
||||
|
||||
|
@ -39,4 +40,5 @@ class ClientV2AlphaRestResource(JsonResource):
|
|||
account.register_servlets(hs, client_resource)
|
||||
register.register_servlets(hs, client_resource)
|
||||
auth.register_servlets(hs, client_resource)
|
||||
receipts.register_servlets(hs, client_resource)
|
||||
keys.register_servlets(hs, client_resource)
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
from ._base import client_v2_pattern
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReceiptRestServlet(RestServlet):
|
||||
PATTERN = client_v2_pattern(
|
||||
"/rooms/(?P<room_id>[^/]*)"
|
||||
"/receipt/(?P<receipt_type>[^/]*)"
|
||||
"/(?P<event_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ReceiptRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.receipts_handler = hs.get_handlers().receipts_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, receipt_type, event_id):
|
||||
user, client = yield self.auth.get_user_by_req(request)
|
||||
|
||||
yield self.receipts_handler.received_client_receipt(
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id=user.to_string(),
|
||||
event_id=event_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ReceiptRestServlet(hs).register(http_server)
|
|
@ -39,6 +39,8 @@ from .signatures import SignatureStore
|
|||
from .filtering import FilteringStore
|
||||
from .end_to_end_keys import EndToEndKeyStore
|
||||
|
||||
from .receipts import ReceiptsStore
|
||||
|
||||
|
||||
import fnmatch
|
||||
import imp
|
||||
|
@ -75,6 +77,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
EventsStore,
|
||||
ReceiptsStore,
|
||||
EndToEndKeyStore,
|
||||
):
|
||||
|
||||
|
|
|
@ -329,13 +329,14 @@ class SQLBaseStore(object):
|
|||
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
self._stream_id_gen = StreamIdGenerator()
|
||||
self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
|
||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
|
||||
|
||||
def start_profiling(self):
|
||||
self._previous_loop_ts = self._clock.time_msec()
|
||||
|
|
|
@ -0,0 +1,348 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014, 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore, cached
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util import unwrapFirstError
|
||||
|
||||
from blist import sorteddict
|
||||
import logging
|
||||
import ujson as json
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReceiptsStore(SQLBaseStore):
|
||||
def __init__(self, hs):
|
||||
super(ReceiptsStore, self).__init__(hs)
|
||||
|
||||
self._receipts_stream_cache = _RoomStreamChangeCache()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
"""Get receipts for multiple rooms for sending to clients.
|
||||
|
||||
Args:
|
||||
room_ids (list): List of room_ids.
|
||||
to_key (int): Max stream id to fetch receipts upto.
|
||||
from_key (int): Min stream id to fetch receipts from. None fetches
|
||||
from the start.
|
||||
|
||||
Returns:
|
||||
list: A list of receipts.
|
||||
"""
|
||||
room_ids = set(room_ids)
|
||||
|
||||
if from_key:
|
||||
room_ids = yield self._receipts_stream_cache.get_rooms_changed(
|
||||
self, room_ids, from_key
|
||||
)
|
||||
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.get_linearized_receipts_for_room(
|
||||
room_id, to_key, from_key=from_key
|
||||
)
|
||||
for room_id in room_ids
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue([ev for res in results for ev in res])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||
"""Get receipts for a single room for sending to clients.
|
||||
|
||||
Args:
|
||||
room_ids (str): The room id.
|
||||
to_key (int): Max stream id to fetch receipts upto.
|
||||
from_key (int): Min stream id to fetch receipts from. None fetches
|
||||
from the start.
|
||||
|
||||
Returns:
|
||||
list: A list of receipts.
|
||||
"""
|
||||
def f(txn):
|
||||
if from_key:
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
" room_id = ? AND stream_id > ? AND stream_id <= ?"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(room_id, from_key, to_key)
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
" room_id = ? AND stream_id <= ?"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(room_id, to_key)
|
||||
)
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
return rows
|
||||
|
||||
rows = yield self.runInteraction(
|
||||
"get_linearized_receipts_for_room", f
|
||||
)
|
||||
|
||||
if not rows:
|
||||
defer.returnValue([])
|
||||
|
||||
content = {}
|
||||
for row in rows:
|
||||
content.setdefault(
|
||||
row["event_id"], {}
|
||||
).setdefault(
|
||||
row["receipt_type"], {}
|
||||
)[row["user_id"]] = json.loads(row["data"])
|
||||
|
||||
defer.returnValue([{
|
||||
"type": "m.receipt",
|
||||
"room_id": room_id,
|
||||
"content": content,
|
||||
}])
|
||||
|
||||
def get_max_receipt_stream_id(self):
|
||||
return self._receipts_id_gen.get_max_token(self)
|
||||
|
||||
@cached
|
||||
@defer.inlineCallbacks
|
||||
def get_graph_receipts_for_room(self, room_id):
|
||||
"""Get receipts for sending to remote servers.
|
||||
"""
|
||||
rows = yield self._simple_select_list(
|
||||
table="receipts_graph",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=["receipt_type", "user_id", "event_id"],
|
||||
desc="get_linearized_receipts_for_room",
|
||||
)
|
||||
|
||||
result = {}
|
||||
for row in rows:
|
||||
result.setdefault(
|
||||
row["user_id"], {}
|
||||
).setdefault(
|
||||
row["receipt_type"], []
|
||||
).append(row["event_id"])
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
|
||||
user_id, event_id, data, stream_id):
|
||||
|
||||
# We don't want to clobber receipts for more recent events, so we
|
||||
# have to compare orderings of existing receipts
|
||||
sql = (
|
||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" INNER JOIN receipts_linearized as r USING (event_id, room_id)"
|
||||
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (room_id, receipt_type, user_id))
|
||||
results = txn.fetchall()
|
||||
|
||||
if results:
|
||||
res = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="events",
|
||||
retcols=["topological_ordering", "stream_ordering"],
|
||||
keyvalues={"event_id": event_id},
|
||||
)
|
||||
topological_ordering = int(res["topological_ordering"])
|
||||
stream_ordering = int(res["stream_ordering"])
|
||||
|
||||
for to, so, _ in results:
|
||||
if int(to) > topological_ordering:
|
||||
return False
|
||||
elif int(to) == topological_ordering and int(so) >= stream_ordering:
|
||||
return False
|
||||
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="receipts_linearized",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="receipts_linearized",
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_id": event_id,
|
||||
"data": json.dumps(data),
|
||||
}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
|
||||
"""Insert a receipt, either from local client or remote server.
|
||||
|
||||
Automatically does conversion between linearized and graph
|
||||
representations.
|
||||
"""
|
||||
if not event_ids:
|
||||
return
|
||||
|
||||
if len(event_ids) == 1:
|
||||
linearized_event_id = event_ids[0]
|
||||
else:
|
||||
# we need to points in graph -> linearized form.
|
||||
# TODO: Make this better.
|
||||
def graph_to_linear(txn):
|
||||
query = (
|
||||
"SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
|
||||
" SELECT max(stream_ordering) WHERE event_id IN (%s)"
|
||||
")"
|
||||
) % (",".join(["?"] * len(event_ids)))
|
||||
|
||||
txn.execute(query, [room_id] + event_ids)
|
||||
rows = txn.fetchall()
|
||||
if rows:
|
||||
return rows[0][0]
|
||||
else:
|
||||
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
|
||||
|
||||
linearized_event_id = yield self.runInteraction(
|
||||
"insert_receipt_conv", graph_to_linear
|
||||
)
|
||||
|
||||
stream_id_manager = yield self._receipts_id_gen.get_next(self)
|
||||
with stream_id_manager as stream_id:
|
||||
yield self._receipts_stream_cache.room_has_changed(
|
||||
self, room_id, stream_id
|
||||
)
|
||||
have_persisted = yield self.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
room_id, receipt_type, user_id, linearized_event_id,
|
||||
data,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
|
||||
if not have_persisted:
|
||||
defer.returnValue(None)
|
||||
|
||||
yield self.insert_graph_receipt(
|
||||
room_id, receipt_type, user_id, event_ids, data
|
||||
)
|
||||
|
||||
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
|
||||
defer.returnValue((stream_id, max_persisted_id))
|
||||
|
||||
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
|
||||
data):
|
||||
return self.runInteraction(
|
||||
"insert_graph_receipt",
|
||||
self.insert_graph_receipt_txn,
|
||||
room_id, receipt_type, user_id, event_ids, data
|
||||
)
|
||||
|
||||
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
|
||||
user_id, event_ids, data):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="receipts_graph",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="receipts_graph",
|
||||
values={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_ids": json.dumps(event_ids),
|
||||
"data": json.dumps(data),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class _RoomStreamChangeCache(object):
|
||||
"""Keeps track of the stream_id of the latest change in rooms.
|
||||
|
||||
Given a list of rooms and stream key, it will give a subset of rooms that
|
||||
may have changed since that key. If the key is too old then the cache
|
||||
will simply return all rooms.
|
||||
"""
|
||||
def __init__(self, size_of_cache=1000):
|
||||
self._size_of_cache = size_of_cache
|
||||
self._room_to_key = {}
|
||||
self._cache = sorteddict()
|
||||
self._earliest_key = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_changed(self, store, room_ids, key):
|
||||
"""Returns subset of room ids that have had new receipts since the
|
||||
given key. If the key is too old it will just return the given list.
|
||||
"""
|
||||
if key > (yield self._get_earliest_key(store)):
|
||||
keys = self._cache.keys()
|
||||
i = keys.bisect_right(key)
|
||||
|
||||
result = set(
|
||||
self._cache[k] for k in keys[i:]
|
||||
).intersection(room_ids)
|
||||
else:
|
||||
result = room_ids
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def room_has_changed(self, store, room_id, key):
|
||||
"""Informs the cache that the room has been changed at the given key.
|
||||
"""
|
||||
if key > (yield self._get_earliest_key(store)):
|
||||
old_key = self._room_to_key.get(room_id, None)
|
||||
if old_key:
|
||||
key = max(key, old_key)
|
||||
self._cache.pop(old_key, None)
|
||||
self._cache[key] = room_id
|
||||
|
||||
while len(self._cache) > self._size_of_cache:
|
||||
k, r = self._cache.popitem()
|
||||
self._earliest_key = max(k, self._earliest_key)
|
||||
self._room_to_key.pop(r, None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_earliest_key(self, store):
|
||||
if self._earliest_key is None:
|
||||
self._earliest_key = yield store.get_max_receipt_stream_id()
|
||||
self._earliest_key = int(self._earliest_key)
|
||||
|
||||
defer.returnValue(self._earliest_key)
|
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2015 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS receipts_graph(
|
||||
room_id TEXT NOT NULL,
|
||||
receipt_type TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
event_ids TEXT NOT NULL,
|
||||
data TEXT NOT NULL,
|
||||
CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS receipts_linearized (
|
||||
stream_id BIGINT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
receipt_type TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
data TEXT NOT NULL,
|
||||
CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX receipts_linearized_id ON receipts_linearized(
|
||||
stream_id
|
||||
);
|
|
@ -72,7 +72,10 @@ class StreamIdGenerator(object):
|
|||
with stream_id_gen.get_next_txn(txn) as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
def __init__(self):
|
||||
def __init__(self, table, column):
|
||||
self.table = table
|
||||
self.column = column
|
||||
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._current_max = None
|
||||
|
@ -157,7 +160,7 @@ class StreamIdGenerator(object):
|
|||
|
||||
def _get_or_compute_current_max(self, txn):
|
||||
with self._lock:
|
||||
txn.execute("SELECT MAX(stream_ordering) FROM events")
|
||||
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
|
||||
rows = txn.fetchall()
|
||||
val, = rows[0]
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from synapse.types import StreamToken
|
|||
from synapse.handlers.presence import PresenceEventSource
|
||||
from synapse.handlers.room import RoomEventSource
|
||||
from synapse.handlers.typing import TypingNotificationEventSource
|
||||
from synapse.handlers.receipts import ReceiptEventSource
|
||||
|
||||
|
||||
class NullSource(object):
|
||||
|
@ -43,6 +44,7 @@ class EventSources(object):
|
|||
"room": RoomEventSource,
|
||||
"presence": PresenceEventSource,
|
||||
"typing": TypingNotificationEventSource,
|
||||
"receipt": ReceiptEventSource,
|
||||
}
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -62,7 +64,10 @@ class EventSources(object):
|
|||
),
|
||||
typing_key=(
|
||||
yield self.sources["typing"].get_current_key()
|
||||
)
|
||||
),
|
||||
receipt_key=(
|
||||
yield self.sources["receipt"].get_current_key()
|
||||
),
|
||||
)
|
||||
defer.returnValue(token)
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ class EventID(DomainSpecificString):
|
|||
class StreamToken(
|
||||
namedtuple(
|
||||
"Token",
|
||||
("room_key", "presence_key", "typing_key")
|
||||
("room_key", "presence_key", "typing_key", "receipt_key")
|
||||
)
|
||||
):
|
||||
_SEPARATOR = "_"
|
||||
|
@ -109,6 +109,9 @@ class StreamToken(
|
|||
def from_string(cls, string):
|
||||
try:
|
||||
keys = string.split(cls._SEPARATOR)
|
||||
if len(keys) == len(cls._fields) - 1:
|
||||
# i.e. old token from before receipt_key
|
||||
keys.append("0")
|
||||
return cls(*keys)
|
||||
except:
|
||||
raise SynapseError(400, "Invalid Token")
|
||||
|
@ -131,6 +134,7 @@ class StreamToken(
|
|||
(other_token.room_stream_id < self.room_stream_id)
|
||||
or (int(other_token.presence_key) < int(self.presence_key))
|
||||
or (int(other_token.typing_key) < int(self.typing_key))
|
||||
or (int(other_token.receipt_key) < int(self.receipt_key))
|
||||
)
|
||||
|
||||
def copy_and_advance(self, key, new_value):
|
||||
|
|
|
@ -66,8 +66,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
|
||||
self.mock_federation_resource = MockHttpResource()
|
||||
|
||||
mock_notifier = Mock(spec=["on_new_user_event"])
|
||||
self.on_new_user_event = mock_notifier.on_new_user_event
|
||||
mock_notifier = Mock(spec=["on_new_event"])
|
||||
self.on_new_event = mock_notifier.on_new_event
|
||||
|
||||
self.auth = Mock(spec=[])
|
||||
|
||||
|
@ -182,7 +182,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
timeout=20000,
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 1, rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
|
@ -245,7 +245,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 1, rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
|
@ -299,7 +299,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
room_id=self.room_id,
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 1, rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
|
@ -331,10 +331,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
timeout=10000,
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 1, rooms=[self.room_id]),
|
||||
])
|
||||
self.on_new_user_event.reset_mock()
|
||||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
||||
|
@ -351,7 +351,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
|
||||
self.clock.advance_time(11)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 2, rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
|
@ -377,10 +377,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
timeout=10000,
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 3, rooms=[self.room_id]),
|
||||
])
|
||||
self.on_new_user_event.reset_mock()
|
||||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 3)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
||||
|
|
|
@ -183,7 +183,17 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
|||
)
|
||||
self.assertEquals(200, code, msg=str(response))
|
||||
|
||||
self.assertEquals(0, len(response["chunk"]))
|
||||
# We may get a presence event for ourselves down
|
||||
self.assertEquals(
|
||||
0,
|
||||
len([
|
||||
c for c in response["chunk"]
|
||||
if not (
|
||||
c.get("type") == "m.presence"
|
||||
and c["content"].get("user_id") == self.user_id
|
||||
)
|
||||
])
|
||||
)
|
||||
|
||||
# joined room (expect all content for room)
|
||||
yield self.join(room=room_id, user=self.user_id, tok=self.token)
|
||||
|
|
|
@ -357,7 +357,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
|||
# all be ours
|
||||
|
||||
# I'll already get my own presence state change
|
||||
self.assertEquals({"start": "0_1_0", "end": "0_1_0", "chunk": []},
|
||||
self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []},
|
||||
response
|
||||
)
|
||||
|
||||
|
@ -376,7 +376,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
|||
"/events?from=s0_1_0&timeout=0", None)
|
||||
|
||||
self.assertEquals(200, code)
|
||||
self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [
|
||||
self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [
|
||||
{"type": "m.presence",
|
||||
"content": {
|
||||
"user_id": "@banana:test",
|
||||
|
|
Loading…
Reference in New Issue