Pass around the reactor explicitly (#3385)
This commit is contained in:
parent
c2eff937ac
commit
77ac14b960
|
@ -33,6 +33,7 @@ import logging
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
import simplejson
|
import simplejson
|
||||||
|
import attr
|
||||||
|
|
||||||
import synapse.util.stringutils as stringutils
|
import synapse.util.stringutils as stringutils
|
||||||
|
|
||||||
|
@ -854,7 +855,11 @@ class AuthHandler(BaseHandler):
|
||||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||||
bcrypt.gensalt(self.bcrypt_rounds))
|
bcrypt.gensalt(self.bcrypt_rounds))
|
||||||
|
|
||||||
return make_deferred_yieldable(threads.deferToThread(_do_hash))
|
return make_deferred_yieldable(
|
||||||
|
threads.deferToThreadPool(
|
||||||
|
self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def validate_hash(self, password, stored_hash):
|
def validate_hash(self, password, stored_hash):
|
||||||
"""Validates that self.hash(password) == stored_hash.
|
"""Validates that self.hash(password) == stored_hash.
|
||||||
|
@ -874,16 +879,21 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if stored_hash:
|
if stored_hash:
|
||||||
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
|
return make_deferred_yieldable(
|
||||||
|
threads.deferToThreadPool(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.hs.get_reactor().getThreadPool(),
|
||||||
|
_do_validate_hash,
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return defer.succeed(False)
|
return defer.succeed(False)
|
||||||
|
|
||||||
|
|
||||||
class MacaroonGeneartor(object):
|
@attr.s
|
||||||
def __init__(self, hs):
|
class MacaroonGenerator(object):
|
||||||
self.clock = hs.get_clock()
|
|
||||||
self.server_name = hs.config.server_name
|
hs = attr.ib()
|
||||||
self.macaroon_secret_key = hs.config.macaroon_secret_key
|
|
||||||
|
|
||||||
def generate_access_token(self, user_id, extra_caveats=None):
|
def generate_access_token(self, user_id, extra_caveats=None):
|
||||||
extra_caveats = extra_caveats or []
|
extra_caveats = extra_caveats or []
|
||||||
|
@ -901,7 +911,7 @@ class MacaroonGeneartor(object):
|
||||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = login")
|
macaroon.add_first_party_caveat("type = login")
|
||||||
now = self.clock.time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
expiry = now + duration_in_ms
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
@ -913,9 +923,9 @@ class MacaroonGeneartor(object):
|
||||||
|
|
||||||
def _generate_base_macaroon(self, user_id):
|
def _generate_base_macaroon(self, user_id):
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
key=self.macaroon_secret_key)
|
key=self.hs.config.macaroon_secret_key)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
return macaroon
|
return macaroon
|
||||||
|
|
|
@ -806,6 +806,7 @@ class EventCreationHandler(object):
|
||||||
# If we're a worker we need to hit out to the master.
|
# If we're a worker we need to hit out to the master.
|
||||||
if self.config.worker_app:
|
if self.config.worker_app:
|
||||||
yield send_event_to_master(
|
yield send_event_to_master(
|
||||||
|
self.hs.get_clock(),
|
||||||
self.http_client,
|
self.http_client,
|
||||||
host=self.config.worker_replication_host,
|
host=self.config.worker_replication_host,
|
||||||
port=self.config.worker_replication_http_port,
|
port=self.config.worker_replication_http_port,
|
||||||
|
|
|
@ -19,7 +19,6 @@ from twisted.internet import defer
|
||||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||||
from synapse.storage.roommember import ProfileInfo
|
from synapse.storage.roommember import ProfileInfo
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.async import sleep
|
|
||||||
from synapse.types import get_localpart_from_id
|
from synapse.types import get_localpart_from_id
|
||||||
|
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
@ -174,7 +173,7 @@ class UserDirectoryHandler(object):
|
||||||
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
|
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
|
||||||
yield self._handle_initial_room(room_id)
|
yield self._handle_initial_room(room_id)
|
||||||
num_processed_rooms += 1
|
num_processed_rooms += 1
|
||||||
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||||
|
|
||||||
logger.info("Processed all rooms.")
|
logger.info("Processed all rooms.")
|
||||||
|
|
||||||
|
@ -188,7 +187,7 @@ class UserDirectoryHandler(object):
|
||||||
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
|
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
|
||||||
yield self._handle_local_user(user_id)
|
yield self._handle_local_user(user_id)
|
||||||
num_processed_users += 1
|
num_processed_users += 1
|
||||||
yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
|
yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
|
||||||
|
|
||||||
logger.info("Processed all users")
|
logger.info("Processed all users")
|
||||||
|
|
||||||
|
@ -236,7 +235,7 @@ class UserDirectoryHandler(object):
|
||||||
count = 0
|
count = 0
|
||||||
for user_id in user_ids:
|
for user_id in user_ids:
|
||||||
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
||||||
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||||
|
|
||||||
if not self.is_mine_id(user_id):
|
if not self.is_mine_id(user_id):
|
||||||
count += 1
|
count += 1
|
||||||
|
@ -251,7 +250,7 @@ class UserDirectoryHandler(object):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
||||||
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
user_set = (user_id, other_user_id)
|
user_set = (user_id, other_user_id)
|
||||||
|
|
|
@ -98,8 +98,8 @@ class SimpleHttpClient(object):
|
||||||
method, uri, *args, **kwargs
|
method, uri, *args, **kwargs
|
||||||
)
|
)
|
||||||
add_timeout_to_deferred(
|
add_timeout_to_deferred(
|
||||||
request_deferred,
|
request_deferred, 60, self.hs.get_reactor(),
|
||||||
60, cancelled_to_request_timed_out_error,
|
cancelled_to_request_timed_out_error,
|
||||||
)
|
)
|
||||||
response = yield make_deferred_yieldable(request_deferred)
|
response = yield make_deferred_yieldable(request_deferred)
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ class SimpleHttpClient(object):
|
||||||
"Error sending request to %s %s: %s %s",
|
"Error sending request to %s %s: %s %s",
|
||||||
method, redact_uri(uri), type(e).__name__, e.message
|
method, redact_uri(uri), type(e).__name__, e.message
|
||||||
)
|
)
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def post_urlencoded_get_json(self, uri, args={}, headers=None):
|
def post_urlencoded_get_json(self, uri, args={}, headers=None):
|
||||||
|
|
|
@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
|
||||||
from synapse.http import cancelled_to_request_timed_out_error
|
from synapse.http import cancelled_to_request_timed_out_error
|
||||||
from synapse.http.endpoint import matrix_federation_endpoint
|
from synapse.http.endpoint import matrix_federation_endpoint
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
from synapse.util.async import sleep, add_timeout_to_deferred
|
from synapse.util.async import add_timeout_to_deferred
|
||||||
from synapse.util import logcontext
|
from synapse.util import logcontext
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
import synapse.util.retryutils
|
import synapse.util.retryutils
|
||||||
|
@ -193,6 +193,7 @@ class MatrixFederationHttpClient(object):
|
||||||
add_timeout_to_deferred(
|
add_timeout_to_deferred(
|
||||||
request_deferred,
|
request_deferred,
|
||||||
timeout / 1000. if timeout else 60,
|
timeout / 1000. if timeout else 60,
|
||||||
|
self.hs.get_reactor(),
|
||||||
cancelled_to_request_timed_out_error,
|
cancelled_to_request_timed_out_error,
|
||||||
)
|
)
|
||||||
response = yield make_deferred_yieldable(
|
response = yield make_deferred_yieldable(
|
||||||
|
@ -234,7 +235,7 @@ class MatrixFederationHttpClient(object):
|
||||||
delay = min(delay, 2)
|
delay = min(delay, 2)
|
||||||
delay *= random.uniform(0.8, 1.4)
|
delay *= random.uniform(0.8, 1.4)
|
||||||
|
|
||||||
yield sleep(delay)
|
yield self.clock.sleep(delay)
|
||||||
retries_left -= 1
|
retries_left -= 1
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
|
@ -161,6 +161,7 @@ class Notifier(object):
|
||||||
self.user_to_user_stream = {}
|
self.user_to_user_stream = {}
|
||||||
self.room_to_user_streams = {}
|
self.room_to_user_streams = {}
|
||||||
|
|
||||||
|
self.hs = hs
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.pending_new_room_events = []
|
self.pending_new_room_events = []
|
||||||
|
@ -340,6 +341,7 @@ class Notifier(object):
|
||||||
add_timeout_to_deferred(
|
add_timeout_to_deferred(
|
||||||
listener.deferred,
|
listener.deferred,
|
||||||
(end_time - now) / 1000.,
|
(end_time - now) / 1000.,
|
||||||
|
self.hs.get_reactor(),
|
||||||
)
|
)
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
yield listener.deferred
|
yield listener.deferred
|
||||||
|
@ -561,6 +563,7 @@ class Notifier(object):
|
||||||
add_timeout_to_deferred(
|
add_timeout_to_deferred(
|
||||||
listener.deferred.addTimeout,
|
listener.deferred.addTimeout,
|
||||||
(end_time - now) / 1000.,
|
(end_time - now) / 1000.,
|
||||||
|
self.hs.get_reactor(),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
|
|
|
@ -21,7 +21,6 @@ from synapse.api.errors import (
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
from synapse.util.async import sleep
|
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.types import Requester, UserID
|
from synapse.types import Requester, UserID
|
||||||
|
@ -33,11 +32,12 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_event_to_master(client, host, port, requester, event, context,
|
def send_event_to_master(clock, client, host, port, requester, event, context,
|
||||||
ratelimit, extra_users):
|
ratelimit, extra_users):
|
||||||
"""Send event to be handled on the master
|
"""Send event to be handled on the master
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
clock (synapse.util.Clock)
|
||||||
client (SimpleHttpClient)
|
client (SimpleHttpClient)
|
||||||
host (str): host of master
|
host (str): host of master
|
||||||
port (int): port on master listening for HTTP replication
|
port (int): port on master listening for HTTP replication
|
||||||
|
@ -77,7 +77,7 @@ def send_event_to_master(client, host, port, requester, event, context,
|
||||||
|
|
||||||
# If we timed out we probably don't need to worry about backing
|
# If we timed out we probably don't need to worry about backing
|
||||||
# off too much, but lets just wait a little anyway.
|
# off too much, but lets just wait a little anyway.
|
||||||
yield sleep(1)
|
yield clock.sleep(1)
|
||||||
except MatrixCodeMessageException as e:
|
except MatrixCodeMessageException as e:
|
||||||
# We convert to SynapseError as we know that it was a SynapseError
|
# We convert to SynapseError as we know that it was a SynapseError
|
||||||
# on the master process that we should send to the client. (And
|
# on the master process that we should send to the client. (And
|
||||||
|
|
|
@ -58,6 +58,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
|
||||||
|
|
||||||
class MediaRepository(object):
|
class MediaRepository(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.client = MatrixFederationHttpClient(hs)
|
self.client = MatrixFederationHttpClient(hs)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
@ -94,7 +95,7 @@ class MediaRepository(object):
|
||||||
storage_providers.append(provider)
|
storage_providers.append(provider)
|
||||||
|
|
||||||
self.media_storage = MediaStorage(
|
self.media_storage = MediaStorage(
|
||||||
self.primary_base_path, self.filepaths, storage_providers,
|
self.hs, self.primary_base_path, self.filepaths, storage_providers,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.clock.looping_call(
|
self.clock.looping_call(
|
||||||
|
|
|
@ -37,13 +37,15 @@ class MediaStorage(object):
|
||||||
"""Responsible for storing/fetching files from local sources.
|
"""Responsible for storing/fetching files from local sources.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
hs (synapse.server.Homeserver)
|
||||||
local_media_directory (str): Base path where we store media on disk
|
local_media_directory (str): Base path where we store media on disk
|
||||||
filepaths (MediaFilePaths)
|
filepaths (MediaFilePaths)
|
||||||
storage_providers ([StorageProvider]): List of StorageProvider that are
|
storage_providers ([StorageProvider]): List of StorageProvider that are
|
||||||
used to fetch and store files.
|
used to fetch and store files.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, local_media_directory, filepaths, storage_providers):
|
def __init__(self, hs, local_media_directory, filepaths, storage_providers):
|
||||||
|
self.hs = hs
|
||||||
self.local_media_directory = local_media_directory
|
self.local_media_directory = local_media_directory
|
||||||
self.filepaths = filepaths
|
self.filepaths = filepaths
|
||||||
self.storage_providers = storage_providers
|
self.storage_providers = storage_providers
|
||||||
|
@ -175,7 +177,8 @@ class MediaStorage(object):
|
||||||
res = yield provider.fetch(path, file_info)
|
res = yield provider.fetch(path, file_info)
|
||||||
if res:
|
if res:
|
||||||
with res:
|
with res:
|
||||||
consumer = BackgroundFileConsumer(open(local_path, "w"))
|
consumer = BackgroundFileConsumer(
|
||||||
|
open(local_path, "w"), self.hs.get_reactor())
|
||||||
yield res.write_to_consumer(consumer)
|
yield res.write_to_consumer(consumer)
|
||||||
yield consumer.wait()
|
yield consumer.wait()
|
||||||
defer.returnValue(local_path)
|
defer.returnValue(local_path)
|
||||||
|
|
|
@ -40,7 +40,7 @@ from synapse.federation.transport.client import TransportLayerClient
|
||||||
from synapse.federation.transaction_queue import TransactionQueue
|
from synapse.federation.transaction_queue import TransactionQueue
|
||||||
from synapse.handlers import Handlers
|
from synapse.handlers import Handlers
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
|
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
||||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||||
from synapse.handlers.device import DeviceHandler
|
from synapse.handlers.device import DeviceHandler
|
||||||
|
@ -165,15 +165,19 @@ class HomeServer(object):
|
||||||
'server_notices_sender',
|
'server_notices_sender',
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hostname, **kwargs):
|
def __init__(self, hostname, reactor=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hostname : The hostname for the server.
|
hostname : The hostname for the server.
|
||||||
"""
|
"""
|
||||||
|
if not reactor:
|
||||||
|
from twisted.internet import reactor
|
||||||
|
|
||||||
|
self._reactor = reactor
|
||||||
self.hostname = hostname
|
self.hostname = hostname
|
||||||
self._building = {}
|
self._building = {}
|
||||||
|
|
||||||
self.clock = Clock()
|
self.clock = Clock(reactor)
|
||||||
self.distributor = Distributor()
|
self.distributor = Distributor()
|
||||||
self.ratelimiter = Ratelimiter()
|
self.ratelimiter = Ratelimiter()
|
||||||
|
|
||||||
|
@ -186,6 +190,12 @@ class HomeServer(object):
|
||||||
self.datastore = DataStore(self.get_db_conn(), self)
|
self.datastore = DataStore(self.get_db_conn(), self)
|
||||||
logger.info("Finished setting up.")
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
|
def get_reactor(self):
|
||||||
|
"""
|
||||||
|
Fetch the Twisted reactor in use by this HomeServer.
|
||||||
|
"""
|
||||||
|
return self._reactor
|
||||||
|
|
||||||
def get_ip_from_request(self, request):
|
def get_ip_from_request(self, request):
|
||||||
# X-Forwarded-For is handled by our custom request type.
|
# X-Forwarded-For is handled by our custom request type.
|
||||||
return request.getClientIP()
|
return request.getClientIP()
|
||||||
|
@ -261,7 +271,7 @@ class HomeServer(object):
|
||||||
return AuthHandler(self)
|
return AuthHandler(self)
|
||||||
|
|
||||||
def build_macaroon_generator(self):
|
def build_macaroon_generator(self):
|
||||||
return MacaroonGeneartor(self)
|
return MacaroonGenerator(self)
|
||||||
|
|
||||||
def build_device_handler(self):
|
def build_device_handler(self):
|
||||||
return DeviceHandler(self)
|
return DeviceHandler(self)
|
||||||
|
@ -328,6 +338,7 @@ class HomeServer(object):
|
||||||
|
|
||||||
return adbapi.ConnectionPool(
|
return adbapi.ConnectionPool(
|
||||||
name,
|
name,
|
||||||
|
cp_reactor=self.get_reactor(),
|
||||||
**self.db_config.get("args", {})
|
**self.db_config.get("args", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import synapse.util.async
|
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from . import engines
|
from . import engines
|
||||||
|
@ -92,7 +91,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
logger.info("Starting background schema updates")
|
logger.info("Starting background schema updates")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
yield synapse.util.async.sleep(
|
yield self.hs.get_clock().sleep(
|
||||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
|
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import Cache
|
from ._base import Cache
|
||||||
from . import background_updates
|
from . import background_updates
|
||||||
|
@ -70,7 +70,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
self._client_ip_looper = self._clock.looping_call(
|
self._client_ip_looper = self._clock.looping_call(
|
||||||
self._update_client_ips_batch, 5 * 1000
|
self._update_client_ips_batch, 5 * 1000
|
||||||
)
|
)
|
||||||
reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
|
self.hs.get_reactor().addSystemEventTrigger(
|
||||||
|
"before", "shutdown", self._update_client_ips_batch
|
||||||
|
)
|
||||||
|
|
||||||
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
|
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
|
||||||
now=None):
|
now=None):
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore, LoggingTransaction
|
from synapse.storage._base import SQLBaseStore, LoggingTransaction
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from synapse.util.async import sleep
|
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -800,7 +799,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||||
)
|
)
|
||||||
if caught_up:
|
if caught_up:
|
||||||
break
|
break
|
||||||
yield sleep(5)
|
yield self.hs.get_clock().sleep(5)
|
||||||
finally:
|
finally:
|
||||||
self._doing_notif_rotation = False
|
self._doing_notif_rotation = False
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
|
@ -265,7 +265,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to callback")
|
logger.exception("Failed to callback")
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
reactor.callFromThread(fire, event_list, row_dict)
|
self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("do_fetch")
|
logger.exception("do_fetch")
|
||||||
|
|
||||||
|
@ -278,7 +278,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
if event_list:
|
if event_list:
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
reactor.callFromThread(fire, event_list)
|
self.hs.get_reactor().callFromThread(fire, event_list)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
|
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
|
||||||
|
|
|
@ -13,15 +13,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
|
||||||
|
|
||||||
from twisted.internet import defer, reactor, task
|
|
||||||
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from twisted.internet import defer, task
|
||||||
|
|
||||||
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,16 +30,24 @@ def unwrapFirstError(failure):
|
||||||
return failure.value.subFailure
|
return failure.value.subFailure
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
class Clock(object):
|
class Clock(object):
|
||||||
"""A small utility that obtains current time-of-day so that time may be
|
|
||||||
mocked during unit-tests.
|
|
||||||
|
|
||||||
TODO(paul): Also move the sleep() functionality into it
|
|
||||||
"""
|
"""
|
||||||
|
A Clock wraps a Twisted reactor and provides utilities on top of it.
|
||||||
|
"""
|
||||||
|
_reactor = attr.ib()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def sleep(self, seconds):
|
||||||
|
d = defer.Deferred()
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
self._reactor.callLater(seconds, d.callback, seconds)
|
||||||
|
res = yield d
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
def time(self):
|
def time(self):
|
||||||
"""Returns the current system time in seconds since epoch."""
|
"""Returns the current system time in seconds since epoch."""
|
||||||
return time.time()
|
return self._reactor.seconds()
|
||||||
|
|
||||||
def time_msec(self):
|
def time_msec(self):
|
||||||
"""Returns the current system time in miliseconds since epoch."""
|
"""Returns the current system time in miliseconds since epoch."""
|
||||||
|
@ -56,6 +63,7 @@ class Clock(object):
|
||||||
msec(float): How long to wait between calls in milliseconds.
|
msec(float): How long to wait between calls in milliseconds.
|
||||||
"""
|
"""
|
||||||
call = task.LoopingCall(f)
|
call = task.LoopingCall(f)
|
||||||
|
call.clock = self._reactor
|
||||||
call.start(msec / 1000.0, now=False)
|
call.start(msec / 1000.0, now=False)
|
||||||
return call
|
return call
|
||||||
|
|
||||||
|
@ -73,7 +81,7 @@ class Clock(object):
|
||||||
callback(*args, **kwargs)
|
callback(*args, **kwargs)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
||||||
|
|
||||||
def cancel_call_later(self, timer, ignore_errs=False):
|
def cancel_call_later(self, timer, ignore_errs=False):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -13,14 +13,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import CancelledError
|
from twisted.internet.defer import CancelledError
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
|
|
||||||
from .logcontext import (
|
from .logcontext import (
|
||||||
PreserveLoggingContext, make_deferred_yieldable, run_in_background
|
PreserveLoggingContext, make_deferred_yieldable, run_in_background
|
||||||
)
|
)
|
||||||
from synapse.util import logcontext, unwrapFirstError
|
from synapse.util import logcontext, unwrapFirstError, Clock
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
@ -31,15 +31,6 @@ from six.moves import range
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def sleep(seconds):
|
|
||||||
d = defer.Deferred()
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
reactor.callLater(seconds, d.callback, seconds)
|
|
||||||
res = yield d
|
|
||||||
defer.returnValue(res)
|
|
||||||
|
|
||||||
|
|
||||||
class ObservableDeferred(object):
|
class ObservableDeferred(object):
|
||||||
"""Wraps a deferred object so that we can add observer deferreds. These
|
"""Wraps a deferred object so that we can add observer deferreds. These
|
||||||
observer deferreds do not affect the callback chain of the original
|
observer deferreds do not affect the callback chain of the original
|
||||||
|
@ -172,13 +163,18 @@ class Linearizer(object):
|
||||||
# do some work.
|
# do some work.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, name=None):
|
def __init__(self, name=None, clock=None):
|
||||||
if name is None:
|
if name is None:
|
||||||
self.name = id(self)
|
self.name = id(self)
|
||||||
else:
|
else:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.key_to_defer = {}
|
self.key_to_defer = {}
|
||||||
|
|
||||||
|
if not clock:
|
||||||
|
from twisted.internet import reactor
|
||||||
|
clock = Clock(reactor)
|
||||||
|
self._clock = clock
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def queue(self, key):
|
def queue(self, key):
|
||||||
# If there is already a deferred in the queue, we pull it out so that
|
# If there is already a deferred in the queue, we pull it out so that
|
||||||
|
@ -219,7 +215,7 @@ class Linearizer(object):
|
||||||
# the context manager, but it needs to happen while we hold the
|
# the context manager, but it needs to happen while we hold the
|
||||||
# lock, and the context manager's exit code must be synchronous,
|
# lock, and the context manager's exit code must be synchronous,
|
||||||
# so actually this is the only sensible place.
|
# so actually this is the only sensible place.
|
||||||
yield sleep(0)
|
yield self._clock.sleep(0)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.info("Acquired uncontended linearizer lock %r for key %r",
|
logger.info("Acquired uncontended linearizer lock %r for key %r",
|
||||||
|
@ -396,7 +392,7 @@ class DeferredTimeoutError(Exception):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
|
def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
|
||||||
"""
|
"""
|
||||||
Add a timeout to a deferred by scheduling it to be cancelled after
|
Add a timeout to a deferred by scheduling it to be cancelled after
|
||||||
timeout seconds.
|
timeout seconds.
|
||||||
|
@ -411,6 +407,7 @@ def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
|
||||||
Args:
|
Args:
|
||||||
deferred (defer.Deferred): deferred to be timed out
|
deferred (defer.Deferred): deferred to be timed out
|
||||||
timeout (Number): seconds to time out after
|
timeout (Number): seconds to time out after
|
||||||
|
reactor (twisted.internet.reactor): the Twisted reactor to use
|
||||||
|
|
||||||
on_timeout_cancel (callable): A callable which is called immediately
|
on_timeout_cancel (callable): A callable which is called immediately
|
||||||
after the deferred times out, and not if this deferred is
|
after the deferred times out, and not if this deferred is
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import threads, reactor
|
from twisted.internet import threads
|
||||||
|
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ class BackgroundFileConsumer(object):
|
||||||
Args:
|
Args:
|
||||||
file_obj (file): The file like object to write to. Closed when
|
file_obj (file): The file like object to write to. Closed when
|
||||||
finished.
|
finished.
|
||||||
|
reactor (twisted.internet.reactor): the Twisted reactor to use
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# For PushProducers pause if we have this many unwritten slices
|
# For PushProducers pause if we have this many unwritten slices
|
||||||
|
@ -34,9 +35,11 @@ class BackgroundFileConsumer(object):
|
||||||
# And resume once the size of the queue is less than this
|
# And resume once the size of the queue is less than this
|
||||||
_RESUME_ON_QUEUE_SIZE = 2
|
_RESUME_ON_QUEUE_SIZE = 2
|
||||||
|
|
||||||
def __init__(self, file_obj):
|
def __init__(self, file_obj, reactor):
|
||||||
self._file_obj = file_obj
|
self._file_obj = file_obj
|
||||||
|
|
||||||
|
self._reactor = reactor
|
||||||
|
|
||||||
# Producer we're registered with
|
# Producer we're registered with
|
||||||
self._producer = None
|
self._producer = None
|
||||||
|
|
||||||
|
@ -71,7 +74,10 @@ class BackgroundFileConsumer(object):
|
||||||
self._producer = producer
|
self._producer = producer
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
self._finished_deferred = run_in_background(
|
self._finished_deferred = run_in_background(
|
||||||
threads.deferToThread, self._writer
|
threads.deferToThreadPool,
|
||||||
|
self._reactor,
|
||||||
|
self._reactor.getThreadPool(),
|
||||||
|
self._writer,
|
||||||
)
|
)
|
||||||
if not streaming:
|
if not streaming:
|
||||||
self._producer.resumeProducing()
|
self._producer.resumeProducing()
|
||||||
|
@ -109,7 +115,7 @@ class BackgroundFileConsumer(object):
|
||||||
# producer.
|
# producer.
|
||||||
if self._producer and self._paused_producer:
|
if self._producer and self._paused_producer:
|
||||||
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
|
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
|
||||||
reactor.callFromThread(self._resume_paused_producer)
|
self._reactor.callFromThread(self._resume_paused_producer)
|
||||||
|
|
||||||
bytes = self._bytes_queue.get()
|
bytes = self._bytes_queue.get()
|
||||||
|
|
||||||
|
@ -121,7 +127,7 @@ class BackgroundFileConsumer(object):
|
||||||
# If its a pull producer then we need to explicitly ask for
|
# If its a pull producer then we need to explicitly ask for
|
||||||
# more stuff.
|
# more stuff.
|
||||||
if not self.streaming and self._producer:
|
if not self.streaming and self._producer:
|
||||||
reactor.callFromThread(self._producer.resumeProducing)
|
self._reactor.callFromThread(self._producer.resumeProducing)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._write_exception = e
|
self._write_exception = e
|
||||||
raise
|
raise
|
||||||
|
|
|
@ -17,7 +17,6 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import LimitExceededError
|
from synapse.api.errors import LimitExceededError
|
||||||
|
|
||||||
from synapse.util.async import sleep
|
|
||||||
from synapse.util.logcontext import (
|
from synapse.util.logcontext import (
|
||||||
run_in_background, make_deferred_yieldable,
|
run_in_background, make_deferred_yieldable,
|
||||||
PreserveLoggingContext,
|
PreserveLoggingContext,
|
||||||
|
@ -153,7 +152,7 @@ class _PerHostRatelimiter(object):
|
||||||
"Ratelimit [%s]: sleeping req",
|
"Ratelimit [%s]: sleeping req",
|
||||||
id(request_id),
|
id(request_id),
|
||||||
)
|
)
|
||||||
ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
|
ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
|
||||||
|
|
||||||
self.sleeping_requests.add(request_id)
|
self.sleeping_requests.add(request_id)
|
||||||
|
|
||||||
|
|
|
@ -19,10 +19,10 @@ import signedjson.sign
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.crypto import keyring
|
from synapse.crypto import keyring
|
||||||
from synapse.util import async, logcontext
|
from synapse.util import logcontext, Clock
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
from tests import unittest, utils
|
from tests import unittest, utils
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
|
|
||||||
class MockPerspectiveServer(object):
|
class MockPerspectiveServer(object):
|
||||||
|
@ -118,6 +118,7 @@ class KeyringTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_verify_json_objects_for_server_awaits_previous_requests(self):
|
def test_verify_json_objects_for_server_awaits_previous_requests(self):
|
||||||
|
clock = Clock(reactor)
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
kr = keyring.Keyring(self.hs)
|
kr = keyring.Keyring(self.hs)
|
||||||
|
@ -167,7 +168,7 @@ class KeyringTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# wait a tick for it to send the request to the perspectives server
|
# wait a tick for it to send the request to the perspectives server
|
||||||
# (it first tries the datastore)
|
# (it first tries the datastore)
|
||||||
yield async.sleep(1) # XXX find out why this takes so long!
|
yield clock.sleep(1) # XXX find out why this takes so long!
|
||||||
self.http_client.post_json.assert_called_once()
|
self.http_client.post_json.assert_called_once()
|
||||||
|
|
||||||
self.assertIs(LoggingContext.current_context(), context_11)
|
self.assertIs(LoggingContext.current_context(), context_11)
|
||||||
|
@ -183,7 +184,7 @@ class KeyringTestCase(unittest.TestCase):
|
||||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1)],
|
[("server10", json1)],
|
||||||
)
|
)
|
||||||
yield async.sleep(1)
|
yield clock.sleep(1)
|
||||||
self.http_client.post_json.assert_not_called()
|
self.http_client.post_json.assert_not_called()
|
||||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
from synapse.rest.client.transactions import HttpTransactionCache
|
from synapse.rest.client.transactions import HttpTransactionCache
|
||||||
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
|
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer, reactor
|
||||||
from mock import Mock, call
|
from mock import Mock, call
|
||||||
|
|
||||||
from synapse.util import async
|
from synapse.util import Clock
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.utils import MockClock
|
from tests.utils import MockClock
|
||||||
|
@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
||||||
def test_logcontexts_with_async_result(self):
|
def test_logcontexts_with_async_result(self):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def cb():
|
def cb():
|
||||||
yield async.sleep(0)
|
yield Clock(reactor).sleep(0)
|
||||||
defer.returnValue("yay")
|
defer.returnValue("yay")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
from synapse.rest.media.v1._base import FileInfo
|
from synapse.rest.media.v1._base import FileInfo
|
||||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||||
|
@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase):
|
||||||
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
|
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
|
||||||
|
|
||||||
hs = Mock()
|
hs = Mock()
|
||||||
|
hs.get_reactor = Mock(return_value=reactor)
|
||||||
hs.config.media_store_path = self.primary_base_path
|
hs.config.media_store_path = self.primary_base_path
|
||||||
|
|
||||||
storage_providers = [FileStorageProviderBackend(
|
storage_providers = [FileStorageProviderBackend(
|
||||||
|
@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase):
|
||||||
|
|
||||||
self.filepaths = MediaFilePaths(self.primary_base_path)
|
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||||
self.media_storage = MediaStorage(
|
self.media_storage = MediaStorage(
|
||||||
self.primary_base_path, self.filepaths, storage_providers,
|
hs, self.primary_base_path, self.filepaths, storage_providers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
|
|
@ -30,7 +30,7 @@ class FileConsumerTests(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_pull_consumer(self):
|
def test_pull_consumer(self):
|
||||||
string_file = StringIO()
|
string_file = StringIO()
|
||||||
consumer = BackgroundFileConsumer(string_file)
|
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
producer = DummyPullProducer()
|
producer = DummyPullProducer()
|
||||||
|
@ -54,7 +54,7 @@ class FileConsumerTests(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_push_consumer(self):
|
def test_push_consumer(self):
|
||||||
string_file = BlockingStringWrite()
|
string_file = BlockingStringWrite()
|
||||||
consumer = BackgroundFileConsumer(string_file)
|
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
producer = NonCallableMock(spec_set=[])
|
producer = NonCallableMock(spec_set=[])
|
||||||
|
@ -80,7 +80,7 @@ class FileConsumerTests(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_push_producer_feedback(self):
|
def test_push_producer_feedback(self):
|
||||||
string_file = BlockingStringWrite()
|
string_file = BlockingStringWrite()
|
||||||
consumer = BackgroundFileConsumer(string_file)
|
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
|
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
|
||||||
|
|
|
@ -12,10 +12,11 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.util import async, logcontext
|
|
||||||
|
from synapse.util import logcontext, Clock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
@ -53,7 +54,7 @@ class LinearizerTestCase(unittest.TestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
logcontext.LoggingContext.current_context(), lc)
|
logcontext.LoggingContext.current_context(), lc)
|
||||||
if sleep:
|
if sleep:
|
||||||
yield async.sleep(0)
|
yield Clock(reactor).sleep(0)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
logcontext.LoggingContext.current_context(), lc)
|
logcontext.LoggingContext.current_context(), lc)
|
||||||
|
|
|
@ -3,8 +3,7 @@ from twisted.internet import defer
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from .. import unittest
|
from .. import unittest
|
||||||
|
|
||||||
from synapse.util.async import sleep
|
from synapse.util import logcontext, Clock
|
||||||
from synapse.util import logcontext
|
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_sleep(self):
|
def test_sleep(self):
|
||||||
|
clock = Clock(reactor)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def competing_callback():
|
def competing_callback():
|
||||||
with LoggingContext() as competing_context:
|
with LoggingContext() as competing_context:
|
||||||
competing_context.request = "competing"
|
competing_context.request = "competing"
|
||||||
yield sleep(0)
|
yield clock.sleep(0)
|
||||||
self._check_test_key("competing")
|
self._check_test_key("competing")
|
||||||
|
|
||||||
reactor.callLater(0, competing_callback)
|
reactor.callLater(0, competing_callback)
|
||||||
|
|
||||||
with LoggingContext() as context_one:
|
with LoggingContext() as context_one:
|
||||||
context_one.request = "one"
|
context_one.request = "one"
|
||||||
yield sleep(0)
|
yield clock.sleep(0)
|
||||||
self._check_test_key("one")
|
self._check_test_key("one")
|
||||||
|
|
||||||
def _test_run_in_background(self, function):
|
def _test_run_in_background(self, function):
|
||||||
|
@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||||
def test_run_in_background_with_blocking_fn(self):
|
def test_run_in_background_with_blocking_fn(self):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def blocking_function():
|
def blocking_function():
|
||||||
yield sleep(0)
|
yield Clock(reactor).sleep(0)
|
||||||
|
|
||||||
return self._test_run_in_background(blocking_function)
|
return self._test_run_in_background(blocking_function)
|
||||||
|
|
||||||
|
|
|
@ -37,11 +37,15 @@ USE_POSTGRES_FOR_TESTS = False
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
|
||||||
|
**kargs):
|
||||||
"""Setup a homeserver suitable for running tests against. Keyword arguments
|
"""Setup a homeserver suitable for running tests against. Keyword arguments
|
||||||
are passed to the Homeserver constructor. If no datastore is supplied a
|
are passed to the Homeserver constructor. If no datastore is supplied a
|
||||||
datastore backed by an in-memory sqlite db will be given to the HS.
|
datastore backed by an in-memory sqlite db will be given to the HS.
|
||||||
"""
|
"""
|
||||||
|
if reactor is None:
|
||||||
|
from twisted.internet import reactor
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = Mock()
|
config = Mock()
|
||||||
config.signing_key = [MockKey()]
|
config.signing_key = [MockKey()]
|
||||||
|
@ -110,6 +114,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
database_engine=db_engine,
|
database_engine=db_engine,
|
||||||
room_list_handler=object(),
|
room_list_handler=object(),
|
||||||
tls_server_context_factory=Mock(),
|
tls_server_context_factory=Mock(),
|
||||||
|
reactor=reactor,
|
||||||
**kargs
|
**kargs
|
||||||
)
|
)
|
||||||
db_conn = hs.get_db_conn()
|
db_conn = hs.get_db_conn()
|
||||||
|
|
Loading…
Reference in New Issue