Fix logcontexts in _check_sigs_and_hashes
This commit is contained in:
parent
72472456d8
commit
6de74ea6d7
|
@ -18,8 +18,7 @@ from synapse.api.errors import SynapseError
|
||||||
from synapse.crypto.event_signing import check_event_content_hash
|
from synapse.crypto.event_signing import check_event_content_hash
|
||||||
from synapse.events import spamcheck
|
from synapse.events import spamcheck
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError, logcontext
|
||||||
from synapse.util.logcontext import preserve_context_over_deferred
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -51,56 +50,52 @@ class FederationBase(object):
|
||||||
"""
|
"""
|
||||||
deferreds = self._check_sigs_and_hashes(pdus)
|
deferreds = self._check_sigs_and_hashes(pdus)
|
||||||
|
|
||||||
def callback(pdu):
|
@defer.inlineCallbacks
|
||||||
return pdu
|
def handle_check_result(pdu, deferred):
|
||||||
|
try:
|
||||||
|
res = yield logcontext.make_deferred_yieldable(deferred)
|
||||||
|
except SynapseError:
|
||||||
|
res = None
|
||||||
|
|
||||||
def errback(failure, pdu):
|
|
||||||
failure.trap(SynapseError)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def try_local_db(res, pdu):
|
|
||||||
if not res:
|
if not res:
|
||||||
# Check local db.
|
# Check local db.
|
||||||
return self.store.get_event(
|
res = yield self.store.get_event(
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
allow_rejected=True,
|
allow_rejected=True,
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
return res
|
|
||||||
|
|
||||||
def try_remote(res, pdu):
|
|
||||||
if not res and pdu.origin != origin:
|
if not res and pdu.origin != origin:
|
||||||
return self.get_pdu(
|
try:
|
||||||
destinations=[pdu.origin],
|
res = yield self.get_pdu(
|
||||||
event_id=pdu.event_id,
|
destinations=[pdu.origin],
|
||||||
outlier=outlier,
|
event_id=pdu.event_id,
|
||||||
timeout=10000,
|
outlier=outlier,
|
||||||
).addErrback(lambda e: None)
|
timeout=10000,
|
||||||
return res
|
)
|
||||||
|
except SynapseError:
|
||||||
|
pass
|
||||||
|
|
||||||
def warn(res, pdu):
|
|
||||||
if not res:
|
if not res:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Failed to find copy of %s with valid signature",
|
"Failed to find copy of %s with valid signature",
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
)
|
)
|
||||||
return res
|
|
||||||
|
|
||||||
for pdu, deferred in zip(pdus, deferreds):
|
defer.returnValue(res)
|
||||||
deferred.addCallbacks(
|
|
||||||
callback, errback, errbackArgs=[pdu]
|
handle = logcontext.preserve_fn(handle_check_result)
|
||||||
).addCallback(
|
deferreds2 = [
|
||||||
try_local_db, pdu
|
handle(pdu, deferred)
|
||||||
).addCallback(
|
for pdu, deferred in zip(pdus, deferreds)
|
||||||
try_remote, pdu
|
]
|
||||||
).addCallback(
|
|
||||||
warn, pdu
|
valid_pdus = yield logcontext.make_deferred_yieldable(
|
||||||
|
defer.gatherResults(
|
||||||
|
deferreds2,
|
||||||
|
consumeErrors=True,
|
||||||
)
|
)
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
|
|
||||||
deferreds,
|
|
||||||
consumeErrors=True
|
|
||||||
)).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
if include_none:
|
if include_none:
|
||||||
defer.returnValue(valid_pdus)
|
defer.returnValue(valid_pdus)
|
||||||
|
@ -108,7 +103,9 @@ class FederationBase(object):
|
||||||
defer.returnValue([p for p in valid_pdus if p])
|
defer.returnValue([p for p in valid_pdus if p])
|
||||||
|
|
||||||
def _check_sigs_and_hash(self, pdu):
|
def _check_sigs_and_hash(self, pdu):
|
||||||
return self._check_sigs_and_hashes([pdu])[0]
|
return logcontext.make_deferred_yieldable(
|
||||||
|
self._check_sigs_and_hashes([pdu])[0],
|
||||||
|
)
|
||||||
|
|
||||||
def _check_sigs_and_hashes(self, pdus):
|
def _check_sigs_and_hashes(self, pdus):
|
||||||
"""Checks that each of the received events is correctly signed by the
|
"""Checks that each of the received events is correctly signed by the
|
||||||
|
@ -123,6 +120,7 @@ class FederationBase(object):
|
||||||
* returns a redacted version of the event (if the signature
|
* returns a redacted version of the event (if the signature
|
||||||
matched but the hash did not)
|
matched but the hash did not)
|
||||||
* throws a SynapseError if the signature check failed.
|
* throws a SynapseError if the signature check failed.
|
||||||
|
The deferreds run their callbacks in the sentinel logcontext.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
redacted_pdus = [
|
redacted_pdus = [
|
||||||
|
@ -135,29 +133,33 @@ class FederationBase(object):
|
||||||
for p in redacted_pdus
|
for p in redacted_pdus
|
||||||
])
|
])
|
||||||
|
|
||||||
|
ctx = logcontext.LoggingContext.current_context()
|
||||||
|
|
||||||
def callback(_, pdu, redacted):
|
def callback(_, pdu, redacted):
|
||||||
if not check_event_content_hash(pdu):
|
with logcontext.PreserveLoggingContext(ctx):
|
||||||
logger.warn(
|
if not check_event_content_hash(pdu):
|
||||||
"Event content has been tampered, redacting %s: %s",
|
logger.warn(
|
||||||
pdu.event_id, pdu.get_pdu_json()
|
"Event content has been tampered, redacting %s: %s",
|
||||||
)
|
pdu.event_id, pdu.get_pdu_json()
|
||||||
return redacted
|
)
|
||||||
|
return redacted
|
||||||
|
|
||||||
if spamcheck.check_event_for_spam(pdu):
|
if spamcheck.check_event_for_spam(pdu):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Event contains spam, redacting %s: %s",
|
"Event contains spam, redacting %s: %s",
|
||||||
pdu.event_id, pdu.get_pdu_json()
|
pdu.event_id, pdu.get_pdu_json()
|
||||||
)
|
)
|
||||||
return redacted
|
return redacted
|
||||||
|
|
||||||
return pdu
|
return pdu
|
||||||
|
|
||||||
def errback(failure, pdu):
|
def errback(failure, pdu):
|
||||||
failure.trap(SynapseError)
|
failure.trap(SynapseError)
|
||||||
logger.warn(
|
with logcontext.PreserveLoggingContext(ctx):
|
||||||
"Signature check failed for %s",
|
logger.warn(
|
||||||
pdu.event_id,
|
"Signature check failed for %s",
|
||||||
)
|
pdu.event_id,
|
||||||
|
)
|
||||||
return failure
|
return failure
|
||||||
|
|
||||||
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
|
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
|
||||||
|
|
|
@ -22,7 +22,7 @@ from synapse.api.constants import Membership
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException, HttpResponseException, SynapseError,
|
CodeMessageException, HttpResponseException, SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError, logcontext
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
|
@ -189,10 +189,10 @@ class FederationClient(FederationBase):
|
||||||
]
|
]
|
||||||
|
|
||||||
# FIXME: We should handle signature failures more gracefully.
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
|
pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||||
self._check_sigs_and_hashes(pdus),
|
self._check_sigs_and_hashes(pdus),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
)).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError))
|
||||||
|
|
||||||
defer.returnValue(pdus)
|
defer.returnValue(pdus)
|
||||||
|
|
||||||
|
@ -252,7 +252,7 @@ class FederationClient(FederationBase):
|
||||||
pdu = pdu_list[0]
|
pdu = pdu_list[0]
|
||||||
|
|
||||||
# Check signatures are correct.
|
# Check signatures are correct.
|
||||||
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
signed_pdu = yield self._check_sigs_and_hash(pdu)
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue