Merge remote-tracking branch 'origin/develop' into markjh/end-to-end-key-federation

This commit is contained in:
Mark Haines 2015-08-13 17:27:53 +01:00
commit c5966b2a97
35 changed files with 616 additions and 370 deletions

View File

@ -101,25 +101,26 @@ header files for python C extensions.
Installing prerequisites on Ubuntu or Debian:: Installing prerequisites on Ubuntu or Debian::
$ sudo apt-get install build-essential python2.7-dev libffi-dev \ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \ python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on ArchLinux:: Installing prerequisites on ArchLinux::
$ sudo pacman -S base-devel python2 python-pip \ sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3 python-setuptools python-virtualenv sqlite3
Installing prerequisites on Mac OS X:: Installing prerequisites on Mac OS X::
$ xcode-select --install xcode-select --install
$ sudo pip install virtualenv sudo easy_install pip
sudo pip install virtualenv
To install the synapse homeserver run:: To install the synapse homeserver run::
$ virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
$ source ~/.synapse/bin/activate source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into a virtual This installs synapse, along with the libraries it uses, into a virtual
environment under ``~/.synapse``. Feel free to pick a different directory environment under ``~/.synapse``. Feel free to pick a different directory
@ -132,8 +133,8 @@ above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
To set up your homeserver, run (in your virtualenv, as before):: To set up your homeserver, run (in your virtualenv, as before)::
$ cd ~/.synapse cd ~/.synapse
$ python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
@ -192,9 +193,9 @@ Running Synapse
To actually run your new homeserver, pick a working directory for Synapse to run To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and:: (e.g. ``~/.synapse``), and::
$ cd ~/.synapse cd ~/.synapse
$ source ./bin/activate source ./bin/activate
$ synctl start synctl start
Platform Specific Instructions Platform Specific Instructions
============================== ==============================
@ -212,12 +213,12 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ):: pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
$ sudo pip2.7 install --upgrade pip sudo pip2.7 install --upgrade pip
You also may need to explicitly specify python 2.7 again during the install You also may need to explicitly specify python 2.7 again during the install
request:: request::
$ pip2.7 install --process-dependency-links \ pip2.7 install --process-dependency-links \
https://github.com/matrix-org/synapse/tarball/master https://github.com/matrix-org/synapse/tarball/master
If you encounter an error with lib bcrypt causing an Wrong ELF Class: If you encounter an error with lib bcrypt causing an Wrong ELF Class:
@ -225,13 +226,13 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if compile it under the right architecture. (This should not be needed if
installing under virtualenv):: installing under virtualenv)::
$ sudo pip2.7 uninstall py-bcrypt sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt sudo pip2.7 install py-bcrypt
During setup of Synapse you need to call python2.7 directly again:: During setup of Synapse you need to call python2.7 directly again::
$ cd ~/.synapse cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \ python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
@ -279,22 +280,22 @@ Synapse requires pip 1.7 or later, so if your OS provides too old a version and
you get errors about ``error: no such option: --process-dependency-links`` you you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it:: may need to manually upgrade it::
$ sudo pip install --upgrade pip sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it refuse to run until you remove the temporary installation directory it
created. To reset the installation:: created. To reset the installation::
$ rm -rf /tmp/pip_install_matrix rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are happens, you will have to individually install the dependencies which are
failing, e.g.:: failing, e.g.::
$ pip install twisted pip install twisted
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments. will need to export CFLAGS=-Qunused-arguments.
Troubleshooting Running Troubleshooting Running
@ -310,10 +311,11 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
fix try re-installing from PyPI or directly from fix try re-installing from PyPI or directly from
(https://github.com/pyca/pynacl):: (https://github.com/pyca/pynacl)::
$ # Install from PyPI # Install from PyPI
$ pip install --user --upgrade --force pynacl pip install --user --upgrade --force pynacl
$ # Install from github
$ pip install --user https://github.com/pyca/pynacl/tarball/master # Install from github
pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux ArchLinux
~~~~~~~~~ ~~~~~~~~~
@ -321,7 +323,7 @@ ArchLinux
If running `$ synctl start` fails with 'returned non-zero exit status 1', If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as:: you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
...or by editing synctl with the correct python executable. ...or by editing synctl with the correct python executable.
@ -331,16 +333,16 @@ Synapse Development
To check out a synapse for development, clone the git repo into a working To check out a synapse for development, clone the git repo into a working
directory of your choice:: directory of your choice::
$ git clone https://github.com/matrix-org/synapse.git git clone https://github.com/matrix-org/synapse.git
$ cd synapse cd synapse
Synapse has a number of external dependencies, that are easiest Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv:: to install using pip and a virtualenv::
$ virtualenv env virtualenv env
$ source env/bin/activate source env/bin/activate
$ python synapse/python_dependencies.py | xargs -n1 pip install python synapse/python_dependencies.py | xargs -n1 pip install
$ pip install setuptools_trial mock pip install setuptools_trial mock
This will run a process of downloading and installing all the needed This will run a process of downloading and installing all the needed
dependencies into a virtual env. dependencies into a virtual env.
@ -348,7 +350,7 @@ dependencies into a virtual env.
Once this is done, you may wish to run Synapse's unit tests, to Once this is done, you may wish to run Synapse's unit tests, to
check that everything is installed as it should be:: check that everything is installed as it should be::
$ python setup.py test python setup.py test
This should end with a 'PASSED' result:: This should end with a 'PASSED' result::
@ -389,11 +391,11 @@ IDs:
For the first form, simply pass the required hostname (of the machine) as the For the first form, simply pass the required hostname (of the machine) as the
--server-name parameter:: --server-name parameter::
$ python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml python -m synapse.app.homeserver --config-path homeserver.yaml
Alternatively, you can run ``synctl start`` to guide you through the process. Alternatively, you can run ``synctl start`` to guide you through the process.
@ -410,11 +412,11 @@ record would then look something like::
At this point, you should then run the homeserver with the hostname of this At this point, you should then run the homeserver with the hostname of this
SRV record, as that is the name other machines will expect it to have:: SRV record, as that is the name other machines will expect it to have::
$ python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \ --server-name YOURDOMAIN \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml python -m synapse.app.homeserver --config-path homeserver.yaml
You may additionally want to pass one or more "-v" options, in order to You may additionally want to pass one or more "-v" options, in order to
@ -428,7 +430,7 @@ private federation (``localhost:8080``, ``localhost:8081`` and
``localhost:8082``) which you can then access through the webclient running at ``localhost:8082``) which you can then access through the webclient running at
http://localhost:8080. Simply run:: http://localhost:8080. Simply run::
$ demo/start.sh demo/start.sh
This is mainly useful just for development purposes. This is mainly useful just for development purposes.
@ -502,10 +504,10 @@ Building Internal API Documentation
Before building internal API documentation install sphinx and Before building internal API documentation install sphinx and
sphinxcontrib-napoleon:: sphinxcontrib-napoleon::
$ pip install sphinx pip install sphinx
$ pip install sphinxcontrib-napoleon pip install sphinxcontrib-napoleon
Building internal API documentation:: Building internal API documentation::
$ python setup.py build_sphinx python setup.py build_sphinx

