Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.13.0

This commit is contained in:
Erik Johnston 2016-02-10 14:12:48 +00:00
commit e66d0bd03a
91 changed files with 1796 additions and 1101 deletions

View File

@ -51,3 +51,6 @@ Steven Hammerton <steven.hammerton at openmarket.com>
Mads Robin Christensen <mads at v42 dot dk>
* CentOS 7 installation instructions.
Florent Violleau <floviolleau at gmail dot com>
* Add Raspberry Pi installation instructions and general troubleshooting items

View File

@ -125,6 +125,15 @@ Installing prerequisites on Mac OS X::
sudo easy_install pip
sudo pip install virtualenv
Installing prerequisites on Raspbian::
sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
sudo pip install --upgrade pip
sudo pip install --upgrade ndg-httpsclient
sudo pip install --upgrade virtualenv
To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse
@ -310,6 +319,18 @@ may need to manually upgrade it::
sudo pip install --upgrade pip
Installing may fail with ``Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)``.
You can fix this by manually upgrading pip and virtualenv::
sudo pip install --upgrade virtualenv
You can next rerun ``virtualenv -p python2.7 synapse`` to update the virtual env.
Installing may fail during installing virtualenv with ``InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.``
You can fix this by manually installing ndg-httpsclient::
pip install --upgrade ndg-httpsclient
Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``.
You can fix this by upgrading setuptools::

24
scripts-dev/dump_macaroon.py Executable file
View File

@ -0,0 +1,24 @@
#!/usr/bin/env python2
import pymacaroons
import sys
if len(sys.argv) == 1:
sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],))
sys.exit(1)
macaroon_string = sys.argv[1]
key = sys.argv[2] if len(sys.argv) > 2 else None
macaroon = pymacaroons.Macaroon.deserialize(macaroon_string)
print macaroon.inspect()
print ""
verifier = pymacaroons.Verifier()
verifier.satisfy_general(lambda c: True)
try:
verifier.verify(macaroon, key)
print "Signature is correct"
except Exception as e:
print e.message

View File

@ -0,0 +1,62 @@
#! /usr/bin/python
import ast
import argparse
import os
import sys
import yaml
PATTERNS_V1 = []
PATTERNS_V2 = []
RESULT = {
"v1": PATTERNS_V1,
"v2": PATTERNS_V2,
}
class CallVisitor(ast.NodeVisitor):
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
name = node.func.id
else:
return
if name == "client_path_patterns":
PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns":
PATTERNS_V2.append(node.args[0].s)
def find_patterns_in_code(input_code):
input_ast = ast.parse(input_code)
visitor = CallVisitor()
visitor.visit(input_ast)
def find_patterns_in_file(filepath):
with open(filepath) as f:
find_patterns_in_code(f.read())
parser = argparse.ArgumentParser(description='Find url patterns.')
parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
)
args = parser.parse_args()
for directory in args.directories:
for root, dirs, files in os.walk(directory):
for filename in files:
if filename.endswith(".py"):
filepath = os.path.join(root, filename)
find_patterns_in_file(filepath)
PATTERNS_V1.sort()
PATTERNS_V2.sort()
yaml.dump(RESULT, sys.stdout, default_flow_style=False)

View File

@ -16,3 +16,4 @@ ignore =
[flake8]
max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.

View File

@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from unpaddedbase64 import decode_base64
import logging
@ -529,7 +530,8 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
self.store.insert_client_ip(
preserve_context_over_fn(
self.store.insert_client_ip,
user=user,
access_token=access_token,
ip=ip_addr,
@ -696,6 +698,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
logger.warn("Unrecognised access token - not in store: %s" % (token,))
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
@ -713,6 +716,7 @@ class Auth(object):
token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token)
if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,))
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",

View File

@ -23,5 +23,6 @@ WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/v1"
MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@ -12,3 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, MissingRequirementError
) # NOQA
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)

View File

