Merge pull request #2459 from matrix-org/rav/keyring_cleanups
Clean up Keyring code
This commit is contained in:
commit
c94ab5976a
|
@ -18,7 +18,7 @@ from synapse.crypto.keyclient import fetch_server_key
|
|||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from synapse.util.logcontext import (
|
||||
preserve_context_over_fn, PreserveLoggingContext,
|
||||
PreserveLoggingContext,
|
||||
preserve_fn
|
||||
)
|
||||
from synapse.util.metrics import Measure
|
||||
|
@ -57,7 +57,8 @@ Attributes:
|
|||
json_object(dict): The JSON object to verify.
|
||||
deferred(twisted.internet.defer.Deferred):
|
||||
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||
a verify key has been fetched
|
||||
a verify key has been fetched. The deferreds' callbacks are run with no
|
||||
logcontext.
|
||||
"""
|
||||
|
||||
|
||||
|
@ -82,9 +83,11 @@ class Keyring(object):
|
|||
self.key_downloads = {}
|
||||
|
||||
def verify_json_for_server(self, server_name, json_object):
|
||||
return self.verify_json_objects_for_server(
|
||||
[(server_name, json_object)]
|
||||
)[0]
|
||||
return logcontext.make_deferred_yieldable(
|
||||
self.verify_json_objects_for_server(
|
||||
[(server_name, json_object)]
|
||||
)[0]
|
||||
)
|
||||
|
||||
def verify_json_objects_for_server(self, server_and_json):
|
||||
"""Bulk verifies signatures of json objects, bulk fetching keys as
|
||||
|
@ -94,8 +97,10 @@ class Keyring(object):
|
|||
server_and_json (list): List of pairs of (server_name, json_object)
|
||||
|
||||
Returns:
|
||||
list of deferreds indicating success or failure to verify each
|
||||
json object's signature for the given server_name.
|
||||
List<Deferred>: for each input pair, a deferred indicating success
|
||||
or failure to verify each json object's signature for the given
|
||||
server_name. The deferreds run their callbacks in the sentinel
|
||||
logcontext.
|
||||
"""
|
||||
verify_requests = []
|
||||
|
||||
|
@ -122,96 +127,72 @@ class Keyring(object):
|
|||
|
||||
verify_requests.append(verify_request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_key_deferred(verify_request):
|
||||
server_name = verify_request.server_name
|
||||
try:
|
||||
_, key_id, verify_key = yield verify_request.deferred
|
||||
except IOError as e:
|
||||
logger.warn(
|
||||
"Got IOError when downloading keys for %s: %s %s",
|
||||
server_name, type(e).__name__, str(e.message),
|
||||
)
|
||||
raise SynapseError(
|
||||
502,
|
||||
"Error downloading keys for %s" % (server_name,),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Got Exception when downloading keys for %s: %s %s",
|
||||
server_name, type(e).__name__, str(e.message),
|
||||
)
|
||||
raise SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s" % (server_name, key_ids),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
json_object = verify_request.json_object
|
||||
|
||||
logger.debug("Got key %s %s:%s for server %s, verifying" % (
|
||||
key_id, verify_key.alg, verify_key.version, server_name,
|
||||
))
|
||||
try:
|
||||
verify_signed_json(json_object, server_name, verify_key)
|
||||
except:
|
||||
raise SynapseError(
|
||||
401,
|
||||
"Invalid signature for server %s with key %s:%s" % (
|
||||
server_name, verify_key.alg, verify_key.version
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
server_to_deferred = {
|
||||
server_name: defer.Deferred()
|
||||
for server_name, _ in server_and_json
|
||||
}
|
||||
|
||||
with PreserveLoggingContext():
|
||||
|
||||
# We want to wait for any previous lookups to complete before
|
||||
# proceeding.
|
||||
wait_on_deferred = self.wait_for_previous_lookups(
|
||||
[server_name for server_name, _ in server_and_json],
|
||||
server_to_deferred,
|
||||
)
|
||||
|
||||
# Actually start fetching keys.
|
||||
wait_on_deferred.addBoth(
|
||||
lambda _: self.get_server_verify_keys(verify_requests)
|
||||
)
|
||||
|
||||
# When we've finished fetching all the keys for a given server_name,
|
||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||
# any lookups waiting will proceed.
|
||||
server_to_request_ids = {}
|
||||
|
||||
def remove_deferreds(res, server_name, verify_request):
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids[server_name].discard(request_id)
|
||||
if not server_to_request_ids[server_name]:
|
||||
d = server_to_deferred.pop(server_name, None)
|
||||
if d:
|
||||
d.callback(None)
|
||||
return res
|
||||
|
||||
for verify_request in verify_requests:
|
||||
server_name = verify_request.server_name
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||
verify_request.deferred.addBoth(
|
||||
remove_deferreds, server_name, verify_request,
|
||||
)
|
||||
preserve_fn(self._start_key_lookups)(verify_requests)
|
||||
|
||||
# Pass those keys to handle_key_deferred so that the json object
|
||||
# signatures can be verified
|
||||
handle = preserve_fn(_handle_key_deferred)
|
||||
return [
|
||||
preserve_context_over_fn(handle_key_deferred, verify_request)
|
||||
for verify_request in verify_requests
|
||||
handle(rq) for rq in verify_requests
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _start_key_lookups(self, verify_requests):
|
||||
"""Sets off the key fetches for each verify request
|
||||
|
||||
Once each fetch completes, verify_request.deferred will be resolved.
|
||||
|
||||
Args:
|
||||
verify_requests (List[VerifyKeyRequest]):
|
||||
"""
|
||||
|
||||
# create a deferred for each server we're going to look up the keys
|
||||
# for; we'll resolve them once we have completed our lookups.
|
||||
# These will be passed into wait_for_previous_lookups to block
|
||||
# any other lookups until we have finished.
|
||||
# The deferreds are called with no logcontext.
|
||||
server_to_deferred = {
|
||||
rq.server_name: defer.Deferred()
|
||||
for rq in verify_requests
|
||||
}
|
||||
|
||||
# We want to wait for any previous lookups to complete before
|
||||
# proceeding.
|
||||
yield self.wait_for_previous_lookups(
|
||||
[rq.server_name for rq in verify_requests],
|
||||
server_to_deferred,
|
||||
)
|
||||
|
||||
# Actually start fetching keys.
|
||||
self._get_server_verify_keys(verify_requests)
|
||||
|
||||
# When we've finished fetching all the keys for a given server_name,
|
||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||
# any lookups waiting will proceed.
|
||||
#
|
||||
# map from server name to a set of request ids
|
||||
server_to_request_ids = {}
|
||||
|
||||
for verify_request in verify_requests:
|
||||
server_name = verify_request.server_name
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||
|
||||
def remove_deferreds(res, verify_request):
|
||||
server_name = verify_request.server_name
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids[server_name].discard(request_id)
|
||||
if not server_to_request_ids[server_name]:
|
||||
d = server_to_deferred.pop(server_name, None)
|
||||
if d:
|
||||
d.callback(None)
|
||||
return res
|
||||
|
||||
for verify_request in verify_requests:
|
||||
verify_request.deferred.addBoth(
|
||||
remove_deferreds, verify_request,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_previous_lookups(self, server_names, server_to_deferred):
|
||||
"""Waits for any previous key lookups for the given servers to finish.
|
||||
|
@ -247,7 +228,7 @@ class Keyring(object):
|
|||
self.key_downloads[server_name] = deferred
|
||||
deferred.addBoth(rm, server_name)
|
||||
|
||||
def get_server_verify_keys(self, verify_requests):
|
||||
def _get_server_verify_keys(self, verify_requests):
|
||||
"""Tries to find at least one key for each verify request
|
||||
|
||||
For each verify_request, verify_request.deferred is called back with
|
||||
|
@ -316,21 +297,23 @@ class Keyring(object):
|
|||
if not missing_keys:
|
||||
break
|
||||
|
||||
for verify_request in requests_missing_keys.values():
|
||||
verify_request.deferred.errback(SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s" % (
|
||||
verify_request.server_name, verify_request.key_ids,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
with PreserveLoggingContext():
|
||||
for verify_request in requests_missing_keys.values():
|
||||
verify_request.deferred.errback(SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s" % (
|
||||
verify_request.server_name, verify_request.key_ids,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
|
||||
def on_err(err):
|
||||
for verify_request in verify_requests:
|
||||
if not verify_request.deferred.called:
|
||||
verify_request.deferred.errback(err)
|
||||
with PreserveLoggingContext():
|
||||
for verify_request in verify_requests:
|
||||
if not verify_request.deferred.called:
|
||||
verify_request.deferred.errback(err)
|
||||
|
||||
do_iterations().addErrback(on_err)
|
||||
preserve_fn(do_iterations)().addErrback(on_err)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_store(self, server_name_and_key_ids):
|
||||
|
@ -740,3 +723,47 @@ class Keyring(object):
|
|||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError))
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_key_deferred(verify_request):
|
||||
server_name = verify_request.server_name
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
_, key_id, verify_key = yield verify_request.deferred
|
||||
except IOError as e:
|
||||
logger.warn(
|
||||
"Got IOError when downloading keys for %s: %s %s",
|
||||
server_name, type(e).__name__, str(e.message),
|
||||
)
|
||||
raise SynapseError(
|
||||
502,
|
||||
"Error downloading keys for %s" % (server_name,),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Got Exception when downloading keys for %s: %s %s",
|
||||
server_name, type(e).__name__, str(e.message),
|
||||
)
|
||||
raise SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s" % (server_name, verify_request.key_ids),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
json_object = verify_request.json_object
|
||||
|
||||
logger.debug("Got key %s %s:%s for server %s, verifying" % (
|
||||
key_id, verify_key.alg, verify_key.version, server_name,
|
||||
))
|
||||
try:
|
||||
verify_signed_json(json_object, server_name, verify_key)
|
||||
except:
|
||||
raise SynapseError(
|
||||
401,
|
||||
"Invalid signature for server %s with key %s:%s" % (
|
||||
server_name, verify_key.alg, verify_key.version
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
|
|
@ -18,8 +18,7 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.events import spamcheck
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.logcontext import preserve_context_over_deferred, preserve_fn
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from twisted.internet import defer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -51,56 +50,52 @@ class FederationBase(object):
|
|||
"""
|
||||
deferreds = self._check_sigs_and_hashes(pdus)
|
||||
|
||||
def callback(pdu):
|
||||
return pdu
|
||||
@defer.inlineCallbacks
|
||||
def handle_check_result(pdu, deferred):
|
||||
try:
|
||||
res = yield logcontext.make_deferred_yieldable(deferred)
|
||||
except SynapseError:
|
||||
res = None
|
||||
|
||||
def errback(failure, pdu):
|
||||
failure.trap(SynapseError)
|
||||
return None
|
||||
|
||||
def try_local_db(res, pdu):
|
||||
if not res:
|
||||
# Check local db.
|
||||
return self.store.get_event(
|
||||
res = yield self.store.get_event(
|
||||
pdu.event_id,
|
||||
allow_rejected=True,
|
||||
allow_none=True,
|
||||
)
|
||||
return res
|
||||
|
||||
def try_remote(res, pdu):
|
||||
if not res and pdu.origin != origin:
|
||||
return self.get_pdu(
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
).addErrback(lambda e: None)
|
||||
return res
|
||||
try:
|
||||
res = yield self.get_pdu(
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
)
|
||||
except SynapseError:
|
||||
pass
|
||||
|
||||
def warn(res, pdu):
|
||||
if not res:
|
||||
logger.warn(
|
||||
"Failed to find copy of %s with valid signature",
|
||||
pdu.event_id,
|
||||
)
|
||||
return res
|
||||
|
||||
for pdu, deferred in zip(pdus, deferreds):
|
||||
deferred.addCallbacks(
|
||||
callback, errback, errbackArgs=[pdu]
|
||||
).addCallback(
|
||||
try_local_db, pdu
|
||||
).addCallback(
|
||||
try_remote, pdu
|
||||
).addCallback(
|
||||
warn, pdu
|
||||
defer.returnValue(res)
|
||||
|
||||
handle = logcontext.preserve_fn(handle_check_result)
|
||||
deferreds2 = [
|
||||
handle(pdu, deferred)
|
||||
for pdu, deferred in zip(pdus, deferreds)
|
||||
]
|
||||
|
||||
valid_pdus = yield logcontext.make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
deferreds2,
|
||||
consumeErrors=True,
|
||||
)
|
||||
|
||||
valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
deferreds,
|
||||
consumeErrors=True
|
||||
)).addErrback(unwrapFirstError)
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
if include_none:
|
||||
defer.returnValue(valid_pdus)
|
||||
|
@ -108,7 +103,9 @@ class FederationBase(object):
|
|||
defer.returnValue([p for p in valid_pdus if p])
|
||||
|
||||
def _check_sigs_and_hash(self, pdu):
|
||||
return self._check_sigs_and_hashes([pdu])[0]
|
||||
return logcontext.make_deferred_yieldable(
|
||||
self._check_sigs_and_hashes([pdu])[0],
|
||||
)
|
||||
|
||||
def _check_sigs_and_hashes(self, pdus):
|
||||
"""Checks that each of the received events is correctly signed by the
|
||||
|
@ -123,6 +120,7 @@ class FederationBase(object):
|
|||
* returns a redacted version of the event (if the signature
|
||||
matched but the hash did not)
|
||||
* throws a SynapseError if the signature check failed.
|
||||
The deferreds run their callbacks in the sentinel logcontext.
|
||||
"""
|
||||
|
||||
redacted_pdus = [
|
||||
|
@ -130,34 +128,38 @@ class FederationBase(object):
|
|||
for pdu in pdus
|
||||
]
|
||||
|
||||
deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
|
||||
deferreds = self.keyring.verify_json_objects_for_server([
|
||||
(p.origin, p.get_pdu_json())
|
||||
for p in redacted_pdus
|
||||
])
|
||||
|
||||
ctx = logcontext.LoggingContext.current_context()
|
||||
|
||||
def callback(_, pdu, redacted):
|
||||
if not check_event_content_hash(pdu):
|
||||
logger.warn(
|
||||
"Event content has been tampered, redacting %s: %s",
|
||||
pdu.event_id, pdu.get_pdu_json()
|
||||
)
|
||||
return redacted
|
||||
with logcontext.PreserveLoggingContext(ctx):
|
||||
if not check_event_content_hash(pdu):
|
||||
logger.warn(
|
||||
"Event content has been tampered, redacting %s: %s",
|
||||
pdu.event_id, pdu.get_pdu_json()
|
||||
)
|
||||
return redacted
|
||||
|
||||
if spamcheck.check_event_for_spam(pdu):
|
||||
logger.warn(
|
||||
"Event contains spam, redacting %s: %s",
|
||||
pdu.event_id, pdu.get_pdu_json()
|
||||
)
|
||||
return redacted
|
||||
if spamcheck.check_event_for_spam(pdu):
|
||||
logger.warn(
|
||||
"Event contains spam, redacting %s: %s",
|
||||
pdu.event_id, pdu.get_pdu_json()
|
||||
)
|
||||
return redacted
|
||||
|
||||
return pdu
|
||||
return pdu
|
||||
|
||||
def errback(failure, pdu):
|
||||
failure.trap(SynapseError)
|
||||
logger.warn(
|
||||
"Signature check failed for %s",
|
||||
pdu.event_id,
|
||||
)
|
||||
with logcontext.PreserveLoggingContext(ctx):
|
||||
logger.warn(
|
||||
"Signature check failed for %s",
|
||||
pdu.event_id,
|
||||
)
|
||||
return failure
|
||||
|
||||
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
|
||||
|
|
|
@ -22,7 +22,7 @@ from synapse.api.constants import Membership
|
|||
from synapse.api.errors import (
|
||||
CodeMessageException, HttpResponseException, SynapseError,
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
|
@ -189,10 +189,10 @@ class FederationClient(FederationBase):
|
|||
]
|
||||
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
self._check_sigs_and_hashes(pdus),
|
||||
consumeErrors=True,
|
||||
)).addErrback(unwrapFirstError)
|
||||
).addErrback(unwrapFirstError))
|
||||
|
||||
defer.returnValue(pdus)
|
||||
|
||||
|
@ -252,7 +252,7 @@ class FederationClient(FederationBase):
|
|||
pdu = pdu_list[0]
|
||||
|
||||
# Check signatures are correct.
|
||||
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
||||
signed_pdu = yield self._check_sigs_and_hash(pdu)
|
||||
|
||||
break
|
||||
|
||||
|
|
|
@ -113,30 +113,37 @@ class KeyStore(SQLBaseStore):
|
|||
keys[key_id] = key
|
||||
defer.returnValue(keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|
||||
verify_key):
|
||||
"""Stores a NACL verification key for the given server.
|
||||
Args:
|
||||
server_name (str): The name of the server.
|
||||
key_id (str): The version of the key for the server.
|
||||
from_server (str): Where the verification key was looked up
|
||||
ts_now_ms (int): The time now in milliseconds
|
||||
verification_key (VerifyKey): The NACL verify key.
|
||||
time_now_ms (int): The time now in milliseconds
|
||||
verify_key (nacl.signing.VerifyKey): The NACL verify key.
|
||||
"""
|
||||
yield self._simple_upsert(
|
||||
table="server_signature_keys",
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
"key_id": "%s:%s" % (verify_key.alg, verify_key.version),
|
||||
},
|
||||
values={
|
||||
"from_server": from_server,
|
||||
"ts_added_ms": time_now_ms,
|
||||
"verify_key": buffer(verify_key.encode()),
|
||||
},
|
||||
desc="store_server_verify_key",
|
||||
)
|
||||
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
|
||||
|
||||
def _txn(txn):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="server_signature_keys",
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
"key_id": key_id,
|
||||
},
|
||||
values={
|
||||
"from_server": from_server,
|
||||
"ts_added_ms": time_now_ms,
|
||||
"verify_key": buffer(verify_key.encode()),
|
||||
},
|
||||
)
|
||||
txn.call_after(
|
||||
self._get_server_verify_key.invalidate,
|
||||
(server_name, key_id)
|
||||
)
|
||||
|
||||
return self.runInteraction("store_server_verify_key", _txn)
|
||||
|
||||
def store_server_keys_json(self, server_name, key_id, from_server,
|
||||
ts_now_ms, ts_expires_ms, key_json_bytes):
|
||||
|
|
|
@ -12,39 +12,72 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import signedjson
|
||||
import time
|
||||
|
||||
import signedjson.key
|
||||
import signedjson.sign
|
||||
from mock import Mock
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.crypto import keyring
|
||||
from synapse.util import async
|
||||
from synapse.util import async, logcontext
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from tests import unittest, utils
|
||||
from twisted.internet import defer
|
||||
|
||||
|
||||
class MockPerspectiveServer(object):
|
||||
def __init__(self):
|
||||
self.server_name = "mock_server"
|
||||
self.key = signedjson.key.generate_signing_key(0)
|
||||
|
||||
def get_verify_keys(self):
|
||||
vk = signedjson.key.get_verify_key(self.key)
|
||||
return {
|
||||
"%s:%s" % (vk.alg, vk.version): vk,
|
||||
}
|
||||
|
||||
def get_signed_key(self, server_name, verify_key):
|
||||
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
|
||||
res = {
|
||||
"server_name": server_name,
|
||||
"old_verify_keys": {},
|
||||
"valid_until_ts": time.time() * 1000 + 3600,
|
||||
"verify_keys": {
|
||||
key_id: {
|
||||
"key": signedjson.key.encode_verify_key_base64(verify_key)
|
||||
}
|
||||
}
|
||||
}
|
||||
signedjson.sign.sign_json(res, self.server_name, self.key)
|
||||
return res
|
||||
|
||||
|
||||
class KeyringTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_perspective_server = MockPerspectiveServer()
|
||||
self.http_client = Mock()
|
||||
self.hs = yield utils.setup_test_homeserver(
|
||||
handlers=None,
|
||||
http_client=self.http_client,
|
||||
)
|
||||
self.hs.config.perspectives = {
|
||||
"persp_server": {"k": "v"}
|
||||
self.mock_perspective_server.server_name:
|
||||
self.mock_perspective_server.get_verify_keys()
|
||||
}
|
||||
|
||||
def check_context(self, _, expected):
|
||||
self.assertEquals(
|
||||
getattr(LoggingContext.current_context(), "test_key", None),
|
||||
expected
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_wait_for_previous_lookups(self):
|
||||
sentinel_context = LoggingContext.current_context()
|
||||
|
||||
kr = keyring.Keyring(self.hs)
|
||||
|
||||
def check_context(_, expected):
|
||||
self.assertEquals(
|
||||
LoggingContext.current_context().test_key, expected
|
||||
)
|
||||
|
||||
lookup_1_deferred = defer.Deferred()
|
||||
lookup_2_deferred = defer.Deferred()
|
||||
|
||||
|
@ -60,7 +93,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
self.assertTrue(wait_1_deferred.called)
|
||||
# ... so we should have preserved the LoggingContext.
|
||||
self.assertIs(LoggingContext.current_context(), context_one)
|
||||
wait_1_deferred.addBoth(check_context, "one")
|
||||
wait_1_deferred.addBoth(self.check_context, "one")
|
||||
|
||||
with LoggingContext("two") as context_two:
|
||||
context_two.test_key = "two"
|
||||
|
@ -74,7 +107,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
self.assertFalse(wait_2_deferred.called)
|
||||
# ... so we should have reset the LoggingContext.
|
||||
self.assertIs(LoggingContext.current_context(), sentinel_context)
|
||||
wait_2_deferred.addBoth(check_context, "two")
|
||||
wait_2_deferred.addBoth(self.check_context, "two")
|
||||
|
||||
# let the first lookup complete (in the sentinel context)
|
||||
lookup_1_deferred.callback(None)
|
||||
|
@ -89,38 +122,108 @@ class KeyringTestCase(unittest.TestCase):
|
|||
|
||||
kr = keyring.Keyring(self.hs)
|
||||
json1 = {}
|
||||
signedjson.sign.sign_json(json1, "server1", key1)
|
||||
signedjson.sign.sign_json(json1, "server10", key1)
|
||||
|
||||
self.http_client.post_json.return_value = defer.Deferred()
|
||||
persp_resp = {
|
||||
"server_keys": [
|
||||
self.mock_perspective_server.get_signed_key(
|
||||
"server10",
|
||||
signedjson.key.get_verify_key(key1)
|
||||
),
|
||||
]
|
||||
}
|
||||
persp_deferred = defer.Deferred()
|
||||
|
||||
# start off a first set of lookups
|
||||
res_deferreds = kr.verify_json_objects_for_server(
|
||||
[("server1", json1),
|
||||
("server2", {})
|
||||
]
|
||||
@defer.inlineCallbacks
|
||||
def get_perspectives(**kwargs):
|
||||
self.assertEquals(
|
||||
LoggingContext.current_context().test_key, "11",
|
||||
)
|
||||
with logcontext.PreserveLoggingContext():
|
||||
yield persp_deferred
|
||||
defer.returnValue(persp_resp)
|
||||
self.http_client.post_json.side_effect = get_perspectives
|
||||
|
||||
with LoggingContext("11") as context_11:
|
||||
context_11.test_key = "11"
|
||||
|
||||
# start off a first set of lookups
|
||||
res_deferreds = kr.verify_json_objects_for_server(
|
||||
[("server10", json1),
|
||||
("server11", {})
|
||||
]
|
||||
)
|
||||
|
||||
# the unsigned json should be rejected pretty quickly
|
||||
self.assertTrue(res_deferreds[1].called)
|
||||
try:
|
||||
yield res_deferreds[1]
|
||||
self.assertFalse("unsigned json didn't cause a failure")
|
||||
except SynapseError:
|
||||
pass
|
||||
|
||||
self.assertFalse(res_deferreds[0].called)
|
||||
res_deferreds[0].addBoth(self.check_context, None)
|
||||
|
||||
# wait a tick for it to send the request to the perspectives server
|
||||
# (it first tries the datastore)
|
||||
yield async.sleep(0.005)
|
||||
self.http_client.post_json.assert_called_once()
|
||||
|
||||
self.assertIs(LoggingContext.current_context(), context_11)
|
||||
|
||||
context_12 = LoggingContext("12")
|
||||
context_12.test_key = "12"
|
||||
with logcontext.PreserveLoggingContext(context_12):
|
||||
# a second request for a server with outstanding requests
|
||||
# should block rather than start a second call
|
||||
self.http_client.post_json.reset_mock()
|
||||
self.http_client.post_json.return_value = defer.Deferred()
|
||||
|
||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||
[("server10", json1)],
|
||||
)
|
||||
yield async.sleep(0.005)
|
||||
self.http_client.post_json.assert_not_called()
|
||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||
|
||||
# complete the first request
|
||||
with logcontext.PreserveLoggingContext():
|
||||
persp_deferred.callback(persp_resp)
|
||||
self.assertIs(LoggingContext.current_context(), context_11)
|
||||
|
||||
with logcontext.PreserveLoggingContext():
|
||||
yield res_deferreds[0]
|
||||
yield res_deferreds_2[0]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_verify_json_for_server(self):
|
||||
kr = keyring.Keyring(self.hs)
|
||||
|
||||
key1 = signedjson.key.generate_signing_key(1)
|
||||
yield self.hs.datastore.store_server_verify_key(
|
||||
"server9", "", time.time() * 1000,
|
||||
signedjson.key.get_verify_key(key1),
|
||||
)
|
||||
json1 = {}
|
||||
signedjson.sign.sign_json(json1, "server9", key1)
|
||||
|
||||
# the unsigned json should be rejected pretty quickly
|
||||
try:
|
||||
yield res_deferreds[1]
|
||||
self.assertFalse("unsigned json didn't cause a failure")
|
||||
except SynapseError:
|
||||
pass
|
||||
sentinel_context = LoggingContext.current_context()
|
||||
|
||||
self.assertFalse(res_deferreds[0].called)
|
||||
with LoggingContext("one") as context_one:
|
||||
context_one.test_key = "one"
|
||||
|
||||
# wait a tick for it to send the request to the perspectives server
|
||||
# (it first tries the datastore)
|
||||
yield async.sleep(0.005)
|
||||
self.http_client.post_json.assert_called_once()
|
||||
defer = kr.verify_json_for_server("server9", {})
|
||||
try:
|
||||
yield defer
|
||||
self.fail("should fail on unsigned json")
|
||||
except SynapseError:
|
||||
pass
|
||||
self.assertIs(LoggingContext.current_context(), context_one)
|
||||
|
||||
# a second request for a server with outstanding requests should
|
||||
# block rather than start a second call
|
||||
self.http_client.post_json.reset_mock()
|
||||
self.http_client.post_json.return_value = defer.Deferred()
|
||||
defer = kr.verify_json_for_server("server9", json1)
|
||||
self.assertFalse(defer.called)
|
||||
self.assertIs(LoggingContext.current_context(), sentinel_context)
|
||||
yield defer
|
||||
|
||||
kr.verify_json_objects_for_server(
|
||||
[("server1", json1)],
|
||||
)
|
||||
yield async.sleep(0.005)
|
||||
self.http_client.post_json.assert_not_called()
|
||||
self.assertIs(LoggingContext.current_context(), context_one)
|
||||
|
|
Loading…
Reference in New Issue