View File

@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
exit 1 exit 1
fi fi
find "$DIR" -name "*.log" -delete for port in 8080 8081 8082; do
find "$DIR" -name "*.db" -delete rm -rf $DIR/$port
rm -rf $DIR/media_store.$port
done
rm -rf $DIR/etc rm -rf $DIR/etc

View File

@ -8,14 +8,6 @@ cd "$DIR/.."
mkdir -p demo/etc mkdir -p demo/etc
# Check the --no-rate-limit param
PARAMS=""
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
fi
fi
export PYTHONPATH=$(readlink -f $(pwd)) export PYTHONPATH=$(readlink -f $(pwd))
@ -31,10 +23,20 @@ for port in 8080 8081 8082; do
#rm $DIR/etc/$port.config #rm $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--generate-config \ --generate-config \
--enable_registration \
-H "localhost:$https_port" \ -H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
# Check script parameters
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
# Set high limits in config file to disable rate limiting
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
fi
fi
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
-D \ -D \

View File

@ -16,3 +16,6 @@ ignore =
docs/* docs/*
pylint.cfg pylint.cfg
tox.ini tox.ini
[flake8]
max-line-length = 90

View File

@ -48,7 +48,7 @@ setup(
description="Reference Synapse Home Server", description="Reference Synapse Home Server",
install_requires=dependencies['requirements'](include_conditional=True).keys(), install_requires=dependencies['requirements'](include_conditional=True).keys(),
setup_requires=[ setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0 "Twisted>=15.1.0", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial", "setuptools_trial",
"mock" "mock"
], ],

View File

@ -44,6 +44,11 @@ class Auth(object):
def check(self, event, auth_events): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
""" """
@ -319,7 +324,7 @@ class Auth(object):
Returns: Returns:
tuple : of UserID and device string: tuple : of UserID and device string:
User ID object of the user making the request User ID object of the user making the request
Client ID object of the client instance the user is using ClientInfo object of the client instance the user is using
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
@ -352,7 +357,7 @@ class Auth(object):
) )
return return
except KeyError: except KeyError:
pass # normal users won't have this query parameter set pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_token(access_token) user_info = yield self.get_user_by_token(access_token)
user = user_info["user"] user = user_info["user"]
@ -521,23 +526,22 @@ class Auth(object):
# Check state_key # Check state_key
if hasattr(event, "state_key"): if hasattr(event, "state_key"):
if not event.state_key.startswith("_"): if event.state_key.startswith("@"):
if event.state_key.startswith("@"): if event.state_key != event.user_id:
if event.state_key != event.user_id: raise AuthError(
403,
"You are not allowed to set others state"
)
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError( raise AuthError(
403, 403,
"You are not allowed to set others state" "You are not allowed to set others state"
) )
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True return True

View File

@ -657,7 +657,8 @@ def run(hs):
if hs.config.daemonize: if hs.config.daemonize:
print hs.config.pid_file if hs.config.print_pidfile:
print hs.config.pid_file
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",

View File

@ -138,12 +138,19 @@ class Config(object):
action="store_true", action="store_true",
help="Generate a config file for the server name" help="Generate a config file for the server name"
) )
config_parser.add_argument(
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit"
)
config_parser.add_argument( config_parser.add_argument(
"-H", "--server-name", "-H", "--server-name",
help="The server name to generate a config file for" help="The server name to generate a config file for"
) )
config_args, remaining_args = config_parser.parse_known_args(argv) config_args, remaining_args = config_parser.parse_known_args(argv)
generate_keys = config_args.generate_keys
if config_args.generate_config: if config_args.generate_config:
if not config_args.config_path: if not config_args.config_path:
config_parser.error( config_parser.error(
@ -151,51 +158,40 @@ class Config(object):
" generated using \"--generate-config -H SERVER_NAME" " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name
if not server_name:
print "Must specify a server_name to a generate config for."
sys.exit(1)
(config_path,) = config_args.config_path (config_path,) = config_args.config_path
if not os.path.exists(config_dir_path): if not os.path.exists(config_path):
os.makedirs(config_dir_path) config_dir_path = os.path.dirname(config_path)
if os.path.exists(config_path): config_dir_path = os.path.abspath(config_dir_path)
print "Config file %r already exists" % (config_path,)
yaml_config = cls.read_config_file(config_path) server_name = config_args.server_name
yaml_name = yaml_config["server_name"] if not server_name:
if server_name != yaml_name: print "Must specify a server_name to a generate config for."
print (
"Config file %r has a different server_name: "
" %r != %r" % (config_path, server_name, yaml_name)
)
sys.exit(1) sys.exit(1)
config_bytes, config = obj.generate_config( if not os.path.exists(config_dir_path):
config_dir_path, server_name os.makedirs(config_dir_path)
) with open(config_path, "wb") as config_file:
config.update(yaml_config) config_bytes, config = obj.generate_config(
print "Generating any missing keys for %r" % (server_name,) config_dir_path, server_name
obj.invoke_all("generate_files", config) )
sys.exit(0) obj.invoke_all("generate_files", config)
with open(config_path, "wb") as config_file: config_file.write(config_bytes)
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
print ( print (
"A config file has been generated in %s for server name" "A config file has been generated in %r for server name"
" '%s' with corresponding SSL keys and self-signed" " %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to" " certificates. Please review this file and customise it"
" your needs." " to your needs."
) % (config_path, server_name) ) % (config_path, server_name)
print ( print (
"If this server name is incorrect, you will need to regenerate" "If this server name is incorrect, you will need to"
" the SSL certificates" " regenerate the SSL certificates"
) )
sys.exit(0) sys.exit(0)
else:
print (
"Config file %r already exists. Generating any missing key"
" files."
) % (config_path,)
generate_keys = True
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
parents=[config_parser], parents=[config_parser],
@ -213,7 +209,7 @@ class Config(object):
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
config_dir_path = os.path.dirname(config_args.config_path[0]) config_dir_path = os.path.dirname(config_args.config_path[-1])
config_dir_path = os.path.abspath(config_dir_path) config_dir_path = os.path.abspath(config_dir_path)
specified_config = {} specified_config = {}
@ -226,6 +222,10 @@ class Config(object):
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if generate_keys:
obj.invoke_all("generate_files", config)
sys.exit(0)
obj.invoke_all("read_config", config) obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args) obj.invoke_all("read_arguments", args)

View File

@ -24,6 +24,7 @@ class ServerConfig(Config):
self.web_client = config["web_client"] self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"] self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize") self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", []) self.listeners = config.get("listeners", [])
@ -208,12 +209,18 @@ class ServerConfig(Config):
self.manhole = args.manhole self.manhole = args.manhole
if args.daemonize is not None: if args.daemonize is not None:
self.daemonize = args.daemonize self.daemonize = args.daemonize
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
def add_arguments(self, parser): def add_arguments(self, parser):
server_group = parser.add_argument_group("server") server_group = parser.add_argument_group("server")
server_group.add_argument("-D", "--daemonize", action='store_true', server_group.add_argument("-D", "--daemonize", action='store_true',
default=None, default=None,
help="Daemonize the home server") help="Daemonize the home server")
server_group.add_argument("--print-pidfile", action='store_true',
default=None,
help="Print the path to the pidfile just"
" before daemonizing")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole", server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int, type=int,
help="Turn on the twisted telnet manhole" help="Turn on the twisted telnet manhole"

View File

@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
http_client = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable! # XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090'] # trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org'] trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds: if 'id_server' in creds:
id_server = creds['id_server'] id_server = creds['id_server']

View File

@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be randomly generated. one will be randomly generated.
password (str) : The password to assign to this user so they can password (str) : The password to assign to this user so they can
login again. login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:

View File

@ -16,7 +16,7 @@
from twisted.internet import defer, reactor, protocol from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter(
) )
class MatrixFederationHttpAgent(_AgentBase): class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_context_factory = hs.tls_context_factory
def __init__(self, reactor, pool=None): def endpointForURI(self, uri):
_AgentBase.__init__(self, reactor, pool) destination = uri.netloc
def request(self, destination, endpoint, method, path, params, query, return matrix_federation_endpoint(
headers, body_producer): reactor, destination, timeout=10,
ssl_context_factory=self.tls_context_factory
outgoing_requests_counter.inc(method) )
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
def _cb(response):
incoming_responses_counter.inc(method, response.code)
return response
def _eb(failure):
incoming_responses_counter.inc(method, "ERR")
return failure
d.addCallbacks(_cb, _eb)
return d
class MatrixFederationHttpClient(object): class MatrixFederationHttpClient(object):
@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object):
self.server_name = hs.hostname self.server_name = hs.hostname
pool = HTTPConnectionPool(reactor) pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10 pool.maxPersistentPerHost = 10
self.agent = MatrixFederationHttpAgent(reactor, pool=pool) self.agent = Agent.usingEndpointFactory(
reactor, MatrixFederationEndpointFactory(hs), pool=pool
)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.version_string = hs.version_string self.version_string = hs.version_string
self._next_id = 1 self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse(
("matrix", destination, path_bytes, param_bytes, query_bytes, "")
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"", body_callback, headers_dict={}, param_bytes=b"",
@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"User-Agent"] = [self.version_string] headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination] headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse( url_bytes = self._create_url(
("", "", path_bytes, param_bytes, query_bytes, "",) destination, path_bytes, param_bytes, query_bytes
) )
txn_id = "%s-O-%s" % (method, self._next_id) txn_id = "%s-O-%s" % (method, self._next_id)
@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object):
# (once we have reliable transactions in place) # (once we have reliable transactions in place)
retries_left = 5 retries_left = 5
endpoint = preserve_context_over_fn( http_url_bytes = urlparse.urlunparse(
self._getEndpoint, reactor, destination ("", "", path_bytes, param_bytes, query_bytes, "")
) )
log_result = None log_result = None
@ -148,17 +130,14 @@ class MatrixFederationHttpClient(object):
while True: while True:
producer = None producer = None
if body_callback: if body_callback:
producer = body_callback(method, url_bytes, headers_dict) producer = body_callback(method, http_url_bytes, headers_dict)
try: try:
def send_request(): def send_request():
request_deferred = self.agent.request( request_deferred = preserve_context_over_fn(
destination, self.agent.request,
endpoint,
method, method,
path_bytes, url_bytes,
param_bytes,
query_bytes,
Headers(headers_dict), Headers(headers_dict),
producer producer
) )
@ -452,12 +431,6 @@ class MatrixFederationHttpClient(object):
defer.returnValue((length, headers)) defer.returnValue((length, headers))
def _getEndpoint(self, reactor, destination):
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
class _ReadBodyToFileProtocol(protocol.Protocol): class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size): def __init__(self, stream, deferred, max_size):

View File

@ -18,8 +18,12 @@ from __future__ import absolute_import
import logging import logging
from resource import getrusage, getpagesize, RUSAGE_SELF from resource import getrusage, getpagesize, RUSAGE_SELF
import functools
import os import os
import stat import stat
import time
from twisted.internet import reactor
from .metric import ( from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@ -144,3 +148,28 @@ def _process_fds():
return counts return counts
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"]) get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
reactor_metrics = get_metrics_for("reactor")
tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
def runUntilCurrentTimer(func):
@functools.wraps(func)
def f(*args, **kwargs):
pending_calls = len(reactor.getDelayedCalls())
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(pending_calls)
return ret
return f
if hasattr(reactor, "runUntilCurrent"):
# runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)

View File

@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil>=0.0.7": ["syutil>=0.0.7"], "syutil>=0.0.7": ["syutil>=0.0.7"],
"Twisted==14.0.2": ["twisted==14.0.2"], "Twisted>=15.1.0": ["twisted>=15.1.0"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],

View File

@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_request_allow_empty from ._base import client_v2_pattern, parse_json_dict_from_request
import logging import logging
import hmac import hmac
@ -55,30 +55,55 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_dict_from_request(request)
body = parse_request_allow_empty(request) # we do basic sanity checks here because the auth layer will store these
# we do basic sanity checks here because the auth # in sessions. Pull out the username/password provided to us.
# layer will store these in sessions desired_password = None
if 'password' in body: if 'password' in body:
if ((not isinstance(body['password'], str) and if (not isinstance(body['password'], basestring) or
not isinstance(body['password'], unicode)) or
len(body['password']) > 512): len(body['password']) > 512):
raise SynapseError(400, "Invalid password") raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body: if 'username' in body:
if ((not isinstance(body['username'], str) and if (not isinstance(body['username'], basestring) or
not isinstance(body['username'], unicode)) or
len(body['username']) > 512): len(body['username']) > 512):
raise SynapseError(400, "Invalid username") raise SynapseError(400, "Invalid username")
desired_username = body['username'] desired_username = body['username']
yield self.registration_handler.check_username(desired_username)
is_using_shared_secret = False appservice = None
is_application_server = False
service = None
if 'access_token' in request.args: if 'access_token' in request.args:
service = yield self.auth.get_appservice_by_req(request) appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
# have completely different registration flows to normal users
# == Application Service Registration ==
if appservice:
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body["mac"]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Normal User Registration == (everyone else)
if self.hs.config.disable_registration:
raise SynapseError(403, "Registration has been disabled")
if desired_username is not None:
yield self.registration_handler.check_username(desired_username)
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
flows = [ flows = [
@ -91,39 +116,20 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
] ]
result = None authed, result, params = yield self.auth_handler.check_auth(
if service: flows, body, self.hs.get_ip_from_request(request)
is_application_server = True
params = body
elif 'mac' in body:
# Check registration-specific shared secret auth
if 'username' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
self._check_shared_secret_auth(
body['username'], body['mac']
)
is_using_shared_secret = True
params = body
else:
authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, result))
can_register = (
not self.hs.config.disable_registration
or is_application_server
or is_using_shared_secret
) )
if not can_register:
raise SynapseError(403, "Registration has been disabled")
if not authed:
defer.returnValue((401, result))
return
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params: if 'password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM) raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
desired_username = params['username'] if 'username' in params else None
new_password = params['password'] desired_username = params.get("username", None)
new_password = params.get("password", None)
(user_id, token) = yield self.registration_handler.register( (user_id, token) = yield self.registration_handler.register(
localpart=desired_username, localpart=desired_username,
@ -156,18 +162,21 @@ class RegisterRestServlet(RestServlet):
else: else:
logger.info("bind_email not specified: not binding email") logger.info("bind_email not specified: not binding email")
result = { result = self._create_registration_details(user_id, token)
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result)) defer.returnValue((200, result))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def _check_shared_secret_auth(self, username, mac): @defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token):
(user_id, token) = yield self.registration_handler.appservice_register(
username, as_token
)
defer.returnValue(self._create_registration_details(user_id, token))
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac):
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled") raise SynapseError(400, "Shared secret registration is not enabled")
@ -183,13 +192,23 @@ class RegisterRestServlet(RestServlet):
digestmod=sha1, digestmod=sha1,
).hexdigest() ).hexdigest()
if compare_digest(want_mac, got_mac): if not compare_digest(want_mac, got_mac):
return True
else:
raise SynapseError( raise SynapseError(
403, "HMAC incorrect", 403, "HMAC incorrect",
) )
(user_id, token) = yield self.registration_handler.register(
localpart=username, password=password
)
defer.returnValue(self._create_registration_details(user_id, token))
def _create_registration_details(self, user_id, token):
return {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View File

@ -244,43 +244,52 @@ class BaseMediaResource(Resource):
) )
return return
scales = set() local_thumbnails = []
crops = set()
for r_width, r_height, r_method, r_type in requirements: def generate_thumbnails():
if r_method == "scale": scales = set()
t_width, t_height = thumbnailer.aspect(r_width, r_height) crops = set()
scales.add(( for r_width, r_height, r_method, r_type in requirements:
min(m_width, t_width), min(m_height, t_height), r_type, if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
)) ))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales: for t_width, t_height, t_type in crops:
t_method = "scale" if (t_width, t_height, t_type) in scales:
t_path = self.filepaths.local_media_thumbnail( # If the aspect ratio of the cropped thumbnail matches a purely
media_id, t_width, t_height, t_type, t_method # scaled one then there is no point in calculating a separate
) # thumbnail.
self._makedirs(t_path) continue
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) t_method = "crop"
yield self.store.store_local_thumbnail( t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len media_id, t_width, t_height, t_type, t_method
) )
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
for t_width, t_height, t_type in crops: yield threads.deferToThread(generate_thumbnails)
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely for l in local_thumbnails:
# scaled one then there is no point in calculating a separate yield self.store.store_local_thumbnail(*l)
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
defer.returnValue({ defer.returnValue({
"width": m_width, "width": m_width,

View File

@ -162,11 +162,12 @@ class ThumbnailResource(BaseMediaResource):
t_method = info["thumbnail_method"] t_method = info["thumbnail_method"]
if t_method == "scale" or t_method == "crop": if t_method == "scale" or t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w) aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h)) size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"] type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"] length_quality = info["thumbnail_length"]
info_list.append(( info_list.append((
aspect_quality, size_quality, type_quality, aspect_quality, min_quality, size_quality, type_quality,
length_quality, info length_quality, info
)) ))
if info_list: if info_list:

View File

@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip) key = (user.to_string(), access_token, device_id, ip)
try: try:
last_seen = self.client_ip_last_seen.get(*key) last_seen = self.client_ip_last_seen.get(key)
except KeyError: except KeyError:
last_seen = None last_seen = None
@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None) defer.returnValue(None)
self.client_ip_last_seen.prefill(*key + (now,)) self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint, # It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@ -354,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
) )
logger.debug("Running script %s", relative_path) logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine) module.run_upgrade(cur, database_engine)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
pass
elif ext == ".sql": elif ext == ".sql":
# A plain old .sql file, just read and execute it # A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path) logger.debug("Applying schema %s", relative_path)

View File

@ -15,6 +15,7 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
@ -27,6 +28,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import functools import functools
import inspect
import sys import sys
import time import time
import threading import threading
@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
) )
_CacheSentinel = object()
class Cache(object): class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=False): def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru: if lru:
self.cache = LruCache(max_size=max_entries) self.cache = LruCache(max_size=max_entries)
self.max_entries = None self.max_entries = None
@ -81,45 +86,44 @@ class Cache(object):
"Cache objects can only be accessed from the main thread" "Cache objects can only be accessed from the main thread"
) )
def get(self, *keyargs): def get(self, key, default=_CacheSentinel):
if len(keyargs) != self.keylen: val = self.cache.get(key, _CacheSentinel)
raise ValueError("Expected a key to have %d items", self.keylen) if val is not _CacheSentinel:
if keyargs in self.cache:
cache_counter.inc_hits(self.name) cache_counter.inc_hits(self.name)
return self.cache[keyargs] return val
cache_counter.inc_misses(self.name) cache_counter.inc_misses(self.name)
raise KeyError()
def update(self, sequence, *args): if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread() self.check_thread()
if self.sequence == sequence: if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the # Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369) # number that the cache had before the SELECT was started (SYN-369)
self.prefill(*args) self.prefill(key, value)
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
def prefill(self, key, value):
if self.max_entries is not None: if self.max_entries is not None:
while len(self.cache) >= self.max_entries: while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False) self.cache.popitem(last=False)
self.cache[keyargs] = value self.cache[key] = value
def invalidate(self, *keyargs): def invalidate(self, key):
self.check_thread() self.check_thread()
if len(keyargs) != self.keylen: if not isinstance(key, tuple):
raise ValueError("Expected a key to have %d items", self.keylen) raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
# Increment the sequence number so that any SELECT statements that # Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369) # raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1 self.sequence += 1
self.cache.pop(keyargs, None) self.cache.pop(key, None)
def invalidate_all(self): def invalidate_all(self):
self.check_thread() self.check_thread()
@ -130,6 +134,9 @@ class Cache(object):
class CacheDescriptor(object): class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache; a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value. misses use the function body to generate the value.
@ -141,58 +148,92 @@ class CacheDescriptor(object):
which can be used to insert values into the cache specifically, without which can be used to insert values into the cache specifically, without
calling the calculation function. calling the calculation function.
""" """
def __init__(self, orig, max_entries=1000, num_args=1, lru=False): def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries self.max_entries = max_entries
self.num_args = num_args self.num_args = num_args
self.lru = lru self.lru = lru
def __get__(self, obj, objtype=None): self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
cache = Cache(
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
self.cache = Cache(
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
lru=self.lru, lru=self.lru,
) )
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig) @functools.wraps(self.orig)
@defer.inlineCallbacks def wrapped(*args, **kwargs):
def wrapped(*keyargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try: try:
cached_result = cache.get(*keyargs[:self.num_args]) cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
actual_result = yield self.orig(obj, *keyargs) @defer.inlineCallbacks
if actual_result != cached_result: def check_result(cached_result):
logger.error( actual_result = yield self.function_to_call(obj, *args, **kwargs)
"Stale cache entry %s%r: cached: %r, actual %r", if actual_result != cached_result:
self.orig.__name__, keyargs, logger.error(
cached_result, actual_result, "Stale cache entry %s%r: cached: %r, actual %r",
) self.orig.__name__, cache_key,
raise ValueError("Stale cache entry") cached_result, actual_result,
defer.returnValue(cached_result) )
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError: except KeyError:
# Get the sequence number of the cache before reading from the # Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated # database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = cache.sequence sequence = self.cache.sequence
ret = yield self.orig(obj, *keyargs) ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
cache.update(sequence, *keyargs[:self.num_args] + (ret,)) def onErr(f):
self.cache.invalidate(cache_key)
return f
defer.returnValue(ret) ret.addErrback(onErr)
wrapped.invalidate = cache.invalidate ret = ObservableDeferred(ret, consumeErrors=True)
wrapped.invalidate_all = cache.invalidate_all self.cache.update(sequence, cache_key, ret)
wrapped.prefill = cache.prefill
return ret.observe()
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.orig.__name__] = wrapped
return wrapped return wrapped
def cached(max_entries=1000, num_args=1, lru=False): def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
@ -201,6 +242,16 @@ def cached(max_entries=1000, num_args=1, lru=False):
) )
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
class LoggingTransaction(object): class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()

View File

@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
}, },
desc="create_room_alias_association", desc="create_room_alias_association",
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
room_alias, room_alias,
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id) defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias): def _delete_room_alias_txn(self, txn, room_alias):

View File

@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore):
for room_id in events_by_room: for room_id in events_by_room:
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id self.get_latest_event_ids_in_room.invalidate, (room_id,)
) )
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?" query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))

View File

@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
if current_state: if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id) txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn( self._simple_delete_txn(
@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
if not context.rejected: if not context.rejected:
txn.call_after( txn.call_after(
self.get_current_state_for_key.invalidate, self.get_current_state_for_key.invalidate,
event.room_id, event.type, event.state_key (event.room_id, event.type, event.state_key,)
) )
if event.type in [EventTypes.Name, EventTypes.Aliases]: if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after( txn.call_after(
self.get_room_name_and_aliases.invalidate, self.get_room_name_and_aliases.invalidate,
event.room_id (event.room_id,)
) )
self._simple_upsert_txn( self._simple_upsert_txn(
@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id): def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True): for check_redacted in (False, True):
for get_prev_content in (False, True): for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted, self._get_event_cache.invalidate(
get_prev_content) (event_id, check_redacted, get_prev_content)
)
def _get_event_txn(self, txn, event_id, check_redacted=True, def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False): get_prev_content=False, allow_rejected=False):
@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events: for event_id in events:
try: try:
ret = self._get_event_cache.get( ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content (event_id, check_redacted, get_prev_content,)
) )
if allow_rejected or not ret.rejected_reason: if allow_rejected or not ret.rejected_reason:
@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
defer.returnValue(ev) defer.returnValue(ev)
@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
return ev return ev

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore, cached from _base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate", desc="store_server_certificate",
) )
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_all_server_verify_keys(self, server_name): def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="server_signature_keys", table="server_signature_keys",
@ -132,7 +131,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key", desc="store_server_verify_key",
) )
self.get_all_server_verify_keys.invalidate(server_name) self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes): ts_now_ms, ts_expires_ms, key_json_bytes):

View File

@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True}, updatevalues={"accepted": True},
desc="set_presence_list_accepted", desc="set_presence_list_accepted",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result) defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid}, "observed_user_id": observed_userid},
desc="del_presence_list", desc="del_presence_list",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_for_user(self, user_name): def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table=PushRuleTable.table_name, table=PushRuleTable.table_name,
@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name): def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name, table=PushRuleEnableTable.table_name,
@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule", desc="delete_push_rule",
) )
self.get_push_rules_for_user.invalidate(user_name) self.get_push_rules_for_user.invalidate((user_name,))
self.get_push_rules_enabled_for_user.invalidate(user_name) self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled): def set_push_rule_enabled(self, user_name, rule_id, enabled):
@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id}, {'id': new_id},
) )
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self) return self._receipts_id_gen.get_max_token(self)
@cached @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_graph_receipts_for_room(self, room_id): def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers. """Get receipts for sending to remote servers.
""" """

View File

@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
user_id user_id
) )
for r in rows: for r in rows:
self.get_user_by_token.invalidate(r) self.get_user_by_token.invalidate((r,))
@cached() @cached()
def get_user_by_token(self, token): def get_user_by_token(self, token):

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
import collections import collections
import logging import logging
@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
} }
) )
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id): def get_room_name_and_aliases(self, room_id):
def f(txn): def f(txn):
sql = ( sql = (

View File

@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore):
) )
for event in events: for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.
@ -78,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None lambda events: events[0] if events else None
) )
@cached() @cached(max_entries=5000)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):
@ -154,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn) RoomsForUser(**r) for r in self.cursor_to_dict(txn)
] ]
@cached() @cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):
return self.runInteraction( return self.runInteraction(
"get_joined_hosts_for_room", "get_joined_hosts_for_room",

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -91,7 +91,6 @@ class StateStore(SQLBaseStore):
defer.returnValue(dict(state_list)) defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, key, events): def _fetch_events_for_group(self, key, events):
return self._get_events( return self._get_events(
events, get_prev_content=False events, get_prev_content=False
@ -189,8 +188,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
@cached(num_args=3) @cachedInlineCallbacks(num_args=3)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key): def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn): def f(txn):
sql = ( sql = (

View File

@ -178,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
Live tokens start with an "s" followed by the "stream_ordering" id of the Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-", "topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after. followed by the "stream_ordering" id of the event it comes after.
""" """
__slots__ = [] __slots__ = []
@ -211,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# token_id is the primary key ID of the access token, not the access token itself.
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

View File

@ -51,7 +51,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_observers", set()) object.__setattr__(self, "_observers", set())
def callback(r): def callback(r):
self._result = (True, r) object.__setattr__(self, "_result", (True, r))
while self._observers: while self._observers:
try: try:
self._observers.pop().callback(r) self._observers.pop().callback(r)
@ -60,7 +60,7 @@ class ObservableDeferred(object):
return r return r
def errback(f): def errback(f):
self._result = (False, f) object.__setattr__(self, "_result", (False, f))
while self._observers: while self._observers:
try: try:
self._observers.pop().errback(f) self._observers.pop().errback(f)
@ -97,3 +97,8 @@ class ObservableDeferred(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self._deferred, name, value) setattr(self._deferred, name, value)
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)

View File

@ -0,0 +1,134 @@
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.api.errors import SynapseError
from twisted.internet import defer
from mock import Mock, MagicMock
from tests import unittest
import json
class RegisterRestServletTestCase(unittest.TestCase):
def setUp(self):
# do the dance to hook up request data to self.request_data
self.request_data = ""
self.request = Mock(
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
)
self.request.args = {}
self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
side_effect=lambda x: defer.succeed(self.appservice))
)
self.auth_result = (False, None, None)
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x,y,z: self.auth_result)
)
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
# do the dance to hook it up to the hs global
self.handlers = Mock(
auth_handler=self.auth_handler,
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
)
self.hs = Mock()
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.config.disable_registration = False
# init the thing we're testing
self.servlet = RegisterRestServlet(self.hs)
@defer.inlineCallbacks
def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
self.appservice = {
"id": "1234"
}
self.registration_handler.appservice_register = Mock(
return_value=(user_id, token)
)
result = yield self.servlet.on_POST(self.request)
self.assertEquals(result, (200, {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname
}))
@defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self):
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
self.appservice = None # no application service exists
result = yield self.servlet.on_POST(self.request)
self.assertEquals(result, (401, None))
def test_POST_bad_password(self):
self.request_data = json.dumps({
"username": "kermit",
"password": 666
})
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)
def test_POST_bad_username(self):
self.request_data = json.dumps({
"username": 777,
"password": "monkey"
})
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)
@defer.inlineCallbacks
def test_POST_user_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
self.request_data = json.dumps({
"username": "kermit",
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
self.registration_handler.register = Mock(return_value=(user_id, token))
result = yield self.servlet.on_POST(self.request)
self.assertEquals(result, (200, {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname
}))
def test_POST_disabled_registration(self):
self.hs.config.disable_registration = True
self.request_data = json.dumps({
"username": "kermit",
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)

View File

@ -17,6 +17,8 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async import ObservableDeferred
from synapse.storage._base import Cache, cached from synapse.storage._base import Cache, cached
@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
self.assertEquals(self.cache.get("foo"), 123) self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self): def test_invalidate(self):
self.cache.prefill("foo", 123) self.cache.prefill(("foo",), 123)
self.cache.invalidate("foo") self.cache.invalidate(("foo",))
failed = False failed = False
try: try:
self.cache.get("foo") self.cache.get(("foo",))
except KeyError: except KeyError:
failed = True failed = True
@ -139,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
a.func.invalidate("foo") a.func.invalidate(("foo",))
yield a.func("foo") yield a.func("foo")
@ -151,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
def func(self, key): def func(self, key):
return key return key
A().func.invalidate("what") A().func.invalidate(("what",))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_max_entries(self): def test_max_entries(self):
@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertTrue(callcount[0] >= 14, self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0])) msg="Expected callcount >= 14, got %d" % (callcount[0]))
@defer.inlineCallbacks
def test_prefill(self): def test_prefill(self):
callcount = [0] callcount = [0]
d = defer.succeed(123)
class A(object): class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return d
a = A() a = A()
a.func.prefill("foo", 123) a.func.prefill(("foo",), ObservableDeferred(d))
self.assertEquals((yield a.func("foo")), 123) self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)

View File

@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
yield d yield d
self.assertTrue(d.called) self.assertTrue(d.called)
observers[0].assert_called_once("Go") observers[0].assert_called_once_with("Go")
observers[1].assert_called_once("Go") observers[1].assert_called_once_with("Go")
self.assertEquals(mock_logger.warning.call_count, 1) self.assertEquals(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0], self.assertIsInstance(mock_logger.warning.call_args[0][0],