@ -14,27 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from synapse.rest import ClientRestResource
import synapse
import contextlib
import logging
import os
import re
import resource
import subprocess
import sys
import time
from synapse.config._base import ConfigError
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError
check_requirements, DEPENDENCY_LINKS
)
if __name__ == '__main__':
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)
from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException
@ -60,7 +56,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
SERVER_KEY_V2_PREFIX,
)
from synapse.config.homeserver import HomeServerConfig
@ -73,17 +69,6 @@ from synapse import events
from daemonize import Daemonize
import synapse
import contextlib
import logging
import os
import re
import resource
import subprocess
import time
logger = logging.getLogger("synapse.app.homeserver")
@ -163,8 +148,10 @@ class SynapseHomeServer(HomeServer):
})
if name in ["media", "federation", "client"]:
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: MediaRepositoryResource(self),
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
),
@ -366,11 +353,20 @@ def setup(config_options):
Returns:
HomeServer
"""
try:
config = HomeServerConfig.load_config(
"Synapse Homeserver",
config_options,
generate_section="Homeserver"
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
if not config:
# If a config isn't returned, and an exception isn't raised, we're just
# generating config files and shouldn't try to continue.
sys.exit(0)
config.setup_logging()
@ -690,8 +686,8 @@ def run(hs):
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False)
stats["total_room_count"] = len(all_rooms)
room_count = yield hs.get_datastore().get_room_count()
stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
@ -713,6 +709,8 @@ def run(hs):
phone_home_task.start(60 * 60 * 24, now=False)
def in_thread():
# Uncomment to enable tracing of log context changes.
# sys.settrace(logcontext_tracer)
with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit)
reactor.run()

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.config._base import ConfigError
if __name__ == "__main__":
import sys
@ -21,7 +22,11 @@ if __name__ == "__main__":
if action == "read":
key = sys.argv[2]
try:
config = HomeServerConfig.load_config("", sys.argv[3:])
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
print getattr(config, key)
sys.exit(0)

View File

@ -17,7 +17,6 @@ import argparse
import errno
import os
import yaml
import sys
from textwrap import dedent
@ -136,13 +135,20 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs))
return results
def generate_config(self, config_dir_path, server_name, report_stats=None):
def generate_config(
self,
config_dir_path,
server_name,
is_generating_file,
report_stats=None,
):
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config",
config_dir_path=config_dir_path,
server_name=server_name,
is_generating_file=is_generating_file,
report_stats=report_stats,
))
@ -244,8 +250,10 @@ class Config(object):
server_name = config_args.server_name
if not server_name:
print "Must specify a server_name to a generate config for."
sys.exit(1)
raise ConfigError(
"Must specify a server_name to a generate config for."
" Pass -H server.name."
)
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file:
@ -253,6 +261,7 @@ class Config(object):
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
is_generating_file=True
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
@ -266,7 +275,7 @@ class Config(object):
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
sys.exit(0)
return
else:
print (
"Config file %r already exists. Generating any missing key"
@ -302,25 +311,25 @@ class Config(object):
specified_config.update(yaml_config)
if "server_name" not in specified_config:
sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n")
sys.exit(1)
raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"]
_, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name
server_name=server_name,
is_generating_file=False,
)
config.pop("log_config")
config.update(specified_config)
if "report_stats" not in config:
sys.stderr.write(
"\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
MISSING_REPORT_STATS_SPIEL + "\n")
sys.exit(1)
raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
MISSING_REPORT_STATS_SPIEL
)
if generate_keys:
obj.invoke_all("generate_files", config)
sys.exit(0)
return
obj.invoke_all("read_config", config)

View File

@ -22,8 +22,14 @@ from signedjson.key import (
read_signing_keys, write_signing_keys, NACL_ED25519
)
from unpaddedbase64 import decode_base64
from synapse.util.stringutils import random_string_with_symbols
import os
import hashlib
import logging
logger = logging.getLogger(__name__)
class KeyConfig(Config):
@ -40,9 +46,29 @@ class KeyConfig(Config):
config["perspectives"]
)
def default_config(self, config_dir_path, server_name, **kwargs):
self.macaroon_secret_key = config.get(
"macaroon_secret_key", self.registration_shared_secret
)
if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key.
logger.warn("Config is missing missing macaroon_secret_key")
seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed)
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
if is_generating_file:
macaroon_secret_key = random_string_with_symbols(50)
else:
macaroon_secret_key = None
return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
## Signing Keys ##
# Path to the signing key to sign messages with

View File

@ -23,22 +23,23 @@ from distutils.util import strtobool
class RegistrationConfig(Config):
def read_config(self, config):
self.disable_registration = not bool(
self.enable_registration = bool(
strtobool(str(config["enable_registration"]))
)
if "disable_registration" in config:
self.disable_registration = bool(
self.enable_registration = not bool(
strtobool(str(config["disable_registration"]))
)
self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
self.allow_guest_access = config.get("allow_guest_access", False)
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50)
return """\
## Registration ##
@ -49,8 +50,6 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
macaroon_secret_key: "%(macaroon_secret_key)s"
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12.
@ -60,6 +59,12 @@ class RegistrationConfig(Config):
# participate in rooms hosted on this server which have been made
# accessible to anonymous users.
allow_guest_access: False
# The list of identity servers trusted to verify third party
# identifiers by this server.
trusted_third_party_id_servers:
- matrix.org
- vector.im
""" % locals()
def add_arguments(self, parser):
@ -71,6 +76,6 @@ class RegistrationConfig(Config):
def read_arguments(self, args):
if args.enable_registration is not None:
self.disable_registration = not bool(
self.enable_registration = bool(
strtobool(str(args.enable_registration))
)

View File

@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
from twisted.internet import defer
@ -142,6 +146,8 @@ class Keyring(object):
for server_name, _ in server_and_json
}
with PreserveLoggingContext():
# We want to wait for any previous lookups to complete before
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
@ -175,7 +181,8 @@ class Keyring(object):
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
handle_key_deferred(
preserve_context_over_fn(
handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
@ -198,12 +205,13 @@ class Keyring(object):
if server_name in self.key_downloads
]
if wait_on:
with PreserveLoggingContext():
yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred.items():
d = ObservableDeferred(deferred)
d = ObservableDeferred(preserve_context_over_deferred(deferred))
self.key_downloads[server_name] = d
def rm(r, server_name):
@ -244,6 +252,7 @@ class Keyring(object):
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
@ -504,7 +513,7 @@ class Keyring(object):
yield defer.gatherResults(
[
self.store_keys(
preserve_fn(self.store_keys)(
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@ -573,7 +582,7 @@ class Keyring(object):
yield defer.gatherResults(
[
self.store.store_server_keys_json(
preserve_fn(self.store.store_server_keys_json)(
server_name=server_name,
key_id=key_id,
from_server=server_name,
@ -675,7 +684,7 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults(
[
self.store.store_server_verify_key(
preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()

View File

@ -20,3 +20,4 @@ class EventContext(object):
self.current_state = current_state
self.state_group = None
self.rejected = False
self.push_actions = []

View File

@ -57,7 +57,7 @@ class FederationClient(FederationBase):
cache_name="get_pdu_cache",
clock=self._clock,
max_len=1000,
expiry_ms=120*1000,
expiry_ms=120 * 1000,
reset_expiry_on_get=False,
)

View File

@ -126,10 +126,8 @@ class FederationServer(FederationBase):
results = []
for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
try:
yield d
yield self._handle_new_pdu(transaction.origin, pdu)
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)

View File

@ -103,7 +103,6 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
@defer.inlineCallbacks
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
@ -141,8 +140,6 @@ class TransactionQueue(object):
deferreds.append(deferred)
yield defer.DeferredList(deferreds, consumeErrors=True)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination

View File

@ -53,25 +53,10 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def _filter_events_for_clients(self, user_tuples, events):
def _filter_events_for_clients(self, user_tuples, events, event_id_to_state):
""" Returns dict of user_id -> list of events that user is allowed to
see.
"""
# If there is only one user, just get the state for that one user,
# otherwise just get all the state.
if len(user_tuples) == 1:
types = (
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_tuples[0][0]),
)
else:
types = None
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=types
)
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
room_id,
@ -135,7 +120,17 @@ class BaseHandler(object):
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_peeking=False):
# Assumes that user has at some point joined the room if not is_guest.
res = yield self._filter_events_for_clients([(user_id, is_peeking)], events)
types = (
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=types
)
res = yield self._filter_events_for_clients(
[(user_id, is_peeking)], events, event_id_to_state
)
defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id):
@ -147,7 +142,7 @@ class BaseHandler(object):
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000*(time_allowed - time_now)),
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
@defer.inlineCallbacks
@ -269,13 +264,13 @@ class BaseHandler(object):
"You don't have permission to redact events"
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, self
event, context, self
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
destinations = set()
@ -293,19 +288,11 @@ class BaseHandler(object):
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event(
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
notify_d.addErrback(log_failure)
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)

View File

@ -175,8 +175,8 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = (
[self.server_name]
+ [s for s in servers if s != self.server_name]
[self.server_name] +
[s for s in servers if s != self.server_name]
)
else:
servers = list(servers)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_context_over_fn
from ._base import BaseHandler
@ -29,11 +30,17 @@ logger = logging.getLogger(__name__)
def started_user_eventstream(distributor, user):
return distributor.fire("started_user_eventstream", user)
return preserve_context_over_fn(
distributor.fire,
"started_user_eventstream", user
)
def stopped_user_eventstream(distributor, user):
return distributor.fire("stopped_user_eventstream", user)
return preserve_context_over_fn(
distributor.fire,
"stopped_user_eventstream", user
)
class EventStreamHandler(BaseHandler):
@ -130,7 +137,7 @@ class EventStreamHandler(BaseHandler):
# Add some randomness to this value to try and mitigate against
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout,

View File

@ -221,19 +221,11 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key))
@ -244,12 +236,6 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, self
)
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events(
@ -643,19 +629,11 @@ class FederationHandler(BaseHandler):
)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[joinee]
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
@ -730,18 +708,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key)
@ -811,19 +781,11 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
defer.returnValue(event)
@defer.inlineCallbacks
@ -948,18 +910,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
new_pdu = event
destinations = set()
@ -1113,6 +1067,12 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context, self
)
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,

View File

@ -36,14 +36,15 @@ class IdentityHandler(BaseHandler):
self.http_client = hs.get_simple_http_client()
self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
self.trust_any_id_server_just_for_testing_do_not_use = (
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
)
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
@ -58,7 +59,16 @@ class IdentityHandler(BaseHandler):
else:
raise SynapseError(400, "No client_secret in creds")
if id_server not in trustedIdServers:
if id_server not in self.trusted_id_servers:
if self.trust_any_id_server_just_for_testing_do_not_use:
logger.warn(
"Trusting untrustworthy ID server %r even though it isn't"
" in the trusted id list for testing because"
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
" is set in the config",
id_server,
)
else:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', id_server)
defer.returnValue(None)

View File

@ -34,7 +34,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
# Don't bother bumping "last active" time if it differs by less than 60 seconds
LAST_ACTIVE_GRANULARITY = 60*1000
LAST_ACTIVE_GRANULARITY = 60 * 1000
# Keep no more than this number of offline serial revisions
MAX_OFFLINE_SERIALS = 1000
@ -378,9 +378,9 @@ class PresenceHandler(BaseHandler):
was_polling = target_user in self._user_cachemap
if now_online and not was_polling:
self.start_polling_presence(target_user, state=state)
yield self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
self.stop_polling_presence(target_user)
yield self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
@ -394,6 +394,7 @@ class PresenceHandler(BaseHandler):
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
return
with PreserveLoggingContext():
self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user):
@ -466,6 +467,7 @@ class PresenceHandler(BaseHandler):
local_user, room_ids=[room_id], add_to_cache=False
)
with PreserveLoggingContext():
self.push_update_to_local_and_remote(
observed_user=local_user,
users_to_push=[user],
@ -556,7 +558,7 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string()
)
self.start_polling_presence(
yield self.start_polling_presence(
observer_user, target_user=observed_user
)

View File

@ -21,7 +21,6 @@ from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
from ._base import BaseHandler
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
@ -45,6 +44,8 @@ class RegistrationHandler(BaseHandler):
self.distributor.declare("registered_user")
self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None):
yield run_on_reactor()
@ -91,7 +92,7 @@ class RegistrationHandler(BaseHandler):
Args:
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
one will be generated.
password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
@ -108,6 +109,18 @@ class RegistrationHandler(BaseHandler):
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None
if not was_guest:
try:
int(localpart)
raise RegistrationError(
400,
"Numeric user IDs are reserved for guest users."
)
except ValueError:
pass
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@ -118,38 +131,36 @@ class RegistrationHandler(BaseHandler):
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=guest_access_token is not None,
was_guest=was_guest,
make_guest=make_guest,
)
yield registered_user(self.distributor, user)
else:
# autogen a random user ID
# autogen a sequential user ID
attempts = 0
user_id = None
token = None
while not user_id:
try:
localpart = self._generate_user_id()
user = None
while not user:
localpart = yield self._generate_user_id(attempts > 0)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
if generate_token:
token = self.auth_handler().generate_access_token(user_id)
try:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=password_hash)
yield registered_user(self.distributor, user)
password_hash=password_hash,
make_guest=make_guest
)
except SynapseError:
# if user id is taken, just generate another
user_id = None
token = None
attempts += 1
if attempts > 5:
raise RegistrationError(
500, "Cannot generate user ID.")
yield registered_user(self.distributor, user)
# We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding
@ -175,7 +186,7 @@ class RegistrationHandler(BaseHandler):
token=token,
password_hash=""
)
registered_user(self.distributor, user)
yield registered_user(self.distributor, user)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
@ -281,8 +292,16 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE
)
def _generate_user_id(self):
return "-" + stringutils.random_string(18)
@defer.inlineCallbacks
def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart()
)
id = self._next_generated_user_id
self._next_generated_user_id += 1
defer.returnValue(str(id))
@defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response):

View File

@ -18,13 +18,14 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
@ -46,11 +47,17 @@ def collect_presencelike_data(distributor, user, content):
def user_left_room(distributor, user, room_id):
return distributor.fire("user_left_room", user=user, room_id=room_id)
return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user=user, room_id=room_id)
return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class RoomCreationHandler(BaseHandler):
@ -876,39 +883,71 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_public_room_list(self):
chunk = yield self.store.get_rooms(is_public=True)
room_members = yield defer.gatherResults(
[
self.store.get_users_in_room(room["room_id"])
for room in chunk
],
consumeErrors=True,
).addErrback(unwrapFirstError)
avatar_urls = yield defer.gatherResults(
[
self.get_room_avatar_url(room["room_id"])
for room in chunk
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for i, room in enumerate(chunk):
room["num_joined_members"] = len(room_members[i])
if avatar_urls[i]:
room["avatar_url"] = avatar_urls[i]
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
room_ids = yield self.store.get_public_room_ids()
@defer.inlineCallbacks
def get_room_avatar_url(self, room_id):
event = yield self.hs.get_state_handler().get_current_state(
room_id, "m.room.avatar"
def handle_room(room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
if not aliases:
defer.returnValue(None)
state = yield self.state_handler.get_current_state(room_id)
result = {"aliases": aliases, "room_id": room_id}
name_event = state.get((EventTypes.Name, ""), None)
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = state.get((EventTypes.Topic, ""), None)
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = state.get((EventTypes.CanonicalAlias, ""), None)
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = state.get((EventTypes.GuestAccess, ""), None)
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = state.get(("m.room.avatar", ""), None)
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
result["num_joined_members"] = sum(
1 for (event_type, _), ev in state.items()
if event_type == EventTypes.Member and ev.membership == Membership.JOIN
)
if event and "url" in event.content:
defer.returnValue(event.content["url"])
defer.returnValue(result)
result = []
for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
chunk_result = yield defer.gatherResults([
handle_room(room_id)
for room_id in chunk
], consumeErrors=True).addErrback(unwrapFirstError)
result.extend(v for v in chunk_result if v)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": result})
class RoomContextHandler(BaseHandler):
@ -927,7 +966,7 @@ class RoomContextHandler(BaseHandler):
Returns:
dict, or None if the event isn't found
"""
before_limit = math.floor(limit/2.)
before_limit = math.floor(limit / 2.)
after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token()
@ -997,6 +1036,11 @@ class RoomEventSource(object):
to_key = yield self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
logger.warn("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
@ -1008,15 +1052,30 @@ class RoomEventSource(object):
limit=limit,
)
else:
events, end_key = yield self.store.get_room_events_stream(
user_id=user.to_string(),
room_events = yield self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
)
room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key,
to_key=to_key,
limit=limit,
room_ids=room_ids,
is_guest=is_guest,
limit=limit or 10,
)
events = list(room_events)
events.extend(e for evs, _ in room_to_events.values() for e in evs)
events.sort(key=lambda e: e.internal_metadata.order)
if limit:
events[:] = events[:limit]
if events:
end_key = events[-1].internal_metadata.after
else:
end_key = to_key
defer.returnValue((events, end_key))
def get_current_key(self, direction='f'):

