Preserve some logcontexts

Erik Johnston 2016-08-23 15:23:39 +01:00
parent 928b2187ea
commit 9219139351
18 changed files with 136 additions and 99 deletions

View File

@@ -50,7 +50,7 @@ from synapse.api.urls import (
 )
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, logcontext_tracer
 from synapse.metrics import register_memory_metrics, get_metrics_for
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
@@ -449,6 +449,7 @@ def run(hs):
     # Uncomment to enable tracing of log context changes.
     # sys.settrace(logcontext_tracer)
     with LoggingContext("run"):
+        sys.settrace(logcontext_tracer)
         change_resource_limit(hs.config.soft_file_limit)
         if hs.config.gc_thresholds:
             gc.set_threshold(*hs.config.gc_thresholds)

View File

@@ -150,12 +150,12 @@ class _TransactionController(object):
             if service_is_up:
                 sent = yield txn.send(self.as_api)
                 if sent:
-                    txn.complete(self.store)
+                    yield txn.complete(self.store)
                 else:
-                    self._start_recoverer(service)
+                    preserve_fn(self._start_recoverer)(service)
         except Exception as e:
             logger.exception(e)
-            self._start_recoverer(service)
+            preserve_fn(self._start_recoverer)(service)

     @defer.inlineCallbacks
     def on_recovered(self, recoverer):
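
For illustration (not part of the commit), a minimal sketch of the fire-and-forget pattern this hunk adopts; notify_send_failed is a hypothetical caller, and the behaviour described follows from the preserve_fn change later in this diff:

    from twisted.internet import defer
    from synapse.util.logcontext import preserve_fn

    @defer.inlineCallbacks
    def _start_recoverer(service):
        # stand-in for the real recovery logic
        yield defer.succeed(None)

    def notify_send_failed(service):
        # Fire-and-forget: preserve_fn invokes the function in the current
        # LoggingContext, and the Deferred it hands back is chained through
        # the sentinel context, so abandoning it cannot leak our context.
        preserve_fn(_start_recoverer)(service)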

View File

@@ -308,15 +308,15 @@ class Keyring(object):
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
-        res = yield defer.gatherResults(
+        res = yield preserve_context_over_deferred(defer.gatherResults(
             [
-                self.store.get_server_verify_keys(
+                preserve_fn(self.store.get_server_verify_keys)(
                     server_name, key_ids
                 ).addCallback(lambda ks, server: (server, ks), server_name)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)

         defer.returnValue(dict(res))
@@ -337,13 +337,13 @@ class Keyring(object):
                 )
                 defer.returnValue({})

-        results = yield defer.gatherResults(
+        results = yield preserve_context_over_deferred(defer.gatherResults(
             [
-                get_key(p_name, p_keys)
+                preserve_fn(get_key)(p_name, p_keys)
                 for p_name, p_keys in self.perspective_servers.items()
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)

         union_of_keys = {}
         for result in results:
@@ -383,13 +383,13 @@ class Keyring(object):
            defer.returnValue(keys)

-        results = yield defer.gatherResults(
+        results = yield preserve_context_over_deferred(defer.gatherResults(
             [
-                get_key(server_name, key_ids)
+                preserve_fn(get_key)(server_name, key_ids)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)

         merged = {}
         for result in results:
@@ -466,9 +466,9 @@ class Keyring(object):
             for server_name, response_keys in processed_response.items():
                 keys.setdefault(server_name, {}).update(response_keys)

-        yield defer.gatherResults(
+        yield preserve_context_over_deferred(defer.gatherResults(
             [
-                self.store_keys(
+                preserve_fn(self.store_keys)(
                     server_name=server_name,
                     from_server=perspective_name,
                     verify_keys=response_keys,
@@ -476,7 +476,7 @@ class Keyring(object):
                 for server_name, response_keys in keys.items()
             ],
             consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)

         defer.returnValue(keys)
@@ -524,7 +524,7 @@ class Keyring(object):
             keys.update(response_keys)

-        yield defer.gatherResults(
+        yield preserve_context_over_deferred(defer.gatherResults(
             [
                 preserve_fn(self.store_keys)(
                     server_name=key_server_name,
@@ -534,7 +534,7 @@ class Keyring(object):
                 for key_server_name, verify_keys in keys.items()
             ],
             consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)

         defer.returnValue(keys)
@@ -600,7 +600,7 @@ class Keyring(object):
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)

-        yield defer.gatherResults(
+        yield preserve_context_over_deferred(defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_keys_json)(
                     server_name=server_name,
@@ -613,7 +613,7 @@ class Keyring(object):
                 for key_id in updated_key_ids
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)

         results[server_name] = response_keys
@@ -702,7 +702,7 @@ class Keyring(object):
             A deferred that completes when the keys are stored.
         """
         # TODO(markjh): Store whether the keys have expired.
-        yield defer.gatherResults(
+        yield preserve_context_over_deferred(defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_verify_key)(
                     server_name, server_name, key.time_added, key
@@ -710,4 +710,4 @@ class Keyring(object):
                 for key_id, key in verify_keys.items()
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)
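
The same transformation recurs throughout this file, so a generic sketch of the pattern may help; fetch_all, fetch_one, and items are hypothetical stand-ins:

    from twisted.internet import defer
    from synapse.util import unwrapFirstError
    from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred

    @defer.inlineCallbacks
    def fetch_all(fetch_one, items):
        # preserve_fn starts each call in the caller's LoggingContext without
        # leaking it; preserve_context_over_deferred wraps the combined
        # Deferred so the code after the yield resumes in the original
        # context rather than whichever context last fired a callback.
        results = yield preserve_context_over_deferred(defer.gatherResults(
            [preserve_fn(fetch_one)(item) for item in items],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)
        defer.returnValue(results)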

View File

@@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash
 from synapse.api.errors import SynapseError
 from synapse.util import unwrapFirstError
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred

 import logging
@@ -102,10 +103,10 @@ class FederationBase(object):
                 warn, pdu
             )

-        valid_pdus = yield defer.gatherResults(
+        valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
             deferreds,
             consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)

         if include_none:
             defer.returnValue(valid_pdus)
@@ -129,7 +130,7 @@ class FederationBase(object):
             for pdu in pdus
         ]

-        deferreds = self.keyring.verify_json_objects_for_server([
+        deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
             (p.origin, p.get_pdu_json())
             for p in redacted_pdus
         ])

View File

@@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError
 from synapse.util.async import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.events import FrozenEvent
 import synapse.metrics
@@ -225,10 +226,10 @@ class FederationClient(FederationBase):
         ]

         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield defer.gatherResults(
+        pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
             self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)

         defer.returnValue(pdus)
@@ -457,14 +458,16 @@ class FederationClient(FederationBase):
             batch = set(missing_events[i:i + batch_size])

             deferreds = [
-                self.get_pdu(
+                preserve_fn(self.get_pdu)(
                     destinations=random_server_list(),
                     event_id=e_id,
                 )
                 for e_id in batch
             ]

-            res = yield defer.DeferredList(deferreds, consumeErrors=True)
+            res = yield preserve_context_over_deferred(
+                defer.DeferredList(deferreds, consumeErrors=True)
+            )
             for success, result in res:
                 if success:
                     signed_events.append(result)
@@ -853,14 +856,16 @@ class FederationClient(FederationBase):
             return srvs

         deferreds = [
-            self.get_pdu(
+            preserve_fn(self.get_pdu)(
                 destinations=random_server_list(),
                 event_id=e_id,
             )
             for e_id, depth in ordered_missing[:limit - len(signed_events)]
         ]

-        res = yield defer.DeferredList(deferreds, consumeErrors=True)
+        res = yield preserve_context_over_deferred(
+            defer.DeferredList(deferreds, consumeErrors=True)
+        )
         for (result, val), (e_id, _) in zip(res, ordered_missing):
             if result and val:
                 signed_events.append(val)
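
Where partial failure is tolerated, the commit wraps defer.DeferredList rather than gatherResults, in the same way. A sketch with hypothetical names:

    from twisted.internet import defer
    from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred

    @defer.inlineCallbacks
    def fetch_best_effort(fetch_one, items):
        deferreds = [preserve_fn(fetch_one)(item) for item in items]
        # A DeferredList never errbacks as a whole: it yields a list of
        # (success, result_or_failure) pairs, so failed fetches can be skipped.
        res = yield preserve_context_over_deferred(
            defer.DeferredList(deferreds, consumeErrors=True)
        )
        defer.returnValue([result for success, result in res if success])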

View File

@@ -17,7 +17,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes
 from synapse.util.metrics import Measure
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred

 import logging
@@ -163,10 +163,10 @@ class ApplicationServicesHandler(object):
     def query_3pe(self, kind, protocol, fields):
         services = yield self._get_services_for_3pn(protocol)

-        results = yield defer.DeferredList([
-            self.appservice_api.query_3pe(service, kind, protocol, fields)
+        results = yield preserve_context_over_deferred(defer.DeferredList([
+            preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
             for service in services
-        ], consumeErrors=True)
+        ], consumeErrors=True))

         ret = []
         for (success, result) in results:

View File

@@ -26,7 +26,9 @@ from synapse.api.errors import (
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.events.validator import EventValidator
 from synapse.util import unwrapFirstError
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
+)
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.util.frozenutils import unfreeze
@@ -361,9 +363,9 @@ class FederationHandler(BaseHandler):
             missing_auth - failed_to_fetch
         )

-        results = yield defer.gatherResults(
+        results = yield preserve_context_over_deferred(defer.gatherResults(
             [
-                self.replication_layer.get_pdu(
+                preserve_fn(self.replication_layer.get_pdu)(
                     [dest],
                     event_id,
                     outlier=True,
@@ -372,7 +374,7 @@ class FederationHandler(BaseHandler):
                 for event_id in missing_auth - failed_to_fetch
             ],
             consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)

         auth_events.update({a.event_id: a for a in results})
         required_auth.update(
             a_id for event in results for a_id, _ in event.auth_events
@@ -552,10 +554,10 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())

-        states = yield defer.gatherResults([
-            self.state_handler.resolve_state_groups(room_id, [e])
+        states = yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
             for e in event_ids
-        ])
+        ]))

         states = dict(zip(event_ids, [s[1] for s in states]))

         for e_id, _ in sorted_extremeties_tuple:
@@ -1166,9 +1168,9 @@ class FederationHandler(BaseHandler):
         a bunch of outliers, but not a chunk of individual events that depend
         on each other for state calculations.
         """
-        contexts = yield defer.gatherResults(
+        contexts = yield preserve_context_over_deferred(defer.gatherResults(
             [
-                self._prep_event(
+                preserve_fn(self._prep_event)(
                     origin,
                     ev_info["event"],
                     state=ev_info.get("state"),
@@ -1176,7 +1178,7 @@ class FederationHandler(BaseHandler):
                 )
                 for ev_info in event_infos
             ]
-        )
+        ))

         yield self.store.persist_events(
             [
@@ -1460,9 +1462,9 @@ class FederationHandler(BaseHandler):
             # Do auth conflict res.
             logger.info("Different auth: %s", different_auth)

-            different_events = yield defer.gatherResults(
+            different_events = yield preserve_context_over_deferred(defer.gatherResults(
                 [
-                    self.store.get_event(
+                    preserve_fn(self.store.get_event)(
                         d,
                         allow_none=True,
                         allow_rejected=False,
@@ -1471,7 +1473,7 @@ class FederationHandler(BaseHandler):
                     if d in have_events and not have_events[d]
                 ],
                 consumeErrors=True
-            ).addErrback(unwrapFirstError)
+            )).addErrback(unwrapFirstError)

             if different_events:
                 local_view = dict(auth_events)

View File

@@ -28,7 +28,8 @@ from synapse.types import (
 from synapse.util import unwrapFirstError
 from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
 from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client

 from ._base import BaseHandler
@@ -502,15 +503,17 @@ class MessageHandler(BaseHandler):
                 lambda states: states[event.event_id]
             )

-        (messages, token), current_state = yield defer.gatherResults(
+        (messages, token), current_state = yield preserve_context_over_deferred(
+            defer.gatherResults(
                 [
-                    self.store.get_recent_events_for_room(
+                    preserve_fn(self.store.get_recent_events_for_room)(
                         event.room_id,
                         limit=limit,
                         end_token=room_end_token,
                     ),
                     deferred_room_state,
                 ]
+            )
         ).addErrback(unwrapFirstError)

         messages = yield filter_events_for_client(
@@ -719,9 +722,9 @@ class MessageHandler(BaseHandler):
         presence, receipts, (messages, token) = yield defer.gatherResults(
             [
-                get_presence(),
-                get_receipts(),
-                self.store.get_recent_events_for_room(
+                preserve_fn(get_presence)(),
+                preserve_fn(get_receipts)(),
+                preserve_fn(self.store.get_recent_events_for_room)(
                     room_id,
                     limit=limit,
                     end_token=now_token.room_key,
@@ -755,6 +758,7 @@ class MessageHandler(BaseHandler):
         defer.returnValue(ret)

+    @measure_func("_create_new_client_event")
     @defer.inlineCallbacks
     def _create_new_client_event(self, builder, prev_event_ids=None):
         if prev_event_ids:
@@ -806,6 +810,7 @@ class MessageHandler(BaseHandler):
             (event, context,)
         )

+    @measure_func("handle_new_client_event")
     @defer.inlineCallbacks
     def handle_new_client_event(
         self,
@@ -934,7 +939,7 @@ class MessageHandler(BaseHandler):
         @defer.inlineCallbacks
         def _notify():
             yield run_on_reactor()
-            self.notifier.on_new_room_event(
+            yield self.notifier.on_new_room_event(
                 event, event_stream_id, max_stream_id,
                 extra_users=extra_users
             )
@@ -944,6 +949,6 @@ class MessageHandler(BaseHandler):
             # If invite, remove room_state from unsigned before sending.
             event.unsigned.pop("invite_room_state", None)

-            federation_handler.handle_new_event(
+            preserve_fn(federation_handler.handle_new_event)(
                 event, destinations=destinations,
             )
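
A usage note on the @measure_func decorator added above: it appears to wrap the whole inlineCallbacks method so its runtime is recorded under the given name via synapse.util.metrics.Measure, against the object's clock attribute. A hypothetical handler, assuming hs.get_clock() as the clock source:

    from twisted.internet import defer
    from synapse.util.metrics import measure_func

    class ExampleHandler(object):
        def __init__(self, hs):
            self.clock = hs.get_clock()  # Measure records against self.clock

        @measure_func("example_operation")
        @defer.inlineCallbacks
        def example_operation(self):
            # the full generator run is timed and reported under
            # "example_operation"
            yield defer.succeed(None)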

View File

@@ -16,7 +16,9 @@
 from twisted.internet import defer

 from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+)
 from synapse.util.metrics import Measure
 from synapse.types import UserID
@@ -169,13 +171,13 @@ class TypingHandler(object):
         deferreds = []
         for domain in domains:
             if domain == self.server_name:
-                self._push_update_local(
+                preserve_fn(self._push_update_local)(
                     room_id=room_id,
                     user_id=user_id,
                     typing=typing
                 )
             else:
-                deferreds.append(self.federation.send_edu(
+                deferreds.append(preserve_fn(self.federation.send_edu)(
                     destination=domain,
                     edu_type="m.typing",
                     content={
@@ -185,7 +187,9 @@ class TypingHandler(object):
                     },
                 ))

-        yield defer.DeferredList(deferreds, consumeErrors=True)
+        yield preserve_context_over_deferred(
+            defer.DeferredList(deferreds, consumeErrors=True)
+        )

     @defer.inlineCallbacks
     def _recv_edu(self, origin, content):

View File

@@ -19,7 +19,7 @@ from synapse.api.errors import AuthError
 from synapse.util.logutils import log_function
 from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
 from synapse.util.metrics import Measure
 from synapse.types import StreamToken
 from synapse.visibility import filter_events_for_client
@@ -174,6 +174,7 @@ class Notifier(object):
             lambda: len(self.user_to_user_stream),
         )

+    @preserve_fn
     def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
                           extra_users=[]):
         """ Used by handlers to inform the notifier something has happened
@@ -195,6 +196,7 @@ class Notifier(object):
         self.notify_replication()

+    @preserve_fn
     def _notify_pending_new_room_events(self, max_room_stream_id):
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
@@ -212,6 +214,7 @@ class Notifier(object):
             else:
                 self._on_new_room_event(event, room_stream_id, extra_users)

+    @preserve_fn
     def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
         """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
@@ -226,6 +229,7 @@ class Notifier(object):
             rooms=[event.room_id],
         )

+    @preserve_fn
     def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
         """ Used to inform listeners that something has happend event wise.
@@ -252,6 +256,7 @@ class Notifier(object):
         self.notify_replication()

+    @preserve_fn
     def on_new_replication_data(self):
         """Used to inform replication listeners that something has happend
         without waking up any of the normal user event streams"""
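
Note that preserve_fn doubles as a decorator here: wrapping the notifier entry points once at definition spares every caller from writing preserve_fn(notifier.on_new_event)(...) themselves. A minimal sketch with hypothetical names:

    from synapse.util.logcontext import preserve_fn

    class ExampleNotifier(object):
        @preserve_fn
        def on_something(self):
            # every call now gets preserve_fn's context handling, however
            # the caller invokes it
            pass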

View File

@@ -17,14 +17,15 @@ from twisted.internet import defer
 from synapse.util.presentable_names import (
     calculate_room_name, name_from_member_event
 )
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred

 @defer.inlineCallbacks
 def get_badge_count(store, user_id):
-    invites, joins = yield defer.gatherResults([
-        store.get_invited_rooms_for_user(user_id),
-        store.get_rooms_for_user(user_id),
-    ], consumeErrors=True)
+    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
+        preserve_fn(store.get_invited_rooms_for_user)(user_id),
+        preserve_fn(store.get_rooms_for_user)(user_id),
+    ], consumeErrors=True))

     my_receipts_by_room = yield store.get_receipts_for_user(
         user_id, "m.read",

View File

@@ -17,7 +17,7 @@
 from twisted.internet import defer

 import pusher
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.util.async import run_on_reactor

 import logging
@@ -130,10 +130,12 @@ class PusherPool:
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         deferreds.append(
-                            p.on_new_notifications(min_stream_id, max_stream_id)
+                            preserve_fn(p.on_new_notifications)(
+                                min_stream_id, max_stream_id
+                            )
                         )

-            yield defer.gatherResults(deferreds)
+            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
         except:
             logger.exception("Exception in pusher on_new_notifications")
@@ -155,10 +157,10 @@ class PusherPool:
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         deferreds.append(
-                            p.on_new_receipts(min_stream_id, max_stream_id)
+                            preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
                         )

-            yield defer.gatherResults(deferreds)
+            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
         except:
             logger.exception("Exception in pusher on_new_receipts")

View File

@@ -403,10 +403,9 @@ class RegisterRestServlet(RestServlet):
         # register the user's device
         device_id = params.get("device_id")
         initial_display_name = params.get("initial_device_display_name")
-        device_id = self.device_handler.check_device_registered(
+        return self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
-        return device_id

     @defer.inlineCallbacks
     def _do_guest_registration(self):

View File

@@ -20,7 +20,9 @@ from synapse.events import FrozenEvent, USE_FROZEN_DICTS
 from synapse.events.utils import prune_event

 from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
+from synapse.util.logcontext import (
+    preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
+)
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
@@ -202,7 +204,7 @@ class EventsStore(SQLBaseStore):
         deferreds = []
         for room_id, evs_ctxs in partitioned.items():
-            d = self._event_persist_queue.add_to_queue(
+            d = preserve_fn(self._event_persist_queue.add_to_queue)(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
                 current_state=None,
@@ -212,7 +214,9 @@ class EventsStore(SQLBaseStore):
         for room_id in partitioned.keys():
             self._maybe_start_persisting(room_id)

-        return defer.gatherResults(deferreds, consumeErrors=True)
+        return preserve_context_over_deferred(
+            defer.gatherResults(deferreds, consumeErrors=True)
+        )

     @defer.inlineCallbacks
     @log_function
@@ -225,7 +229,7 @@ class EventsStore(SQLBaseStore):
         self._maybe_start_persisting(event.room_id)

-        yield deferred
+        yield preserve_context_over_deferred(deferred)

         max_persisted_id = yield self._stream_id_gen.get_current_token()
         defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@@ -1088,7 +1092,7 @@ class EventsStore(SQLBaseStore):
         if not allow_rejected:
             rows[:] = [r for r in rows if not r["rejects"]]

-        res = yield defer.gatherResults(
+        res = yield preserve_context_over_deferred(defer.gatherResults(
             [
                 preserve_fn(self._get_event_from_row)(
                     row["internal_metadata"], row["json"], row["redacts"],
@@ -1097,7 +1101,7 @@ class EventsStore(SQLBaseStore):
                 for row in rows
             ],
             consumeErrors=True
-        )
+        ))

         defer.returnValue({
             e.event.event_id: e
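
One subtlety in this file: the persistence queue's Deferred is shared between requests, so it may fire under some other request's LoggingContext; wrapping the yield restores the caller's own context afterwards. A minimal sketch, with wait_on_shared hypothetical:

    from twisted.internet import defer
    from synapse.util.logcontext import preserve_context_over_deferred

    @defer.inlineCallbacks
    def wait_on_shared(shared_deferred):
        # shared_deferred may be completed by unrelated work; without the
        # wrapper we would resume in whatever context that work ran under
        yield preserve_context_over_deferred(shared_deferred)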

View File

@@ -39,7 +39,7 @@ from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached
 from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine

 import logging
@@ -234,12 +234,12 @@ class StreamStore(SQLBaseStore):
         results = {}
         room_ids = list(room_ids)
         for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
-            res = yield defer.gatherResults([
+            res = yield preserve_context_over_deferred(defer.gatherResults([
                 preserve_fn(self.get_room_events_stream_for_room)(
                     room_id, from_key, to_key, limit, order=order,
                 )
                 for room_id in rm_ids
-            ])
+            ]))
             results.update(dict(zip(rm_ids, res)))

         defer.returnValue(results)

View File

@@ -146,10 +146,10 @@ def concurrently_execute(func, args, limit):
         except StopIteration:
             pass

-    return defer.gatherResults([
+    return preserve_context_over_deferred(defer.gatherResults([
         preserve_fn(_concurrently_execute_inner)()
         for _ in xrange(limit)
-    ], consumeErrors=True).addErrback(unwrapFirstError)
+    ], consumeErrors=True)).addErrback(unwrapFirstError)

 class Linearizer(object):
@@ -181,7 +181,8 @@ class Linearizer(object):
         self.key_to_defer[key] = new_defer

         if current_defer:
-            yield preserve_context_over_deferred(current_defer)
+            with PreserveLoggingContext():
+                yield current_defer

         @contextmanager
         def _ctx_manager():
@@ -264,7 +265,7 @@ class ReadWriteLock(object):
             curr_readers.clear()
             self.key_to_current_writer[key] = new_defer

-            yield defer.gatherResults(to_wait_on)
+            yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))

             @contextmanager
             def _ctx_manager():
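
The Linearizer change replaces the deferred wrapper with a PreserveLoggingContext block around the yield: while queued behind the previous holder, the coroutine waits in the sentinel context, and its own context is restored when the wait completes. A sketch of just that step, with wait_for_previous hypothetical:

    from twisted.internet import defer
    from synapse.util.logcontext import PreserveLoggingContext

    @defer.inlineCallbacks
    def wait_for_previous(current_defer):
        # current_defer is the previous lock-holder's Deferred, or None
        if current_defer:
            with PreserveLoggingContext():
                # wait in the sentinel context; our LoggingContext comes
                # back as soon as the yield returns
                yield current_defer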

View File

@@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs):
     return res

-def preserve_context_over_deferred(deferred):
+def preserve_context_over_deferred(deferred, context=None):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
     """
-    current_context = LoggingContext.current_context()
-    d = _PreservingContextDeferred(current_context)
+    if context is None:
+        context = LoggingContext.current_context()
+    d = _PreservingContextDeferred(context)
     deferred.chainDeferred(d)
     return d
@@ -316,7 +317,13 @@ def preserve_fn(f):
     def g(*args, **kwargs):
         with PreserveLoggingContext(current):
-            return f(*args, **kwargs)
+            res = f(*args, **kwargs)
+            if isinstance(res, defer.Deferred):
+                return preserve_context_over_deferred(
+                    res, context=LoggingContext.sentinel
+                )
+            else:
+                return res
     return g
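
To make the new preserve_fn semantics concrete, a small hypothetical example; the key point is that the Deferred handed back to the caller is chained through the sentinel context:

    from twisted.internet import defer
    from synapse.util.logcontext import LoggingContext, preserve_fn

    @defer.inlineCallbacks
    def background_work():
        yield defer.succeed(None)

    def request_handler():
        with LoggingContext("request"):
            # background_work starts out in the "request" context...
            d = preserve_fn(background_work)()
        # ...but callbacks added to d from here run in the sentinel context,
        # not in "request", which has already been exited above.
        d.addCallback(lambda _: None)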

View File

@@ -17,7 +17,7 @@ from twisted.internet import defer
 from synapse.api.constants import Membership, EventTypes
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred

 import logging
@@ -55,12 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
         given events
         events ([synapse.events.EventBase]): list of events to filter
     """
-    forgotten = yield defer.gatherResults([
+    forgotten = yield preserve_context_over_deferred(defer.gatherResults([
         preserve_fn(store.who_forgot_in_room)(
             room_id,
         )
         for room_id in frozenset(e.room_id for e in events)
-    ], consumeErrors=True)
+    ], consumeErrors=True))

     # Set of membership event_ids that have been forgotten
     event_id_forgotten = frozenset(