Implement remote key lookup api

This commit is contained in:
Mark Haines 2015-04-22 14:21:08 +01:00
parent 3ba522bb23
commit f30d47c876
7 changed files with 252 additions and 61 deletions

View File

@ -156,5 +156,5 @@ class ServerConfig(Config):
args.old_signing_key_path = base_key_name + ".old.signing.keys" args.old_signing_key_path = base_key_name + ".old.signing.keys"
if not os.path.exists(args.old_signing_key_path): if not os.path.exists(args.old_signing_key_path):
with open(args.old_signing_key_path, "w") as old_signing_key_file: with open(args.old_signing_key_path, "w"):
pass pass

View File

@ -26,7 +26,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
KEY_API_V1 = b"/_matrix/key/v1/" KEY_API_V1 = b"/_matrix/key/v1/"
KEY_API_V2 = b"/_matrix/key/v2/local"
@defer.inlineCallbacks @defer.inlineCallbacks
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
@ -94,8 +94,8 @@ class SynapseKeyClientProtocol(HTTPClient):
if status != b"200": if status != b"200":
# logger.info("Non-200 response from %s: %s %s", # logger.info("Non-200 response from %s: %s %s",
# self.transport.getHost(), status, message) # self.transport.getHost(), status, message)
error = SynapseKeyClientError("Non-200 response %r from %r" % error = SynapseKeyClientError(
(status, self.host) "Non-200 response %r from %r" % (status, self.host)
) )
error.status = status error.status = status
self.errback(error) self.errback(error)

View File