View File

@ -18,11 +18,14 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.metrics import Measure
from twisted.internet import defer
import collections
import logging
import itertools
logger = logging.getLogger(__name__)
@ -139,6 +142,15 @@ class SyncHandler(BaseHandler):
A Deferred SyncResult.
"""
context = LoggingContext.current_context()
if context:
if since_token is None:
context.tag = "initial_sync"
elif full_state:
context.tag = "full_state_sync"
else:
context.tag = "incremental_sync"
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
@ -167,18 +179,6 @@ class SyncHandler(BaseHandler):
else:
return self.incremental_sync_with_gap(sync_config, since_token)
def last_read_event_id_for_room_and_user(self, room_id, user_id, ephemeral_by_room):
if room_id not in ephemeral_by_room:
return None
for e in ephemeral_by_room[room_id]:
if e['type'] != 'm.receipt':
continue
for receipt_event_id, val in e['content'].items():
if 'm.read' in val:
if user_id in val['m.read']:
return receipt_event_id
return None
@defer.inlineCallbacks
def full_state_sync(self, sync_config, timeline_since_token):
"""Get a sync for a client which is starting without any state.
@ -228,9 +228,14 @@ class SyncHandler(BaseHandler):
invited = []
archived = []
deferreds = []
for event in room_list:
room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
for room_list_chunk in room_list_chunks:
for event in room_list_chunk:
if event.membership == Membership.JOIN:
room_sync_deferred = self.full_state_sync_for_joined_room(
room_sync_deferred = preserve_fn(
self.full_state_sync_for_joined_room
)(
room_id=event.room_id,
sync_config=sync_config,
now_token=now_token,
@ -251,7 +256,9 @@ class SyncHandler(BaseHandler):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_sync_deferred = self.full_state_sync_for_archived_room(
room_sync_deferred = preserve_fn(
self.full_state_sync_for_archived_room
)(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
@ -305,7 +312,6 @@ class SyncHandler(BaseHandler):
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
all_ephemeral_by_room=ephemeral_by_room,
batch=batch,
full_state=True,
)
@ -355,6 +361,7 @@ class SyncHandler(BaseHandler):
typing events for that room.
"""
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0"
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
@ -438,13 +445,6 @@ class SyncHandler(BaseHandler):
)
now_token = now_token.copy_and_replace("presence_key", presence_key)
# We now fetch all ephemeral events for this room in order to get
# this users current read receipt. This could almost certainly be
# optimised.
_, all_ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token
)
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token
)
@ -478,7 +478,7 @@ class SyncHandler(BaseHandler):
)
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_room_changes_for_user(
rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
@ -576,7 +576,6 @@ class SyncHandler(BaseHandler):
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
batch=batch,
full_state=full_state,
)
@ -606,6 +605,7 @@ class SyncHandler(BaseHandler):
"""
:returns a Deferred TimelineBatch
"""
with Measure(self.clock, "load_filtered_recents"):
filtering_factor = 2
timeline_limit = sync_config.filter_collection.timeline_limit()
load_limit = max(timeline_limit * filtering_factor, 10)
@ -613,7 +613,10 @@ class SyncHandler(BaseHandler):
room_key = now_token.room_key
end_key = room_key
limited = recents is None or newly_joined_room or timeline_limit < len(recents)
if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True
else:
limited = False
if recents is not None:
recents = sync_config.filter_collection.filter_room_timeline(recents)
@ -636,7 +639,9 @@ class SyncHandler(BaseHandler):
from_key=since_key,
to_key=end_key,
)
loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
loaded_recents,
@ -670,37 +675,11 @@ class SyncHandler(BaseHandler):
since_token, now_token,
ephemeral_by_room, tags_by_room,
account_data_by_room,
all_ephemeral_by_room,
batch, full_state=False):
if full_state:
state = yield self.get_state_at(room_id, now_token)
elif batch.limited:
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=current_state,
room_id, batch, sync_config, since_token, now_token,
full_state=full_state
)
else:
state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
state = yield self.get_state_at(room_id, now_token)
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
@ -726,14 +705,12 @@ class SyncHandler(BaseHandler):
if room_sync:
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room
room_id, sync_config
)
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["highlight_count"] = notifs["highlight_count"]
logger.debug("Room sync: %r", room_sync)
@ -766,29 +743,10 @@ class SyncHandler(BaseHandler):
logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
leave_event_id
)
if not full_state:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=state_events_at_leave,
room_id, batch, sync_config, since_token, leave_token,
full_state=full_state
)
else:
state_events_delta = state_events_at_leave
state_events_delta = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state_events_delta.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
@ -843,15 +801,19 @@ class SyncHandler(BaseHandler):
state = {}
defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state):
""" Works out the differnce in state between the current state and the
state the client got when it last performed a sync.
@defer.inlineCallbacks
def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
full_state):
""" Works out the differnce in state between the start of the timeline
and the previous sync.
:param str since_token: the point we are comparing against
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
state to compare to
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the
new state
:param str room_id
:param TimelineBatch batch: The timeline batch for the room that will
be sent to the user.
:param sync_config
:param str since_token: Token of the end of the previous batch. May be None.
:param str now_token: Token of the end of the current batch.
:param bool full_state: Whether to force returning the full state.
:returns A new event dictionary
"""
@ -860,12 +822,53 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
state_delta = {}
for key, event in current_state.iteritems():
if (key not in previous_state or
previous_state[key].event_id != event.event_id):
state_delta[key] = event
return state_delta
with Measure(self.clock, "compute_state_delta"):
if full_state:
if batch:
state = yield self.store.get_state_for_event(
batch.events[0].event_id
)
else:
state = yield self.get_state_at(
room_id, stream_position=now_token
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state,
previous={},
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_at_timeline_start = yield self.store.get_state_for_event(
batch.events[0].event_id
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
)
else:
state = {}
defer.returnValue({
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
})
def check_joined_room(self, sync_config, state_delta):
"""
@ -886,9 +889,12 @@ class SyncHandler(BaseHandler):
return False
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room):
last_unread_event_id = self.last_read_event_id_for_room_and_user(
room_id, sync_config.user.to_string(), ephemeral_by_room
def unread_notifs_for_room_id(self, room_id, sync_config):
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
receipt_type="m.read"
)
notifs = []
@ -912,3 +918,37 @@ def _action_has_highlight(actions):
pass
return False
def _calculate_state(timeline_contains, timeline_start, previous):
"""Works out what state to include in a sync response.
Args:
timeline_contains (dict): state in the timeline
timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync)
Returns:
dict
"""
event_id_to_state = {
e.event_id: e
for e in itertools.chain(
timeline_contains.values(),
previous.values(),
timeline_start.values(),
)
}
tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = (ts_ids - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return {
(e.type, e.state_key): e
for e in evs
}

View File

@ -19,6 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
from synapse.types import UserID
import logging
@ -222,6 +223,7 @@ class TypingNotificationHandler(BaseHandler):
class TypingNotificationEventSource(object):
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self._handler = None
self._room_member_handler = None
@ -247,6 +249,7 @@ class TypingNotificationEventSource(object):
}
def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.handler()

View File

@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object):
return self.clock.time_bound_deferred(
request_deferred,
time_out=timeout/1000. if timeout else 60,
time_out=timeout / 1000. if timeout else 60,
)
response = yield preserve_context_over_fn(

View File

@ -41,7 +41,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
incoming_requests_counter = metrics.register_counter(
"requests",
labels=["method", "servlet"],
labels=["method", "servlet", "tag"],
)
outgoing_responses_counter = metrics.register_counter(
"responses",
@ -50,23 +50,23 @@ outgoing_responses_counter = metrics.register_counter(
response_timer = metrics.register_distribution(
"response_time",
labels=["method", "servlet"]
labels=["method", "servlet", "tag"]
)
response_ru_utime = metrics.register_distribution(
"response_ru_utime", labels=["method", "servlet"]
"response_ru_utime", labels=["method", "servlet", "tag"]
)
response_ru_stime = metrics.register_distribution(
"response_ru_stime", labels=["method", "servlet"]
"response_ru_stime", labels=["method", "servlet", "tag"]
)
response_db_txn_count = metrics.register_distribution(
"response_db_txn_count", labels=["method", "servlet"]
"response_db_txn_count", labels=["method", "servlet", "tag"]
)
response_db_txn_duration = metrics.register_distribution(
"response_db_txn_duration", labels=["method", "servlet"]
"response_db_txn_duration", labels=["method", "servlet", "tag"]
)
@ -99,9 +99,8 @@ def request_handler(request_handler):
request_context.request = request_id
with request.processing():
try:
d = request_handler(self, request)
with PreserveLoggingContext():
yield d
with PreserveLoggingContext(request_context):
yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
@ -208,6 +207,9 @@ class JsonResource(HttpServer, resource.Resource):
if request.method == "OPTIONS":
self._send_response(request, 200, {})
return
start_context = LoggingContext.current_context()
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
@ -226,7 +228,6 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
incoming_requests_counter.inc(request.method, servlet_classname)
args = [
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
@ -237,21 +238,40 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return
self._send_response(request, code, response)
response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname
)
try:
context = LoggingContext.current_context()
tag = ""
if context:
tag = context.tag
if context != start_context:
logger.warn(
"Context have unexpectedly changed %r, %r",
context, self.start_context
)
return
incoming_requests_counter.inc(request.method, servlet_classname, tag)
response_timer.inc_by(
self.clock.time_msec() - start, request.method,
servlet_classname, tag
)
ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(ru_utime, request.method, servlet_classname)
response_ru_stime.inc_by(ru_stime, request.method, servlet_classname)
response_ru_utime.inc_by(
ru_utime, request.method, servlet_classname, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag
)
response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname
context.db_txn_count, request.method, servlet_classname, tag
)
response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname
context.db_txn_duration, request.method, servlet_classname, tag
)
except:
pass

View File

@ -18,10 +18,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken
import synapse.metrics
from collections import namedtuple
import logging
@ -71,6 +74,7 @@ class _NotifierUserStream(object):
self.current_token = current_token
self.last_notified_ms = time_now_ms
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms):
@ -86,6 +90,8 @@ class _NotifierUserStream(object):
)
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)
@ -118,6 +124,11 @@ class _NotifierUserStream(object):
return _NotificationListener(self.notify_deferred.observe())
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
return bool(self.events)
class Notifier(object):
""" This class is responsible for notifying any listeners when there are
new events available for it.
@ -177,8 +188,6 @@ class Notifier(object):
lambda: count(bool, self.appservice_to_user_streams.values()),
)
@log_function
@defer.inlineCallbacks
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened
@ -192,8 +201,7 @@ class Notifier(object):
until all previous events have been persisted before notifying
the client streams.
"""
yield run_on_reactor()
with PreserveLoggingContext():
self.pending_new_room_events.append((
room_stream_id, event, extra_users
))
@ -244,15 +252,13 @@ class Notifier(object):
extra_streams=app_streams,
)
@defer.inlineCallbacks
@log_function
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()):
""" Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms.
"""
yield run_on_reactor()
with PreserveLoggingContext():
user_streams = set()
for user in users:
@ -301,7 +307,7 @@ class Notifier(object):
def timed_out():
if listener:
listener.deferred.cancel()
timer = self.clock.call_later(timeout/1000., timed_out)
timer = self.clock.call_later(timeout / 1000., timed_out)
prev_token = from_token
while not result:
@ -318,6 +324,7 @@ class Notifier(object):
# that we don't miss any current_token updates.
prev_token = current_token
listener = user_stream.new_listener(prev_token)
with PreserveLoggingContext():
yield listener.deferred
except defer.CancelledError:
break
@ -356,7 +363,7 @@ class Notifier(object):
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
defer.returnValue(None)
defer.returnValue(EventStreamResult([], (from_token, from_token)))
events = []
end_token = from_token
@ -369,6 +376,7 @@ class Notifier(object):
continue
if only_keys and name not in only_keys:
continue
new_events, new_key = yield source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
@ -388,10 +396,7 @@ class Notifier(object):
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
if events:
defer.returnValue((events, (from_token, end_token)))
else:
defer.returnValue(None)
defer.returnValue(EventStreamResult(events, (from_token, end_token)))
user_id_for_stream = user.to_string()
if is_peeking:
@ -415,9 +420,6 @@ class Notifier(object):
from_token=from_token,
)
if result is None:
result = ([], (from_token, from_token))
defer.returnValue(result)
@defer.inlineCallbacks

View File

@ -17,6 +17,8 @@ from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
import synapse.util.async
import push_rule_evaluator as push_rule_evaluator
@ -27,6 +29,16 @@ import random
logger = logging.getLogger(__name__)
_NEXT_ID = 1
def _get_next_id():
global _NEXT_ID
_id = _NEXT_ID
_NEXT_ID += 1
return _id
# Pushers could now be moved to pull out of the event_push_actions table instead
# of listening on the event stream: this would avoid them having to run the
# rules again.
@ -57,6 +69,8 @@ class Pusher(object):
self.alive = True
self.badge = None
self.name = "Pusher-%d" % (_get_next_id(),)
# The last value of last_active_time that we saw
self.last_last_active_time = 0
self.has_unread = True
@ -86,6 +100,7 @@ class Pusher(object):
@defer.inlineCallbacks
def start(self):
with LoggingContext(self.name):
if not self.last_token:
# First-time setup: get a token to start from (we can't
# just start from no token, ie. 'now'
@ -96,17 +111,24 @@ class Pusher(object):
self.user_id, config, timeout=0, affect_presence=False
)
self.last_token = chunk['end']
self.store.update_pusher_last_token(
yield self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_id, self.last_token
)
logger.info("Pusher %s for user %s starting from token %s",
logger.info("New pusher %s for user %s starting from token %s",
self.pushkey, self.user_id, self.last_token)
else:
logger.info(
"Old pusher %s for user %s starting",
self.pushkey, self.user_id,
)
wait = 0
while self.alive:
try:
if wait > 0:
yield synapse.util.async.sleep(wait)
with Measure(self.clock, "push"):
yield self.get_and_dispatch()
wait = 0
except:
@ -316,7 +338,7 @@ class Pusher(object):
r.room_id, self.user_id, last_unread_event_id
)
)
badge += len(notifs)
badge += notifs["notify_count"]
defer.returnValue(badge)

View File

@ -19,8 +19,6 @@ import bulk_push_rule_evaluator
import logging
from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__)
@ -36,21 +34,15 @@ class ActionGenerator:
# tag (ie. we just need all the users).
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, handler):
if event.type == EventTypes.Redaction and event.redacts is not None:
yield self.store.remove_push_actions_for_event_id(
event.room_id, event.redacts
)
def handle_push_actions_for_event(self, event, context, handler):
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
event.room_id, self.hs, self.store
)
actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, handler, context.current_state
)
yield self.store.set_push_actions_for_event_and_users(
event,
[
context.push_actions = [
(uid, None, actions) for uid, actions in actions_by_user.items()
]
)

View File

@ -98,25 +98,21 @@ class BulkPushRuleEvaluator:
self.store = store
@defer.inlineCallbacks
def action_for_event_by_user(self, event, handler):
def action_for_event_by_user(self, event, handler, current_state):
actions_by_user = {}
users_dict = yield self.store.are_guests(self.rules_by_user.keys())
filtered_by_user = yield handler._filter_events_for_clients(
users_dict.items(), [event]
users_dict.items(), [event], {event.event_id: current_state}
)
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
condition_cache = {}
member_state = yield self.store.get_state_for_event(
event.event_id,
)
display_names = {}
for ev in member_state.values():
for ev in current_state.values():
nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm

View File

@ -304,7 +304,7 @@ def _flatten_dict(d, prefix=[], result={}):
if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix+[key]), result=result)
_flatten_dict(value, prefix=(prefix + [key]), result=result)
return result

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from httppusher import HttpPusher
from synapse.push import PusherConfigException
from synapse.util.logcontext import preserve_fn
import logging
@ -76,7 +77,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s",
app_id, pushkey, p['user_name']
)
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user(self, user_id):
@ -91,7 +92,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
)
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
@ -110,7 +111,7 @@ class PusherPool:
lang=lang,
data=data,
)
self._refresh_pusher(app_id, pushkey, user_id)
yield self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
@ -166,7 +167,7 @@ class PusherPool:
if fullid in self.pushers:
self.pushers[fullid].stop()
self.pushers[fullid] = p
p.start()
preserve_fn(p.start)()
logger.info("Started pushers")

View File

@ -89,7 +89,7 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState="+urllib.quote(
relay_state = "&RelayState=" + urllib.quote(
login_submission["relay_state"])
result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state)

View File

@ -33,7 +33,11 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
user,
)
defer.returnValue((200, {"displayname": displayname}))
ret = {}
if displayname is not None:
ret["displayname"] = displayname
defer.returnValue((200, ret))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@ -66,7 +70,11 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
user,
)
defer.returnValue((200, {"avatar_url": avatar_url}))
ret = {}
if avatar_url is not None:
ret["avatar_url"] = avatar_url
defer.returnValue((200, ret))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@ -102,10 +110,13 @@ class ProfileRestServlet(ClientV1RestServlet):
user,
)
defer.returnValue((200, {
"displayname": displayname,
"avatar_url": avatar_url
}))
ret = {}
if displayname is not None:
ret["displayname"] = displayname
if avatar_url is not None:
ret["avatar_url"] = avatar_url
defer.returnValue((200, ret))
def register_servlets(hs, http_server):

View File

@ -52,7 +52,7 @@ class PusherRestServlet(ClientV1RestServlet):
if i not in content:
missing.append(i)
if len(missing):
raise SynapseError(400, "Missing parameters: "+','.join(missing),
raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
@ -83,7 +83,7 @@ class PusherRestServlet(ClientV1RestServlet):
data=content['data']
)
except PusherConfigException as pce:
raise SynapseError(400, "Config Error: "+pce.message,
raise SynapseError(400, "Config Error: " + pce.message,
errcode=Codes.MISSING_PARAM)
defer.returnValue((200, {}))

View File

@ -38,7 +38,8 @@ logger = logging.getLogger(__name__)
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
compare_digest = lambda a, b: a == b
def compare_digest(a, b):
return a == b
class RegisterRestServlet(ClientV1RestServlet):
@ -58,7 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# }
# TODO: persistent storage
self.sessions = {}
self.disable_registration = hs.config.disable_registration
self.enable_registration = hs.config.enable_registration
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@ -112,7 +113,7 @@ class RegisterRestServlet(ClientV1RestServlet):
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = (
not self.disable_registration
self.enable_registration
or is_application_server
or is_using_shared_secret
)

View File

@ -429,8 +429,6 @@ class RoomEventContext(ClientV1RestServlet):
serialize_event(event, time_now) for event in results["state"]
]
logger.info("Responding with %r", results)
defer.returnValue((200, results))

View File

@ -116,9 +116,10 @@ class ThreepidRestServlet(RestServlet):
body = parse_json_dict_from_request(request)
if 'threePidCreds' not in body:
threePidCreds = body.get('threePidCreds')
threePidCreds = body.get('three_pid_creds', threePidCreds)
if threePidCreds is None:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds']
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()

View File

@ -57,7 +57,7 @@ class AccountDataServlet(RestServlet):
user_id, account_data_type, body
)
yield self.notifier.on_new_event(
self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
@ -99,7 +99,7 @@ class RoomAccountDataServlet(RestServlet):
user_id, room_id, account_data_type, body
)
yield self.notifier.on_new_event(
self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)

View File

@ -34,7 +34,8 @@ from synapse.util.async import run_on_reactor
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
compare_digest = lambda a, b: a == b
def compare_digest(a, b):
return a == b
logger = logging.getLogger(__name__)
@ -116,7 +117,7 @@ class RegisterRestServlet(RestServlet):
return
# == Normal User Registration == (everyone else)
if self.hs.config.disable_registration:
if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled")
guest_access_token = body.get("guest_access_token", None)
@ -152,6 +153,7 @@ class RegisterRestServlet(RestServlet):
desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,

View File

@ -20,7 +20,6 @@ from synapse.http.servlet import (
)
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
from synapse.events import FrozenEvent
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id,
)
@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
state_dict = room.state
timeline_events = room.timeline.events
state_dict = SyncRestServlet._rollback_state_for_timeline(
state_dict, timeline_events)
state_events = state_dict.values()
serialized_state = [serialize(e) for e in state_events]
@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):
return result
@staticmethod
def _rollback_state_for_timeline(state, timeline):
"""
Wind the state dictionary backwards, so that it represents the
state at the start of the timeline, rather than at the end.
:param dict[(str, str), synapse.events.EventBase] state: the
state dictionary. Will be updated to the state before the timeline.
:param list[synapse.events.EventBase] timeline: the event timeline
:return: updated state dictionary
"""
result = state.copy()
for timeline_event in reversed(timeline):
if not timeline_event.is_state():
continue
event_key = (timeline_event.type, timeline_event.state_key)
logger.debug("Considering %s for removal", event_key)
state_event = result.get(event_key)
if (state_event is None or
state_event.event_id != timeline_event.event_id):
# the event in the timeline isn't present in the state
# dictionary.
#
# the most likely cause for this is that there was a fork in
# the event graph, and the state is no longer valid. Really,
# the event shouldn't be in the timeline. We're going to ignore
# it for now, however.
logger.debug("Found state event %r in timeline which doesn't "
"match state dictionary", timeline_event)
continue
prev_event_id = timeline_event.unsigned.get("replaces_state", None)
prev_content = timeline_event.unsigned.get('prev_content')
prev_sender = timeline_event.unsigned.get('prev_sender')
# Empircally it seems possible for the event to have a
# "replaces_state" key but not a prev_content or prev_sender
# markjh conjectures that it could be due to the server not
# having a copy of that event.
# If this is the case the we ignore the previous event. This will
# cause the displayname calculations on the client to be incorrect
if prev_event_id is None or not prev_content or not prev_sender:
logger.debug(
"Removing %r from the state dict, as it is missing"
" prev_content (prev_event_id=%r)",
timeline_event.event_id, prev_event_id
)
del result[event_key]
else:
logger.debug(
"Replacing %r with %r in state dict",
timeline_event.event_id, prev_event_id
)
result[event_key] = FrozenEvent({
"type": timeline_event.type,
"state_key": timeline_event.state_key,
"content": prev_content,
"sender": prev_sender,
"event_id": prev_event_id,
"room_id": timeline_event.room_id,
})
logger.debug("New value: %r", result.get(event_key))
return result
def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server)

View File

@ -80,7 +80,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event(
self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
@ -94,7 +94,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event(
self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)

View File

@ -26,9 +26,7 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request):
return (200, {
"versions": [
"r0.0.1",
]
"versions": ["r0.0.1"]
})

View File

@ -28,6 +28,7 @@ from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn
import os
@ -276,7 +277,8 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
t_len = yield threads.deferToThread(
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
@ -298,7 +300,8 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
t_len = yield threads.deferToThread(
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
@ -372,7 +375,7 @@ class BaseMediaResource(Resource):
media_id, t_width, t_height, t_type, t_method, t_len
))
yield threads.deferToThread(generate_thumbnails)
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
@ -445,7 +448,7 @@ class BaseMediaResource(Resource):
t_width, t_height, t_type, t_method, t_len
])
yield threads.deferToThread(generate_thumbnails)
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)

View File

@ -63,7 +63,7 @@ class StateHandler(object):
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS*1000,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
reset_expiry_on_get=True,
)

View File

@ -45,9 +45,10 @@ from .search import SearchStore
from .tags import TagsStore
from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator
from synapse.util.caches.stream_change_cache import StreamChangeCache
import logging
@ -58,7 +59,7 @@ logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
# 120 seconds == 2 minutes
LAST_SEEN_GRANULARITY = 120*1000
LAST_SEEN_GRANULARITY = 120 * 1000
class DataStore(RoomMemberStore, RoomStore,
@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore,
def __init__(self, db_conn, hs):
self.hs = hs
self.database_engine = hs.database_engine
cur = db_conn.cursor()
try:
@ -117,8 +119,61 @@ class DataStore(RoomMemberStore, RoomStore,
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
events_max = self._stream_id_gen.get_max_token(None)
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
account_max = self._account_data_id_gen.get_max_token(None)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
super(DataStore, self).__init__(hs)
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
}
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
cache = {
row[0]: int(row[1])
for row in rows
}
if cache:
min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())

View File

@ -15,7 +15,7 @@
import logging
from synapse.api.errors import StoreError
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
@ -185,7 +185,7 @@ class SQLBaseStore(object):
time_then = self._previous_loop_ts
self._previous_loop_ts = time_now
ratio = (curr - prev)/(time_now - time_then)
ratio = (curr - prev) / (time_now - time_then)
top_three_counters = self._txn_perf_counters.interval(
time_now - time_then, limit=3
@ -298,8 +298,8 @@ class SQLBaseStore(object):
func, *args, **kwargs
)
result = yield preserve_context_over_fn(
self._db_pool.runWithConnection,
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
@ -326,8 +326,8 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
result = yield preserve_context_over_fn(
self._db_pool.runWithConnection,
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
@ -643,7 +643,10 @@ class SQLBaseStore(object):
if not iterable:
defer.returnValue(results)
chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)]
chunks = [
iterable[i:i + batch_size]
for i in xrange(0, len(iterable), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
desc,

View File

@ -14,7 +14,6 @@
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
import ujson as json
@ -24,14 +23,6 @@ logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore):
def __init__(self, hs):
super(AccountDataStore, self).__init__(hs)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_max_token(None),
max_size=10000,
)
def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user.
@ -166,6 +157,10 @@ class AccountDataStore(SQLBaseStore):
"content": content_json,
}
)
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id,
)
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:

View File

@ -276,7 +276,8 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
"application_services_state",
dict(as_id=service.id),
["state"],
allow_none=True
allow_none=True,
desc="get_appservice_state",
)
if result:
defer.returnValue(result.get("state"))

View File

@ -54,7 +54,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf):
bufsize = len(buf)
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):

View File

@ -58,7 +58,7 @@ class EventFederationStore(SQLBaseStore):
new_front = set()
front_list = list(front)
chunks = [
front_list[x:x+100]
front_list[x:x + 100]
for x in xrange(0, len(front), 100)
]
for chunk in chunks:

View File

@ -24,8 +24,7 @@ logger = logging.getLogger(__name__)
class EventPushActionsStore(SQLBaseStore):
@defer.inlineCallbacks
def set_push_actions_for_event_and_users(self, event, tuples):
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
:param event: the event set actions for
:param tuples: list of tuples of (user_id, profile_tag, actions)
@ -37,21 +36,19 @@ class EventPushActionsStore(SQLBaseStore):
'event_id': event.event_id,
'user_id': uid,
'profile_tag': profile_tag,
'actions': json.dumps(actions)
'actions': json.dumps(actions),
'stream_ordering': event.internal_metadata.stream_ordering,
'topological_ordering': event.depth,
'notif': 1,
'highlight': 1 if _action_has_highlight(actions) else 0,
})
def f(txn):
for uid, _, __ in tuples:
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(event.room_id, uid)
)
return self._simple_insert_many_txn(txn, "event_push_actions", values)
yield self.runInteraction(
"set_actions_for_event_and_users",
f,
)
self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True)
def get_unread_event_push_actions_by_room_for_user(
@ -68,32 +65,34 @@ class EventPushActionsStore(SQLBaseStore):
)
results = txn.fetchall()
if len(results) == 0:
return []
return {"notify_count": 0, "highlight_count": 0}
stream_ordering = results[0][0]
topological_ordering = results[0][1]
sql = (
"SELECT ea.event_id, ea.actions"
" FROM event_push_actions ea, events e"
" WHERE ea.room_id = e.room_id"
" AND ea.event_id = e.event_id"
" AND ea.user_id = ?"
" AND ea.room_id = ?"
"SELECT sum(notif), sum(highlight)"
" FROM event_push_actions ea"
" WHERE"
" user_id = ?"
" AND room_id = ?"
" AND ("
" e.topological_ordering > ?"
" OR (e.topological_ordering = ? AND e.stream_ordering > ?)"
" topological_ordering > ?"
" OR (topological_ordering = ? AND stream_ordering > ?)"
")"
)
txn.execute(sql, (
user_id, room_id,
topological_ordering, topological_ordering, stream_ordering
)
)
return [
{"event_id": row[0], "actions": json.loads(row[1])}
for row in txn.fetchall()
]
))
row = txn.fetchone()
if row:
return {
"notify_count": row[0] or 0,
"highlight_count": row[1] or 0,
}
else:
return {"notify_count": 0, "highlight_count": 0}
ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
@ -101,9 +100,7 @@ class EventPushActionsStore(SQLBaseStore):
)
defer.returnValue(ret)
@defer.inlineCallbacks
def remove_push_actions_for_event_id(self, room_id, event_id):
def f(txn):
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
@ -113,7 +110,14 @@ class EventPushActionsStore(SQLBaseStore):
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
(room_id, event_id)
)
yield self.runInteraction(
"remove_push_actions_for_event_id",
f
)
def _action_has_highlight(actions):
for action in actions:
try:
if action.get("set_tweak", None) == "highlight":
return action.get("value", True)
except AttributeError:
pass
return False

View File

@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
@ -84,7 +84,7 @@ class EventsStore(SQLBaseStore):
event.internal_metadata.stream_ordering = stream
chunks = [
events_and_contexts[x:x+100]
events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
@ -205,25 +205,31 @@ class EventsStore(SQLBaseStore):
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
is_new_state=True):
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
for event, _ in events_and_contexts:
txn.call_after(self._invalidate_get_event_cache, event.event_id)
if not backfilled:
txn.call_after(
self._events_stream_cache.entity_has_changed,
event.room_id, event.internal_metadata.stream_ordering,
)
depth_updates = {}
for event, _ in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if not event.internal_metadata.is_outlier():
depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
if context.push_actions:
self._set_push_actions_for_event_and_users_txn(
txn, event, context.push_actions
)
if event.type == EventTypes.Redaction and event.redacts is not None:
self._remove_push_actions_for_event_id_txn(
txn, event.room_id, event.redacts
)
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
@ -664,6 +670,7 @@ class EventsStore(SQLBaseStore):
for ids, d in lst:
if not d.called:
try:
with PreserveLoggingContext():
d.callback([
res[i]
for i in ids
@ -671,6 +678,7 @@ class EventsStore(SQLBaseStore):
])
except:
logger.exception("Failed to callback")
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
@ -679,9 +687,11 @@ class EventsStore(SQLBaseStore):
def fire(evs):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(e)
if event_list:
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
@ -709,18 +719,20 @@ class EventsStore(SQLBaseStore):
should_start = False
if should_start:
with PreserveLoggingContext():
self.runWithConnection(
self._do_fetch
)
rows = yield preserve_context_over_deferred(events_d)
with PreserveLoggingContext():
rows = yield events_d
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults(
[
self._get_event_from_row(
preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@ -740,7 +752,7 @@ class EventsStore(SQLBaseStore):
rows = []
N = 200
for i in range(1 + len(events) / N):
evs = events[i*N:(i + 1)*N]
evs = events[i * N:(i + 1) * N]
if not evs:
break
@ -755,7 +767,7 @@ class EventsStore(SQLBaseStore):
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"]*len(evs)),)
) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))

View File

@ -39,6 +39,7 @@ class KeyStore(SQLBaseStore):
table="server_tls_certificates",
keyvalues={"server_name": server_name},
retcols=("tls_certificate",),
desc="get_server_certificate",
)
tls_certificate = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 28
SCHEMA_VERSION = 29
dir_path = os.path.abspath(os.path.dirname(__file__))
@ -211,7 +211,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
logger.debug("applied_delta_files: %s", applied_delta_files)
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v)
logger.info("Upgrading schema to v%d", v)
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))

View File

@ -68,8 +68,9 @@ class PresenceStore(SQLBaseStore):
for row in rows
})
@defer.inlineCallbacks
def set_presence_state(self, user_localpart, new_state):
res = self._simple_update_one(
res = yield self._simple_update_one(
table="presence",
keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"],
@ -79,7 +80,7 @@ class PresenceStore(SQLBaseStore):
)
self.get_presence_state.invalidate((user_localpart,))
return res
defer.returnValue(res)
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(

View File

@ -46,6 +46,20 @@ class ReceiptsStore(SQLBaseStore):
desc="get_receipts_for_room",
)
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self._simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id
},
retcol="event_id",
desc="get_own_receipt_for_user",
allow_none=True,
)
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
def f(txn):
@ -226,6 +240,11 @@ class ReceiptsStore(SQLBaseStore):
room_id, stream_id
)
txn.call_after(
self.get_last_receipt_event_id_for_user.invalidate,
(user_id, room_id, receipt_type)
)
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
sql = (

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
@ -134,6 +136,7 @@ class RegistrationStore(SQLBaseStore):
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
desc="get_user_by_id",
)
def get_users_by_id_case_insensitive(self, user_id):
@ -350,3 +353,37 @@ class RegistrationStore(SQLBaseStore):
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, and we aim for them to be as small as
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
rows = self.cursor_to_dict(txn)
regex = re.compile("^@(\d+):")
found = set()
for r in rows:
user_id = r["name"]
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
for i in xrange(len(found) + 1):
if i not in found:
return i
defer.returnValue((yield self.runInteraction(
"find_next_generated_user_id",
_find_next_generated_user_id
)))

View File

@ -87,90 +87,20 @@ class RoomStore(SQLBaseStore):
desc="get_public_room_ids",
)
@defer.inlineCallbacks
def get_rooms(self, is_public):
"""Retrieve a list of all public rooms.
Args:
is_public (bool): True if the rooms returned should be public.
Returns:
A list of room dicts containing at least a "room_id" key, a
"topic" key if one is set, and a "name" key if one is set
def get_room_count(self):
"""Retrieve a list of all rooms
"""
def f(txn):
def subquery(table_name, column_name=None):
column_name = column_name or table_name
return (
"SELECT %(table_name)s.event_id as event_id, "
"%(table_name)s.room_id as room_id, %(column_name)s "
"FROM %(table_name)s "
"INNER JOIN current_state_events as c "
"ON c.event_id = %(table_name)s.event_id " % {
"column_name": column_name,
"table_name": table_name,
}
)
sql = "SELECT count(*) FROM rooms"
txn.execute(sql)
row = txn.fetchone()
return row[0] or 0
sql = (
"SELECT"
" r.room_id,"
" max(n.name),"
" max(t.topic),"
" max(v.history_visibility),"
" max(g.guest_access)"
" FROM rooms AS r"
" LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
" LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
" LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id"
" LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id"
" WHERE r.is_public = ?"
" GROUP BY r.room_id" % {
"topic": subquery("topics", "topic"),
"name": subquery("room_names", "name"),
"history_visibility": subquery("history_visibility"),
"guest_access": subquery("guest_access"),
}
)
txn.execute(sql, (is_public,))
rows = txn.fetchall()
for i, row in enumerate(rows):
room_id = row[0]
aliases = self._simple_select_onecol_txn(
txn,
table="room_aliases",
keyvalues={
"room_id": room_id
},
retcol="room_alias",
)
rows[i] = list(row) + [aliases]
return rows
rows = yield self.runInteraction(
return self.runInteraction(
"get_rooms", f
)
ret = [
{
"room_id": r[0],
"name": r[1],
"topic": r[2],
"world_readable": r[3] == "world_readable",
"guest_can_join": r[4] == "can_join",
"aliases": r[5],
}
for r in rows
if r[5] # We only return rooms that have at least one alias.
]
defer.returnValue(ret)
def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content:
self._simple_insert_txn(

View File

@ -58,6 +58,10 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.

View File

@ -0,0 +1,16 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX public_room_index on rooms(is_public);

View File

@ -0,0 +1,31 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT;
ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT;
ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT;
ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT;
UPDATE event_push_actions SET stream_ordering = (
SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id
), topological_ordering = (
SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id
);
UPDATE event_push_actions SET notif = 1, highlight = 0;
CREATE INDEX event_push_actions_rm_tokens on event_push_actions(
user_id, room_id, topological_ordering, stream_ordering
);

View File

@ -171,15 +171,10 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
def _get_state_groups_from_groups(self, groups_and_types):
def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> state event ids
Args:
groups_and_types (list): list of 2-tuple (`group`, `types`)
"""
def f(txn):
results = {}
for group, types in groups_and_types:
def f(txn, groups):
if types is not None:
where_clause = "AND (%s)" % (
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
@ -188,23 +183,30 @@ class StateStore(SQLBaseStore):
where_clause = ""
sql = (
"SELECT event_id FROM state_groups_state WHERE"
" state_group = ? %s"
) % (where_clause,)
"SELECT state_group, event_id FROM state_groups_state WHERE"
" state_group IN (%s) %s" % (
",".join("?" for _ in groups),
where_clause,
)
)
args = [group]
args = list(groups)
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
results[group] = [r[0] for r in txn.fetchall()]
results = {}
for row in rows:
results.setdefault(row["state_group"], []).append(row["event_id"])
return results
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
for chunk in chunks:
return self.runInteraction(
"_get_state_groups_from_groups",
f,
f, chunk
)
@defer.inlineCallbacks
@ -264,26 +266,20 @@ class StateStore(SQLBaseStore):
)
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
num_args=1)
num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
def f(txn):
results = {}
for event_id in event_ids:
results[event_id] = self._simple_select_one_onecol_txn(
txn,
rows = yield self._simple_select_many_batch(
table="event_to_state_groups",
keyvalues={
"event_id": event_id,
},
retcol="state_group",
allow_none=True,
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=("event_id", "state_group",),
desc="_get_state_group_for_events",
)
return results
return self.runInteraction("_get_state_group_for_events", f)
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups`
@ -355,7 +351,7 @@ class StateStore(SQLBaseStore):
all events are returned.
"""
results = {}
missing_groups_and_types = []
missing_groups = []
if types is not None:
for group in set(groups):
state_dict, missing_types, got_all = self._get_some_state_from_cache(
@ -364,7 +360,7 @@ class StateStore(SQLBaseStore):
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((group, missing_types))
missing_groups.append(group)
else:
for group in set(groups):
state_dict, got_all = self._get_all_state_from_cache(
@ -373,9 +369,9 @@ class StateStore(SQLBaseStore):
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((group, None))
missing_groups.append(group)
if not missing_groups_and_types:
if not missing_groups:
defer.returnValue({
group: {
type_tuple: event
@ -389,7 +385,7 @@ class StateStore(SQLBaseStore):
cache_seq_num = self._state_group_cache.sequence
group_state_dict = yield self._get_state_groups_from_groups(
missing_groups_and_types
missing_groups, types
)
state_events = yield self._get_events(

View File

@ -37,10 +37,9 @@ from twisted.internet import defer
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn
import logging
@ -78,13 +77,6 @@ def upper_bound(token):
class StreamStore(SQLBaseStore):
def __init__(self, hs):
super(StreamStore, self).__init__(hs)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
)
@defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
# NB this lives here instead of appservice.py so we can reuse the
@ -177,14 +169,14 @@ class StreamStore(SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([
self.get_room_events_stream_for_room(
room_id, from_key, to_key, limit
).addCallback(lambda r, rm: (rm, r), room_id)
preserve_fn(self.get_room_events_stream_for_room)(
room_id, from_key, to_key, limit,
)
for room_id in room_ids
])
results.update(dict(res))
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
@ -229,8 +221,11 @@ class StreamStore(SQLBaseStore):
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
return rows
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
@ -246,11 +241,10 @@ class StreamStore(SQLBaseStore):
# get.
key = from_key
return ret, key
res = yield self.runInteraction("get_room_events_stream_for_room", f)
defer.returnValue(res)
defer.returnValue((ret, key))
def get_room_changes_for_user(self, user_id, from_key, to_key):
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
else:
@ -258,7 +252,14 @@ class StreamStore(SQLBaseStore):
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
return defer.succeed([])
defer.returnValue([])
if from_id:
has_changed = self._membership_stream_cache.has_entity_changed(
user_id, int(from_id)
)
if not has_changed:
defer.returnValue([])
def f(txn):
if from_id is not None:
@ -283,17 +284,19 @@ class StreamStore(SQLBaseStore):
txn.execute(sql, (user_id, to_id,))
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
return rows
rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
return ret
self._set_before_and_after(ret, rows, topo_order=False)
return self.runInteraction("get_room_changes_for_user", f)
defer.returnValue(ret)
@log_function
def get_room_events_stream(
self,
user_id,
@ -324,11 +327,6 @@ class StreamStore(SQLBaseStore):
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
if room_ids:
current_room_membership_sql += " AND m.room_id in (%s)" % (
",".join(map(lambda _: "?", room_ids))
)
current_room_membership_args = [user_id] + room_ids
# We also want to get any membership events about that user, e.g.
# invites or leave notifications.
@ -567,6 +565,7 @@ class StreamStore(SQLBaseStore):
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
)
@ -604,6 +603,10 @@ class StreamStore(SQLBaseStore):
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (
int(topo) if topo else 0,
int(stream),
)
@defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit):

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@ -46,7 +46,7 @@ class Clock(object):
def looping_call(self, f, msec):
l = task.LoopingCall(f)
l.start(msec/1000.0, now=False)
l.start(msec / 1000.0, now=False)
return l
def stop_looping_call(self, loop):
@ -61,10 +61,8 @@ class Clock(object):
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
current_context = LoggingContext.current_context()
def wrapped_callback(*args, **kwargs):
with PreserveLoggingContext(current_context):
with PreserveLoggingContext():
callback(*args, **kwargs)
with PreserveLoggingContext():

View File

@ -16,13 +16,16 @@
from twisted.internet import defer, reactor
from .logcontext import preserve_context_over_deferred
from .logcontext import PreserveLoggingContext
@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
with PreserveLoggingContext():
reactor.callLater(seconds, d.callback, seconds)
return preserve_context_over_deferred(d)
res = yield d
defer.returnValue(res)
def run_on_reactor():
@ -54,6 +57,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
# TODO: Handle errors here.
self._observers.pop().callback(r)
except:
pass
@ -63,6 +67,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
# TODO: Handle errors here.
self._observers.pop().errback(f)
except:
pass

View File

@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
from . import caches_by_name, DEBUG_CACHES, cache_counter
@ -149,7 +152,7 @@ class CacheDescriptor(object):
self.lru = lru
self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
if len(self.arg_names) < self.num_args:
raise Exception(
@ -190,7 +193,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
@ -198,6 +201,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call,
obj, *args, **kwargs
)
@ -211,7 +215,7 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
return ret.observe()
return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
@ -250,7 +254,7 @@ class CacheListDescriptor(object):
self.num_args = num_args
self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
@ -299,6 +303,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call,
**args_to_call
)
@ -308,6 +313,7 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
with PreserveLoggingContext():
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
@ -327,10 +333,10 @@ class CacheListDescriptor(object):
cached[arg] = res
return defer.gatherResults(
return preserve_context_over_deferred(defer.gatherResults(
cached.values(),
consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped

View File

@ -55,7 +55,7 @@ class ExpiringCache(object):
def f():
self._prune_cache()
self._clock.looping_call(f, self._expiry_ms/2)
self._clock.looping_call(f, self._expiry_ms / 2)
def __setitem__(self, key, value):
now = self._clock.time_msec()

View File

@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache.
self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None)
return r
result.observe().addBoth(shuffle_along)
result.addBoth(shuffle_along)
return result.observe()

View File

@ -32,7 +32,7 @@ class StreamChangeCache(object):
entities that may have changed since that position. If position key is too
old then the cache will simply return all given entities.
"""
def __init__(self, name, current_stream_pos, max_size=10000):
def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
self._max_size = max_size
self._entity_to_key = {}
self._cache = sorteddict()
@ -40,6 +40,9 @@ class StreamChangeCache(object):
self.name = name
caches_by_name[self.name] = self._cache
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
@ -49,15 +52,10 @@ class StreamChangeCache(object):
cache_counter.inc_misses(self.name)
return True
if stream_pos == self._earliest_known_stream_pos:
# If the same as the earliest key, assume nothing has changed.
cache_counter.inc_hits(self.name)
return False
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
cache_counter.inc_misses(self.name)
return True
cache_counter.inc_hits(self.name)
return False
if stream_pos < latest_entity_change_pos:
cache_counter.inc_misses(self.name)
@ -95,7 +93,7 @@ class StreamChangeCache(object):
if stream_pos > self._earliest_known_stream_pos:
old_pos = self._entity_to_key.get(entity, None)
if old_pos:
if old_pos is not None:
stream_pos = max(stream_pos, old_pos)
self._cache.pop(old_pos, None)
self._cache[stream_pos] = entity

View File

@ -58,7 +58,7 @@ class TreeCache(object):
if n:
break
node_and_keys[i+1][0].pop(k)
node_and_keys[i + 1][0].pop(k)
popped, cnt = _strip_and_count_entires(popped)
self.size -= cnt

View File

@ -15,9 +15,7 @@
from twisted.internet import defer
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred,
)
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util import unwrapFirstError
@ -97,6 +95,7 @@ class Signal(object):
Each observer callable may return a Deferred."""
self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -116,6 +115,7 @@ class Signal(object):
failure.getTracebackObject()))
if not self.suppress_failures:
return failure
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext():
@ -124,8 +124,11 @@ class Signal(object):
for observer in self.observers
]
d = defer.gatherResults(deferreds, consumeErrors=True)
res = yield defer.gatherResults(
deferreds, consumeErrors=True
).addErrback(unwrapFirstError)
d.addErrback(unwrapFirstError)
defer.returnValue(res)
return preserve_context_over_deferred(d)
def __repr__(self):
return "<Signal name=%r>" % (self.name,)

View File

@ -41,13 +41,14 @@ except:
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block. Contexts inherit the state of their parent contexts.
"with" block.
Args:
name (str): Name for the context for debugging.
"""
__slots__ = [
"parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__"
"previous_context", "name", "usage_start", "usage_end", "main_thread",
"__dict__", "tag", "alive",
]
thread_local = threading.local()
@ -72,10 +73,13 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
def __nonzero__(self):
return False
sentinel = Sentinel()
def __init__(self, name=None):
self.parent_context = None
self.previous_context = LoggingContext.current_context()
self.name = name
self.ru_stime = 0.
self.ru_utime = 0.
@ -83,6 +87,8 @@ class LoggingContext(object):
self.db_txn_duration = 0.
self.usage_start = None
self.main_thread = threading.current_thread()
self.tag = ""
self.alive = True
def __str__(self):
return "%s@%x" % (self.name, id(self))
@ -101,6 +107,7 @@ class LoggingContext(object):
The context that was previously active
"""
current = cls.current_context()
if current is not context:
current.stop()
cls.thread_local.current_context = context
@ -109,9 +116,13 @@ class LoggingContext(object):
def __enter__(self):
"""Enters this logging context into thread local storage"""
if self.parent_context is not None:
raise Exception("Attempt to enter logging context multiple times")
self.parent_context = self.set_current_context(self)
old_context = self.set_current_context(self)
if self.previous_context != old_context:
logger.warn(
"Expected previous context %r, found %r",
self.previous_context, old_context
)
self.alive = True
return self
def __exit__(self, type, value, traceback):
@ -120,7 +131,7 @@ class LoggingContext(object):
Returns:
None to avoid suppressing any exeptions that were thrown.
"""
current = self.set_current_context(self.parent_context)
current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
logger.debug("Expected logging context %s has been lost", self)
@ -130,16 +141,11 @@ class LoggingContext(object):
current,
self
)
self.parent_context = None
def __getattr__(self, name):
"""Delegate member lookup to parent context"""
return getattr(self.parent_context, name)
self.previous_context = None
self.alive = False
def copy_to(self, record):
"""Copy fields from this context and its parents to the record"""
if self.parent_context is not None:
self.parent_context.copy_to(record)
"""Copy fields from this context to the record"""
for key, value in self.__dict__.items():
setattr(record, key, value)
@ -208,7 +214,7 @@ class PreserveLoggingContext(object):
exited. Used to restore the context after a function using
@defer.inlineCallbacks is resumed by a callback from the reactor."""
__slots__ = ["current_context", "new_context"]
__slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=LoggingContext.sentinel):
self.new_context = new_context
@ -219,12 +225,27 @@ class PreserveLoggingContext(object):
self.new_context
)
if self.current_context:
self.has_parent = self.current_context.previous_context is not None
if not self.current_context.alive:
logger.debug(
"Entering dead context: %s",
self.current_context,
)
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
LoggingContext.set_current_context(self.current_context)
context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
logger.debug(
"Unexpected logging context: %s is not %s",
context, self.new_context,
)
if self.current_context is not LoggingContext.sentinel:
if self.current_context.parent_context is None:
logger.warn(
if not self.current_context.alive:
logger.debug(
"Restoring dead context: %s",
self.current_context,
)
@ -284,3 +305,74 @@ def preserve_context_over_deferred(deferred):
d = _PreservingContextDeferred(current_context)
deferred.chainDeferred(d)
return d
def preserve_fn(f):
"""Ensures that function is called with correct context and that context is
restored after return. Useful for wrapping functions that return a deferred
which you don't yield on.
"""
current = LoggingContext.current_context()
def g(*args, **kwargs):
with PreserveLoggingContext(current):
return f(*args, **kwargs)
return g
# modules to ignore in `logcontext_tracer`
_to_ignore = [
"synapse.util.logcontext",
"synapse.http.server",
"synapse.storage._base",
"synapse.util.async",
]
def logcontext_tracer(frame, event, arg):
"""A tracer that logs whenever a logcontext "unexpectedly" changes within
a function. Probably inaccurate.
Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
"""
if event == 'call':
name = frame.f_globals["__name__"]
if name.startswith("synapse"):
if name == "synapse.util.logcontext":
if frame.f_code.co_name in ["__enter__", "__exit__"]:
tracer = frame.f_back.f_trace
if tracer:
tracer.just_changed = True
tracer = frame.f_trace
if tracer:
return tracer
if not any(name.startswith(ig) for ig in _to_ignore):
return LineTracer()
class LineTracer(object):
__slots__ = ["context", "just_changed"]
def __init__(self):
self.context = LoggingContext.current_context()
self.just_changed = False
def __call__(self, frame, event, arg):
if event in 'line':
if self.just_changed:
self.context = LoggingContext.current_context()
self.just_changed = False
else:
c = LoggingContext.current_context()
if c != self.context:
logger.info(
"Context changed! %s -> %s, %s, %s",
self.context, c,
frame.f_code.co_filename, frame.f_lineno
)
self.context = c
return self

View File

@ -111,7 +111,7 @@ def time_function(f):
_log_debug_as_f(
f,
"[FUNC END] {%s-%d} %f",
(func_name, id, end-start,),
(func_name, id, end - start,),
)
return r
@ -168,3 +168,38 @@ def trace_function(f):
wrapped.__name__ = func_name
return wrapped
def get_previous_frames():
s = inspect.currentframe().f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
to_return.append("{{ %s:%d %s - Args: %s }}" % (
filename, lineno, function, args_string
))
s = s.f_back
return ", ". join(to_return)
def get_previous_frame(ignore=[]):
s = inspect.currentframe().f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):
if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
return "{{ %s:%d %s - Args: %s }}" % (
filename, lineno, function, args_string
)
s = s.f_back
return None

97
synapse/util/metrics.py Normal file
View File

@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.logcontext import LoggingContext
import synapse.metrics
import logging
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
block_timer = metrics.register_distribution(
"block_timer",
labels=["block_name"]
)
block_ru_utime = metrics.register_distribution(
"block_ru_utime", labels=["block_name"]
)
block_ru_stime = metrics.register_distribution(
"block_ru_stime", labels=["block_name"]
)
block_db_txn_count = metrics.register_distribution(
"block_db_txn_count", labels=["block_name"]
)
block_db_txn_duration = metrics.register_distribution(
"block_db_txn_duration", labels=["block_name"]
)
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
"ru_stime", "db_txn_count", "db_txn_duration"
]
def __init__(self, clock, name):
self.clock = clock
self.name = name
self.start_context = None
self.start = None
def __enter__(self):
self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context()
if self.start_context:
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
self.db_txn_duration = self.start_context.db_txn_duration
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None or not self.start_context:
return
duration = self.clock.time_msec() - self.start
block_timer.inc_by(duration, self.name)
context = LoggingContext.current_context()
if context != self.start_context:
logger.warn(
"Context have unexpectedly changed from '%s' to '%s'. (%r)",
context, self.start_context, self.name
)
return
if not context:
logger.warn("Expected context. (%r)", self.name)
return
ru_utime, ru_stime = context.get_resource_usage()
block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
block_db_txn_duration.inc_by(
context.db_txn_duration - self.db_txn_duration, self.name
)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
from synapse.util.logcontext import preserve_fn
import collections
import contextlib
@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
ret_defer = sleep(self.sleep_msec/1000.0)
ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)

14
tests/config/__init__.py Normal file
View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import shutil
import tempfile
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
class ConfigGenerationTestCase(unittest.TestCase):
def setUp(self):
self.dir = tempfile.mkdtemp()
print self.dir
self.file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self):
shutil.rmtree(self.dir)
def test_generate_config_generates_files(self):
HomeServerConfig.load_config("", [
"--generate-config",
"-c", self.file,
"--report-stats=yes",
"-H", "lemurs.win"
])
self.assertSetEqual(
set([
"homeserver.yaml",
"lemurs.win.log.config",
"lemurs.win.signing.key",
"lemurs.win.tls.crt",
"lemurs.win.tls.dh",
"lemurs.win.tls.key",
]),
set(os.listdir(self.dir))
)

78
tests/config/test_load.py Normal file
View File

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import shutil
import tempfile
import yaml
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
class ConfigLoadingTestCase(unittest.TestCase):
def setUp(self):
self.dir = tempfile.mkdtemp()
print self.dir
self.file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self):
shutil.rmtree(self.dir)
def test_load_fails_if_server_name_missing(self):
self.generate_config_and_remove_lines_containing("server_name")
with self.assertRaises(Exception):
HomeServerConfig.load_config("", ["-c", self.file])
def test_generates_and_loads_macaroon_secret_key(self):
self.generate_config()
with open(self.file,
"r") as f:
raw = yaml.load(f)
self.assertIn("macaroon_secret_key", raw)
config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertTrue(
hasattr(config, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key"
)
if len(config.macaroon_secret_key) < 5:
self.fail(
"Want macaroon secret key to be string of at least length 5,"
"was: %r" % (config.macaroon_secret_key,)
)
def test_load_succeeds_if_macaroon_secret_key_missing(self):
self.generate_config_and_remove_lines_containing("macaroon")
config1 = HomeServerConfig.load_config("", ["-c", self.file])
config2 = HomeServerConfig.load_config("", ["-c", self.file])
self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
def generate_config(self):
HomeServerConfig.load_config("", [
"--generate-config",
"-c", self.file,
"--report-stats=yes",
"-H", "lemurs.win"
])
def generate_config_and_remove_lines_containing(self, needle):
self.generate_config()
with open(self.file, "r") as f:
contents = f.readlines()
contents = [l for l in contents if needle not in l]
with open(self.file, "w") as f:
f.write("".join(contents))

View File

@ -122,7 +122,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.config.enable_registration_captcha = False
hs.config.disable_registration = False
hs.config.enable_registration = True
hs.get_handlers().federation_handler = Mock()

View File

@ -41,7 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.config.disable_registration = False
self.hs.config.enable_registration = True
# init the thing we're testing
self.servlet = RegisterRestServlet(self.hs)
@ -120,7 +120,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
}))
def test_POST_disabled_registration(self):
self.hs.config.disable_registration = True
self.hs.config.enable_registration = False
self.request_data = json.dumps({
"username": "kermit",
"password": "monkey"

View File

@ -51,32 +51,6 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room(self.room.to_string()))
)
@defer.inlineCallbacks
def test_get_rooms(self):
# get_rooms does an INNER JOIN on the room_aliases table :(
rooms = yield self.store.get_rooms(is_public=True)
# Should be empty before we add the alias
self.assertEquals([], rooms)
yield self.store.create_room_alias_association(
room_alias=self.alias,
room_id=self.room.to_string(),
servers=["test"]
)
rooms = yield self.store.get_rooms(is_public=True)
self.assertEquals(1, len(rooms))
self.assertEquals({
"name": None,
"room_id": self.room.to_string(),
"topic": None,
"aliases": [self.alias.to_string()],
"world_readable": False,
"guest_can_join": False,
}, rooms[0])
class RoomEventsStoreTestCase(unittest.TestCase):

