Merge remote-tracking branch 'origin/develop' into markjh/end-to-end-key-federation
commit c5966b2a97
README.rst (92 lines changed)
@@ -101,25 +101,26 @@ header files for python C extensions.
 
 Installing prerequisites on Ubuntu or Debian::
 
-    $ sudo apt-get install build-essential python2.7-dev libffi-dev \
+    sudo apt-get install build-essential python2.7-dev libffi-dev \
                             python-pip python-setuptools sqlite3 \
                             libssl-dev python-virtualenv libjpeg-dev
 
 Installing prerequisites on ArchLinux::
 
-    $ sudo pacman -S base-devel python2 python-pip \
+    sudo pacman -S base-devel python2 python-pip \
             python-setuptools python-virtualenv sqlite3
 
 Installing prerequisites on Mac OS X::
 
-    $ xcode-select --install
-    $ sudo pip install virtualenv
+    xcode-select --install
+    sudo easy_install pip
+    sudo pip install virtualenv
 
 To install the synapse homeserver run::
 
-    $ virtualenv -p python2.7 ~/.synapse
-    $ source ~/.synapse/bin/activate
-    $ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
+    virtualenv -p python2.7 ~/.synapse
+    source ~/.synapse/bin/activate
+    pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
 
 This installs synapse, along with the libraries it uses, into a virtual
 environment under ``~/.synapse``. Feel free to pick a different directory
@@ -132,8 +133,8 @@ above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
 
 To set up your homeserver, run (in your virtualenv, as before)::
 
-    $ cd ~/.synapse
-    $ python -m synapse.app.homeserver \
+    cd ~/.synapse
+    python -m synapse.app.homeserver \
         --server-name machine.my.domain.name \
         --config-path homeserver.yaml \
         --generate-config
@@ -192,9 +193,9 @@ Running Synapse
 To actually run your new homeserver, pick a working directory for Synapse to run
 (e.g. ``~/.synapse``), and::
 
-    $ cd ~/.synapse
-    $ source ./bin/activate
-    $ synctl start
+    cd ~/.synapse
+    source ./bin/activate
+    synctl start
 
 Platform Specific Instructions
 ==============================
@@ -212,12 +213,12 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:
 
 pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
 
-    $ sudo pip2.7 install --upgrade pip
+    sudo pip2.7 install --upgrade pip
 
 You also may need to explicitly specify python 2.7 again during the install
 request::
 
-    $ pip2.7 install --process-dependency-links \
+    pip2.7 install --process-dependency-links \
         https://github.com/matrix-org/synapse/tarball/master
 
 If you encounter an error with lib bcrypt causing an Wrong ELF Class:
@@ -225,13 +226,13 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
 compile it under the right architecture. (This should not be needed if
 installing under virtualenv)::
 
-    $ sudo pip2.7 uninstall py-bcrypt
-    $ sudo pip2.7 install py-bcrypt
+    sudo pip2.7 uninstall py-bcrypt
+    sudo pip2.7 install py-bcrypt
 
 During setup of Synapse you need to call python2.7 directly again::
 
-    $ cd ~/.synapse
-    $ python2.7 -m synapse.app.homeserver \
+    cd ~/.synapse
+    python2.7 -m synapse.app.homeserver \
         --server-name machine.my.domain.name \
         --config-path homeserver.yaml \
         --generate-config
@@ -279,22 +280,22 @@ Synapse requires pip 1.7 or later, so if your OS provides too old a version and
 you get errors about ``error: no such option: --process-dependency-links`` you
 may need to manually upgrade it::
 
-    $ sudo pip install --upgrade pip
+    sudo pip install --upgrade pip
 
 If pip crashes mid-installation for reason (e.g. lost terminal), pip may
 refuse to run until you remove the temporary installation directory it
 created. To reset the installation::
 
-    $ rm -rf /tmp/pip_install_matrix
+    rm -rf /tmp/pip_install_matrix
 
 pip seems to leak *lots* of memory during installation. For instance, a Linux
 host with 512MB of RAM may run out of memory whilst installing Twisted. If this
 happens, you will have to individually install the dependencies which are
 failing, e.g.::
 
-    $ pip install twisted
+    pip install twisted
 
-On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
+On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
 will need to export CFLAGS=-Qunused-arguments.
 
 Troubleshooting Running
@@ -310,10 +311,11 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
 fix try re-installing from PyPI or directly from
 (https://github.com/pyca/pynacl)::
 
-    $ # Install from PyPI
-    $ pip install --user --upgrade --force pynacl
-    $ # Install from github
-    $ pip install --user https://github.com/pyca/pynacl/tarball/master
+    # Install from PyPI
+    pip install --user --upgrade --force pynacl
+
+    # Install from github
+    pip install --user https://github.com/pyca/pynacl/tarball/master
 
 ArchLinux
 ~~~~~~~~~
@@ -321,7 +323,7 @@ ArchLinux
 If running `$ synctl start` fails with 'returned non-zero exit status 1',
 you will need to explicitly call Python2.7 - either running as::
 
-    $ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
+    python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
 
 ...or by editing synctl with the correct python executable.
 
@@ -331,16 +333,16 @@ Synapse Development
 To check out a synapse for development, clone the git repo into a working
 directory of your choice::
 
-    $ git clone https://github.com/matrix-org/synapse.git
-    $ cd synapse
+    git clone https://github.com/matrix-org/synapse.git
+    cd synapse
 
 Synapse has a number of external dependencies, that are easiest
 to install using pip and a virtualenv::
 
-    $ virtualenv env
-    $ source env/bin/activate
-    $ python synapse/python_dependencies.py | xargs -n1 pip install
-    $ pip install setuptools_trial mock
+    virtualenv env
+    source env/bin/activate
+    python synapse/python_dependencies.py | xargs -n1 pip install
+    pip install setuptools_trial mock
 
 This will run a process of downloading and installing all the needed
 dependencies into a virtual env.
@@ -348,7 +350,7 @@ dependencies into a virtual env.
 Once this is done, you may wish to run Synapse's unit tests, to
 check that everything is installed as it should be::
 
-    $ python setup.py test
+    python setup.py test
 
 This should end with a 'PASSED' result::
 
@@ -389,11 +391,11 @@ IDs:
 For the first form, simply pass the required hostname (of the machine) as the
 --server-name parameter::
 
-    $ python -m synapse.app.homeserver \
+    python -m synapse.app.homeserver \
         --server-name machine.my.domain.name \
         --config-path homeserver.yaml \
         --generate-config
-    $ python -m synapse.app.homeserver --config-path homeserver.yaml
+    python -m synapse.app.homeserver --config-path homeserver.yaml
 
 Alternatively, you can run ``synctl start`` to guide you through the process.
 
@@ -410,11 +412,11 @@ record would then look something like::
 At this point, you should then run the homeserver with the hostname of this
 SRV record, as that is the name other machines will expect it to have::
 
-    $ python -m synapse.app.homeserver \
+    python -m synapse.app.homeserver \
         --server-name YOURDOMAIN \
         --config-path homeserver.yaml \
         --generate-config
-    $ python -m synapse.app.homeserver --config-path homeserver.yaml
+    python -m synapse.app.homeserver --config-path homeserver.yaml
 
 
 You may additionally want to pass one or more "-v" options, in order to
@@ -428,7 +430,7 @@ private federation (``localhost:8080``, ``localhost:8081`` and
 ``localhost:8082``) which you can then access through the webclient running at
 http://localhost:8080. Simply run::
 
-    $ demo/start.sh
+    demo/start.sh
 
 This is mainly useful just for development purposes.
 
@@ -502,10 +504,10 @@ Building Internal API Documentation
 Before building internal API documentation install sphinx and
 sphinxcontrib-napoleon::
 
-    $ pip install sphinx
-    $ pip install sphinxcontrib-napoleon
+    pip install sphinx
+    pip install sphinxcontrib-napoleon
 
 Building internal API documentation::
 
-    $ python setup.py build_sphinx
+    python setup.py build_sphinx
demo/clean.sh
@@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
     exit 1
 fi
 
-find "$DIR" -name "*.log" -delete
-find "$DIR" -name "*.db" -delete
+for port in 8080 8081 8082; do
+    rm -rf $DIR/$port
+    rm -rf $DIR/media_store.$port
+done
 
 rm -rf $DIR/etc
demo/start.sh
@@ -8,14 +8,6 @@ cd "$DIR/.."
 
 mkdir -p demo/etc
 
-# Check the --no-rate-limit param
-PARAMS=""
-if [ $# -eq 1 ]; then
-    if [ $1 = "--no-rate-limit" ]; then
-        PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
-    fi
-fi
-
 export PYTHONPATH=$(readlink -f $(pwd))
 
 
@@ -31,10 +23,20 @@ for port in 8080 8081 8082; do
     #rm $DIR/etc/$port.config
     python -m synapse.app.homeserver \
         --generate-config \
-        --enable_registration \
         -H "localhost:$https_port" \
         --config-path "$DIR/etc/$port.config" \
 
+    # Check script parameters
+    if [ $# -eq 1 ]; then
+        if [ $1 = "--no-rate-limit" ]; then
+            # Set high limits in config file to disable rate limiting
+            perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
+            perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
+        fi
+    fi
+
+    perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
+
     python -m synapse.app.homeserver \
         --config-path "$DIR/etc/$port.config" \
         -D \
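A note on the rewritten rate-limit handling: instead of passing ``--rc-messages-per-second``/``--rc-message-burst-count`` flags at startup, the script now edits the generated config in place with ``perl -p -i -e``. The same edit as a Python sketch (regexes modelled on the script; the path is illustrative)::

    import re

    def set_option(path, key, value):
        # Equivalent of: perl -p -i -e 's/KEY.*/KEY: VALUE/g' PATH
        with open(path) as f:
            text = f.read()
        text = re.sub(r"%s.*" % re.escape(key), "%s: %s" % (key, value), text)
        with open(path, "w") as f:
            f.write(text)

    set_option("demo/etc/8080.config", "rc_messages_per_second", 1000)
    set_option("demo/etc/8080.config", "enable_registration", "true")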
setup.cfg
@@ -16,3 +16,6 @@ ignore =
     docs/*
     pylint.cfg
     tox.ini
+
+[flake8]
+max-line-length = 90
setup.py (2 lines changed)
@@ -48,7 +48,7 @@ setup(
     description="Reference Synapse Home Server",
     install_requires=dependencies['requirements'](include_conditional=True).keys(),
     setup_requires=[
-        "Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
+        "Twisted>=15.1.0", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
         "setuptools_trial",
         "mock"
     ],
synapse/api/auth.py
@@ -44,6 +44,11 @@ class Auth(object):
     def check(self, event, auth_events):
         """ Checks if this event is correctly authed.
 
+        Args:
+            event: the event being checked.
+            auth_events (dict: event-key -> event): the existing room state.
+
+
         Returns:
             True if the auth checks pass.
         """
@@ -319,7 +324,7 @@ class Auth(object):
         Returns:
             tuple : of UserID and device string:
                 User ID object of the user making the request
-                Client ID object of the client instance the user is using
+                ClientInfo object of the client instance the user is using
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
@@ -352,7 +357,7 @@ class Auth(object):
             )
             return
         except KeyError:
-            pass  # normal users won't have this query parameter set
+            pass  # normal users won't have the user_id query parameter set.
 
         user_info = yield self.get_user_by_token(access_token)
         user = user_info["user"]
@@ -521,23 +526,22 @@ class Auth(object):
 
         # Check state_key
         if hasattr(event, "state_key"):
-            if not event.state_key.startswith("_"):
-                if event.state_key.startswith("@"):
-                    if event.state_key != event.user_id:
+            if event.state_key.startswith("@"):
+                if event.state_key != event.user_id:
+                    raise AuthError(
+                        403,
+                        "You are not allowed to set others state"
+                    )
+                else:
+                    sender_domain = UserID.from_string(
+                        event.user_id
+                    ).domain
+
+                    if sender_domain != event.state_key:
                         raise AuthError(
                             403,
                             "You are not allowed to set others state"
                         )
-                    else:
-                        sender_domain = UserID.from_string(
-                            event.user_id
-                        ).domain
-
-                        if sender_domain != event.state_key:
-                            raise AuthError(
-                                403,
-                                "You are not allowed to set others state"
-                            )
 
         return True
 
synapse/app/homeserver.py
@@ -657,7 +657,8 @@ def run(hs):
 
     if hs.config.daemonize:
 
-        print hs.config.pid_file
+        if hs.config.print_pidfile:
+            print hs.config.pid_file
 
         daemon = Daemonize(
             app="synapse-homeserver",
synapse/config/_base.py
@@ -138,12 +138,19 @@ class Config(object):
             action="store_true",
             help="Generate a config file for the server name"
         )
+        config_parser.add_argument(
+            "--generate-keys",
+            action="store_true",
+            help="Generate any missing key files then exit"
+        )
         config_parser.add_argument(
             "-H", "--server-name",
             help="The server name to generate a config file for"
         )
         config_args, remaining_args = config_parser.parse_known_args(argv)
 
+        generate_keys = config_args.generate_keys
+
         if config_args.generate_config:
             if not config_args.config_path:
                 config_parser.error(
@@ -151,51 +158,40 @@ class Config(object):
                     " generated using \"--generate-config -H SERVER_NAME"
                     " -c CONFIG-FILE\""
                 )
-            config_dir_path = os.path.dirname(config_args.config_path[0])
-            config_dir_path = os.path.abspath(config_dir_path)
-
-            server_name = config_args.server_name
-            if not server_name:
-                print "Must specify a server_name to a generate config for."
-                sys.exit(1)
             (config_path,) = config_args.config_path
-            if not os.path.exists(config_dir_path):
-                os.makedirs(config_dir_path)
-            if os.path.exists(config_path):
-                print "Config file %r already exists" % (config_path,)
-                yaml_config = cls.read_config_file(config_path)
-                yaml_name = yaml_config["server_name"]
-                if server_name != yaml_name:
-                    print (
-                        "Config file %r has a different server_name: "
-                        " %r != %r" % (config_path, server_name, yaml_name)
-                    )
-                    sys.exit(1)
-                config_bytes, config = obj.generate_config(
-                    config_dir_path, server_name
-                )
-                config.update(yaml_config)
-                print "Generating any missing keys for %r" % (server_name,)
-                obj.invoke_all("generate_files", config)
-                sys.exit(0)
-            with open(config_path, "wb") as config_file:
-                config_bytes, config = obj.generate_config(
-                    config_dir_path, server_name
-                )
-                obj.invoke_all("generate_files", config)
-                config_file.write(config_bytes)
-            print (
-                "A config file has been generated in %s for server name"
-                " '%s' with corresponding SSL keys and self-signed"
-                " certificates. Please review this file and customise it to"
-                " your needs."
-            ) % (config_path, server_name)
-            print (
-                "If this server name is incorrect, you will need to regenerate"
-                " the SSL certificates"
-            )
-            sys.exit(0)
+            if not os.path.exists(config_path):
+                config_dir_path = os.path.dirname(config_path)
+                config_dir_path = os.path.abspath(config_dir_path)
+
+                server_name = config_args.server_name
+                if not server_name:
+                    print "Must specify a server_name to a generate config for."
+                    sys.exit(1)
+                if not os.path.exists(config_dir_path):
+                    os.makedirs(config_dir_path)
+                with open(config_path, "wb") as config_file:
+                    config_bytes, config = obj.generate_config(
+                        config_dir_path, server_name
+                    )
+                    obj.invoke_all("generate_files", config)
+                    config_file.write(config_bytes)
+                print (
+                    "A config file has been generated in %r for server name"
+                    " %r with corresponding SSL keys and self-signed"
+                    " certificates. Please review this file and customise it"
+                    " to your needs."
+                ) % (config_path, server_name)
+                print (
+                    "If this server name is incorrect, you will need to"
+                    " regenerate the SSL certificates"
+                )
+                sys.exit(0)
+            else:
+                print (
+                    "Config file %r already exists. Generating any missing key"
+                    " files."
+                ) % (config_path,)
+                generate_keys = True
 
         parser = argparse.ArgumentParser(
             parents=[config_parser],
@@ -213,7 +209,7 @@ class Config(object):
                 " -c CONFIG-FILE\""
             )
 
-        config_dir_path = os.path.dirname(config_args.config_path[0])
+        config_dir_path = os.path.dirname(config_args.config_path[-1])
         config_dir_path = os.path.abspath(config_dir_path)
 
         specified_config = {}
@ -226,6 +222,10 @@ class Config(object):
|
||||||
config.pop("log_config")
|
config.pop("log_config")
|
||||||
config.update(specified_config)
|
config.update(specified_config)
|
||||||
|
|
||||||
|
if generate_keys:
|
||||||
|
obj.invoke_all("generate_files", config)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
obj.invoke_all("read_config", config)
|
obj.invoke_all("read_config", config)
|
||||||
|
|
||||||
obj.invoke_all("read_arguments", args)
|
obj.invoke_all("read_arguments", args)
|
||||||
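The ``--generate-keys`` flag above rides on a two-phase argparse pattern: a small bootstrap parser consumes the generate/config flags with ``parse_known_args``, then a full parser built with ``parents=[config_parser]`` handles everything else. A minimal sketch of that pattern, with flag names taken from the diff and the rest illustrative::

    import argparse

    # Phase 1: a bare-bones parser picks off the bootstrap flags only.
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument("--generate-config", action="store_true")
    config_parser.add_argument("--generate-keys", action="store_true")
    config_parser.add_argument("-c", "--config-path", action="append")

    config_args, remaining_args = config_parser.parse_known_args(
        ["--generate-keys", "-c", "homeserver.yaml", "--verbose"]
    )
    assert config_args.generate_keys  # handled before full parsing

    # Phase 2: the full parser inherits the bootstrap flags via parents=
    # (so --help lists them all) and parses whatever phase 1 left over.
    parser = argparse.ArgumentParser(parents=[config_parser])
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args(remaining_args)
    assert args.verbose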
synapse/config/server.py
@@ -24,6 +24,7 @@ class ServerConfig(Config):
         self.web_client = config["web_client"]
         self.soft_file_limit = config["soft_file_limit"]
         self.daemonize = config.get("daemonize")
+        self.print_pidfile = config.get("print_pidfile")
         self.use_frozen_dicts = config.get("use_frozen_dicts", True)
 
         self.listeners = config.get("listeners", [])
@@ -208,12 +209,18 @@ class ServerConfig(Config):
             self.manhole = args.manhole
         if args.daemonize is not None:
             self.daemonize = args.daemonize
+        if args.print_pidfile is not None:
+            self.print_pidfile = args.print_pidfile
 
     def add_arguments(self, parser):
         server_group = parser.add_argument_group("server")
         server_group.add_argument("-D", "--daemonize", action='store_true',
                                   default=None,
                                   help="Daemonize the home server")
+        server_group.add_argument("--print-pidfile", action='store_true',
+                                  default=None,
+                                  help="Print the path to the pidfile just"
+                                  " before daemonizing")
         server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
                                   type=int,
                                   help="Turn on the twisted telnet manhole"
synapse/handlers/identity.py
@@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
         http_client = SimpleHttpClient(self.hs)
         # XXX: make this configurable!
         # trustedIdServers = ['matrix.org', 'localhost:8090']
-        trustedIdServers = ['matrix.org']
+        trustedIdServers = ['matrix.org', 'vector.im']
 
         if 'id_server' in creds:
             id_server = creds['id_server']
synapse/handlers/register.py
@@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
             localpart : The local part of the user ID to register. If None,
               one will be randomly generated.
             password (str) : The password to assign to this user so they can
-              login again.
+              login again. This can be None which means they cannot login again
+              via a password (e.g. the user is an application service user).
         Returns:
             A tuple of (user_id, access_token).
         Raises:
synapse/http/matrixfederationclient.py
@@ -16,7 +16,7 @@
 
 from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
-from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
+from twisted.web.client import readBody, HTTPConnectionPool, Agent
 from twisted.web.http_headers import Headers
 from twisted.web._newclient import ResponseDone
 
@@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter(
 )
 
 
-class MatrixFederationHttpAgent(_AgentBase):
+class MatrixFederationEndpointFactory(object):
+    def __init__(self, hs):
+        self.tls_context_factory = hs.tls_context_factory
 
-    def __init__(self, reactor, pool=None):
-        _AgentBase.__init__(self, reactor, pool)
+    def endpointForURI(self, uri):
+        destination = uri.netloc
 
-    def request(self, destination, endpoint, method, path, params, query,
-                headers, body_producer):
-
-        outgoing_requests_counter.inc(method)
-
-        host = b""
-        port = 0
-        fragment = b""
-
-        parsed_URI = _URI(b"http", destination, host, port, path, params,
-                          query, fragment)
-
-        # Set the connection pool key to be the destination.
-        key = destination
-
-        d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
-                                      headers, body_producer,
-                                      parsed_URI.originForm)
-
-        def _cb(response):
-            incoming_responses_counter.inc(method, response.code)
-            return response
-
-        def _eb(failure):
-            incoming_responses_counter.inc(method, "ERR")
-            return failure
-
-        d.addCallbacks(_cb, _eb)
-
-        return d
+        return matrix_federation_endpoint(
+            reactor, destination, timeout=10,
+            ssl_context_factory=self.tls_context_factory
+        )
 
 
 class MatrixFederationHttpClient(object):
@@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object):
         self.server_name = hs.hostname
         pool = HTTPConnectionPool(reactor)
         pool.maxPersistentPerHost = 10
-        self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
+        self.agent = Agent.usingEndpointFactory(
+            reactor, MatrixFederationEndpointFactory(hs), pool=pool
+        )
         self.clock = hs.get_clock()
         self.version_string = hs.version_string
 
         self._next_id = 1
 
+    def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
+        return urlparse.urlunparse(
+            ("matrix", destination, path_bytes, param_bytes, query_bytes, "")
+        )
+
     @defer.inlineCallbacks
     def _create_request(self, destination, method, path_bytes,
                         body_callback, headers_dict={}, param_bytes=b"",
@@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object):
         headers_dict[b"User-Agent"] = [self.version_string]
         headers_dict[b"Host"] = [destination]
 
-        url_bytes = urlparse.urlunparse(
-            ("", "", path_bytes, param_bytes, query_bytes, "",)
+        url_bytes = self._create_url(
+            destination, path_bytes, param_bytes, query_bytes
         )
 
         txn_id = "%s-O-%s" % (method, self._next_id)
@@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object):
         # (once we have reliable transactions in place)
         retries_left = 5
 
-        endpoint = preserve_context_over_fn(
-            self._getEndpoint, reactor, destination
+        http_url_bytes = urlparse.urlunparse(
+            ("", "", path_bytes, param_bytes, query_bytes, "")
         )
 
         log_result = None
@@ -148,17 +130,14 @@ class MatrixFederationHttpClient(object):
         while True:
             producer = None
             if body_callback:
-                producer = body_callback(method, url_bytes, headers_dict)
+                producer = body_callback(method, http_url_bytes, headers_dict)
 
             try:
                 def send_request():
-                    request_deferred = self.agent.request(
-                        destination,
-                        endpoint,
+                    request_deferred = preserve_context_over_fn(
+                        self.agent.request,
                         method,
-                        path_bytes,
-                        param_bytes,
-                        query_bytes,
+                        url_bytes,
                         Headers(headers_dict),
                         producer
                     )
@@ -452,12 +431,6 @@ class MatrixFederationHttpClient(object):
 
         defer.returnValue((length, headers))
 
-    def _getEndpoint(self, reactor, destination):
-        return matrix_federation_endpoint(
-            reactor, destination, timeout=10,
-            ssl_context_factory=self.hs.tls_context_factory
-        )
-
 
 class _ReadBodyToFileProtocol(protocol.Protocol):
     def __init__(self, stream, deferred, max_size):
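``Agent.usingEndpointFactory`` (public API since Twisted 15.0) is what lets the rewritten client drop the private ``_AgentBase``/``_URI`` machinery: any object with an ``endpointForURI(uri)`` method decides how connections are made, which is how the ``matrix`` scheme in the rebuilt URL gets resolved to a federation endpoint. A minimal sketch of the shape, using a plain TCP endpoint instead of Synapse's::

    from twisted.internet import reactor
    from twisted.internet.endpoints import TCP4ClientEndpoint
    from twisted.web.client import Agent, HTTPConnectionPool

    class PinnedEndpointFactory(object):
        """Connects every request to one host, ignoring the URI's netloc,
        much as Synapse maps uri.netloc to a federation endpoint."""

        def __init__(self, reactor, host, port):
            self._reactor, self._host, self._port = reactor, host, port

        def endpointForURI(self, uri):
            return TCP4ClientEndpoint(self._reactor, self._host, self._port)

    pool = HTTPConnectionPool(reactor)
    agent = Agent.usingEndpointFactory(
        reactor, PinnedEndpointFactory(reactor, "example.com", 80), pool=pool
    )
    d = agent.request(b"GET", b"http://example.com/")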
synapse/metrics/__init__.py
@@ -18,8 +18,12 @@ from __future__ import absolute_import
 
 import logging
 from resource import getrusage, getpagesize, RUSAGE_SELF
+import functools
 import os
 import stat
+import time
+
+from twisted.internet import reactor
 
 from .metric import (
     CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@@ -144,3 +148,28 @@ def _process_fds():
     return counts
 
 get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
+
+reactor_metrics = get_metrics_for("reactor")
+tick_time = reactor_metrics.register_distribution("tick_time")
+pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
+
+
+def runUntilCurrentTimer(func):
+
+    @functools.wraps(func)
+    def f(*args, **kwargs):
+        pending_calls = len(reactor.getDelayedCalls())
+        start = time.time() * 1000
+        ret = func(*args, **kwargs)
+        end = time.time() * 1000
+        tick_time.inc_by(end - start)
+        pending_calls_metric.inc_by(pending_calls)
+        return ret
+
+    return f
+
+
+if hasattr(reactor, "runUntilCurrent"):
+    # runUntilCurrent is called when we have pending calls. It is called once
+    # per iteratation after fd polling.
+    reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
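The ``runUntilCurrentTimer`` addition is an ordinary timing decorator applied by monkey-patching a bound method on the reactor. The same pattern in isolation, standard library only, with the metric counters replaced by a print::

    import functools
    import time

    def timed(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time() * 1000  # milliseconds, as in the diff
            ret = func(*args, **kwargs)
            print("%s took %.3f ms" % (func.__name__,
                                       time.time() * 1000 - start))
            return ret
        return wrapper

    class FakeReactor(object):
        def runUntilCurrent(self):
            time.sleep(0.01)  # stand-in for processing pending calls

    reactor = FakeReactor()
    # Patch the bound method, exactly as the metrics module does.
    reactor.runUntilCurrent = timed(reactor.runUntilCurrent)
    reactor.runUntilCurrent()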
synapse/python_dependencies.py
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
 
 REQUIREMENTS = {
     "syutil>=0.0.7": ["syutil>=0.0.7"],
-    "Twisted==14.0.2": ["twisted==14.0.2"],
+    "Twisted>=15.1.0": ["twisted>=15.1.0"],
     "service_identity>=1.0.0": ["service_identity>=1.0.0"],
     "pyopenssl>=0.14": ["OpenSSL>=0.14"],
     "pyyaml": ["yaml"],
synapse/rest/client/v2_alpha/register.py
@@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes
 from synapse.http.servlet import RestServlet
 
-from ._base import client_v2_pattern, parse_request_allow_empty
+from ._base import client_v2_pattern, parse_json_dict_from_request
 
 import logging
 import hmac
@@ -55,30 +55,55 @@ class RegisterRestServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request):
         yield run_on_reactor()
+        body = parse_json_dict_from_request(request)
 
-        body = parse_request_allow_empty(request)
-        # we do basic sanity checks here because the auth
-        # layer will store these in sessions
+        # we do basic sanity checks here because the auth layer will store these
+        # in sessions. Pull out the username/password provided to us.
+        desired_password = None
         if 'password' in body:
-            if ((not isinstance(body['password'], str) and
-                    not isinstance(body['password'], unicode)) or
+            if (not isinstance(body['password'], basestring) or
                     len(body['password']) > 512):
                 raise SynapseError(400, "Invalid password")
+            desired_password = body["password"]
 
+        desired_username = None
         if 'username' in body:
-            if ((not isinstance(body['username'], str) and
-                    not isinstance(body['username'], unicode)) or
+            if (not isinstance(body['username'], basestring) or
                     len(body['username']) > 512):
                 raise SynapseError(400, "Invalid username")
             desired_username = body['username']
-            yield self.registration_handler.check_username(desired_username)
 
-        is_using_shared_secret = False
-        is_application_server = False
-
-        service = None
+        appservice = None
         if 'access_token' in request.args:
-            service = yield self.auth.get_appservice_by_req(request)
+            appservice = yield self.auth.get_appservice_by_req(request)
+
+        # fork off as soon as possible for ASes and shared secret auth which
+        # have completely different registration flows to normal users
+
+        # == Application Service Registration ==
+        if appservice:
+            result = yield self._do_appservice_registration(
+                desired_username, request.args["access_token"][0]
+            )
+            defer.returnValue((200, result))  # we throw for non 200 responses
+            return
+
+        # == Shared Secret Registration == (e.g. create new user scripts)
+        if 'mac' in body:
+            # FIXME: Should we really be determining if this is shared secret
+            # auth based purely on the 'mac' key?
+            result = yield self._do_shared_secret_registration(
+                desired_username, desired_password, body["mac"]
+            )
+            defer.returnValue((200, result))  # we throw for non 200 responses
+            return
+
+        # == Normal User Registration == (everyone else)
+        if self.hs.config.disable_registration:
+            raise SynapseError(403, "Registration has been disabled")
+
+        if desired_username is not None:
+            yield self.registration_handler.check_username(desired_username)
 
         if self.hs.config.enable_registration_captcha:
             flows = [
@@ -91,39 +116,20 @@ class RegisterRestServlet(RestServlet):
             [LoginType.EMAIL_IDENTITY]
         ]
 
-        result = None
-        if service:
-            is_application_server = True
-            params = body
-        elif 'mac' in body:
-            # Check registration-specific shared secret auth
-            if 'username' not in body:
-                raise SynapseError(400, "", Codes.MISSING_PARAM)
-            self._check_shared_secret_auth(
-                body['username'], body['mac']
-            )
-            is_using_shared_secret = True
-            params = body
-        else:
-            authed, result, params = yield self.auth_handler.check_auth(
-                flows, body, self.hs.get_ip_from_request(request)
-            )
-
-            if not authed:
-                defer.returnValue((401, result))
-
-        can_register = (
-            not self.hs.config.disable_registration
-            or is_application_server
-            or is_using_shared_secret
-        )
-        if not can_register:
-            raise SynapseError(403, "Registration has been disabled")
+        authed, result, params = yield self.auth_handler.check_auth(
+            flows, body, self.hs.get_ip_from_request(request)
+        )
+
+        if not authed:
+            defer.returnValue((401, result))
+            return
 
+        # NB: This may be from the auth handler and NOT from the POST
         if 'password' not in params:
-            raise SynapseError(400, "", Codes.MISSING_PARAM)
-        desired_username = params['username'] if 'username' in params else None
-        new_password = params['password']
+            raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
+
+        desired_username = params.get("username", None)
+        new_password = params.get("password", None)
 
         (user_id, token) = yield self.registration_handler.register(
             localpart=desired_username,
@@ -156,18 +162,21 @@ class RegisterRestServlet(RestServlet):
         else:
             logger.info("bind_email not specified: not binding email")
 
-        result = {
-            "user_id": user_id,
-            "access_token": token,
-            "home_server": self.hs.hostname,
-        }
-
+        result = self._create_registration_details(user_id, token)
         defer.returnValue((200, result))
 
     def on_OPTIONS(self, _):
         return 200, {}
 
-    def _check_shared_secret_auth(self, username, mac):
+    @defer.inlineCallbacks
+    def _do_appservice_registration(self, username, as_token):
+        (user_id, token) = yield self.registration_handler.appservice_register(
+            username, as_token
+        )
+        defer.returnValue(self._create_registration_details(user_id, token))
+
+    @defer.inlineCallbacks
+    def _do_shared_secret_registration(self, username, password, mac):
         if not self.hs.config.registration_shared_secret:
             raise SynapseError(400, "Shared secret registration is not enabled")
 
@@ -183,13 +192,23 @@ class RegisterRestServlet(RestServlet):
             digestmod=sha1,
         ).hexdigest()
 
-        if compare_digest(want_mac, got_mac):
-            return True
-        else:
+        if not compare_digest(want_mac, got_mac):
             raise SynapseError(
                 403, "HMAC incorrect",
             )
 
+        (user_id, token) = yield self.registration_handler.register(
+            localpart=username, password=password
+        )
+        defer.returnValue(self._create_registration_details(user_id, token))
+
+    def _create_registration_details(self, user_id, token):
+        return {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname,
+        }
+
 
 def register_servlets(hs, http_server):
     RegisterRestServlet(hs).register(http_server)
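``_do_shared_secret_registration`` rejects a bad MAC via ``compare_digest``, which compares in constant time so an attacker cannot learn how many leading characters of the HMAC matched. The verification step in isolation (the message being MACed here is illustrative, not Synapse's exact format)::

    import hmac
    from hashlib import sha1

    shared_secret = b"registration_shared_secret from the config"

    def expected_mac(message):
        return hmac.new(shared_secret, message, digestmod=sha1).hexdigest()

    def verify(message, got_mac):
        want_mac = expected_mac(message)
        # hmac.compare_digest (Python 2.7.7+) avoids the timing side
        # channel of ordinary string comparison.
        if not hmac.compare_digest(want_mac, got_mac):
            raise ValueError("HMAC incorrect")

    verify(b"alice", expected_mac(b"alice"))  # passes silently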
synapse/rest/media/v1/base_resource.py
@@ -244,43 +244,52 @@ class BaseMediaResource(Resource):
             )
             return
 
-        scales = set()
-        crops = set()
-        for r_width, r_height, r_method, r_type in requirements:
-            if r_method == "scale":
-                t_width, t_height = thumbnailer.aspect(r_width, r_height)
-                scales.add((
-                    min(m_width, t_width), min(m_height, t_height), r_type,
-                ))
-            elif r_method == "crop":
-                crops.add((r_width, r_height, r_type))
-
-        for t_width, t_height, t_type in scales:
-            t_method = "scale"
-            t_path = self.filepaths.local_media_thumbnail(
-                media_id, t_width, t_height, t_type, t_method
-            )
-            self._makedirs(t_path)
-            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-            yield self.store.store_local_thumbnail(
-                media_id, t_width, t_height, t_type, t_method, t_len
-            )
-
-        for t_width, t_height, t_type in crops:
-            if (t_width, t_height, t_type) in scales:
-                # If the aspect ratio of the cropped thumbnail matches a purely
-                # scaled one then there is no point in calculating a separate
-                # thumbnail.
-                continue
-            t_method = "crop"
-            t_path = self.filepaths.local_media_thumbnail(
-                media_id, t_width, t_height, t_type, t_method
-            )
-            self._makedirs(t_path)
-            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-            yield self.store.store_local_thumbnail(
-                media_id, t_width, t_height, t_type, t_method, t_len
-            )
+        local_thumbnails = []
+
+        def generate_thumbnails():
+            scales = set()
+            crops = set()
+            for r_width, r_height, r_method, r_type in requirements:
+                if r_method == "scale":
+                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
+                    scales.add((
+                        min(m_width, t_width), min(m_height, t_height), r_type,
+                    ))
+                elif r_method == "crop":
+                    crops.add((r_width, r_height, r_type))
+
+            for t_width, t_height, t_type in scales:
+                t_method = "scale"
+                t_path = self.filepaths.local_media_thumbnail(
+                    media_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+
+                local_thumbnails.append((
+                    media_id, t_width, t_height, t_type, t_method, t_len
+                ))
+
+            for t_width, t_height, t_type in crops:
+                if (t_width, t_height, t_type) in scales:
+                    # If the aspect ratio of the cropped thumbnail matches a purely
+                    # scaled one then there is no point in calculating a separate
+                    # thumbnail.
+                    continue
+                t_method = "crop"
+                t_path = self.filepaths.local_media_thumbnail(
+                    media_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+                local_thumbnails.append((
+                    media_id, t_width, t_height, t_type, t_method, t_len
+                ))
+
+        yield threads.deferToThread(generate_thumbnails)
+
+        for l in local_thumbnails:
+            yield self.store.store_local_thumbnail(*l)
 
         defer.returnValue({
             "width": m_width,
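The important move in the thumbnailing change is that the CPU-heavy image work now runs inside ``twisted.internet.threads.deferToThread``, so a large upload no longer stalls the reactor; the worker only collects results into a list, and the database writes happen back on the reactor thread afterwards. The split in miniature (the two helpers are illustrative stand-ins)::

    from twisted.internet import defer, threads

    def expensive_resize(media_id):    # stand-in for the PIL thumbnailing
        return len(media_id)

    def store_result(media_id, size):  # stand-in for store_local_thumbnail
        return defer.succeed(None)

    @defer.inlineCallbacks
    def generate(media_ids):
        results = []

        def cpu_bound_work():
            # Runs on a threadpool thread: fine for blocking work, but it
            # must not touch reactor-only objects such as the DB pool.
            for media_id in media_ids:
                results.append((media_id, expensive_resize(media_id)))

        yield threads.deferToThread(cpu_bound_work)

        # Back on the reactor thread, where the database writes belong.
        for row in results:
            yield store_result(*row)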
synapse/rest/media/v1/thumbnail_resource.py
@@ -162,11 +162,12 @@ class ThumbnailResource(BaseMediaResource):
             t_method = info["thumbnail_method"]
             if t_method == "scale" or t_method == "crop":
                 aspect_quality = abs(d_w * t_h - d_h * t_w)
+                min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
                 size_quality = abs((d_w - t_w) * (d_h - t_h))
                 type_quality = desired_type != info["thumbnail_type"]
                 length_quality = info["thumbnail_length"]
                 info_list.append((
-                    aspect_quality, size_quality, type_quality,
+                    aspect_quality, min_quality, size_quality, type_quality,
                     length_quality, info
                 ))
         if info_list:
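The new ``min_quality`` term joins a tuple that is compared lexicographically when the best stored thumbnail is picked, so any candidate at least as large as the requested size (``min_quality == 0``) now beats every smaller one before the remaining tie-breakers apply. In miniature::

    desired_w, desired_h = 64, 64
    candidates = [(32, 32), (96, 96), (800, 600)]  # stored (width, height)

    def rank(size):
        t_w, t_h = size
        aspect_quality = abs(desired_w * t_h - desired_h * t_w)
        # 0 if at least as large as requested, else 1 -- the new term.
        min_quality = 0 if desired_w <= t_w and desired_h <= t_h else 1
        size_quality = abs((desired_w - t_w) * (desired_h - t_h))
        return (aspect_quality, min_quality, size_quality)

    assert min(candidates, key=rank) == (96, 96)  # not the 32x32 upscale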
synapse/storage/__init__.py
@@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
         key = (user.to_string(), access_token, device_id, ip)
 
         try:
-            last_seen = self.client_ip_last_seen.get(*key)
+            last_seen = self.client_ip_last_seen.get(key)
         except KeyError:
             last_seen = None
 
|
||||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
self.client_ip_last_seen.prefill(*key + (now,))
|
self.client_ip_last_seen.prefill(key, now)
|
||||||
|
|
||||||
# It's safe not to lock here: a) no unique constraint,
|
# It's safe not to lock here: a) no unique constraint,
|
||||||
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
|
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
|
||||||
|
@@ -354,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
                 )
                 logger.debug("Running script %s", relative_path)
                 module.run_upgrade(cur, database_engine)
+            elif ext == ".pyc":
+                # Sometimes .pyc files turn up anyway even though we've
+                # disabled their generation; e.g. from distribution package
+                # installers. Silently skip it
+                pass
             elif ext == ".sql":
                 # A plain old .sql file, just read and execute it
                 logger.debug("Applying schema %s", relative_path)
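The new ``.pyc`` branch slots into a loop that dispatches each file in a schema delta directory on its extension. The dispatch in a standalone sketch (upgrade and SQL execution replaced by prints)::

    import os

    names = ["14/upgrade.py", "14/upgrade.pyc", "14/add_table.sql"]

    for name in names:
        root, ext = os.path.splitext(name)
        if ext == ".py":
            print("running upgrade script %s" % name)
        elif ext == ".pyc":
            # Stray byte-compiled files (e.g. left by package installers)
            # are silently skipped instead of being treated as unknown.
            pass
        elif ext == ".sql":
            print("applying schema %s" % name)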
synapse/storage/_base.py
@@ -15,6 +15,7 @@
 import logging
 
 from synapse.api.errors import StoreError
+from synapse.util.async import ObservableDeferred
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
@@ -27,6 +28,7 @@ from twisted.internet import defer
 from collections import namedtuple, OrderedDict
 
 import functools
+import inspect
 import sys
 import time
 import threading
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_CacheSentinel = object()
|
||||||
|
|
||||||
|
|
||||||
class Cache(object):
|
class Cache(object):
|
||||||
|
|
||||||
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
|
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
|
||||||
if lru:
|
if lru:
|
||||||
self.cache = LruCache(max_size=max_entries)
|
self.cache = LruCache(max_size=max_entries)
|
||||||
self.max_entries = None
|
self.max_entries = None
|
||||||
|
@@ -81,45 +86,44 @@ class Cache(object):
                 "Cache objects can only be accessed from the main thread"
             )
 
-    def get(self, *keyargs):
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-
-        if keyargs in self.cache:
+    def get(self, key, default=_CacheSentinel):
+        val = self.cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
             cache_counter.inc_hits(self.name)
-            return self.cache[keyargs]
+            return val
 
         cache_counter.inc_misses(self.name)
-        raise KeyError()
 
-    def update(self, sequence, *args):
+        if default is _CacheSentinel:
+            raise KeyError()
+        else:
+            return default
+
+    def update(self, sequence, key, value):
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(*args)
-
-    def prefill(self, *args):  # because I can't *keyargs, value
-        keyargs = args[:-1]
-        value = args[-1]
-
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+            self.prefill(key, value)
 
+    def prefill(self, key, value):
         if self.max_entries is not None:
             while len(self.cache) >= self.max_entries:
                 self.cache.popitem(last=False)
 
-        self.cache[keyargs] = value
+        self.cache[key] = value
 
-    def invalidate(self, *keyargs):
+    def invalidate(self, key):
         self.check_thread()
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+        if not isinstance(key, tuple):
+            raise TypeError(
+                "The cache key must be a tuple not %r" % (type(key),)
+            )
 
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
-        self.cache.pop(keyargs, None)
+        self.cache.pop(key, None)
 
     def invalidate_all(self):
         self.check_thread()
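``_CacheSentinel`` is the standard trick for giving a dict-backed ``get`` an optional default when ``None`` is itself a legitimate cached value: a private object that no caller can ever pass in or store. The trick in isolation, together with the new single-tuple keys::

    _CacheSentinel = object()  # can never collide with a real value

    class TupleKeyCache(object):
        def __init__(self):
            self.cache = {}

        def get(self, key, default=_CacheSentinel):
            val = self.cache.get(key, _CacheSentinel)
            if val is not _CacheSentinel:
                return val           # a hit, even if the value is None
            if default is _CacheSentinel:
                raise KeyError(key)  # no default was supplied
            return default

        def prefill(self, key, value):
            self.cache[key] = value

    cache = TupleKeyCache()
    key = ("@alice:example.com", "token", "device1", "1.2.3.4")
    cache.prefill(key, None)       # None is a real cached value here
    assert cache.get(key) is None  # hit: the sentinel told it apart
    assert cache.get(("nope",), 0) == 0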
@@ -130,6 +134,9 @@ class Cache(object):
 class CacheDescriptor(object):
     """ A method decorator that applies a memoizing cache around the function.
 
+    This caches deferreds, rather than the results themselves. Deferreds that
+    fail are removed from the cache.
+
     The function is presumed to take zero or more arguments, which are used in
     a tuple as the key for the cache. Hits are served directly from the cache;
     misses use the function body to generate the value.
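That docstring addition is the heart of the refactor: the cache holds the ``Deferred`` for an in-flight call, so concurrent callers share one computation instead of each missing and recomputing, and failed deferreds are evicted so errors are never served from cache. Deferred-caching in miniature (plain deferreds here; Synapse wraps them in ``ObservableDeferred`` so each caller gets an independent callback chain)::

    from twisted.internet import defer

    def cached_deferred(func, cache={}):
        def wrapper(key):
            if key in cache:
                return cache[key]       # in-flight or completed Deferred
            d = defer.maybeDeferred(func, key)
            cache[key] = d

            def evict_on_failure(failure):
                cache.pop(key, None)    # do not cache failures
                return failure

            d.addErrback(evict_on_failure)
            return d
        return wrapper

    calls = []

    @cached_deferred
    def lookup(key):
        calls.append(key)
        return "value-for-%s" % (key,)

    lookup("a")
    lookup("a")
    assert calls == ["a"]  # second call served from the cached Deferred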
@@ -141,58 +148,92 @@ class CacheDescriptor(object):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+                 inlineCallbacks=False):
         self.orig = orig
 
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
         self.max_entries = max_entries
         self.num_args = num_args
         self.lru = lru
 
-    def __get__(self, obj, objtype=None):
-        cache = Cache(
+        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+        if len(self.arg_names) < self.num_args:
+            raise Exception(
+                "Not enough explicit positional arguments to key off of for %r."
+                " (@cached cannot key off of *args or **kwars)"
+                % (orig.__name__,)
+            )
+
+        self.cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
         )
 
+    def __get__(self, obj, objtype=None):
+
         @functools.wraps(self.orig)
-        @defer.inlineCallbacks
-        def wrapped(*keyargs):
+        def wrapped(*args, **kwargs):
+            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result = cache.get(*keyargs[:self.num_args])
+                cached_result_d = self.cache.get(cache_key)
+
+                observer = cached_result_d.observe()
                 if DEBUG_CACHES:
-                    actual_result = yield self.orig(obj, *keyargs)
-                    if actual_result != cached_result:
-                        logger.error(
-                            "Stale cache entry %s%r: cached: %r, actual %r",
-                            self.orig.__name__, keyargs,
-                            cached_result, actual_result,
-                        )
-                        raise ValueError("Stale cache entry")
-                defer.returnValue(cached_result)
+                    @defer.inlineCallbacks
+                    def check_result(cached_result):
+                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
+                        if actual_result != cached_result:
+                            logger.error(
+                                "Stale cache entry %s%r: cached: %r, actual %r",
+                                self.orig.__name__, cache_key,
+                                cached_result, actual_result,
+                            )
+                            raise ValueError("Stale cache entry")
+                        defer.returnValue(cached_result)
+                    observer.addCallback(check_result)
+
+                return observer
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
+                sequence = self.cache.sequence
 
-                ret = yield self.orig(obj, *keyargs)
+                ret = defer.maybeDeferred(
+                    self.function_to_call,
+                    obj, *args, **kwargs
+                )
 
-                cache.update(sequence, *keyargs[:self.num_args] + (ret,))
+                def onErr(f):
+                    self.cache.invalidate(cache_key)
+                    return f
 
-                defer.returnValue(ret)
+                ret.addErrback(onErr)
 
-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.prefill = cache.prefill
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, cache_key, ret)
+
+                return ret.observe()
+
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped
 
         return wrapped
 
 
-def cached(max_entries=1000, num_args=1, lru=False):
+def cached(max_entries=1000, num_args=1, lru=True):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
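The key derivation above is worth pulling out: instead of trusting positional ``*args``, the wrapper resolves every call through ``inspect.getcallargs`` so keyword and positional invocations of the same method produce the same cache key. A standalone sketch, with an illustrative ``fetch`` method::

    import inspect

    class Thing(object):
        def fetch(self, room_id, event_id):
            return (room_id, event_id)

    # num_args=2 keys off the first two named arguments after self
    arg_names = inspect.getargspec(Thing.fetch).args[1:3]  # ['room_id', 'event_id']

    obj = Thing()
    arg_dict = inspect.getcallargs(Thing.fetch, obj, "!room:hs", event_id="$ev")
    cache_key = tuple(arg_dict[name] for name in arg_names)
    # cache_key == ('!room:hs', '$ev') even though the caller mixed styles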
@@ -201,6 +242,16 @@ def cached(max_entries=1000, num_args=1, lru=False):
     )
 
 
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru,
+        inlineCallbacks=True,
+    )
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
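``cachedInlineCallbacks`` exists so the stores can drop the old two-decorator stack (``@cached()`` over ``@defer.inlineCallbacks``), which the rest of this commit does file by file; the descriptor applies ``defer.inlineCallbacks`` to the raw generator itself and caches the resulting deferred. A usage sketch under those assumptions, with an illustrative store and query::

    from twisted.internet import defer
    from synapse.storage._base import cachedInlineCallbacks

    class ExampleStore(object):
        @cachedInlineCallbacks()
        def get_display_name(self, user_id):
            # stand-in for the real database read
            name = yield defer.succeed("Alice")
            defer.returnValue(name)

    store = ExampleStore()
    store.get_display_name("@alice:hs")   # miss: body runs, deferred is cached
    store.get_display_name("@alice:hs")   # hit: observer on the cached deferred
    store.get_display_name.invalidate(("@alice:hs",))  # keys are tuples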
@@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
             },
             desc="create_room_alias_association",
         )
-        self.get_aliases_for_room.invalidate(room_id)
+        self.get_aliases_for_room.invalidate((room_id,))
 
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
@@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
             room_alias,
         )
 
-        self.get_aliases_for_room.invalidate(room_id)
+        self.get_aliases_for_room.invalidate((room_id,))
         defer.returnValue(room_id)
 
     def _delete_room_alias_txn(self, txn, room_alias):
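This is the pattern repeated across the storage layer from here on: every ``invalidate`` call site switches from bare arguments to an explicit key tuple. The ``isinstance`` guard added to ``Cache.invalidate`` above turns any missed call site into a loud failure rather than a silently useless invalidation, e.g.::

    from synapse.storage._base import Cache

    cache = Cache(name="get_aliases_for_room", keylen=1)
    cache.prefill(("!room:hs",), ["#alias:hs"])

    cache.invalidate(("!room:hs",))   # new style: a 1-tuple key
    try:
        cache.invalidate("!room:hs")  # old style, no longer accepted
    except TypeError as e:
        print(e)  # The cache key must be a tuple not <type 'str'>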
@@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore):
 
         for room_id in events_by_room:
             txn.call_after(
-                self.get_latest_event_ids_in_room.invalidate, room_id
+                self.get_latest_event_ids_in_room.invalidate, (room_id,)
             )
 
     def get_backfill_events(self, room_id, event_list, limit):
@@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
         query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
 
         txn.execute(query, (room_id,))
-        txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
+        txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
@@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
         if current_state:
             txn.call_after(self.get_current_state_for_key.invalidate_all)
             txn.call_after(self.get_rooms_for_user.invalidate_all)
-            txn.call_after(self.get_users_in_room.invalidate, event.room_id)
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
+            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
             txn.call_after(self.get_room_name_and_aliases, event.room_id)
 
             self._simple_delete_txn(
@@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
         if not context.rejected:
             txn.call_after(
                 self.get_current_state_for_key.invalidate,
-                event.room_id, event.type, event.state_key
+                (event.room_id, event.type, event.state_key,)
             )
 
             if event.type in [EventTypes.Name, EventTypes.Aliases]:
                 txn.call_after(
                     self.get_room_name_and_aliases.invalidate,
-                    event.room_id
+                    (event.room_id,)
                 )
 
             self._simple_upsert_txn(
@@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
     def _invalidate_get_event_cache(self, event_id):
         for check_redacted in (False, True):
             for get_prev_content in (False, True):
-                self._get_event_cache.invalidate(event_id, check_redacted,
-                                                 get_prev_content)
+                self._get_event_cache.invalidate(
+                    (event_id, check_redacted, get_prev_content)
+                )
 
     def _get_event_txn(self, txn, event_id, check_redacted=True,
                        get_prev_content=False, allow_rejected=False):
@@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
         for event_id in events:
             try:
                 ret = self._get_event_cache.get(
-                    event_id, check_redacted, get_prev_content
+                    (event_id, check_redacted, get_prev_content,)
                 )
 
                 if allow_rejected or not ret.rejected_reason:
@@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
             ev.unsigned["prev_content"] = prev.get_dict()["content"]
 
         self._get_event_cache.prefill(
-            ev.event_id, check_redacted, get_prev_content, ev
+            (ev.event_id, check_redacted, get_prev_content), ev
         )
 
         defer.returnValue(ev)
@@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
             ev.unsigned["prev_content"] = prev.get_dict()["content"]
 
         self._get_event_cache.prefill(
-            ev.event_id, check_redacted, get_prev_content, ev
+            (ev.event_id, check_redacted, get_prev_content), ev
         )
 
         return ev
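``_get_event_cache`` shows the multi-part case: with ``keylen=3`` the key is the full 3-tuple of ``(event_id, check_redacted, get_prev_content)``, and ``get``, ``prefill`` and ``invalidate`` all take it whole. A sketch, with the constructor arguments assumed from the ``Cache`` usage earlier in this diff and an illustrative cache name::

    from synapse.storage._base import Cache

    event_cache = Cache(name="getEvent", keylen=3, lru=True)

    key = ("$ev123:hs", True, False)
    event_cache.prefill(key, {"type": "m.room.message"})
    assert event_cache.get(key) == {"type": "m.room.message"}
    event_cache.invalidate(key)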
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from _base import SQLBaseStore, cached
+from _base import SQLBaseStore, cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
             desc="store_server_certificate",
         )
 
-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_all_server_verify_keys(self, server_name):
         rows = yield self._simple_select_list(
             table="server_signature_keys",
@@ -132,7 +131,7 @@ class KeyStore(SQLBaseStore):
             desc="store_server_verify_key",
         )
 
-        self.get_all_server_verify_keys.invalidate(server_name)
+        self.get_all_server_verify_keys.invalidate((server_name,))
 
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
@@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
             updatevalues={"accepted": True},
             desc="set_presence_list_accepted",
         )
-        self.get_presence_list_accepted.invalidate(observer_localpart)
+        self.get_presence_list_accepted.invalidate((observer_localpart,))
         defer.returnValue(result)
 
     def get_presence_list(self, observer_localpart, accepted=None):
@@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
              "observed_user_id": observed_userid},
             desc="del_presence_list",
         )
-        self.get_presence_list_accepted.invalidate(observer_localpart)
+        self.get_presence_list_accepted.invalidate((observer_localpart,))
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
 from twisted.internet import defer
 
 import logging
@@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
 
 
 class PushRuleStore(SQLBaseStore):
-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_push_rules_for_user(self, user_name):
         rows = yield self._simple_select_list(
             table=PushRuleTable.table_name,
@@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(rows)
 
-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_push_rules_enabled_for_user(self, user_name):
         results = yield self._simple_select_list(
             table=PushRuleEnableTable.table_name,
@@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
             txn.execute(sql, (user_name, priority_class, new_rule_priority))
 
             txn.call_after(
-                self.get_push_rules_for_user.invalidate, user_name
+                self.get_push_rules_for_user.invalidate, (user_name,)
             )
 
             txn.call_after(
-                self.get_push_rules_enabled_for_user.invalidate, user_name
+                self.get_push_rules_enabled_for_user.invalidate, (user_name,)
             )
 
             self._simple_insert_txn(
@@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
             new_rule['priority'] = new_prio
 
             txn.call_after(
-                self.get_push_rules_for_user.invalidate, user_name
+                self.get_push_rules_for_user.invalidate, (user_name,)
             )
             txn.call_after(
-                self.get_push_rules_enabled_for_user.invalidate, user_name
+                self.get_push_rules_enabled_for_user.invalidate, (user_name,)
             )
 
             self._simple_insert_txn(
@@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
             desc="delete_push_rule",
         )
 
-        self.get_push_rules_for_user.invalidate(user_name)
-        self.get_push_rules_enabled_for_user.invalidate(user_name)
+        self.get_push_rules_for_user.invalidate((user_name,))
+        self.get_push_rules_enabled_for_user.invalidate((user_name,))
 
     @defer.inlineCallbacks
     def set_push_rule_enabled(self, user_name, rule_id, enabled):
@@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
             {'id': new_id},
         )
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, user_name
+            self.get_push_rules_for_user.invalidate, (user_name,)
         )
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, user_name
+            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
         )
 
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_max_token(self)
 
-    @cached
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_graph_receipts_for_room(self, room_id):
         """Get receipts for sending to remote servers.
         """
@@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
             user_id
         )
         for r in rows:
-            self.get_user_by_token.invalidate(r)
+            self.get_user_by_token.invalidate((r,))
 
     @cached()
     def get_user_by_token(self, token):
@@ -17,7 +17,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
 
 import collections
 import logging
@@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
             }
         )
 
-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_room_name_and_aliases(self, room_id):
         def f(txn):
             sql = (
@@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore):
         )
 
         for event in events:
-            txn.call_after(self.get_rooms_for_user.invalidate, event.state_key)
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
-            txn.call_after(self.get_users_in_room.invalidate, event.room_id)
+            txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
+            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
+            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
 
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
@@ -78,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
             lambda events: events[0] if events else None
         )
 
-    @cached()
+    @cached(max_entries=5000)
     def get_users_in_room(self, room_id):
         def f(txn):
 
@@ -154,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
             RoomsForUser(**r) for r in self.cursor_to_dict(txn)
         ]
 
-    @cached()
+    @cached(max_entries=5000)
     def get_joined_hosts_for_room(self, room_id):
         return self.runInteraction(
             "get_joined_hosts_for_room",
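The bumped ``max_entries=5000`` above matters because, as the ``Cache.prefill`` hunk earlier shows, the bound is enforced eagerly by evicting the oldest entries (``popitem(last=False)``) before each insert. A small sketch, assuming the non-LRU backing store honours ``max_entries`` this way::

    from synapse.storage._base import Cache

    c = Cache(name="bounded", max_entries=2, keylen=1)
    c.prefill(("a",), 1)
    c.prefill(("b",), 2)
    c.prefill(("c",), 3)                    # evicts ("a",), the oldest entry
    assert c.get(("a",), default=None) is None
    assert c.get(("c",)) == 3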
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -91,7 +91,6 @@ class StateStore(SQLBaseStore):
 
         defer.returnValue(dict(state_list))
 
-    @cached(num_args=1)
     def _fetch_events_for_group(self, key, events):
         return self._get_events(
             events, get_prev_content=False
@@ -189,8 +188,7 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
-    @cached(num_args=3)
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks(num_args=3)
     def get_current_state_for_key(self, room_id, event_type, state_key):
         def f(txn):
             sql = (
@@ -178,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
 
     Live tokens start with an "s" followed by the "stream_ordering" id of the
     event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, follewed by "-",
+    "topological_ordering" id of the event it comes after, followed by "-",
     followed by the "stream_ordering" id of the event it comes after.
     """
     __slots__ = []
@@ -211,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
         return "s%d" % (self.stream,)
 
 
+# token_id is the primary key ID of the access token, not the access token itself.
 ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
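For reference, the two token shapes the docstring distinguishes, as an illustrative re-implementation of the serialization rule only (the real class also carries parsing and comparison logic not shown here)::

    from collections import namedtuple

    class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
        def __str__(self):
            if self.topological is not None:
                return "t%d-%d" % (self.topological, self.stream)
            return "s%d" % (self.stream,)

    print(str(RoomStreamToken(None, 2633508)))  # s2633508      (live)
    print(str(RoomStreamToken(426, 2633508)))   # t426-2633508  (historic)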
@@ -51,7 +51,7 @@ class ObservableDeferred(object):
         object.__setattr__(self, "_observers", set())
 
         def callback(r):
-            self._result = (True, r)
+            object.__setattr__(self, "_result", (True, r))
             while self._observers:
                 try:
                     self._observers.pop().callback(r)
@@ -60,7 +60,7 @@ class ObservableDeferred(object):
             return r
 
         def errback(f):
-            self._result = (False, f)
+            object.__setattr__(self, "_result", (False, f))
             while self._observers:
                 try:
                     self._observers.pop().errback(f)
@@ -97,3 +97,8 @@ class ObservableDeferred(object):
 
     def __setattr__(self, name, value):
         setattr(self._deferred, name, value)
+
+    def __repr__(self):
+        return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
+            id(self), self._result, self._deferred,
+        )
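The ``object.__setattr__`` change is subtler than it looks: as the last hunk's context shows, ``ObservableDeferred`` overrides ``__setattr__`` to forward attribute writes to the wrapped deferred, so the old ``self._result = ...`` never stored the result on the wrapper itself, and observers added after resolution could not be served from it. A usage sketch of the class::

    from twisted.internet import defer
    from synapse.util.async import ObservableDeferred

    d = defer.Deferred()
    od = ObservableDeferred(d, consumeErrors=True)

    early = od.observe()           # each observer gets its own deferred
    d.callback("done")             # fires every registered observer

    late = od.observe()            # resolved straight from the stored _result
    late.addCallback(lambda r: r)  # fires immediately with "done"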
@@ -0,0 +1,134 @@
+from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.api.errors import SynapseError
+from twisted.internet import defer
+from mock import Mock, MagicMock
+from tests import unittest
+import json
+
+
+class RegisterRestServletTestCase(unittest.TestCase):
+
+    def setUp(self):
+        # do the dance to hook up request data to self.request_data
+        self.request_data = ""
+        self.request = Mock(
+            content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+        )
+        self.request.args = {}
+
+        self.appservice = None
+        self.auth = Mock(get_appservice_by_req=Mock(
+            side_effect=lambda x: defer.succeed(self.appservice))
+        )
+
+        self.auth_result = (False, None, None)
+        self.auth_handler = Mock(
+            check_auth=Mock(side_effect=lambda x,y,z: self.auth_result)
+        )
+        self.registration_handler = Mock()
+        self.identity_handler = Mock()
+        self.login_handler = Mock()
+
+        # do the dance to hook it up to the hs global
+        self.handlers = Mock(
+            auth_handler=self.auth_handler,
+            registration_handler=self.registration_handler,
+            identity_handler=self.identity_handler,
+            login_handler=self.login_handler
+        )
+        self.hs = Mock()
+        self.hs.hostname = "superbig~testing~thing.com"
+        self.hs.get_auth = Mock(return_value=self.auth)
+        self.hs.get_handlers = Mock(return_value=self.handlers)
+        self.hs.config.disable_registration = False
+
+        # init the thing we're testing
+        self.servlet = RegisterRestServlet(self.hs)
+
+    @defer.inlineCallbacks
+    def test_POST_appservice_registration_valid(self):
+        user_id = "@kermit:muppet"
+        token = "kermits_access_token"
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "username": "kermit"
+        })
+        self.appservice = {
+            "id": "1234"
+        }
+        self.registration_handler.appservice_register = Mock(
+            return_value=(user_id, token)
+        )
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (200, {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }))
+
+    @defer.inlineCallbacks
+    def test_POST_appservice_registration_invalid(self):
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "username": "kermit"
+        })
+        self.appservice = None  # no application service exists
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (401, None))
+
+    def test_POST_bad_password(self):
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": 666
+        })
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
+
+    def test_POST_bad_username(self):
+        self.request_data = json.dumps({
+            "username": 777,
+            "password": "monkey"
+        })
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
+
+    @defer.inlineCallbacks
+    def test_POST_user_valid(self):
+        user_id = "@kermit:muppet"
+        token = "kermits_access_token"
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.check_username = Mock(return_value=True)
+        self.auth_result = (True, None, {
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.register = Mock(return_value=(user_id, token))
+
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (200, {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }))
+
+    def test_POST_disabled_registration(self):
+        self.hs.config.disable_registration = True
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.check_username = Mock(return_value=True)
+        self.auth_result = (True, None, {
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.register = Mock(return_value=("@user:id", "t"))
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
@@ -17,6 +17,8 @@
 from tests import unittest
 from twisted.internet import defer
 
+from synapse.util.async import ObservableDeferred
+
 from synapse.storage._base import Cache, cached
 
 
@@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
         self.assertEquals(self.cache.get("foo"), 123)
 
     def test_invalidate(self):
-        self.cache.prefill("foo", 123)
-        self.cache.invalidate("foo")
+        self.cache.prefill(("foo",), 123)
+        self.cache.invalidate(("foo",))
 
         failed = False
         try:
-            self.cache.get("foo")
+            self.cache.get(("foo",))
         except KeyError:
             failed = True
 
@@ -139,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         self.assertEquals(callcount[0], 1)
 
-        a.func.invalidate("foo")
+        a.func.invalidate(("foo",))
 
         yield a.func("foo")
 
@@ -151,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
             def func(self, key):
                 return key
 
-        A().func.invalidate("what")
+        A().func.invalidate(("what",))
 
     @defer.inlineCallbacks
     def test_max_entries(self):
@@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
         self.assertTrue(callcount[0] >= 14,
                         msg="Expected callcount >= 14, got %d" % (callcount[0]))
 
-    @defer.inlineCallbacks
     def test_prefill(self):
         callcount = [0]
 
+        d = defer.succeed(123)
+
         class A(object):
             @cached()
             def func(self, key):
                 callcount[0] += 1
-                return key
+                return d
 
         a = A()
 
-        a.func.prefill("foo", 123)
+        a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals((yield a.func("foo")), 123)
+        self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)
@@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
         yield d
         self.assertTrue(d.called)
 
-        observers[0].assert_called_once("Go")
-        observers[1].assert_called_once("Go")
+        observers[0].assert_called_once_with("Go")
+        observers[1].assert_called_once_with("Go")
 
         self.assertEquals(mock_logger.warning.call_count, 1)
         self.assertIsInstance(mock_logger.warning.call_args[0][0],
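This last fix is a classic ``mock`` pitfall: on the ``mock`` library of this era there is no ``assert_called_once`` method, so attribute auto-creation turns the misspelled assertion into a no-op child mock and the test can never fail. For example::

    from mock import Mock

    m = Mock()
    m("Go")

    m.assert_called_once("Go")       # no-op: just creates m.assert_called_once
    m.assert_called_once_with("Go")  # the real assertion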