Add federation support for end-to-end key requests

This commit is contained in:
Mark Haines 2015-07-23 16:03:38 +01:00
parent 4e2e67fd50
commit 62c010283d
5 changed files with 231 additions and 30 deletions

View File

@ -134,6 +134,40 @@ class FederationClient(FederationBase):
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
) )
@log_function
def query_client_keys(self, destination, content, retry_on_dns_fail=True):
"""Query device keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_device_keys")
return self.transport_layer.query_client_keys(
destination, content, retry_on_dns_fail=retry_on_dns_fail
)
@log_function
def claim_client_keys(self, destination, content, retry_on_dns_fail=True):
"""Claims one-time keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_one_time_keys")
return self.transport_layer.claim_client_keys(
destination, content, retry_on_dns_fail=retry_on_dns_fail
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def backfill(self, dest, context, limit, extremities): def backfill(self, dest, context, limit, extremities):

View File

@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
import simplejson as json
import logging import logging
@ -312,6 +313,42 @@ class FederationServer(FederationBase):
(200, send_content) (200, send_content)
) )
@defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
query = []
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
defer.returnValue({"device_keys": json_result})
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
results = yield self.store.claim_e2e_one_time_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_get_missing_events(self, origin, room_id, earliest_events, def on_get_missing_events(self, origin, room_id, earliest_events,

View File

@ -222,6 +222,76 @@ class TransportLayerClient(object):
defer.returnValue(content) defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_client_keys(self, destination, query_content):
"""Query the device keys for a list of user ids hosted on a remote
server.
Request:
{
"device_keys": {
"<user_id>": ["<device_id>"]
} }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {...}
} } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/client_keys/query"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content):
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": "<algorithm>"
} } }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the one-time keys.
"""
path = PREFIX + "/client_keys/claim"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_missing_events(self, destination, room_id, earliest_events, def get_missing_events(self, destination, room_id, earliest_events,

View File

@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/client_keys/query"
@defer.inlineCallbacks
def on_POST(self, origin, content):
response = yield self.handler.on_client_key_query(origin, content)
defer.returnValue((200, response))
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/client_keys/claim"
@defer.inlineCallbacks
def on_POST(self, origin, content):
response = yield self.handler.on_client_key_claim(origin, content)
defer.returnValue((200, response))
class FederationQueryAuthServlet(BaseFederationServlet): class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)" PATH = "/query_auth/([^/]*)/([^/]*)"
@ -373,4 +391,6 @@ SERVLET_CLASSES = (
FederationQueryAuthServlet, FederationQueryAuthServlet,
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,
FederationEventAuthServlet, FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet,
) )

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import UserID
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from ._base import client_v2_pattern from ._base import client_v2_pattern
@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet):
super(KeyQueryServlet, self).__init__() super(KeyQueryServlet, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id): def on_POST(self, request, user_id, device_id):
logger.debug("onPOST")
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
try: try:
body = json.loads(request.content.read()) body = json.loads(request.content.read())
except: except:
raise SynapseError(400, "Invalid key JSON") raise SynapseError(400, "Invalid key JSON")
query = [] result = yield self.handle_request(body)
for user_id, device_ids in body.get("device_keys", {}).items(): defer.returnValue(result)
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
defer.returnValue(self.json_result(request, results))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id): def on_GET(self, request, user_id, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request) auth_user, client_info = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string() auth_user_id = auth_user.to_string()
if not user_id: user_id = user_id if user_id else auth_user_id
user_id = auth_user_id device_ids = [device_id] if device_id else []
if not device_id: result = yield self.handle_request(
device_id = None {"device_keys": {user_id: device_ids}}
# Returns a map of user_id->device_id->json_bytes. )
results = yield self.store.get_e2e_device_keys([(user_id, device_id)]) defer.returnValue(result)
defer.returnValue(self.json_result(request, results))
@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_ids in body.get("device_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
else:
remote_queries.set_default(user.domain, {})[user_id] = list(
device_ids
)
results = yield self.store.get_e2e_device_keys(local_query)
def json_result(self, request, results):
json_result = {} json_result = {}
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items(): for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads( json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes json_bytes
) )
return (200, {"device_keys": json_result})
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": device_keys}
)
for user_id, keys in remote_result.items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"device_keys": json_result}))
class OneTimeKeyServlet(RestServlet): class OneTimeKeyServlet(RestServlet):
@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm): def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
results = yield self.store.claim_e2e_one_time_keys( result = yield self.handle_request(
[(user_id, device_id, algorithm)] {"one_time_keys": {user_id: {device_id: algorithm}}}
) )
defer.returnValue(self.json_result(request, results)) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm): def on_POST(self, request, user_id, device_id, algorithm):
@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet):
body = json.loads(request.content.read()) body = json.loads(request.content.read())
except: except:
raise SynapseError(400, "Invalid key JSON") raise SynapseError(400, "Invalid key JSON")
query = [] result = yield self.handle_request(body)
for user_id, device_keys in body.get("one_time_keys", {}).items(): defer.returnValue(result)
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm)) @defer.inlineCallbacks
results = yield self.store.claim_e2e_one_time_keys(query) def handle_request(self, body):
defer.returnValue(self.json_result(request, results)) local_query = []
remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
remote_queries.set_default(user.domain, {})[user_id] = (
device_keys
)
results = yield self.store.claim_e2e_one_time_keys(local_query)
def json_result(self, request, results):
json_result = {} json_result = {}
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet):
json_result.setdefault(user_id, {})[device_id] = { json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes) key_id: json.loads(json_bytes)
} }
return (200, {"one_time_keys": json_result})
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.query_client_keys(
destination, {"one_time_keys": device_keys}
)
for user_id, keys in remote_result.items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"one_time_keys": json_result}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):