View File

@ -5,6 +5,7 @@ from .. import unittest
from synapse.util.async import sleep
from synapse.util.logcontext import LoggingContext
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
@ -17,15 +18,6 @@ class LoggingContextTestCase(unittest.TestCase):
context_one.test_key = "test"
self._check_test_key("test")
def test_chaining(self):
with LoggingContext() as context_one:
context_one.test_key = "one"
with LoggingContext() as context_two:
self._check_test_key("one")
context_two.test_key = "two"
self._check_test_key("two")
self._check_test_key("one")
@defer.inlineCallbacks
def test_sleep(self):
@defer.inlineCallbacks

View File

@ -46,9 +46,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config = Mock()
config.signing_key = [MockKey()]
config.event_cache_size = 1
config.disable_registration = False
config.enable_registration = True
config.macaroon_secret_key = "not even a little secret"
config.server_name = "server.under.test"
config.trusted_third_party_id_servers = []
if "clock" not in kargs:
kargs["clock"] = MockClock()

View File

@ -11,7 +11,7 @@ deps =
setenv =
PYTHONDONTWRITEBYTECODE = no_byte_code
commands =
/bin/bash -c "coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
/bin/bash -c "find {toxinidir} -name '*.pyc' -delete ; coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
{envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}"
{env:DUMP_COVERAGE_COMMAND:coverage report -m}