Implement device key caching over federation
This commit is contained in:
parent
51e9fe36e4
commit
c974116f19
|
@ -126,6 +126,16 @@ class FederationClient(FederationBase):
|
|||
destination, content, timeout
|
||||
)
|
||||
|
||||
@log_function
|
||||
def query_user_devices(self, destination, user_id, timeout=30000):
|
||||
"""Query the device keys for a list of user ids hosted on a remote
|
||||
server.
|
||||
"""
|
||||
sent_queries_counter.inc("user_devices")
|
||||
return self.transport_layer.query_user_devices(
|
||||
destination, user_id, timeout
|
||||
)
|
||||
|
||||
@log_function
|
||||
def claim_client_keys(self, destination, content, timeout):
|
||||
"""Claims one-time keys for a device hosted on a remote server.
|
||||
|
|
|
@ -416,6 +416,9 @@ class FederationServer(FederationBase):
|
|||
def on_query_client_keys(self, origin, content):
|
||||
return self.on_query_request("client_keys", content)
|
||||
|
||||
def on_query_user_devices(self, origin, user_id):
|
||||
return self.on_query_request("user_devices", user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_claim_client_keys(self, origin, content):
|
||||
|
|
|
@ -346,6 +346,32 @@ class TransportLayerClient(object):
|
|||
)
|
||||
defer.returnValue(content)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def query_user_devices(self, destination, user_id, timeout):
|
||||
"""Query the devices for a user id hosted on a remote server.
|
||||
|
||||
Response:
|
||||
{
|
||||
"stream_id": "...",
|
||||
"devices": [ { ... } ]
|
||||
}
|
||||
|
||||
Args:
|
||||
destination(str): The server to query.
|
||||
query_content(dict): The user ids to query.
|
||||
Returns:
|
||||
A dict containg the device keys.
|
||||
"""
|
||||
path = PREFIX + "/user/devices/" + user_id
|
||||
|
||||
content = yield self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
timeout=timeout,
|
||||
)
|
||||
defer.returnValue(content)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def claim_client_keys(self, destination, query_content, timeout):
|
||||
|
|
|
@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
|
|||
return self.handler.on_query_client_keys(origin, content)
|
||||
|
||||
|
||||
class FederationUserDevicesQueryServlet(BaseFederationServlet):
|
||||
PATH = "/user/devices/(?P<user_id>[^/]*)"
|
||||
|
||||
def on_GET(self, origin, content, query, user_id):
|
||||
return self.handler.on_query_user_devices(origin, user_id)
|
||||
|
||||
|
||||
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
||||
PATH = "/user/keys/claim"
|
||||
|
||||
|
@ -613,6 +620,7 @@ SERVLET_CLASSES = (
|
|||
FederationGetMissingEventsServlet,
|
||||
FederationEventAuthServlet,
|
||||
FederationClientKeysQueryServlet,
|
||||
FederationUserDevicesQueryServlet,
|
||||
FederationClientKeysClaimServlet,
|
||||
FederationThirdPartyInviteExchangeServlet,
|
||||
On3pidBindServlet,
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from synapse.api import errors
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.types import get_domain_from_id
|
||||
from twisted.internet import defer
|
||||
from ._base import BaseHandler
|
||||
|
@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler):
|
|||
def __init__(self, hs):
|
||||
super(DeviceHandler, self).__init__(hs)
|
||||
|
||||
self.hs = hs
|
||||
self.state = hs.get_state_handler()
|
||||
self.federation = hs.get_federation_sender()
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self._remote_edue_linearizer = Linearizer(name="remote_device_list")
|
||||
|
||||
self.federation.register_edu_handler(
|
||||
"m.device_list_update", self._incoming_device_list_update,
|
||||
)
|
||||
self.federation.register_query_handler(
|
||||
"user_devices", self.on_federation_query_user_devices,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_device_registered(self, user_id, device_id,
|
||||
|
@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler):
|
|||
initial_device_display_name=initial_device_display_name,
|
||||
)
|
||||
if new_device:
|
||||
yield self.notify_device_update(user_id, device_id)
|
||||
yield self.notify_device_update(user_id, [device_id])
|
||||
defer.returnValue(device_id)
|
||||
|
||||
# if the device id is not specified, we'll autogen one, but loop a few
|
||||
|
@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler):
|
|||
initial_device_display_name=initial_device_display_name,
|
||||
)
|
||||
if new_device:
|
||||
yield self.notify_device_update(user_id, device_id)
|
||||
yield self.notify_device_update(user_id, [device_id])
|
||||
defer.returnValue(device_id)
|
||||
attempts += 1
|
||||
|
||||
|
@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler):
|
|||
user_id=user_id, device_id=device_id
|
||||
)
|
||||
|
||||
yield self.notify_device_update(user_id, device_id)
|
||||
yield self.notify_device_update(user_id, [device_id])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_device(self, user_id, device_id, content):
|
||||
|
@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler):
|
|||
device_id,
|
||||
new_display_name=content.get("display_name")
|
||||
)
|
||||
yield self.notify_device_update(user_id, device_id)
|
||||
yield self.notify_device_update(user_id, [device_id])
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
raise errors.NotFoundError()
|
||||
|
@ -180,26 +191,28 @@ class DeviceHandler(BaseHandler):
|
|||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_device_update(self, user_id, device_id):
|
||||
def notify_device_update(self, user_id, device_ids):
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
room_ids = [r.room_id for r in rooms]
|
||||
|
||||
hosts = set()
|
||||
if self.hs.is_mine_id(user_id):
|
||||
for room_id in room_ids:
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
hosts.update(get_domain_from_id(u) for u in users)
|
||||
hosts.discard(self.server_name)
|
||||
|
||||
position = yield self.store.add_device_change_to_streams(
|
||||
user_id, device_id, list(hosts)
|
||||
user_id, device_ids, list(hosts)
|
||||
)
|
||||
|
||||
yield self.notifier.on_new_event(
|
||||
"device_list_key", position, rooms=room_ids,
|
||||
)
|
||||
|
||||
logger.info("Sending device list update notif to: %r", hosts)
|
||||
for host in hosts:
|
||||
self.federation.send_device_messages(host)
|
||||
self.federation_sender.send_device_messages(host)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_device_list_changes(self, user_id, room_ids, from_key):
|
||||
|
@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler):
|
|||
|
||||
defer.returnValue(user_ids_changed)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _incoming_device_list_update(self, origin, edu_content):
|
||||
user_id = edu_content["user_id"]
|
||||
device_id = edu_content["device_id"]
|
||||
stream_id = edu_content["stream_id"]
|
||||
prev_ids = edu_content.get("prev_id", [])
|
||||
|
||||
if get_domain_from_id(user_id) != origin:
|
||||
# TODO: Raise?
|
||||
return
|
||||
|
||||
logger.info("Got edu: %r", edu_content)
|
||||
|
||||
with (yield self._remote_edue_linearizer.queue(user_id)):
|
||||
resync = True
|
||||
if len(prev_ids) == 1:
|
||||
extremity = yield self.store.get_device_list_remote_extremity(user_id)
|
||||
logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
|
||||
if str(extremity) == str(prev_ids[0]):
|
||||
resync = False
|
||||
|
||||
if resync:
|
||||
result = yield self.federation.query_user_devices(origin, user_id)
|
||||
stream_id = result["stream_id"]
|
||||
devices = result["devices"]
|
||||
yield self.store.update_remote_device_list_cache(
|
||||
user_id, devices, stream_id,
|
||||
)
|
||||
device_ids = [device["device_id"] for device in devices]
|
||||
yield self.notify_device_update(user_id, device_ids)
|
||||
else:
|
||||
content = dict(edu_content)
|
||||
for key in ("user_id", "device_id", "stream_id", "prev_ids"):
|
||||
content.pop(key, None)
|
||||
yield self.store.update_remote_device_list_cache_entry(
|
||||
user_id, device_id, content, stream_id,
|
||||
)
|
||||
yield self.notify_device_update(user_id, [device_id])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_federation_query_user_devices(self, user_id):
|
||||
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"stream_id": stream_id,
|
||||
"devices": devices,
|
||||
})
|
||||
|
||||
|
||||
def _update_device_from_client_ips(device, client_ips):
|
||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||
|
|
|
@ -73,8 +73,7 @@ class E2eKeysHandler(object):
|
|||
if self.is_mine_id(user_id):
|
||||
local_query[user_id] = device_ids
|
||||
else:
|
||||
domain = get_domain_from_id(user_id)
|
||||
remote_queries.setdefault(domain, {})[user_id] = device_ids
|
||||
remote_queries[user_id] = device_ids
|
||||
|
||||
# do the queries
|
||||
failures = {}
|
||||
|
@ -85,9 +84,40 @@ class E2eKeysHandler(object):
|
|||
if user_id in local_query:
|
||||
results[user_id] = keys
|
||||
|
||||
remote_queries_not_in_cache = {}
|
||||
if remote_queries:
|
||||
query_list = []
|
||||
for user_id, device_ids in remote_queries.iteritems():
|
||||
if device_ids:
|
||||
query_list.extend((user_id, device_id) for device_id in device_ids)
|
||||
else:
|
||||
query_list.append((user_id, None))
|
||||
|
||||
user_ids_not_in_cache, remote_results = (
|
||||
yield self.store.get_user_devices_from_cache(
|
||||
query_list
|
||||
)
|
||||
)
|
||||
for user_id, devices in remote_results.iteritems():
|
||||
user_devices = results.setdefault(user_id, {})
|
||||
for device_id, device in devices.iteritems():
|
||||
keys = device.get("keys", None)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if keys:
|
||||
result = dict(keys)
|
||||
unsigned = result.setdefault("unsigned", {})
|
||||
if device_display_name:
|
||||
unsigned["device_display_name"] = device_display_name
|
||||
user_devices[device_id] = result
|
||||
|
||||
for user_id in user_ids_not_in_cache:
|
||||
domain = get_domain_from_id(user_id)
|
||||
r = remote_queries_not_in_cache.setdefault(domain, {})
|
||||
r[user_id] = remote_queries[user_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_remote_query(destination):
|
||||
destination_query = remote_queries[destination]
|
||||
destination_query = remote_queries_not_in_cache[destination]
|
||||
try:
|
||||
limiter = yield get_retry_limiter(
|
||||
destination, self.clock, self.store
|
||||
|
@ -119,7 +149,7 @@ class E2eKeysHandler(object):
|
|||
|
||||
yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(do_remote_query)(destination)
|
||||
for destination in remote_queries
|
||||
for destination in remote_queries_not_in_cache
|
||||
]))
|
||||
|
||||
defer.returnValue({
|
||||
|
@ -259,7 +289,7 @@ class E2eKeysHandler(object):
|
|||
user_id, device_id, time_now,
|
||||
encode_canonical_json(device_keys)
|
||||
)
|
||||
yield self.device_handler.notify_device_update(user_id, device_id)
|
||||
yield self.device_handler.notify_device_update(user_id, [device_id])
|
||||
|
||||
one_time_keys = keys.get("one_time_keys", None)
|
||||
if one_time_keys:
|
||||
|
|
|
@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue({d["device_id"]: d for d in devices})
|
||||
|
||||
def get_device_list_remote_extremity(self, user_id):
|
||||
return self._simple_select_one_onecol(
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="stream_id",
|
||||
desc="get_device_list_remote_extremity",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
|
||||
stream_id):
|
||||
return self.runInteraction(
|
||||
"update_remote_device_list_cache_entry",
|
||||
self._update_remote_device_list_cache_entry_txn,
|
||||
user_id, device_id, content, stream_id,
|
||||
)
|
||||
|
||||
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
|
||||
content, stream_id):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={
|
||||
"content": json.dumps(content),
|
||||
}
|
||||
)
|
||||
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
)
|
||||
|
||||
def update_remote_device_list_cache(self, user_id, devices, stream_id):
|
||||
return self.runInteraction(
|
||||
"update_remote_device_list_cache",
|
||||
self._update_remote_device_list_cache_txn,
|
||||
user_id, devices, stream_id,
|
||||
)
|
||||
|
||||
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
|
||||
stream_id):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="device_lists_remote_cache",
|
||||
values=[
|
||||
{
|
||||
"user_id": user_id,
|
||||
"device_id": content["device_id"],
|
||||
"content": json.dumps(content),
|
||||
}
|
||||
for content in devices
|
||||
]
|
||||
)
|
||||
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
)
|
||||
|
||||
def get_devices_by_remote(self, destination, from_stream_id):
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
|
@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore):
|
|||
txn.execute(prev_sent_id_sql, (destination, user_id, True))
|
||||
rows = txn.fetchall()
|
||||
prev_id = rows[0][0]
|
||||
for device_id, result in user_devices.iteritems():
|
||||
for device_id, device in user_devices.iteritems():
|
||||
stream_id = query_map[(user_id, device_id)]
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
|
@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore):
|
|||
|
||||
prev_id = stream_id
|
||||
|
||||
key_json = result.get("key_json", None)
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = json.loads(key_json)
|
||||
device_display_name = result.get("device_display_name", None)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
||||
|
@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore):
|
|||
|
||||
return (now_stream_id, results)
|
||||
|
||||
def get_user_devices_from_cache(self, query_list):
|
||||
return self.runInteraction(
|
||||
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
|
||||
query_list,
|
||||
)
|
||||
|
||||
def _get_user_devices_from_cache_txn(self, txn, query_list):
|
||||
user_ids = {user_id for user_id, _ in query_list}
|
||||
|
||||
user_ids_in_cache = set()
|
||||
for user_id in user_ids:
|
||||
stream_ids = self._simple_select_onecol_txn(
|
||||
txn,
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
retcol="stream_id",
|
||||
)
|
||||
if stream_ids:
|
||||
user_ids_in_cache.add(user_id)
|
||||
|
||||
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
||||
|
||||
results = {}
|
||||
for user_id, device_id in query_list:
|
||||
if user_id not in user_ids_in_cache:
|
||||
continue
|
||||
|
||||
if device_id:
|
||||
content = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
retcol="content",
|
||||
)
|
||||
results.setdefault(user_id, {})[device_id] = json.loads(content)
|
||||
else:
|
||||
devices = self._simple_select_list_txn(
|
||||
txn,
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
retcols=("device_id", "content"),
|
||||
)
|
||||
results[user_id] = {
|
||||
device["device_id"]: json.loads(device["content"])
|
||||
for device in devices
|
||||
}
|
||||
user_ids_in_cache.discard(user_id)
|
||||
|
||||
return user_ids_not_in_cache, results
|
||||
|
||||
def get_devices_with_keys_by_user(self, user_id):
|
||||
return self.runInteraction(
|
||||
"get_devices_with_keys_by_user",
|
||||
self._get_devices_with_keys_by_user_txn, user_id,
|
||||
)
|
||||
|
||||
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(
|
||||
txn, [(user_id, None)], include_all_devices=True
|
||||
)
|
||||
|
||||
for user_id, user_devices in devices.iteritems():
|
||||
results = []
|
||||
for device_id, device in user_devices.iteritems():
|
||||
result = {
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = json.loads(key_json)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
||||
results.append(result)
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
return now_stream_id, []
|
||||
|
||||
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
||||
return self.runInteraction(
|
||||
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
|
||||
|
@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore):
|
|||
defer.returnValue(set(row["user_id"] for row in rows))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_device_change_to_streams(self, user_id, device_id, hosts):
|
||||
def add_device_change_to_streams(self, user_id, device_ids, hosts):
|
||||
# device_lists_stream
|
||||
# device_lists_outbound_pokes
|
||||
with self._device_list_id_gen.get_next() as stream_id:
|
||||
yield self.runInteraction(
|
||||
"add_device_change_to_streams", self._add_device_change_txn,
|
||||
user_id, device_id, hosts, stream_id,
|
||||
user_id, device_ids, hosts, stream_id,
|
||||
)
|
||||
defer.returnValue(stream_id)
|
||||
|
||||
def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id):
|
||||
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
|
||||
txn.call_after(
|
||||
self._device_list_stream_cache.entity_has_changed,
|
||||
user_id, stream_id,
|
||||
|
@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore):
|
|||
host, stream_id,
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="device_lists_stream",
|
||||
values={
|
||||
values=[
|
||||
{
|
||||
"stream_id": stream_id,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
}
|
||||
for device_id in device_ids
|
||||
]
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
|
@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore):
|
|||
"sent": False,
|
||||
}
|
||||
for destination in hosts
|
||||
for device_id in device_ids
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -52,11 +52,11 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||
query_params = []
|
||||
|
||||
for (user_id, device_id) in query_list:
|
||||
query_clause = "k.user_id = ?"
|
||||
query_clause = "user_id = ?"
|
||||
query_params.append(user_id)
|
||||
|
||||
if device_id:
|
||||
query_clause += " AND k.device_id = ?"
|
||||
query_clause += " AND device_id = ?"
|
||||
query_params.append(device_id)
|
||||
|
||||
query_clauses.append(query_clause)
|
||||
|
|
|
@ -13,18 +13,6 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE device_list_streams_remote (
|
||||
list_id TEXT NOT NULL,
|
||||
origin TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
is_full BOOLEAN NOT NULL,
|
||||
ts BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote(
|
||||
origin, list_id, user_id
|
||||
);
|
||||
|
||||
|
||||
CREATE TABLE device_lists_remote_cache (
|
||||
user_id TEXT NOT NULL,
|
||||
|
@ -35,6 +23,14 @@ CREATE TABLE device_lists_remote_cache (
|
|||
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
|
||||
|
||||
|
||||
CREATE TABLE device_lists_remote_extremeties (
|
||||
user_id TEXT NOT NULL,
|
||||
stream_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
|
||||
|
||||
|
||||
CREATE TABLE device_lists_stream (
|
||||
stream_id BIGINT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
|
|
|
@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
hs = yield utils.setup_test_homeserver(handlers=None)
|
||||
self.handler = synapse.handlers.device.DeviceHandler(hs)
|
||||
hs = yield utils.setup_test_homeserver()
|
||||
self.handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_device_is_created_if_doesnt_exist(self):
|
||||
res = yield self.handler.check_device_registered(
|
||||
user_id="boris",
|
||||
user_id="@boris:foo",
|
||||
device_id="fco",
|
||||
initial_device_display_name="display name"
|
||||
)
|
||||
self.assertEqual(res, "fco")
|
||||
|
||||
dev = yield self.handler.store.get_device("boris", "fco")
|
||||
dev = yield self.handler.store.get_device("@boris:foo", "fco")
|
||||
self.assertEqual(dev["display_name"], "display name")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_device_is_preserved_if_exists(self):
|
||||
res1 = yield self.handler.check_device_registered(
|
||||
user_id="boris",
|
||||
user_id="@boris:foo",
|
||||
device_id="fco",
|
||||
initial_device_display_name="display name"
|
||||
)
|
||||
self.assertEqual(res1, "fco")
|
||||
|
||||
res2 = yield self.handler.check_device_registered(
|
||||
user_id="boris",
|
||||
user_id="@boris:foo",
|
||||
device_id="fco",
|
||||
initial_device_display_name="new display name"
|
||||
)
|
||||
self.assertEqual(res2, "fco")
|
||||
|
||||
dev = yield self.handler.store.get_device("boris", "fco")
|
||||
dev = yield self.handler.store.get_device("@boris:foo", "fco")
|
||||
self.assertEqual(dev["display_name"], "display name")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_device_id_is_made_up_if_unspecified(self):
|
||||
device_id = yield self.handler.check_device_registered(
|
||||
user_id="theresa",
|
||||
user_id="@theresa:foo",
|
||||
device_id=None,
|
||||
initial_device_display_name="display"
|
||||
)
|
||||
|
||||
dev = yield self.handler.store.get_device("theresa", device_id)
|
||||
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
|
||||
self.assertEqual(dev["display_name"], "display")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.mock_federation = Mock(spec=[
|
||||
"make_query",
|
||||
"register_edu_handler",
|
||||
])
|
||||
|
||||
self.query_handlers = {}
|
||||
|
|
|
@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.mock_federation = Mock(spec=[
|
||||
"make_query",
|
||||
"register_edu_handler",
|
||||
])
|
||||
|
||||
self.query_handlers = {}
|
||||
|
|
|
@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||
event_cache_size=1,
|
||||
password_providers=[],
|
||||
)
|
||||
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
|
||||
hs = yield setup_test_homeserver(
|
||||
config=config,
|
||||
federation_sender=Mock(),
|
||||
replication_layer=Mock(),
|
||||
)
|
||||
|
||||
self.as_token = "token1"
|
||||
self.as_url = "some_url"
|
||||
|
@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||
event_cache_size=1,
|
||||
password_providers=[],
|
||||
)
|
||||
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
|
||||
hs = yield setup_test_homeserver(
|
||||
config=config,
|
||||
federation_sender=Mock(),
|
||||
replication_layer=Mock(),
|
||||
)
|
||||
self.db_pool = hs.get_db_pool()
|
||||
|
||||
self.as_list = [
|
||||
|
@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
hs = yield setup_test_homeserver(
|
||||
config=config,
|
||||
datastore=Mock(),
|
||||
federation_sender=Mock()
|
||||
federation_sender=Mock(),
|
||||
replication_layer=Mock(),
|
||||
)
|
||||
|
||||
ApplicationServiceStore(hs)
|
||||
|
@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
hs = yield setup_test_homeserver(
|
||||
config=config,
|
||||
datastore=Mock(),
|
||||
federation_sender=Mock()
|
||||
federation_sender=Mock(),
|
||||
replication_layer=Mock(),
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
|
@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
hs = yield setup_test_homeserver(
|
||||
config=config,
|
||||
datastore=Mock(),
|
||||
federation_sender=Mock()
|
||||
federation_sender=Mock(),
|
||||
replication_layer=Mock(),
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
|
|
Loading…
Reference in New Issue