@ -15,7 +15,9 @@
from synapse.crypto.keyclient import fetch_server_key from synapse.crypto.keyclient import fetch_server_key
from twisted.internet import defer from twisted.internet import defer
from syutil.crypto.jsonsign import verify_signed_json, signature_ids from syutil.crypto.jsonsign import (
verify_signed_json, signature_ids, sign_json, encode_canonical_json
)
from syutil.crypto.signing_key import ( from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes is_signing_algorithm_supported, decode_verify_key_bytes
) )
@ -26,6 +28,8 @@ from synapse.util.retryutils import get_retry_limiter
from OpenSSL import crypto from OpenSSL import crypto
import urllib
import hashlib
import logging import logging
@ -37,6 +41,7 @@ class Keyring(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client() self.client = hs.get_http_client()
self.config = hs.get_config()
self.perspective_servers = {} self.perspective_servers = {}
self.hs = hs self.hs = hs
@ -127,7 +132,6 @@ class Keyring(object):
server_name, key_ids server_name, key_ids
) )
for key_id in key_ids: for key_id in key_ids:
if key_id in keys: if key_id in keys:
defer.returnValue(keys[key_id]) defer.returnValue(keys[key_id])
@ -142,13 +146,14 @@ class Keyring(object):
perspective_name, self.clock, self.store perspective_name, self.clock, self.store
) )
with limiter:
responses = yield self.client.post_json( responses = yield self.client.post_json(
destination=perspective_name, destination=perspective_name,
path=b"/_matrix/key/v2/query", path=b"/_matrix/key/v2/query",
data={u"server_keys": {server_name: list(key_ids)}}, data={u"server_keys": {server_name: list(key_ids)}},
) )
keys = dict() keys = {}
for response in responses: for response in responses:
if (u"signatures" not in response if (u"signatures" not in response
@ -181,7 +186,9 @@ class Keyring(object):
" server %r" % (perspective_name,) " server %r" % (perspective_name,)
) )
response_keys = process_v2_response(self, server_name, key_ids) response_keys = yield self.process_v2_response(
server_name, perspective_name, response
)
keys.update(response_keys) keys.update(response_keys)
@ -202,11 +209,11 @@ class Keyring(object):
if requested_key_id in keys: if requested_key_id in keys:
continue continue
(response_json, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory, server_name, self.hs.tls_context_factory,
path="/_matrix/key/v2/server/%s" % ( path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id), urllib.quote(requested_key_id),
), )).encode("ascii"),
) )
if (u"signatures" not in response if (u"signatures" not in response
@ -223,17 +230,18 @@ class Keyring(object):
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint) sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
response_sha256_fingerprints = set() response_sha256_fingerprints = set()
for fingerprint in response_json[u"tls_fingerprints"]: for fingerprint in response[u"tls_fingerprints"]:
if u"sha256" in fingerprint: if u"sha256" in fingerprint:
response_sha256_fingerprints.add(fingerprint[u"sha256"]) response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint not in response_sha256_fingerprints: if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise ValueError("TLS certificate not allowed by fingerprints") raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
server_name=server_name, server_name=server_name,
from_server=server_name, from_server=server_name,
response_json=response_json, requested_id=requested_key_id,
response_json=response,
) )
keys.update(response_keys) keys.update(response_keys)
@ -244,19 +252,15 @@ class Keyring(object):
verify_keys=keys, verify_keys=keys,
) )
for key_id in key_ids: defer.returnValue(keys)
if key_id in verify_keys:
defer.returnValue(verify_keys[key_id])
return
raise ValueError("No verification key found for given key ids")
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response(self, server_name, from_server, json_response): def process_v2_response(self, server_name, from_server, response_json,
time_now_ms = clock.time_msec() requested_id=None):
time_now_ms = self.clock.time_msec()
response_keys = {} response_keys = {}
verify_keys = {} verify_keys = {}
for key_id, key_data in response["verify_keys"].items(): for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id): if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"] key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64) key_bytes = decode_base64(key_base64)
@ -264,7 +268,7 @@ class Keyring(object):
verify_keys[key_id] = verify_key verify_keys[key_id] = verify_key
old_verify_keys = {} old_verify_keys = {}
for key_id, key_data in response["verify_keys"].items(): for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id): if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"] key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64) key_bytes = decode_base64(key_base64)
@ -273,21 +277,21 @@ class Keyring(object):
verify_key.time_added = time_now_ms verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key old_verify_keys[key_id] = verify_key
for key_id in response["signatures"][server_name]: for key_id in response_json["signatures"][server_name]:
if key_id not in response["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise ValueError( raise ValueError(
"Key response must include verification keys for all" "Key response must include verification keys for all"
" signatures" " signatures"
) )
if key_id in verify_keys: if key_id in verify_keys:
verify_signed_json( verify_signed_json(
response, response_json,
server_name, server_name,
verify_keys[key_id] verify_keys[key_id]
) )
signed_key_json = sign_json( signed_key_json = sign_json(
response, response_json,
self.config.server_name, self.config.server_name,
self.config.signing_key[0], self.config.signing_key[0],
) )
@ -295,7 +299,9 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until"] ts_valid_until_ms = signed_key_json[u"valid_until"]
updated_key_ids = set([requested_key_id]) updated_key_ids = set()
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids.update(verify_keys) updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys) updated_key_ids.update(old_verify_keys)
@ -307,8 +313,8 @@ class Keyring(object):
server_name=server_name, server_name=server_name,
key_id=key_id, key_id=key_id,
from_server=server_name, from_server=server_name,
ts_now_ms=ts_now_ms, ts_now_ms=time_now_ms,
ts_valid_until_ms=valid_until, ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes, key_json_bytes=signed_key_json_bytes,
) )
@ -373,7 +379,6 @@ class Keyring(object):
verify_keys[key_id] verify_keys[key_id]
) )
yield self.store.store_server_certificate( yield self.store.store_server_certificate(
server_name, server_name,
server_name, server_name,

View File

@ -13,7 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.web.resource import Resource
from .local_key_resource import LocalKey from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
class KeyApiV2Resource(LocalKey):
pass class KeyApiV2Resource(Resource):
def __init__(self, hs):
Resource.__init__(self)
self.putChild("server", LocalKey(hs))
self.putChild("query", RemoteKey(hs))

View File

@ -31,7 +31,7 @@ class LocalKey(Resource):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL """HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server:: signature verification keys for this server::
GET /_matrix/key/v2/ HTTP/1.1 GET /_matrix/key/v2/server/a.key.id HTTP/1.1
HTTP/1.1 200 OK HTTP/1.1 200 OK
Content-Type: application/json Content-Type: application/json
@ -56,6 +56,8 @@ class LocalKey(Resource):
} }
""" """
isLeaf = True
def __init__(self, hs): def __init__(self, hs):
self.version_string = hs.version_string self.version_string = hs.version_string
self.config = hs.config self.config = hs.config
@ -68,7 +70,6 @@ class LocalKey(Resource):
self.expires = int(time_now_msec + refresh_interval) self.expires = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object()) self.response_body = encode_canonical_json(self.response_json_object())
def response_json_object(self): def response_json_object(self):
verify_keys = {} verify_keys = {}
for key in self.config.signing_key: for key in self.config.signing_key:
@ -120,7 +121,3 @@ class LocalKey(Resource):
request, 200, self.response_body, request, 200, self.response_body,
version_string=self.version_string version_string=self.version_string
) )
def getChild(self, name, request):
if name == '':
return self

View File

@ -0,0 +1,174 @@
from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.api.errors import SynapseError, Codes
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from io import BytesIO
import json
import logging
logger = logging.getLogger(__name__)
class RemoteKey(Resource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
that the NACL signature for the remote server is valid. Returns a dict of
JSON signed by both the remote server and by this server.
Supports individual GET APIs and a bulk query POST API.
Requsts:
GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
POST /_matrix/v2/query HTTP/1.1
Content-Type: application/json
{
"server_keys": { "remote.server.example.com": ["a.key.id"] }
}
Response:
HTTP/1.1 200 OK
Content-Type: application/json
{
"server_keys": [
{
"server_name": "remote.server.example.com"
"valid_until": # posix timestamp
"verify_keys": {
"a.key.id": { # The identifier for a key.
key: "" # base64 encoded verification key.
}
}
"old_verify_keys": {
"an.old.key.id": { # The identifier for an old key.
key: "", # base64 encoded key
expired: 0, # when th e
}
}
"tls_fingerprints": [
{ "sha256": # fingerprint }
]
"signatures": {
"remote.server.example.com": {...}
"this.server.example.com": {...}
}
}
]
}
"""
isLeaf = True
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_GET(self, request):
self.async_render_GET(request)
return NOT_DONE_YET
@request_handler
@defer.inlineCallbacks
def async_render_GET(self, request):
if len(request.postpath) == 1:
server, = request.postpath
query = {server: [None]}
elif len(request.postpath) == 2:
server, key_id = request.postpath
query = {server: [key_id]}
else:
raise SynapseError(
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
)
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
def render_POST(self, request):
self.async_render_POST(request)
return NOT_DONE_YET
@request_handler
@defer.inlineCallbacks
def async_render_POST(self, request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise ValueError()
except ValueError:
raise SynapseError(
400, "Content must be JSON object.", errcode=Codes.NOT_JSON
)
query = content["server_keys"]
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
@defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False):
store_queries = []
for server_name, key_ids in query.items():
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
cached = yield self.store.get_server_keys_json(store_queries)
json_results = []
time_now_ms = self.clock.time_msec()
cache_misses = dict()
for (server_name, key_id, from_server), results in cached.items():
results = [
(result["ts_added_ms"], result) for result in results
if result["ts_valid_until_ms"] > time_now_ms
]
if not results:
if key_id is not None:
cache_misses.setdefault(server_name, set()).add(key_id)
continue
if key_id is not None:
most_recent_result = max(results)
json_results.append(most_recent_result[-1]["key_json"])
else:
for result in results:
json_results.append(result[-1]["key_json"])
if cache_misses and query_remote_on_cache_miss:
for server_name, key_ids in cache_misses.items():
try:
yield self.keyring.get_server_verify_key_v2_direct(
server_name, key_ids
)
except:
logger.exception("Failed to get key for %r", server_name)
pass
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)
else:
result_io = BytesIO()
result_io.write(b"{\"server_keys\":")
sep = b"["
for json_bytes in json_results:
result_io.write(sep)
result_io.write(json_bytes)
sep = b","
if sep == b"[":
result_io.write(sep)
result_io.write(b"]}")
respond_with_json_bytes(
request, 200, result_io.getvalue(),
version_string=self.version_string
)

View File

@ -140,8 +140,8 @@ class KeyStore(SQLBaseStore):
"key_id": key_id, "key_id": key_id,
"from_server": from_server, "from_server": from_server,
"ts_added_ms": ts_now_ms, "ts_added_ms": ts_now_ms,
"ts_valid_until_ms": ts_valid_until_ms, "ts_valid_until_ms": ts_expires_ms,
"key_json": key_json_bytes, "key_json": buffer(key_json_bytes),
}, },
or_replace=True, or_replace=True,
) )
@ -149,9 +149,9 @@ class KeyStore(SQLBaseStore):
def get_server_keys_json(self, server_keys): def get_server_keys_json(self, server_keys):
"""Retrive the key json for a list of server_keys and key ids. """Retrive the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet will be missing from the that server, key_id, and source triplet entry will be an empty list.
returned dictionary. The JSON is returned as a byte array so that it The JSON is returned as a byte array so that it can be efficiently
can be efficiently used in an HTTP response. used in an HTTP response.
Args: Args:
server_keys (list): List of (server_name, key_id, source) triplets. server_keys (list): List of (server_name, key_id, source) triplets.
Returns: Returns:
@ -161,16 +161,25 @@ class KeyStore(SQLBaseStore):
def _get_server_keys_json_txn(txn): def _get_server_keys_json_txn(txn):
results = {} results = {}
for server_name, key_id, from_server in server_keys: for server_name, key_id, from_server in server_keys:
rows = _simple_select_list_txn( keyvalues = {"server_name": server_name}
keyvalues={ if key_id is not None:
"server_name": server_name, keyvalues["key_id"] = key_id
"key_id": key_id, if from_server is not None:
"from_server": from_server, keyvalues["from_server"] = from_server
}, rows = self._simple_select_list_txn(
retcols=("ts_valid_until_ms", "key_json"), txn,
"server_keys_json",
keyvalues=keyvalues,
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
) )
results[(server_name, key_id, from_server)] = rows results[(server_name, key_id, from_server)] = rows
return results return results
return runInteraction( return self.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn "get_server_keys_json", _get_server_keys_json_txn
) )