Implement remote key lookup api
This commit is contained in:
parent
3ba522bb23
commit
f30d47c876
|
@ -62,7 +62,7 @@ class ServerConfig(Config):
|
||||||
server_group.add_argument("--old-signing-key-path",
|
server_group.add_argument("--old-signing-key-path",
|
||||||
help="The old signing keys")
|
help="The old signing keys")
|
||||||
server_group.add_argument("--key-refresh-interval",
|
server_group.add_argument("--key-refresh-interval",
|
||||||
default=24 * 60 * 60 * 1000, # 1 Day
|
default=24 * 60 * 60 * 1000, # 1 Day
|
||||||
help="How long a key response is valid for."
|
help="How long a key response is valid for."
|
||||||
" Used to set the exipiry in /key/v2/."
|
" Used to set the exipiry in /key/v2/."
|
||||||
" Controls how frequently servers will"
|
" Controls how frequently servers will"
|
||||||
|
@ -156,5 +156,5 @@ class ServerConfig(Config):
|
||||||
args.old_signing_key_path = base_key_name + ".old.signing.keys"
|
args.old_signing_key_path = base_key_name + ".old.signing.keys"
|
||||||
|
|
||||||
if not os.path.exists(args.old_signing_key_path):
|
if not os.path.exists(args.old_signing_key_path):
|
||||||
with open(args.old_signing_key_path, "w") as old_signing_key_file:
|
with open(args.old_signing_key_path, "w"):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -26,7 +26,7 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
KEY_API_V1 = b"/_matrix/key/v1/"
|
KEY_API_V1 = b"/_matrix/key/v1/"
|
||||||
KEY_API_V2 = b"/_matrix/key/v2/local"
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
|
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
|
||||||
|
@ -94,8 +94,8 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
if status != b"200":
|
if status != b"200":
|
||||||
# logger.info("Non-200 response from %s: %s %s",
|
# logger.info("Non-200 response from %s: %s %s",
|
||||||
# self.transport.getHost(), status, message)
|
# self.transport.getHost(), status, message)
|
||||||
error = SynapseKeyClientError("Non-200 response %r from %r" %
|
error = SynapseKeyClientError(
|
||||||
(status, self.host)
|
"Non-200 response %r from %r" % (status, self.host)
|
||||||
)
|
)
|
||||||
error.status = status
|
error.status = status
|
||||||
self.errback(error)
|
self.errback(error)
|
||||||
|
|
|
@ -15,7 +15,9 @@
|
||||||
|
|
||||||
from synapse.crypto.keyclient import fetch_server_key
|
from synapse.crypto.keyclient import fetch_server_key
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from syutil.crypto.jsonsign import verify_signed_json, signature_ids
|
from syutil.crypto.jsonsign import (
|
||||||
|
verify_signed_json, signature_ids, sign_json, encode_canonical_json
|
||||||
|
)
|
||||||
from syutil.crypto.signing_key import (
|
from syutil.crypto.signing_key import (
|
||||||
is_signing_algorithm_supported, decode_verify_key_bytes
|
is_signing_algorithm_supported, decode_verify_key_bytes
|
||||||
)
|
)
|
||||||
|
@ -26,6 +28,8 @@ from synapse.util.retryutils import get_retry_limiter
|
||||||
|
|
||||||
from OpenSSL import crypto
|
from OpenSSL import crypto
|
||||||
|
|
||||||
|
import urllib
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,6 +41,7 @@ class Keyring(object):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.client = hs.get_http_client()
|
self.client = hs.get_http_client()
|
||||||
|
self.config = hs.get_config()
|
||||||
self.perspective_servers = {}
|
self.perspective_servers = {}
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
|
@ -127,7 +132,6 @@ class Keyring(object):
|
||||||
server_name, key_ids
|
server_name, key_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
for key_id in key_ids:
|
for key_id in key_ids:
|
||||||
if key_id in keys:
|
if key_id in keys:
|
||||||
defer.returnValue(keys[key_id])
|
defer.returnValue(keys[key_id])
|
||||||
|
@ -142,17 +146,18 @@ class Keyring(object):
|
||||||
perspective_name, self.clock, self.store
|
perspective_name, self.clock, self.store
|
||||||
)
|
)
|
||||||
|
|
||||||
responses = yield self.client.post_json(
|
with limiter:
|
||||||
destination=perspective_name,
|
responses = yield self.client.post_json(
|
||||||
path=b"/_matrix/key/v2/query",
|
destination=perspective_name,
|
||||||
data={u"server_keys": {server_name: list(key_ids)}},
|
path=b"/_matrix/key/v2/query",
|
||||||
)
|
data={u"server_keys": {server_name: list(key_ids)}},
|
||||||
|
)
|
||||||
|
|
||||||
keys = dict()
|
keys = {}
|
||||||
|
|
||||||
for response in responses:
|
for response in responses:
|
||||||
if (u"signatures" not in response
|
if (u"signatures" not in response
|
||||||
or perspective_name not in response[u"signatures"]):
|
or perspective_name not in response[u"signatures"]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Key response not signed by perspective server"
|
"Key response not signed by perspective server"
|
||||||
" %r" % (perspective_name,)
|
" %r" % (perspective_name,)
|
||||||
|
@ -181,7 +186,9 @@ class Keyring(object):
|
||||||
" server %r" % (perspective_name,)
|
" server %r" % (perspective_name,)
|
||||||
)
|
)
|
||||||
|
|
||||||
response_keys = process_v2_response(self, server_name, key_ids)
|
response_keys = yield self.process_v2_response(
|
||||||
|
server_name, perspective_name, response
|
||||||
|
)
|
||||||
|
|
||||||
keys.update(response_keys)
|
keys.update(response_keys)
|
||||||
|
|
||||||
|
@ -202,15 +209,15 @@ class Keyring(object):
|
||||||
if requested_key_id in keys:
|
if requested_key_id in keys:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
(response_json, tls_certificate) = yield fetch_server_key(
|
(response, tls_certificate) = yield fetch_server_key(
|
||||||
server_name, self.hs.tls_context_factory,
|
server_name, self.hs.tls_context_factory,
|
||||||
path="/_matrix/key/v2/server/%s" % (
|
path=(b"/_matrix/key/v2/server/%s" % (
|
||||||
urllib.quote(requested_key_id),
|
urllib.quote(requested_key_id),
|
||||||
),
|
)).encode("ascii"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if (u"signatures" not in response
|
if (u"signatures" not in response
|
||||||
or server_name not in response[u"signatures"]):
|
or server_name not in response[u"signatures"]):
|
||||||
raise ValueError("Key response not signed by remote server")
|
raise ValueError("Key response not signed by remote server")
|
||||||
|
|
||||||
if "tls_fingerprints" not in response:
|
if "tls_fingerprints" not in response:
|
||||||
|
@ -223,17 +230,18 @@ class Keyring(object):
|
||||||
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
|
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
|
||||||
|
|
||||||
response_sha256_fingerprints = set()
|
response_sha256_fingerprints = set()
|
||||||
for fingerprint in response_json[u"tls_fingerprints"]:
|
for fingerprint in response[u"tls_fingerprints"]:
|
||||||
if u"sha256" in fingerprint:
|
if u"sha256" in fingerprint:
|
||||||
response_sha256_fingerprints.add(fingerprint[u"sha256"])
|
response_sha256_fingerprints.add(fingerprint[u"sha256"])
|
||||||
|
|
||||||
if sha256_fingerprint not in response_sha256_fingerprints:
|
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
|
||||||
raise ValueError("TLS certificate not allowed by fingerprints")
|
raise ValueError("TLS certificate not allowed by fingerprints")
|
||||||
|
|
||||||
response_keys = yield self.process_v2_response(
|
response_keys = yield self.process_v2_response(
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
from_server=server_name,
|
from_server=server_name,
|
||||||
response_json=response_json,
|
requested_id=requested_key_id,
|
||||||
|
response_json=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
keys.update(response_keys)
|
keys.update(response_keys)
|
||||||
|
@ -244,19 +252,15 @@ class Keyring(object):
|
||||||
verify_keys=keys,
|
verify_keys=keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
for key_id in key_ids:
|
defer.returnValue(keys)
|
||||||
if key_id in verify_keys:
|
|
||||||
defer.returnValue(verify_keys[key_id])
|
|
||||||
return
|
|
||||||
|
|
||||||
raise ValueError("No verification key found for given key ids")
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def process_v2_response(self, server_name, from_server, json_response):
|
def process_v2_response(self, server_name, from_server, response_json,
|
||||||
time_now_ms = clock.time_msec()
|
requested_id=None):
|
||||||
|
time_now_ms = self.clock.time_msec()
|
||||||
response_keys = {}
|
response_keys = {}
|
||||||
verify_keys = {}
|
verify_keys = {}
|
||||||
for key_id, key_data in response["verify_keys"].items():
|
for key_id, key_data in response_json["verify_keys"].items():
|
||||||
if is_signing_algorithm_supported(key_id):
|
if is_signing_algorithm_supported(key_id):
|
||||||
key_base64 = key_data["key"]
|
key_base64 = key_data["key"]
|
||||||
key_bytes = decode_base64(key_base64)
|
key_bytes = decode_base64(key_base64)
|
||||||
|
@ -264,7 +268,7 @@ class Keyring(object):
|
||||||
verify_keys[key_id] = verify_key
|
verify_keys[key_id] = verify_key
|
||||||
|
|
||||||
old_verify_keys = {}
|
old_verify_keys = {}
|
||||||
for key_id, key_data in response["verify_keys"].items():
|
for key_id, key_data in response_json["old_verify_keys"].items():
|
||||||
if is_signing_algorithm_supported(key_id):
|
if is_signing_algorithm_supported(key_id):
|
||||||
key_base64 = key_data["key"]
|
key_base64 = key_data["key"]
|
||||||
key_bytes = decode_base64(key_base64)
|
key_bytes = decode_base64(key_base64)
|
||||||
|
@ -273,21 +277,21 @@ class Keyring(object):
|
||||||
verify_key.time_added = time_now_ms
|
verify_key.time_added = time_now_ms
|
||||||
old_verify_keys[key_id] = verify_key
|
old_verify_keys[key_id] = verify_key
|
||||||
|
|
||||||
for key_id in response["signatures"][server_name]:
|
for key_id in response_json["signatures"][server_name]:
|
||||||
if key_id not in response["verify_keys"]:
|
if key_id not in response_json["verify_keys"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Key response must include verification keys for all"
|
"Key response must include verification keys for all"
|
||||||
" signatures"
|
" signatures"
|
||||||
)
|
)
|
||||||
if key_id in verify_keys:
|
if key_id in verify_keys:
|
||||||
verify_signed_json(
|
verify_signed_json(
|
||||||
response,
|
response_json,
|
||||||
server_name,
|
server_name,
|
||||||
verify_keys[key_id]
|
verify_keys[key_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
signed_key_json = sign_json(
|
signed_key_json = sign_json(
|
||||||
response,
|
response_json,
|
||||||
self.config.server_name,
|
self.config.server_name,
|
||||||
self.config.signing_key[0],
|
self.config.signing_key[0],
|
||||||
)
|
)
|
||||||
|
@ -295,7 +299,9 @@ class Keyring(object):
|
||||||
signed_key_json_bytes = encode_canonical_json(signed_key_json)
|
signed_key_json_bytes = encode_canonical_json(signed_key_json)
|
||||||
ts_valid_until_ms = signed_key_json[u"valid_until"]
|
ts_valid_until_ms = signed_key_json[u"valid_until"]
|
||||||
|
|
||||||
updated_key_ids = set([requested_key_id])
|
updated_key_ids = set()
|
||||||
|
if requested_id is not None:
|
||||||
|
updated_key_ids.add(requested_id)
|
||||||
updated_key_ids.update(verify_keys)
|
updated_key_ids.update(verify_keys)
|
||||||
updated_key_ids.update(old_verify_keys)
|
updated_key_ids.update(old_verify_keys)
|
||||||
|
|
||||||
|
@ -307,8 +313,8 @@ class Keyring(object):
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
key_id=key_id,
|
key_id=key_id,
|
||||||
from_server=server_name,
|
from_server=server_name,
|
||||||
ts_now_ms=ts_now_ms,
|
ts_now_ms=time_now_ms,
|
||||||
ts_valid_until_ms=valid_until,
|
ts_expires_ms=ts_valid_until_ms,
|
||||||
key_json_bytes=signed_key_json_bytes,
|
key_json_bytes=signed_key_json_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -373,7 +379,6 @@ class Keyring(object):
|
||||||
verify_keys[key_id]
|
verify_keys[key_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
yield self.store.store_server_certificate(
|
yield self.store.store_server_certificate(
|
||||||
server_name,
|
server_name,
|
||||||
server_name,
|
server_name,
|
||||||
|
|
|
@ -13,7 +13,13 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.web.resource import Resource
|
||||||
from .local_key_resource import LocalKey
|
from .local_key_resource import LocalKey
|
||||||
|
from .remote_key_resource import RemoteKey
|
||||||
|
|
||||||
class KeyApiV2Resource(LocalKey):
|
|
||||||
pass
|
class KeyApiV2Resource(Resource):
|
||||||
|
def __init__(self, hs):
|
||||||
|
Resource.__init__(self)
|
||||||
|
self.putChild("server", LocalKey(hs))
|
||||||
|
self.putChild("query", RemoteKey(hs))
|
||||||
|
|
|
@ -31,7 +31,7 @@ class LocalKey(Resource):
|
||||||
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
|
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
|
||||||
signature verification keys for this server::
|
signature verification keys for this server::
|
||||||
|
|
||||||
GET /_matrix/key/v2/ HTTP/1.1
|
GET /_matrix/key/v2/server/a.key.id HTTP/1.1
|
||||||
|
|
||||||
HTTP/1.1 200 OK
|
HTTP/1.1 200 OK
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
|
@ -56,6 +56,8 @@ class LocalKey(Resource):
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
@ -68,7 +70,6 @@ class LocalKey(Resource):
|
||||||
self.expires = int(time_now_msec + refresh_interval)
|
self.expires = int(time_now_msec + refresh_interval)
|
||||||
self.response_body = encode_canonical_json(self.response_json_object())
|
self.response_body = encode_canonical_json(self.response_json_object())
|
||||||
|
|
||||||
|
|
||||||
def response_json_object(self):
|
def response_json_object(self):
|
||||||
verify_keys = {}
|
verify_keys = {}
|
||||||
for key in self.config.signing_key:
|
for key in self.config.signing_key:
|
||||||
|
@ -120,7 +121,3 @@ class LocalKey(Resource):
|
||||||
request, 200, self.response_body,
|
request, 200, self.response_body,
|
||||||
version_string=self.version_string
|
version_string=self.version_string
|
||||||
)
|
)
|
||||||
|
|
||||||
def getChild(self, name, request):
|
|
||||||
if name == '':
|
|
||||||
return self
|
|
||||||
|
|
|
@ -0,0 +1,174 @@
|
||||||
|
from synapse.http.server import request_handler, respond_with_json_bytes
|
||||||
|
from synapse.api.errors import SynapseError, Codes
|
||||||
|
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteKey(Resource):
|
||||||
|
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
||||||
|
verification keys for a collection of servers. Checks that the reported
|
||||||
|
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
|
||||||
|
that the NACL signature for the remote server is valid. Returns a dict of
|
||||||
|
JSON signed by both the remote server and by this server.
|
||||||
|
|
||||||
|
Supports individual GET APIs and a bulk query POST API.
|
||||||
|
|
||||||
|
Requsts:
|
||||||
|
|
||||||
|
GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
|
||||||
|
|
||||||
|
GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
|
||||||
|
|
||||||
|
POST /_matrix/v2/query HTTP/1.1
|
||||||
|
Content-Type: application/json
|
||||||
|
{
|
||||||
|
"server_keys": { "remote.server.example.com": ["a.key.id"] }
|
||||||
|
}
|
||||||
|
|
||||||
|
Response:
|
||||||
|
|
||||||
|
HTTP/1.1 200 OK
|
||||||
|
Content-Type: application/json
|
||||||
|
{
|
||||||
|
"server_keys": [
|
||||||
|
{
|
||||||
|
"server_name": "remote.server.example.com"
|
||||||
|
"valid_until": # posix timestamp
|
||||||
|
"verify_keys": {
|
||||||
|
"a.key.id": { # The identifier for a key.
|
||||||
|
key: "" # base64 encoded verification key.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"old_verify_keys": {
|
||||||
|
"an.old.key.id": { # The identifier for an old key.
|
||||||
|
key: "", # base64 encoded key
|
||||||
|
expired: 0, # when th e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"tls_fingerprints": [
|
||||||
|
{ "sha256": # fingerprint }
|
||||||
|
]
|
||||||
|
"signatures": {
|
||||||
|
"remote.server.example.com": {...}
|
||||||
|
"this.server.example.com": {...}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
isLeaf = True
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.keyring = hs.get_keyring()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.version_string = hs.version_string
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
def render_GET(self, request):
|
||||||
|
self.async_render_GET(request)
|
||||||
|
return NOT_DONE_YET
|
||||||
|
|
||||||
|
@request_handler
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def async_render_GET(self, request):
|
||||||
|
if len(request.postpath) == 1:
|
||||||
|
server, = request.postpath
|
||||||
|
query = {server: [None]}
|
||||||
|
elif len(request.postpath) == 2:
|
||||||
|
server, key_id = request.postpath
|
||||||
|
query = {server: [key_id]}
|
||||||
|
else:
|
||||||
|
raise SynapseError(
|
||||||
|
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
|
||||||
|
)
|
||||||
|
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||||
|
|
||||||
|
def render_POST(self, request):
|
||||||
|
self.async_render_POST(request)
|
||||||
|
return NOT_DONE_YET
|
||||||
|
|
||||||
|
@request_handler
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def async_render_POST(self, request):
|
||||||
|
try:
|
||||||
|
content = json.loads(request.content.read())
|
||||||
|
if type(content) != dict:
|
||||||
|
raise ValueError()
|
||||||
|
except ValueError:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Content must be JSON object.", errcode=Codes.NOT_JSON
|
||||||
|
)
|
||||||
|
|
||||||
|
query = content["server_keys"]
|
||||||
|
|
||||||
|
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
||||||
|
store_queries = []
|
||||||
|
for server_name, key_ids in query.items():
|
||||||
|
for key_id in key_ids:
|
||||||
|
store_queries.append((server_name, key_id, None))
|
||||||
|
|
||||||
|
cached = yield self.store.get_server_keys_json(store_queries)
|
||||||
|
|
||||||
|
json_results = []
|
||||||
|
|
||||||
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
|
cache_misses = dict()
|
||||||
|
for (server_name, key_id, from_server), results in cached.items():
|
||||||
|
results = [
|
||||||
|
(result["ts_added_ms"], result) for result in results
|
||||||
|
if result["ts_valid_until_ms"] > time_now_ms
|
||||||
|
]
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
if key_id is not None:
|
||||||
|
cache_misses.setdefault(server_name, set()).add(key_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key_id is not None:
|
||||||
|
most_recent_result = max(results)
|
||||||
|
json_results.append(most_recent_result[-1]["key_json"])
|
||||||
|
else:
|
||||||
|
for result in results:
|
||||||
|
json_results.append(result[-1]["key_json"])
|
||||||
|
|
||||||
|
if cache_misses and query_remote_on_cache_miss:
|
||||||
|
for server_name, key_ids in cache_misses.items():
|
||||||
|
try:
|
||||||
|
yield self.keyring.get_server_verify_key_v2_direct(
|
||||||
|
server_name, key_ids
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
logger.exception("Failed to get key for %r", server_name)
|
||||||
|
pass
|
||||||
|
yield self.query_keys(
|
||||||
|
request, query, query_remote_on_cache_miss=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result_io = BytesIO()
|
||||||
|
result_io.write(b"{\"server_keys\":")
|
||||||
|
sep = b"["
|
||||||
|
for json_bytes in json_results:
|
||||||
|
result_io.write(sep)
|
||||||
|
result_io.write(json_bytes)
|
||||||
|
sep = b","
|
||||||
|
if sep == b"[":
|
||||||
|
result_io.write(sep)
|
||||||
|
result_io.write(b"]}")
|
||||||
|
|
||||||
|
respond_with_json_bytes(
|
||||||
|
request, 200, result_io.getvalue(),
|
||||||
|
version_string=self.version_string
|
||||||
|
)
|
|
@ -140,8 +140,8 @@ class KeyStore(SQLBaseStore):
|
||||||
"key_id": key_id,
|
"key_id": key_id,
|
||||||
"from_server": from_server,
|
"from_server": from_server,
|
||||||
"ts_added_ms": ts_now_ms,
|
"ts_added_ms": ts_now_ms,
|
||||||
"ts_valid_until_ms": ts_valid_until_ms,
|
"ts_valid_until_ms": ts_expires_ms,
|
||||||
"key_json": key_json_bytes,
|
"key_json": buffer(key_json_bytes),
|
||||||
},
|
},
|
||||||
or_replace=True,
|
or_replace=True,
|
||||||
)
|
)
|
||||||
|
@ -149,9 +149,9 @@ class KeyStore(SQLBaseStore):
|
||||||
def get_server_keys_json(self, server_keys):
|
def get_server_keys_json(self, server_keys):
|
||||||
"""Retrive the key json for a list of server_keys and key ids.
|
"""Retrive the key json for a list of server_keys and key ids.
|
||||||
If no keys are found for a given server, key_id and source then
|
If no keys are found for a given server, key_id and source then
|
||||||
that server, key_id, and source triplet will be missing from the
|
that server, key_id, and source triplet entry will be an empty list.
|
||||||
returned dictionary. The JSON is returned as a byte array so that it
|
The JSON is returned as a byte array so that it can be efficiently
|
||||||
can be efficiently used in an HTTP response.
|
used in an HTTP response.
|
||||||
Args:
|
Args:
|
||||||
server_keys (list): List of (server_name, key_id, source) triplets.
|
server_keys (list): List of (server_name, key_id, source) triplets.
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -161,16 +161,25 @@ class KeyStore(SQLBaseStore):
|
||||||
def _get_server_keys_json_txn(txn):
|
def _get_server_keys_json_txn(txn):
|
||||||
results = {}
|
results = {}
|
||||||
for server_name, key_id, from_server in server_keys:
|
for server_name, key_id, from_server in server_keys:
|
||||||
rows = _simple_select_list_txn(
|
keyvalues = {"server_name": server_name}
|
||||||
keyvalues={
|
if key_id is not None:
|
||||||
"server_name": server_name,
|
keyvalues["key_id"] = key_id
|
||||||
"key_id": key_id,
|
if from_server is not None:
|
||||||
"from_server": from_server,
|
keyvalues["from_server"] = from_server
|
||||||
},
|
rows = self._simple_select_list_txn(
|
||||||
retcols=("ts_valid_until_ms", "key_json"),
|
txn,
|
||||||
|
"server_keys_json",
|
||||||
|
keyvalues=keyvalues,
|
||||||
|
retcols=(
|
||||||
|
"key_id",
|
||||||
|
"from_server",
|
||||||
|
"ts_added_ms",
|
||||||
|
"ts_valid_until_ms",
|
||||||
|
"key_json",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
results[(server_name, key_id, from_server)] = rows
|
results[(server_name, key_id, from_server)] = rows
|
||||||
return results
|
return results
|
||||||
return runInteraction(
|
return self.runInteraction(
|
||||||
"get_server_keys_json", _get_server_keys_json_txn
|
"get_server_keys_json", _get_server_keys_json_txn
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue