Merge branch 'develop' into markjh/twisted-15
Conflicts: synapse/http/matrixfederationclient.py
This commit is contained in:
commit
998a72d4d9
|
@ -38,3 +38,10 @@ Brabo <brabo at riseup.net>
|
|||
|
||||
Ivan Shapovalov <intelfx100 at gmail.com>
|
||||
* contrib/systemd: a sample systemd unit file and a logger configuration
|
||||
|
||||
Eric Myhre <hash at exultant.us>
|
||||
* Fix bug where ``media_store_path`` config option was ignored by v0 content
|
||||
repository API.
|
||||
|
||||
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
|
||||
* Add SAML2 support for registration and logins.
|
||||
|
|
51
CHANGES.rst
51
CHANGES.rst
|
@ -1,3 +1,54 @@
|
|||
Changes in synapse v0.9.3 (2015-07-01)
|
||||
======================================
|
||||
|
||||
No changes from v0.9.3 Release Candidate 1.
|
||||
|
||||
Changes in synapse v0.9.3-rc1 (2015-06-23)
|
||||
==========================================
|
||||
|
||||
General:
|
||||
|
||||
* Fix a memory leak in the notifier. (SYN-412)
|
||||
* Improve performance of room initial sync. (SYN-418)
|
||||
* General improvements to logging.
|
||||
* Remove ``access_token`` query params from ``INFO`` level logging.
|
||||
|
||||
Configuration:
|
||||
|
||||
* Add support for specifying and configuring multiple listeners. (SYN-389)
|
||||
|
||||
Application services:
|
||||
|
||||
* Fix bug where synapse failed to send user queries to application services.
|
||||
|
||||
Changes in synapse v0.9.2-r2 (2015-06-15)
|
||||
=========================================
|
||||
|
||||
Fix packaging so that schema delta python files get included in the package.
|
||||
|
||||
Changes in synapse v0.9.2 (2015-06-12)
|
||||
======================================
|
||||
|
||||
General:
|
||||
|
||||
* Use ultrajson for json (de)serialisation when a canonical encoding is not
|
||||
required. Ultrajson is significantly faster than simplejson in certain
|
||||
circumstances.
|
||||
* Use connection pools for outgoing HTTP connections.
|
||||
* Process thumbnails on separate threads.
|
||||
|
||||
Configuration:
|
||||
|
||||
* Add option, ``gzip_responses``, to disable HTTP response compression.
|
||||
|
||||
Federation:
|
||||
|
||||
* Improve resilience of backfill by ensuring we fetch any missing auth events.
|
||||
* Improve performance of backfill and joining remote rooms by removing
|
||||
unnecessary computations. This included handling events we'd previously
|
||||
handled as well as attempting to compute the current state for outliers.
|
||||
|
||||
|
||||
Changes in synapse v0.9.1 (2015-05-26)
|
||||
======================================
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ include *.rst
|
|||
include demo/README
|
||||
|
||||
recursive-include synapse/storage/schema *.sql
|
||||
recursive-include synapse/storage/schema *.py
|
||||
|
||||
recursive-include demo *.dh
|
||||
recursive-include demo *.py
|
||||
|
|
91
README.rst
91
README.rst
|
@ -101,36 +101,40 @@ header files for python C extensions.
|
|||
|
||||
Installing prerequisites on Ubuntu or Debian::
|
||||
|
||||
$ sudo apt-get install build-essential python2.7-dev libffi-dev \
|
||||
sudo apt-get install build-essential python2.7-dev libffi-dev \
|
||||
python-pip python-setuptools sqlite3 \
|
||||
libssl-dev python-virtualenv libjpeg-dev
|
||||
|
||||
Installing prerequisites on ArchLinux::
|
||||
|
||||
$ sudo pacman -S base-devel python2 python-pip \
|
||||
sudo pacman -S base-devel python2 python-pip \
|
||||
python-setuptools python-virtualenv sqlite3
|
||||
|
||||
Installing prerequisites on Mac OS X::
|
||||
|
||||
$ xcode-select --install
|
||||
$ sudo pip install virtualenv
|
||||
xcode-select --install
|
||||
sudo easy_install pip
|
||||
sudo pip install virtualenv
|
||||
|
||||
To install the synapse homeserver run::
|
||||
|
||||
$ virtualenv -p python2.7 ~/.synapse
|
||||
$ source ~/.synapse/bin/activate
|
||||
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
|
||||
virtualenv -p python2.7 ~/.synapse
|
||||
source ~/.synapse/bin/activate
|
||||
pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
|
||||
|
||||
This installs synapse, along with the libraries it uses, into a virtual
|
||||
environment under ``~/.synapse``.
|
||||
environment under ``~/.synapse``. Feel free to pick a different directory
|
||||
if you prefer.
|
||||
|
||||
In case of problems, please see the _Troubleshooting section below.
|
||||
|
||||
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
|
||||
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
|
||||
|
||||
To set up your homeserver, run (in your virtualenv, as before)::
|
||||
|
||||
$ cd ~/.synapse
|
||||
$ python -m synapse.app.homeserver \
|
||||
cd ~/.synapse
|
||||
python -m synapse.app.homeserver \
|
||||
--server-name machine.my.domain.name \
|
||||
--config-path homeserver.yaml \
|
||||
--generate-config
|
||||
|
@ -189,9 +193,9 @@ Running Synapse
|
|||
To actually run your new homeserver, pick a working directory for Synapse to run
|
||||
(e.g. ``~/.synapse``), and::
|
||||
|
||||
$ cd ~/.synapse
|
||||
$ source ./bin/activate
|
||||
$ synctl start
|
||||
cd ~/.synapse
|
||||
source ./bin/activate
|
||||
synctl start
|
||||
|
||||
Platform Specific Instructions
|
||||
==============================
|
||||
|
@ -209,12 +213,12 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:
|
|||
|
||||
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
|
||||
|
||||
$ sudo pip2.7 install --upgrade pip
|
||||
sudo pip2.7 install --upgrade pip
|
||||
|
||||
You also may need to explicitly specify python 2.7 again during the install
|
||||
request::
|
||||
|
||||
$ pip2.7 install --process-dependency-links \
|
||||
pip2.7 install --process-dependency-links \
|
||||
https://github.com/matrix-org/synapse/tarball/master
|
||||
|
||||
If you encounter an error with lib bcrypt causing an Wrong ELF Class:
|
||||
|
@ -222,13 +226,13 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
|
|||
compile it under the right architecture. (This should not be needed if
|
||||
installing under virtualenv)::
|
||||
|
||||
$ sudo pip2.7 uninstall py-bcrypt
|
||||
$ sudo pip2.7 install py-bcrypt
|
||||
sudo pip2.7 uninstall py-bcrypt
|
||||
sudo pip2.7 install py-bcrypt
|
||||
|
||||
During setup of Synapse you need to call python2.7 directly again::
|
||||
|
||||
$ cd ~/.synapse
|
||||
$ python2.7 -m synapse.app.homeserver \
|
||||
cd ~/.synapse
|
||||
python2.7 -m synapse.app.homeserver \
|
||||
--server-name machine.my.domain.name \
|
||||
--config-path homeserver.yaml \
|
||||
--generate-config
|
||||
|
@ -276,22 +280,22 @@ Synapse requires pip 1.7 or later, so if your OS provides too old a version and
|
|||
you get errors about ``error: no such option: --process-dependency-links`` you
|
||||
may need to manually upgrade it::
|
||||
|
||||
$ sudo pip install --upgrade pip
|
||||
sudo pip install --upgrade pip
|
||||
|
||||
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
|
||||
refuse to run until you remove the temporary installation directory it
|
||||
created. To reset the installation::
|
||||
|
||||
$ rm -rf /tmp/pip_install_matrix
|
||||
rm -rf /tmp/pip_install_matrix
|
||||
|
||||
pip seems to leak *lots* of memory during installation. For instance, a Linux
|
||||
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
|
||||
happens, you will have to individually install the dependencies which are
|
||||
failing, e.g.::
|
||||
|
||||
$ pip install twisted
|
||||
pip install twisted
|
||||
|
||||
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
|
||||
On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
|
||||
will need to export CFLAGS=-Qunused-arguments.
|
||||
|
||||
Troubleshooting Running
|
||||
|
@ -307,10 +311,11 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
|
|||
fix try re-installing from PyPI or directly from
|
||||
(https://github.com/pyca/pynacl)::
|
||||
|
||||
$ # Install from PyPI
|
||||
$ pip install --user --upgrade --force pynacl
|
||||
$ # Install from github
|
||||
$ pip install --user https://github.com/pyca/pynacl/tarball/master
|
||||
# Install from PyPI
|
||||
pip install --user --upgrade --force pynacl
|
||||
|
||||
# Install from github
|
||||
pip install --user https://github.com/pyca/pynacl/tarball/master
|
||||
|
||||
ArchLinux
|
||||
~~~~~~~~~
|
||||
|
@ -318,7 +323,7 @@ ArchLinux
|
|||
If running `$ synctl start` fails with 'returned non-zero exit status 1',
|
||||
you will need to explicitly call Python2.7 - either running as::
|
||||
|
||||
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
|
||||
python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
|
||||
|
||||
...or by editing synctl with the correct python executable.
|
||||
|
||||
|
@ -328,16 +333,16 @@ Synapse Development
|
|||
To check out a synapse for development, clone the git repo into a working
|
||||
directory of your choice::
|
||||
|
||||
$ git clone https://github.com/matrix-org/synapse.git
|
||||
$ cd synapse
|
||||
git clone https://github.com/matrix-org/synapse.git
|
||||
cd synapse
|
||||
|
||||
Synapse has a number of external dependencies, that are easiest
|
||||
to install using pip and a virtualenv::
|
||||
|
||||
$ virtualenv env
|
||||
$ source env/bin/activate
|
||||
$ python synapse/python_dependencies.py | xargs -n1 pip install
|
||||
$ pip install setuptools_trial mock
|
||||
virtualenv env
|
||||
source env/bin/activate
|
||||
python synapse/python_dependencies.py | xargs -n1 pip install
|
||||
pip install setuptools_trial mock
|
||||
|
||||
This will run a process of downloading and installing all the needed
|
||||
dependencies into a virtual env.
|
||||
|
@ -345,7 +350,7 @@ dependencies into a virtual env.
|
|||
Once this is done, you may wish to run Synapse's unit tests, to
|
||||
check that everything is installed as it should be::
|
||||
|
||||
$ python setup.py test
|
||||
python setup.py test
|
||||
|
||||
This should end with a 'PASSED' result::
|
||||
|
||||
|
@ -386,11 +391,11 @@ IDs:
|
|||
For the first form, simply pass the required hostname (of the machine) as the
|
||||
--server-name parameter::
|
||||
|
||||
$ python -m synapse.app.homeserver \
|
||||
python -m synapse.app.homeserver \
|
||||
--server-name machine.my.domain.name \
|
||||
--config-path homeserver.yaml \
|
||||
--generate-config
|
||||
$ python -m synapse.app.homeserver --config-path homeserver.yaml
|
||||
python -m synapse.app.homeserver --config-path homeserver.yaml
|
||||
|
||||
Alternatively, you can run ``synctl start`` to guide you through the process.
|
||||
|
||||
|
@ -407,11 +412,11 @@ record would then look something like::
|
|||
At this point, you should then run the homeserver with the hostname of this
|
||||
SRV record, as that is the name other machines will expect it to have::
|
||||
|
||||
$ python -m synapse.app.homeserver \
|
||||
python -m synapse.app.homeserver \
|
||||
--server-name YOURDOMAIN \
|
||||
--config-path homeserver.yaml \
|
||||
--generate-config
|
||||
$ python -m synapse.app.homeserver --config-path homeserver.yaml
|
||||
python -m synapse.app.homeserver --config-path homeserver.yaml
|
||||
|
||||
|
||||
You may additionally want to pass one or more "-v" options, in order to
|
||||
|
@ -425,7 +430,7 @@ private federation (``localhost:8080``, ``localhost:8081`` and
|
|||
``localhost:8082``) which you can then access through the webclient running at
|
||||
http://localhost:8080. Simply run::
|
||||
|
||||
$ demo/start.sh
|
||||
demo/start.sh
|
||||
|
||||
This is mainly useful just for development purposes.
|
||||
|
||||
|
@ -499,10 +504,10 @@ Building Internal API Documentation
|
|||
Before building internal API documentation install sphinx and
|
||||
sphinxcontrib-napoleon::
|
||||
|
||||
$ pip install sphinx
|
||||
$ pip install sphinxcontrib-napoleon
|
||||
pip install sphinx
|
||||
pip install sphinxcontrib-napoleon
|
||||
|
||||
Building internal API documentation::
|
||||
|
||||
$ python setup.py build_sphinx
|
||||
python setup.py build_sphinx
|
||||
|
||||
|
|
|
@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
find "$DIR" -name "*.log" -delete
|
||||
find "$DIR" -name "*.db" -delete
|
||||
for port in 8080 8081 8082; do
|
||||
rm -rf $DIR/$port
|
||||
rm -rf $DIR/media_store.$port
|
||||
done
|
||||
|
||||
rm -rf $DIR/etc
|
||||
|
|
|
@ -8,14 +8,6 @@ cd "$DIR/.."
|
|||
|
||||
mkdir -p demo/etc
|
||||
|
||||
# Check the --no-rate-limit param
|
||||
PARAMS=""
|
||||
if [ $# -eq 1 ]; then
|
||||
if [ $1 = "--no-rate-limit" ]; then
|
||||
PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
|
||||
fi
|
||||
fi
|
||||
|
||||
export PYTHONPATH=$(readlink -f $(pwd))
|
||||
|
||||
|
||||
|
@ -35,6 +27,15 @@ for port in 8080 8081 8082; do
|
|||
-H "localhost:$https_port" \
|
||||
--config-path "$DIR/etc/$port.config" \
|
||||
|
||||
# Check script parameters
|
||||
if [ $# -eq 1 ]; then
|
||||
if [ $1 = "--no-rate-limit" ]; then
|
||||
# Set high limits in config file to disable rate limiting
|
||||
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
|
||||
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
|
||||
fi
|
||||
fi
|
||||
|
||||
python -m synapse.app.homeserver \
|
||||
--config-path "$DIR/etc/$port.config" \
|
||||
-D \
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.9.1"
|
||||
__version__ = "0.9.3"
|
||||
|
|
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
AuthEventTypes = (
|
||||
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
|
||||
EventTypes.JoinRules,
|
||||
EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
|
||||
)
|
||||
|
||||
|
||||
|
@ -44,6 +44,11 @@ class Auth(object):
|
|||
def check(self, event, auth_events):
|
||||
""" Checks if this event is correctly authed.
|
||||
|
||||
Args:
|
||||
event: the event being checked.
|
||||
auth_events (dict: event-key -> event): the existing room state.
|
||||
|
||||
|
||||
Returns:
|
||||
True if the auth checks pass.
|
||||
"""
|
||||
|
@ -187,6 +192,9 @@ class Auth(object):
|
|||
join_rule = JoinRules.INVITE
|
||||
|
||||
user_level = self._get_user_power_level(event.user_id, auth_events)
|
||||
target_level = self._get_user_power_level(
|
||||
target_user_id, auth_events
|
||||
)
|
||||
|
||||
# FIXME (erikj): What should we do here as the default?
|
||||
ban_level = self._get_named_level(auth_events, "ban", 50)
|
||||
|
@ -258,12 +266,12 @@ class Auth(object):
|
|||
elif target_user_id != event.user_id:
|
||||
kick_level = self._get_named_level(auth_events, "kick", 50)
|
||||
|
||||
if user_level < kick_level:
|
||||
if user_level < kick_level or user_level <= target_level:
|
||||
raise AuthError(
|
||||
403, "You cannot kick user %s." % target_user_id
|
||||
)
|
||||
elif Membership.BAN == membership:
|
||||
if user_level < ban_level:
|
||||
if user_level < ban_level or user_level <= target_level:
|
||||
raise AuthError(403, "You don't have permission to ban")
|
||||
else:
|
||||
raise AuthError(500, "Unknown membership %s" % membership)
|
||||
|
@ -316,7 +324,7 @@ class Auth(object):
|
|||
Returns:
|
||||
tuple : of UserID and device string:
|
||||
User ID object of the user making the request
|
||||
Client ID object of the client instance the user is using
|
||||
ClientInfo object of the client instance the user is using
|
||||
Raises:
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
|
@ -349,7 +357,7 @@ class Auth(object):
|
|||
)
|
||||
return
|
||||
except KeyError:
|
||||
pass # normal users won't have this query parameter set
|
||||
pass # normal users won't have the user_id query parameter set.
|
||||
|
||||
user_info = yield self.get_user_by_token(access_token)
|
||||
user = user_info["user"]
|
||||
|
@ -370,6 +378,8 @@ class Auth(object):
|
|||
user_agent=user_agent
|
||||
)
|
||||
|
||||
request.authenticated_entity = user.to_string()
|
||||
|
||||
defer.returnValue((user, ClientInfo(device_id, token_id)))
|
||||
except KeyError:
|
||||
raise AuthError(
|
||||
|
@ -516,7 +526,6 @@ class Auth(object):
|
|||
|
||||
# Check state_key
|
||||
if hasattr(event, "state_key"):
|
||||
if not event.state_key.startswith("_"):
|
||||
if event.state_key.startswith("@"):
|
||||
if event.state_key != event.user_id:
|
||||
raise AuthError(
|
||||
|
@ -571,25 +580,26 @@ class Auth(object):
|
|||
|
||||
# Check other levels:
|
||||
levels_to_check = [
|
||||
("users_default", []),
|
||||
("events_default", []),
|
||||
("ban", []),
|
||||
("redact", []),
|
||||
("kick", []),
|
||||
("invite", []),
|
||||
("users_default", None),
|
||||
("events_default", None),
|
||||
("state_default", None),
|
||||
("ban", None),
|
||||
("redact", None),
|
||||
("kick", None),
|
||||
("invite", None),
|
||||
]
|
||||
|
||||
old_list = current_state.content.get("users")
|
||||
for user in set(old_list.keys() + user_list.keys()):
|
||||
levels_to_check.append(
|
||||
(user, ["users"])
|
||||
(user, "users")
|
||||
)
|
||||
|
||||
old_list = current_state.content.get("events")
|
||||
new_list = event.content.get("events")
|
||||
for ev_id in set(old_list.keys() + new_list.keys()):
|
||||
levels_to_check.append(
|
||||
(ev_id, ["events"])
|
||||
(ev_id, "events")
|
||||
)
|
||||
|
||||
old_state = current_state.content
|
||||
|
@ -597,12 +607,10 @@ class Auth(object):
|
|||
|
||||
for level_to_check, dir in levels_to_check:
|
||||
old_loc = old_state
|
||||
for d in dir:
|
||||
old_loc = old_loc.get(d, {})
|
||||
|
||||
new_loc = new_state
|
||||
for d in dir:
|
||||
new_loc = new_loc.get(d, {})
|
||||
if dir:
|
||||
old_loc = old_loc.get(dir, {})
|
||||
new_loc = new_loc.get(dir, {})
|
||||
|
||||
if level_to_check in old_loc:
|
||||
old_level = int(old_loc[level_to_check])
|
||||
|
@ -618,6 +626,14 @@ class Auth(object):
|
|||
if new_level == old_level:
|
||||
continue
|
||||
|
||||
if dir == "users" and level_to_check != event.user_id:
|
||||
if old_level == user_level:
|
||||
raise AuthError(
|
||||
403,
|
||||
"You don't have permission to remove ops level equal "
|
||||
"to your own"
|
||||
)
|
||||
|
||||
if old_level > user_level or new_level > user_level:
|
||||
raise AuthError(
|
||||
403,
|
||||
|
|
|
@ -75,6 +75,8 @@ class EventTypes(object):
|
|||
Redaction = "m.room.redaction"
|
||||
Feedback = "m.room.message.feedback"
|
||||
|
||||
RoomHistoryVisibility = "m.room.history_visibility"
|
||||
|
||||
# These are used for validation
|
||||
Message = "m.room.message"
|
||||
Topic = "m.room.topic"
|
||||
|
@ -85,3 +87,8 @@ class RejectedReason(object):
|
|||
AUTH_ERROR = "auth_error"
|
||||
REPLACED = "replaced"
|
||||
NOT_ANCESTOR = "not_ancestor"
|
||||
|
||||
|
||||
class RoomCreationPreset(object):
|
||||
PRIVATE_CHAT = "private_chat"
|
||||
PUBLIC_CHAT = "public_chat"
|
||||
|
|
|
@ -34,8 +34,7 @@ from twisted.application import service
|
|||
from twisted.enterprise import adbapi
|
||||
from twisted.web.resource import Resource, EncodingResourceWrapper
|
||||
from twisted.web.static import File
|
||||
from twisted.web.server import Site, GzipEncoderFactory
|
||||
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
|
||||
from twisted.web.server import Site, GzipEncoderFactory, Request
|
||||
from synapse.http.server import JsonResource, RootRedirect
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||
|
@ -61,11 +60,13 @@ import twisted.manhole.telnet
|
|||
|
||||
import synapse
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import resource
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
|
||||
logger = logging.getLogger("synapse.app.homeserver")
|
||||
|
@ -87,10 +88,10 @@ class SynapseHomeServer(HomeServer):
|
|||
return MatrixFederationHttpClient(self)
|
||||
|
||||
def build_resource_for_client(self):
|
||||
return gz_wrap(ClientV1RestResource(self))
|
||||
return ClientV1RestResource(self)
|
||||
|
||||
def build_resource_for_client_v2_alpha(self):
|
||||
return gz_wrap(ClientV2AlphaRestResource(self))
|
||||
return ClientV2AlphaRestResource(self)
|
||||
|
||||
def build_resource_for_federation(self):
|
||||
return JsonResource(self)
|
||||
|
@ -113,7 +114,7 @@ class SynapseHomeServer(HomeServer):
|
|||
|
||||
def build_resource_for_content_repo(self):
|
||||
return ContentRepoResource(
|
||||
self, self.upload_dir, self.auth, self.content_addr
|
||||
self, self.config.uploads_path, self.auth, self.content_addr
|
||||
)
|
||||
|
||||
def build_resource_for_media_repository(self):
|
||||
|
@ -139,152 +140,105 @@ class SynapseHomeServer(HomeServer):
|
|||
**self.db_config.get("args", {})
|
||||
)
|
||||
|
||||
def create_resource_tree(self, redirect_root_to_web_client):
|
||||
"""Create the resource tree for this Home Server.
|
||||
def _listener_http(self, config, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
tls = listener_config.get("tls", False)
|
||||
site_tag = listener_config.get("tag", port)
|
||||
|
||||
This in unduly complicated because Twisted does not support putting
|
||||
child resources more than 1 level deep at a time.
|
||||
|
||||
Args:
|
||||
web_client (bool): True to enable the web client.
|
||||
redirect_root_to_web_client (bool): True to redirect '/' to the
|
||||
location of the web client. This does nothing if web_client is not
|
||||
True.
|
||||
"""
|
||||
config = self.get_config()
|
||||
web_client = config.web_client
|
||||
|
||||
# list containing (path_str, Resource) e.g:
|
||||
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
|
||||
desired_tree = [
|
||||
(CLIENT_PREFIX, self.get_resource_for_client()),
|
||||
(CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()),
|
||||
(FEDERATION_PREFIX, self.get_resource_for_federation()),
|
||||
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
|
||||
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
|
||||
(SERVER_KEY_V2_PREFIX, self.get_resource_for_server_key_v2()),
|
||||
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
|
||||
(STATIC_PREFIX, self.get_resource_for_static_content()),
|
||||
]
|
||||
|
||||
if web_client:
|
||||
logger.info("Adding the web client.")
|
||||
desired_tree.append((WEB_CLIENT_PREFIX,
|
||||
self.get_resource_for_web_client()))
|
||||
|
||||
if web_client and redirect_root_to_web_client:
|
||||
self.root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
||||
else:
|
||||
self.root_resource = Resource()
|
||||
if tls and config.no_tls:
|
||||
return
|
||||
|
||||
metrics_resource = self.get_resource_for_metrics()
|
||||
if config.metrics_port is None and metrics_resource is not None:
|
||||
desired_tree.append((METRICS_PREFIX, metrics_resource))
|
||||
|
||||
# ideally we'd just use getChild and putChild but getChild doesn't work
|
||||
# unless you give it a Request object IN ADDITION to the name :/ So
|
||||
# instead, we'll store a copy of this mapping so we can actually add
|
||||
# extra resources to existing nodes. See self._resource_id for the key.
|
||||
resource_mappings = {}
|
||||
for full_path, res in desired_tree:
|
||||
logger.info("Attaching %s to path %s", res, full_path)
|
||||
last_resource = self.root_resource
|
||||
for path_seg in full_path.split('/')[1:-1]:
|
||||
if path_seg not in last_resource.listNames():
|
||||
# resource doesn't exist, so make a "dummy resource"
|
||||
child_resource = Resource()
|
||||
last_resource.putChild(path_seg, child_resource)
|
||||
res_id = self._resource_id(last_resource, path_seg)
|
||||
resource_mappings[res_id] = child_resource
|
||||
last_resource = child_resource
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "client":
|
||||
if res["compress"]:
|
||||
client_v1 = gz_wrap(self.get_resource_for_client())
|
||||
client_v2 = gz_wrap(self.get_resource_for_client_v2_alpha())
|
||||
else:
|
||||
# we have an existing Resource, use that instead.
|
||||
res_id = self._resource_id(last_resource, path_seg)
|
||||
last_resource = resource_mappings[res_id]
|
||||
client_v1 = self.get_resource_for_client()
|
||||
client_v2 = self.get_resource_for_client_v2_alpha()
|
||||
|
||||
# ===========================
|
||||
# now attach the actual desired resource
|
||||
last_path_seg = full_path.split('/')[-1]
|
||||
resources.update({
|
||||
CLIENT_PREFIX: client_v1,
|
||||
CLIENT_V2_ALPHA_PREFIX: client_v2,
|
||||
})
|
||||
|
||||
# if there is already a resource here, thieve its children and
|
||||
# replace it
|
||||
res_id = self._resource_id(last_resource, last_path_seg)
|
||||
if res_id in resource_mappings:
|
||||
# there is a dummy resource at this path already, which needs
|
||||
# to be replaced with the desired resource.
|
||||
existing_dummy_resource = resource_mappings[res_id]
|
||||
for child_name in existing_dummy_resource.listNames():
|
||||
child_res_id = self._resource_id(existing_dummy_resource,
|
||||
child_name)
|
||||
child_resource = resource_mappings[child_res_id]
|
||||
# steal the children
|
||||
res.putChild(child_name, child_resource)
|
||||
if name == "federation":
|
||||
resources.update({
|
||||
FEDERATION_PREFIX: self.get_resource_for_federation(),
|
||||
})
|
||||
|
||||
# finally, insert the desired resource in the right place
|
||||
last_resource.putChild(last_path_seg, res)
|
||||
res_id = self._resource_id(last_resource, last_path_seg)
|
||||
resource_mappings[res_id] = res
|
||||
if name in ["static", "client"]:
|
||||
resources.update({
|
||||
STATIC_PREFIX: self.get_resource_for_static_content(),
|
||||
})
|
||||
|
||||
return self.root_resource
|
||||
if name in ["media", "federation", "client"]:
|
||||
resources.update({
|
||||
MEDIA_PREFIX: self.get_resource_for_media_repository(),
|
||||
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
|
||||
})
|
||||
|
||||
def _resource_id(self, resource, path_seg):
|
||||
"""Construct an arbitrary resource ID so you can retrieve the mapping
|
||||
later.
|
||||
if name in ["keys", "federation"]:
|
||||
resources.update({
|
||||
SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
|
||||
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
|
||||
})
|
||||
|
||||
If you want to represent resource A putChild resource B with path C,
|
||||
the mapping should looks like _resource_id(A,C) = B.
|
||||
if name == "webclient":
|
||||
resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
|
||||
|
||||
Args:
|
||||
resource (Resource): The *parent* Resource
|
||||
path_seg (str): The name of the child Resource to be attached.
|
||||
Returns:
|
||||
str: A unique string which can be a key to the child Resource.
|
||||
"""
|
||||
return "%s-%s" % (resource, path_seg)
|
||||
if name == "metrics" and metrics_resource:
|
||||
resources[METRICS_PREFIX] = metrics_resource
|
||||
|
||||
root_resource = create_resource_tree(resources)
|
||||
if tls:
|
||||
reactor.listenSSL(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.https.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
self.tls_context_factory,
|
||||
interface=bind_address
|
||||
)
|
||||
else:
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse now listening on port %d", port)
|
||||
|
||||
def start_listening(self):
|
||||
config = self.get_config()
|
||||
|
||||
if not config.no_tls and config.bind_port is not None:
|
||||
reactor.listenSSL(
|
||||
config.bind_port,
|
||||
SynapseSite(
|
||||
"synapse.access.https",
|
||||
config,
|
||||
self.root_resource,
|
||||
),
|
||||
self.tls_context_factory,
|
||||
interface=config.bind_host
|
||||
)
|
||||
logger.info("Synapse now listening on port %d", config.bind_port)
|
||||
|
||||
if config.unsecure_port is not None:
|
||||
for listener in config.listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listener_http(config, listener)
|
||||
elif listener["type"] == "manhole":
|
||||
f = twisted.manhole.telnet.ShellFactory()
|
||||
f.username = "matrix"
|
||||
f.password = "rabbithole"
|
||||
f.namespace['hs'] = self
|
||||
reactor.listenTCP(
|
||||
config.unsecure_port,
|
||||
SynapseSite(
|
||||
"synapse.access.http",
|
||||
config,
|
||||
self.root_resource,
|
||||
),
|
||||
interface=config.bind_host
|
||||
)
|
||||
logger.info("Synapse now listening on port %d", config.unsecure_port)
|
||||
|
||||
metrics_resource = self.get_resource_for_metrics()
|
||||
if metrics_resource and config.metrics_port is not None:
|
||||
reactor.listenTCP(
|
||||
config.metrics_port,
|
||||
SynapseSite(
|
||||
"synapse.access.metrics",
|
||||
config,
|
||||
metrics_resource,
|
||||
),
|
||||
interface=config.metrics_bind_host,
|
||||
)
|
||||
logger.info(
|
||||
"Metrics now running on %s port %d",
|
||||
config.metrics_bind_host, config.metrics_port,
|
||||
listener["port"],
|
||||
f,
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
def run_startup_checks(self, db_conn, database_engine):
|
||||
all_users_native = are_all_users_on_domain(
|
||||
|
@ -419,11 +373,6 @@ def setup(config_options):
|
|||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
if re.search(":[0-9]+$", config.server_name):
|
||||
domain_with_port = config.server_name
|
||||
else:
|
||||
domain_with_port = "%s:%s" % (config.server_name, config.bind_port)
|
||||
|
||||
tls_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
||||
database_engine = create_engine(config.database_config["name"])
|
||||
|
@ -431,8 +380,6 @@ def setup(config_options):
|
|||
|
||||
hs = SynapseHomeServer(
|
||||
config.server_name,
|
||||
domain_with_port=domain_with_port,
|
||||
upload_dir=os.path.abspath("uploads"),
|
||||
db_config=config.database_config,
|
||||
tls_context_factory=tls_context_factory,
|
||||
config=config,
|
||||
|
@ -441,10 +388,6 @@ def setup(config_options):
|
|||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
hs.create_resource_tree(
|
||||
redirect_root_to_web_client=True,
|
||||
)
|
||||
|
||||
logger.info("Preparing database: %r...", config.database_config)
|
||||
|
||||
try:
|
||||
|
@ -469,13 +412,6 @@ def setup(config_options):
|
|||
|
||||
logger.info("Database prepared in %r.", config.database_config)
|
||||
|
||||
if config.manhole:
|
||||
f = twisted.manhole.telnet.ShellFactory()
|
||||
f.username = "matrix"
|
||||
f.password = "rabbithole"
|
||||
f.namespace['hs'] = hs
|
||||
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
|
||||
|
||||
hs.start_listening()
|
||||
|
||||
hs.get_pusherpool().start()
|
||||
|
@ -501,22 +437,194 @@ class SynapseService(service.Service):
|
|||
return self._port.stopListening()
|
||||
|
||||
|
||||
class SynapseRequest(Request):
|
||||
def __init__(self, site, *args, **kw):
|
||||
Request.__init__(self, *args, **kw)
|
||||
self.site = site
|
||||
self.authenticated_entity = None
|
||||
self.start_time = 0
|
||||
|
||||
def __repr__(self):
|
||||
# We overwrite this so that we don't log ``access_token``
|
||||
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
|
||||
self.__class__.__name__,
|
||||
id(self),
|
||||
self.method,
|
||||
self.get_redacted_uri(),
|
||||
self.clientproto,
|
||||
self.site.site_tag,
|
||||
)
|
||||
|
||||
def get_redacted_uri(self):
|
||||
return re.sub(
|
||||
r'(\?.*access_token=)[^&]*(.*)$',
|
||||
r'\1<redacted>\2',
|
||||
self.uri
|
||||
)
|
||||
|
||||
def get_user_agent(self):
|
||||
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
|
||||
|
||||
def started_processing(self):
|
||||
self.site.access_logger.info(
|
||||
"%s - %s - Received request: %s %s",
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
self.method,
|
||||
self.get_redacted_uri()
|
||||
)
|
||||
self.start_time = int(time.time() * 1000)
|
||||
|
||||
def finished_processing(self):
|
||||
self.site.access_logger.info(
|
||||
"%s - %s - {%s}"
|
||||
" Processed request: %dms %sB %s \"%s %s %s\" \"%s\"",
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
self.authenticated_entity,
|
||||
int(time.time() * 1000) - self.start_time,
|
||||
self.sentLength,
|
||||
self.code,
|
||||
self.method,
|
||||
self.get_redacted_uri(),
|
||||
self.clientproto,
|
||||
self.get_user_agent(),
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def processing(self):
|
||||
self.started_processing()
|
||||
yield
|
||||
self.finished_processing()
|
||||
|
||||
|
||||
class XForwardedForRequest(SynapseRequest):
|
||||
def __init__(self, *args, **kw):
|
||||
SynapseRequest.__init__(self, *args, **kw)
|
||||
|
||||
"""
|
||||
Add a layer on top of another request that only uses the value of an
|
||||
X-Forwarded-For header as the result of C{getClientIP}.
|
||||
"""
|
||||
def getClientIP(self):
|
||||
"""
|
||||
@return: The client address (the first address) in the value of the
|
||||
I{X-Forwarded-For header}. If the header is not present, return
|
||||
C{b"-"}.
|
||||
"""
|
||||
return self.requestHeaders.getRawHeaders(
|
||||
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
|
||||
|
||||
|
||||
class SynapseRequestFactory(object):
|
||||
def __init__(self, site, x_forwarded_for):
|
||||
self.site = site
|
||||
self.x_forwarded_for = x_forwarded_for
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.x_forwarded_for:
|
||||
return XForwardedForRequest(self.site, *args, **kwargs)
|
||||
else:
|
||||
return SynapseRequest(self.site, *args, **kwargs)
|
||||
|
||||
|
||||
class SynapseSite(Site):
|
||||
"""
|
||||
Subclass of a twisted http Site that does access logging with python's
|
||||
standard logging
|
||||
"""
|
||||
def __init__(self, logger_name, config, resource, *args, **kwargs):
|
||||
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
|
||||
Site.__init__(self, resource, *args, **kwargs)
|
||||
if config.captcha_ip_origin_is_x_forwarded:
|
||||
self._log_formatter = proxiedLogFormatter
|
||||
else:
|
||||
self._log_formatter = combinedLogFormatter
|
||||
|
||||
self.site_tag = site_tag
|
||||
|
||||
proxied = config.get("x_forwarded", False)
|
||||
self.requestFactory = SynapseRequestFactory(self, proxied)
|
||||
self.access_logger = logging.getLogger(logger_name)
|
||||
|
||||
def log(self, request):
|
||||
line = self._log_formatter(self._logDateTime, request)
|
||||
self.access_logger.info(line)
|
||||
pass
|
||||
|
||||
|
||||
def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
|
||||
"""Create the resource tree for this Home Server.
|
||||
|
||||
This in unduly complicated because Twisted does not support putting
|
||||
child resources more than 1 level deep at a time.
|
||||
|
||||
Args:
|
||||
web_client (bool): True to enable the web client.
|
||||
redirect_root_to_web_client (bool): True to redirect '/' to the
|
||||
location of the web client. This does nothing if web_client is not
|
||||
True.
|
||||
"""
|
||||
if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
|
||||
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
||||
else:
|
||||
root_resource = Resource()
|
||||
|
||||
# ideally we'd just use getChild and putChild but getChild doesn't work
|
||||
# unless you give it a Request object IN ADDITION to the name :/ So
|
||||
# instead, we'll store a copy of this mapping so we can actually add
|
||||
# extra resources to existing nodes. See self._resource_id for the key.
|
||||
resource_mappings = {}
|
||||
for full_path, res in desired_tree.items():
|
||||
logger.info("Attaching %s to path %s", res, full_path)
|
||||
last_resource = root_resource
|
||||
for path_seg in full_path.split('/')[1:-1]:
|
||||
if path_seg not in last_resource.listNames():
|
||||
# resource doesn't exist, so make a "dummy resource"
|
||||
child_resource = Resource()
|
||||
last_resource.putChild(path_seg, child_resource)
|
||||
res_id = _resource_id(last_resource, path_seg)
|
||||
resource_mappings[res_id] = child_resource
|
||||
last_resource = child_resource
|
||||
else:
|
||||
# we have an existing Resource, use that instead.
|
||||
res_id = _resource_id(last_resource, path_seg)
|
||||
last_resource = resource_mappings[res_id]
|
||||
|
||||
# ===========================
|
||||
# now attach the actual desired resource
|
||||
last_path_seg = full_path.split('/')[-1]
|
||||
|
||||
# if there is already a resource here, thieve its children and
|
||||
# replace it
|
||||
res_id = _resource_id(last_resource, last_path_seg)
|
||||
if res_id in resource_mappings:
|
||||
# there is a dummy resource at this path already, which needs
|
||||
# to be replaced with the desired resource.
|
||||
existing_dummy_resource = resource_mappings[res_id]
|
||||
for child_name in existing_dummy_resource.listNames():
|
||||
child_res_id = _resource_id(
|
||||
existing_dummy_resource, child_name
|
||||
)
|
||||
child_resource = resource_mappings[child_res_id]
|
||||
# steal the children
|
||||
res.putChild(child_name, child_resource)
|
||||
|
||||
# finally, insert the desired resource in the right place
|
||||
last_resource.putChild(last_path_seg, res)
|
||||
res_id = _resource_id(last_resource, last_path_seg)
|
||||
resource_mappings[res_id] = res
|
||||
|
||||
return root_resource
|
||||
|
||||
|
||||
def _resource_id(resource, path_seg):
|
||||
"""Construct an arbitrary resource ID so you can retrieve the mapping
|
||||
later.
|
||||
|
||||
If you want to represent resource A putChild resource B with path C,
|
||||
the mapping should looks like _resource_id(A,C) = B.
|
||||
|
||||
Args:
|
||||
resource (Resource): The *parent* Resource
|
||||
path_seg (str): The name of the child Resource to be attached.
|
||||
Returns:
|
||||
str: A unique string which can be a key to the child Resource.
|
||||
"""
|
||||
return "%s-%s" % (resource, path_seg)
|
||||
|
||||
|
||||
def run(hs):
|
||||
|
@ -549,6 +657,7 @@ def run(hs):
|
|||
|
||||
if hs.config.daemonize:
|
||||
|
||||
if hs.config.print_pidfile:
|
||||
print hs.config.pid_file
|
||||
|
||||
daemon = Daemonize(
|
||||
|
|
|
@ -138,47 +138,37 @@ class Config(object):
|
|||
action="store_true",
|
||||
help="Generate a config file for the server name"
|
||||
)
|
||||
config_parser.add_argument(
|
||||
"--generate-keys",
|
||||
action="store_true",
|
||||
help="Generate any missing key files then exit"
|
||||
)
|
||||
config_parser.add_argument(
|
||||
"-H", "--server-name",
|
||||
help="The server name to generate a config file for"
|
||||
)
|
||||
config_args, remaining_args = config_parser.parse_known_args(argv)
|
||||
|
||||
generate_keys = config_args.generate_keys
|
||||
|
||||
if config_args.generate_config:
|
||||
if not config_args.config_path:
|
||||
config_parser.error(
|
||||
"Must supply a config file.\nA config file can be automatically"
|
||||
" generated using \"--generate-config -h SERVER_NAME"
|
||||
" generated using \"--generate-config -H SERVER_NAME"
|
||||
" -c CONFIG-FILE\""
|
||||
)
|
||||
|
||||
config_dir_path = os.path.dirname(config_args.config_path[0])
|
||||
(config_path,) = config_args.config_path
|
||||
if not os.path.exists(config_path):
|
||||
config_dir_path = os.path.dirname(config_path)
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
|
||||
server_name = config_args.server_name
|
||||
if not server_name:
|
||||
print "Must specify a server_name to a generate config for."
|
||||
sys.exit(1)
|
||||
(config_path,) = config_args.config_path
|
||||
if not os.path.exists(config_dir_path):
|
||||
os.makedirs(config_dir_path)
|
||||
if os.path.exists(config_path):
|
||||
print "Config file %r already exists" % (config_path,)
|
||||
yaml_config = cls.read_config_file(config_path)
|
||||
yaml_name = yaml_config["server_name"]
|
||||
if server_name != yaml_name:
|
||||
print (
|
||||
"Config file %r has a different server_name: "
|
||||
" %r != %r" % (config_path, server_name, yaml_name)
|
||||
)
|
||||
sys.exit(1)
|
||||
config_bytes, config = obj.generate_config(
|
||||
config_dir_path, server_name
|
||||
)
|
||||
config.update(yaml_config)
|
||||
print "Generating any missing keys for %r" % (server_name,)
|
||||
obj.invoke_all("generate_files", config)
|
||||
sys.exit(0)
|
||||
with open(config_path, "wb") as config_file:
|
||||
config_bytes, config = obj.generate_config(
|
||||
config_dir_path, server_name
|
||||
|
@ -186,16 +176,22 @@ class Config(object):
|
|||
obj.invoke_all("generate_files", config)
|
||||
config_file.write(config_bytes)
|
||||
print (
|
||||
"A config file has been generated in %s for server name"
|
||||
" '%s' with corresponding SSL keys and self-signed"
|
||||
" certificates. Please review this file and customise it to"
|
||||
" your needs."
|
||||
"A config file has been generated in %r for server name"
|
||||
" %r with corresponding SSL keys and self-signed"
|
||||
" certificates. Please review this file and customise it"
|
||||
" to your needs."
|
||||
) % (config_path, server_name)
|
||||
print (
|
||||
"If this server name is incorrect, you will need to regenerate"
|
||||
" the SSL certificates"
|
||||
"If this server name is incorrect, you will need to"
|
||||
" regenerate the SSL certificates"
|
||||
)
|
||||
sys.exit(0)
|
||||
else:
|
||||
print (
|
||||
"Config file %r already exists. Generating any missing key"
|
||||
" files."
|
||||
) % (config_path,)
|
||||
generate_keys = True
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[config_parser],
|
||||
|
@ -209,11 +205,11 @@ class Config(object):
|
|||
if not config_args.config_path:
|
||||
config_parser.error(
|
||||
"Must supply a config file.\nA config file can be automatically"
|
||||
" generated using \"--generate-config -h SERVER_NAME"
|
||||
" generated using \"--generate-config -H SERVER_NAME"
|
||||
" -c CONFIG-FILE\""
|
||||
)
|
||||
|
||||
config_dir_path = os.path.dirname(config_args.config_path[0])
|
||||
config_dir_path = os.path.dirname(config_args.config_path[-1])
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
|
||||
specified_config = {}
|
||||
|
@ -226,6 +222,10 @@ class Config(object):
|
|||
config.pop("log_config")
|
||||
config.update(specified_config)
|
||||
|
||||
if generate_keys:
|
||||
obj.invoke_all("generate_files", config)
|
||||
sys.exit(0)
|
||||
|
||||
obj.invoke_all("read_config", config)
|
||||
|
||||
obj.invoke_all("read_arguments", args)
|
||||
|
|
|
@ -21,10 +21,6 @@ class CaptchaConfig(Config):
|
|||
self.recaptcha_private_key = config["recaptcha_private_key"]
|
||||
self.recaptcha_public_key = config["recaptcha_public_key"]
|
||||
self.enable_registration_captcha = config["enable_registration_captcha"]
|
||||
# XXX: This is used for more than just captcha
|
||||
self.captcha_ip_origin_is_x_forwarded = (
|
||||
config["captcha_ip_origin_is_x_forwarded"]
|
||||
)
|
||||
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
|
||||
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
|
||||
|
||||
|
@ -33,20 +29,16 @@ class CaptchaConfig(Config):
|
|||
## Captcha ##
|
||||
|
||||
# This Home Server's ReCAPTCHA public key.
|
||||
recaptcha_private_key: "YOUR_PUBLIC_KEY"
|
||||
recaptcha_private_key: "YOUR_PRIVATE_KEY"
|
||||
|
||||
# This Home Server's ReCAPTCHA private key.
|
||||
recaptcha_public_key: "YOUR_PRIVATE_KEY"
|
||||
recaptcha_public_key: "YOUR_PUBLIC_KEY"
|
||||
|
||||
# Enables ReCaptcha checks when registering, preventing signup
|
||||
# unless a captcha is answered. Requires a valid ReCaptcha
|
||||
# public/private key.
|
||||
enable_registration_captcha: False
|
||||
|
||||
# When checking captchas, use the X-Forwarded-For (XFF) header
|
||||
# as the client IP and not the actual client IP.
|
||||
captcha_ip_origin_is_x_forwarded: False
|
||||
|
||||
# A secret key used to bypass the captcha test entirely.
|
||||
#captcha_bypass_secret: "YOUR_SECRET_HERE"
|
||||
|
||||
|
|
|
@ -25,12 +25,13 @@ from .registration import RegistrationConfig
|
|||
from .metrics import MetricsConfig
|
||||
from .appservice import AppServiceConfig
|
||||
from .key import KeyConfig
|
||||
from .saml2 import SAML2Config
|
||||
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
||||
VoipConfig, RegistrationConfig,
|
||||
MetricsConfig, AppServiceConfig, KeyConfig,):
|
||||
VoipConfig, RegistrationConfig, MetricsConfig,
|
||||
AppServiceConfig, KeyConfig, SAML2Config, ):
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -28,10 +28,4 @@ class MetricsConfig(Config):
|
|||
|
||||
# Enable collection and rendering of performance metrics
|
||||
enable_metrics: False
|
||||
|
||||
# Separate port to accept metrics requests on
|
||||
# metrics_port: 8081
|
||||
|
||||
# Which host to bind the metric listener to
|
||||
# metrics_bind_host: 127.0.0.1
|
||||
"""
|
||||
|
|
|
@ -21,13 +21,18 @@ class ContentRepositoryConfig(Config):
|
|||
self.max_upload_size = self.parse_size(config["max_upload_size"])
|
||||
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
|
||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
media_store = self.default_path("media_store")
|
||||
uploads_path = self.default_path("uploads")
|
||||
return """
|
||||
# Directory where uploaded images and attachments are stored.
|
||||
media_store_path: "%(media_store)s"
|
||||
|
||||
# Directory where in-progress uploads are stored.
|
||||
uploads_path: "%(uploads_path)s"
|
||||
|
||||
# The largest allowed upload size in bytes
|
||||
max_upload_size: "10M"
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 Ericsson
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class SAML2Config(Config):
|
||||
"""SAML2 Configuration
|
||||
Synapse uses pysaml2 libraries for providing SAML2 support
|
||||
|
||||
config_path: Path to the sp_conf.py configuration file
|
||||
idp_redirect_url: Identity provider URL which will redirect
|
||||
the user back to /login/saml2 with proper info.
|
||||
|
||||
sp_conf.py file is something like:
|
||||
https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
|
||||
|
||||
More information: https://pythonhosted.org/pysaml2/howto/config.html
|
||||
"""
|
||||
|
||||
def read_config(self, config):
|
||||
saml2_config = config.get("saml2_config", None)
|
||||
if saml2_config:
|
||||
self.saml2_enabled = True
|
||||
self.saml2_config_path = saml2_config["config_path"]
|
||||
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
|
||||
else:
|
||||
self.saml2_enabled = False
|
||||
self.saml2_config_path = None
|
||||
self.saml2_idp_redirect_url = None
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
return """
|
||||
# Enable SAML2 for registration and login. Uses pysaml2
|
||||
# config_path: Path to the sp_conf.py configuration file
|
||||
# idp_redirect_url: Identity provider URL which will redirect
|
||||
# the user back to /login/saml2 with proper info.
|
||||
# See pysaml2 docs for format of config.
|
||||
#saml2_config:
|
||||
# config_path: "%s/sp_conf.py"
|
||||
# idp_redirect_url: "http://%s/idp"
|
||||
""" % (config_dir_path, server_name)
|
|
@ -20,25 +20,98 @@ class ServerConfig(Config):
|
|||
|
||||
def read_config(self, config):
|
||||
self.server_name = config["server_name"]
|
||||
self.bind_port = config["bind_port"]
|
||||
self.bind_host = config["bind_host"]
|
||||
self.unsecure_port = config["unsecure_port"]
|
||||
self.manhole = config.get("manhole")
|
||||
self.pid_file = self.abspath(config.get("pid_file"))
|
||||
self.web_client = config["web_client"]
|
||||
self.soft_file_limit = config["soft_file_limit"]
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.print_pidfile = config.get("print_pidfile")
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
|
||||
|
||||
self.listeners = config.get("listeners", [])
|
||||
|
||||
bind_port = config.get("bind_port")
|
||||
if bind_port:
|
||||
self.listeners = []
|
||||
bind_host = config.get("bind_host", "")
|
||||
gzip_responses = config.get("gzip_responses", True)
|
||||
|
||||
names = ["client", "webclient"] if self.web_client else ["client"]
|
||||
|
||||
self.listeners.append({
|
||||
"port": bind_port,
|
||||
"bind_address": bind_host,
|
||||
"tls": True,
|
||||
"type": "http",
|
||||
"resources": [
|
||||
{
|
||||
"names": names,
|
||||
"compress": gzip_responses,
|
||||
},
|
||||
{
|
||||
"names": ["federation"],
|
||||
"compress": False,
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
unsecure_port = config.get("unsecure_port", bind_port - 400)
|
||||
if unsecure_port:
|
||||
self.listeners.append({
|
||||
"port": unsecure_port,
|
||||
"bind_address": bind_host,
|
||||
"tls": False,
|
||||
"type": "http",
|
||||
"resources": [
|
||||
{
|
||||
"names": names,
|
||||
"compress": gzip_responses,
|
||||
},
|
||||
{
|
||||
"names": ["federation"],
|
||||
"compress": False,
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
manhole = config.get("manhole")
|
||||
if manhole:
|
||||
self.listeners.append({
|
||||
"port": manhole,
|
||||
"bind_address": "127.0.0.1",
|
||||
"type": "manhole",
|
||||
})
|
||||
|
||||
metrics_port = config.get("metrics_port")
|
||||
if metrics_port:
|
||||
self.listeners.append({
|
||||
"port": metrics_port,
|
||||
"bind_address": config.get("metrics_bind_host", "127.0.0.1"),
|
||||
"tls": False,
|
||||
"type": "http",
|
||||
"resources": [
|
||||
{
|
||||
"names": ["metrics"],
|
||||
"compress": False,
|
||||
},
|
||||
]
|
||||
})
|
||||
|
||||
# Attempt to guess the content_addr for the v0 content repostitory
|
||||
content_addr = config.get("content_addr")
|
||||
if not content_addr:
|
||||
for listener in self.listeners:
|
||||
if listener["type"] == "http" and not listener.get("tls", False):
|
||||
unsecure_port = listener["port"]
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Could not determine 'content_addr'")
|
||||
|
||||
host = self.server_name
|
||||
if ':' not in host:
|
||||
host = "%s:%d" % (host, self.unsecure_port)
|
||||
host = "%s:%d" % (host, unsecure_port)
|
||||
else:
|
||||
host = host.split(':')[0]
|
||||
host = "%s:%d" % (host, self.unsecure_port)
|
||||
host = "%s:%d" % (host, unsecure_port)
|
||||
content_addr = "http://%s" % (host,)
|
||||
|
||||
self.content_addr = content_addr
|
||||
|
@ -60,18 +133,6 @@ class ServerConfig(Config):
|
|||
# e.g. matrix.org, localhost:8080, etc.
|
||||
server_name: "%(server_name)s"
|
||||
|
||||
# The port to listen for HTTPS requests on.
|
||||
# For when matrix traffic is sent directly to synapse.
|
||||
bind_port: %(bind_port)s
|
||||
|
||||
# The port to listen for HTTP requests on.
|
||||
# For when matrix traffic passes through loadbalancer that unwraps TLS.
|
||||
unsecure_port: %(unsecure_port)s
|
||||
|
||||
# Local interface to listen on.
|
||||
# The empty string will cause synapse to listen on all interfaces.
|
||||
bind_host: ""
|
||||
|
||||
# When running as a daemon, the file to store the pid in
|
||||
pid_file: %(pid_file)s
|
||||
|
||||
|
@ -83,9 +144,64 @@ class ServerConfig(Config):
|
|||
# hard limit.
|
||||
soft_file_limit: 0
|
||||
|
||||
# List of ports that Synapse should listen on, their purpose and their
|
||||
# configuration.
|
||||
listeners:
|
||||
# Main HTTPS listener
|
||||
# For when matrix traffic is sent directly to synapse.
|
||||
-
|
||||
# The port to listen for HTTPS requests on.
|
||||
port: %(bind_port)s
|
||||
|
||||
# Local interface to listen on.
|
||||
# The empty string will cause synapse to listen on all interfaces.
|
||||
bind_address: ''
|
||||
|
||||
# This is a 'http' listener, allows us to specify 'resources'.
|
||||
type: http
|
||||
|
||||
tls: true
|
||||
|
||||
# Use the X-Forwarded-For (XFF) header as the client IP and not the
|
||||
# actual client IP.
|
||||
x_forwarded: false
|
||||
|
||||
# List of HTTP resources to serve on this listener.
|
||||
resources:
|
||||
-
|
||||
# List of resources to host on this listener.
|
||||
names:
|
||||
- client # The client-server APIs, both v1 and v2
|
||||
- webclient # The bundled webclient.
|
||||
|
||||
# Should synapse compress HTTP responses to clients that support it?
|
||||
# This should be disabled if running synapse behind a load balancer
|
||||
# that can do automatic compression.
|
||||
compress: true
|
||||
|
||||
- names: [federation] # Federation APIs
|
||||
compress: false
|
||||
|
||||
# Unsecure HTTP listener,
|
||||
# For when matrix traffic passes through loadbalancer that unwraps TLS.
|
||||
- port: %(unsecure_port)s
|
||||
tls: false
|
||||
bind_address: ''
|
||||
type: http
|
||||
|
||||
x_forwarded: false
|
||||
|
||||
resources:
|
||||
- names: [client, webclient]
|
||||
compress: true
|
||||
- names: [federation]
|
||||
compress: false
|
||||
|
||||
# Turn on the twisted telnet manhole service on localhost on the given
|
||||
# port.
|
||||
#manhole: 9000
|
||||
# - port: 9000
|
||||
# bind_address: 127.0.0.1
|
||||
# type: manhole
|
||||
""" % locals()
|
||||
|
||||
def read_arguments(self, args):
|
||||
|
@ -93,12 +209,18 @@ class ServerConfig(Config):
|
|||
self.manhole = args.manhole
|
||||
if args.daemonize is not None:
|
||||
self.daemonize = args.daemonize
|
||||
if args.print_pidfile is not None:
|
||||
self.print_pidfile = args.print_pidfile
|
||||
|
||||
def add_arguments(self, parser):
|
||||
server_group = parser.add_argument_group("server")
|
||||
server_group.add_argument("-D", "--daemonize", action='store_true',
|
||||
default=None,
|
||||
help="Daemonize the home server")
|
||||
server_group.add_argument("--print-pidfile", action='store_true',
|
||||
default=None,
|
||||
help="Print the path to the pidfile just"
|
||||
" before daemonizing")
|
||||
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
||||
type=int,
|
||||
help="Turn on the twisted telnet manhole"
|
||||
|
|
|
@ -27,6 +27,7 @@ class TlsConfig(Config):
|
|||
self.tls_certificate = self.read_tls_certificate(
|
||||
config.get("tls_certificate_path")
|
||||
)
|
||||
self.tls_certificate_file = config.get("tls_certificate_path")
|
||||
|
||||
self.no_tls = config.get("no_tls", False)
|
||||
|
||||
|
@ -49,7 +50,11 @@ class TlsConfig(Config):
|
|||
tls_dh_params_path = base_key_name + ".tls.dh"
|
||||
|
||||
return """\
|
||||
# PEM encoded X509 certificate for TLS
|
||||
# PEM encoded X509 certificate for TLS.
|
||||
# You can replace the self-signed certificate that synapse
|
||||
# autogenerates on launch with your own SSL certificate + key pair
|
||||
# if you like. Any required intermediary certificates can be
|
||||
# appended after the primary certificate in hierarchical order.
|
||||
tls_certificate_path: "%(tls_certificate_path)s"
|
||||
|
||||
# PEM encoded private key for TLS
|
||||
|
@ -91,7 +96,7 @@ class TlsConfig(Config):
|
|||
)
|
||||
|
||||
if not os.path.exists(tls_certificate_path):
|
||||
with open(tls_certificate_path, "w") as certifcate_file:
|
||||
with open(tls_certificate_path, "w") as certificate_file:
|
||||
cert = crypto.X509()
|
||||
subject = cert.get_subject()
|
||||
subject.CN = config["server_name"]
|
||||
|
@ -106,7 +111,7 @@ class TlsConfig(Config):
|
|||
|
||||
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
|
||||
|
||||
certifcate_file.write(cert_pem)
|
||||
certificate_file.write(cert_pem)
|
||||
|
||||
if not os.path.exists(tls_dh_params_path):
|
||||
if GENERATE_DH_PARAMS:
|
||||
|
|
|
@ -35,9 +35,9 @@ class ServerContextFactory(ssl.ContextFactory):
|
|||
_ecCurve = _OpenSSLECCurve(_defaultCurveName)
|
||||
_ecCurve.addECKeyToContext(context)
|
||||
except:
|
||||
logger.exception("Failed to enable eliptic curve for TLS")
|
||||
logger.exception("Failed to enable elliptic curve for TLS")
|
||||
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
|
||||
context.use_certificate(config.tls_certificate)
|
||||
context.use_certificate_chain_file(config.tls_certificate_file)
|
||||
|
||||
if not config.no_tls:
|
||||
context.use_privatekey(config.tls_private_key)
|
||||
|
|
|
@ -25,11 +25,13 @@ from syutil.base64util import decode_base64, encode_base64
|
|||
from synapse.api.errors import SynapseError, Codes
|
||||
|
||||
from synapse.util.retryutils import get_retry_limiter
|
||||
from synapse.util import unwrapFirstError
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
|
||||
from OpenSSL import crypto
|
||||
|
||||
from collections import namedtuple
|
||||
import urllib
|
||||
import hashlib
|
||||
import logging
|
||||
|
@ -38,6 +40,9 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
|
||||
|
||||
|
||||
class Keyring(object):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -49,18 +54,55 @@ class Keyring(object):
|
|||
|
||||
self.key_downloads = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def verify_json_for_server(self, server_name, json_object):
|
||||
return self.verify_json_objects_for_server(
|
||||
[(server_name, json_object)]
|
||||
)[0]
|
||||
|
||||
def verify_json_objects_for_server(self, server_and_json):
|
||||
"""Bulk verfies signatures of json objects, bulk fetching keys as
|
||||
necessary.
|
||||
|
||||
Args:
|
||||
server_and_json (list): List of pairs of (server_name, json_object)
|
||||
|
||||
Returns:
|
||||
list of deferreds indicating success or failure to verify each
|
||||
json object's signature for the given server_name.
|
||||
"""
|
||||
group_id_to_json = {}
|
||||
group_id_to_group = {}
|
||||
group_ids = []
|
||||
|
||||
next_group_id = 0
|
||||
deferreds = {}
|
||||
|
||||
for server_name, json_object in server_and_json:
|
||||
logger.debug("Verifying for %s", server_name)
|
||||
group_id = next_group_id
|
||||
next_group_id += 1
|
||||
group_ids.append(group_id)
|
||||
|
||||
key_ids = signature_ids(json_object, server_name)
|
||||
if not key_ids:
|
||||
raise SynapseError(
|
||||
deferreds[group_id] = defer.fail(SynapseError(
|
||||
400,
|
||||
"Not signed with a supported algorithm",
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
))
|
||||
else:
|
||||
deferreds[group_id] = defer.Deferred()
|
||||
|
||||
group = KeyGroup(server_name, group_id, key_ids)
|
||||
|
||||
group_id_to_group[group_id] = group
|
||||
group_id_to_json[group_id] = json_object
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_key_deferred(group, deferred):
|
||||
server_name = group.server_name
|
||||
try:
|
||||
verify_key = yield self.get_server_verify_key(server_name, key_ids)
|
||||
_, _, key_id, verify_key = yield deferred
|
||||
except IOError as e:
|
||||
logger.warn(
|
||||
"Got IOError when downloading keys for %s: %s %s",
|
||||
|
@ -72,7 +114,7 @@ class Keyring(object):
|
|||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
logger.exception(
|
||||
"Got Exception when downloading keys for %s: %s %s",
|
||||
server_name, type(e).__name__, str(e.message),
|
||||
)
|
||||
|
@ -82,6 +124,8 @@ class Keyring(object):
|
|||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
json_object = group_id_to_json[group.group_id]
|
||||
|
||||
try:
|
||||
verify_signed_json(json_object, server_name, verify_key)
|
||||
except:
|
||||
|
@ -93,79 +137,208 @@ class Keyring(object):
|
|||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key(self, server_name, key_ids):
|
||||
"""Finds a verification key for the server with one of the key ids.
|
||||
Trys to fetch the key from a trusted perspective server first.
|
||||
Args:
|
||||
server_name(str): The name of the server to fetch a key for.
|
||||
keys_ids (list of str): The key_ids to check for.
|
||||
"""
|
||||
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
|
||||
server_to_deferred = {
|
||||
server_name: defer.Deferred()
|
||||
for server_name, _ in server_and_json
|
||||
}
|
||||
|
||||
if cached:
|
||||
defer.returnValue(cached[0])
|
||||
return
|
||||
|
||||
download = self.key_downloads.get(server_name)
|
||||
|
||||
if download is None:
|
||||
download = self._get_server_verify_key_impl(server_name, key_ids)
|
||||
download = ObservableDeferred(
|
||||
download,
|
||||
consumeErrors=True
|
||||
# We want to wait for any previous lookups to complete before
|
||||
# proceeding.
|
||||
wait_on_deferred = self.wait_for_previous_lookups(
|
||||
[server_name for server_name, _ in server_and_json],
|
||||
server_to_deferred,
|
||||
)
|
||||
self.key_downloads[server_name] = download
|
||||
|
||||
@download.addBoth
|
||||
def callback(ret):
|
||||
del self.key_downloads[server_name]
|
||||
return ret
|
||||
# Actually start fetching keys.
|
||||
wait_on_deferred.addBoth(
|
||||
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
||||
)
|
||||
|
||||
r = yield download.observe()
|
||||
defer.returnValue(r)
|
||||
# When we've finished fetching all the keys for a given server_name,
|
||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||
# any lookups waiting will proceed.
|
||||
server_to_gids = {}
|
||||
|
||||
def remove_deferreds(res, server_name, group_id):
|
||||
server_to_gids[server_name].discard(group_id)
|
||||
if not server_to_gids[server_name]:
|
||||
server_to_deferred.pop(server_name).callback(None)
|
||||
return res
|
||||
|
||||
for g_id, deferred in deferreds.items():
|
||||
server_name = group_id_to_group[g_id].server_name
|
||||
server_to_gids.setdefault(server_name, set()).add(g_id)
|
||||
deferred.addBoth(remove_deferreds, server_name, g_id)
|
||||
|
||||
# Pass those keys to handle_key_deferred so that the json object
|
||||
# signatures can be verified
|
||||
return [
|
||||
handle_key_deferred(
|
||||
group_id_to_group[g_id],
|
||||
deferreds[g_id],
|
||||
)
|
||||
for g_id in group_ids
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_server_verify_key_impl(self, server_name, key_ids):
|
||||
keys = None
|
||||
def wait_for_previous_lookups(self, server_names, server_to_deferred):
|
||||
"""Waits for any previous key lookups for the given servers to finish.
|
||||
|
||||
Args:
|
||||
server_names (list): list of server_names we want to lookup
|
||||
server_to_deferred (dict): server_name to deferred which gets
|
||||
resolved once we've finished looking up keys for that server
|
||||
"""
|
||||
while True:
|
||||
wait_on = [
|
||||
self.key_downloads[server_name]
|
||||
for server_name in server_names
|
||||
if server_name in self.key_downloads
|
||||
]
|
||||
if wait_on:
|
||||
yield defer.DeferredList(wait_on)
|
||||
else:
|
||||
break
|
||||
|
||||
for server_name, deferred in server_to_deferred:
|
||||
self.key_downloads[server_name] = ObservableDeferred(deferred)
|
||||
|
||||
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||
each group.
|
||||
"""
|
||||
|
||||
# These are functions that produce keys given a list of key ids
|
||||
key_fetch_fns = (
|
||||
self.get_keys_from_store, # First try the local store
|
||||
self.get_keys_from_perspectives, # Then try via perspectives
|
||||
self.get_keys_from_server, # Then try directly
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_iterations():
|
||||
merged_results = {}
|
||||
|
||||
missing_keys = {
|
||||
group.server_name: key_id
|
||||
for group in group_id_to_group.values()
|
||||
for key_id in group.key_ids
|
||||
}
|
||||
|
||||
for fn in key_fetch_fns:
|
||||
results = yield fn(missing_keys.items())
|
||||
merged_results.update(results)
|
||||
|
||||
# We now need to figure out which groups we have keys for
|
||||
# and which we don't
|
||||
missing_groups = {}
|
||||
for group in group_id_to_group.values():
|
||||
for key_id in group.key_ids:
|
||||
if key_id in merged_results[group.server_name]:
|
||||
group_id_to_deferred[group.group_id].callback((
|
||||
group.group_id,
|
||||
group.server_name,
|
||||
key_id,
|
||||
merged_results[group.server_name][key_id],
|
||||
))
|
||||
break
|
||||
else:
|
||||
missing_groups.setdefault(
|
||||
group.server_name, []
|
||||
).append(group)
|
||||
|
||||
if not missing_groups:
|
||||
break
|
||||
|
||||
missing_keys = {
|
||||
server_name: set(
|
||||
key_id for group in groups for key_id in group.key_ids
|
||||
)
|
||||
for server_name, groups in missing_groups.items()
|
||||
}
|
||||
|
||||
for group in missing_groups.values():
|
||||
group_id_to_deferred[group.group_id].errback(SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s" % (
|
||||
group.server_name, group.key_ids,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
|
||||
def on_err(err):
|
||||
for deferred in group_id_to_deferred.values():
|
||||
if not deferred.called:
|
||||
deferred.errback(err)
|
||||
|
||||
do_iterations().addErrback(on_err)
|
||||
|
||||
return group_id_to_deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_store(self, server_name_and_key_ids):
|
||||
res = yield defer.gatherResults(
|
||||
[
|
||||
self.store.get_server_verify_keys(server_name, key_ids)
|
||||
for server_name, key_ids in server_name_and_key_ids
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(dict(zip(
|
||||
[server_name for server_name, _ in server_name_and_key_ids],
|
||||
res
|
||||
)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_perspectives(self, server_name_and_key_ids):
|
||||
@defer.inlineCallbacks
|
||||
def get_key(perspective_name, perspective_keys):
|
||||
try:
|
||||
result = yield self.get_server_verify_key_v2_indirect(
|
||||
server_name, key_ids, perspective_name, perspective_keys
|
||||
server_name_and_key_ids, perspective_name, perspective_keys
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except Exception as e:
|
||||
logging.info(
|
||||
"Unable to getting key %r for %r from %r: %s %s",
|
||||
key_ids, server_name, perspective_name,
|
||||
logger.exception(
|
||||
"Unable to get key from %r: %s %s",
|
||||
perspective_name,
|
||||
type(e).__name__, str(e.message),
|
||||
)
|
||||
defer.returnValue({})
|
||||
|
||||
perspective_results = yield defer.gatherResults([
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
get_key(p_name, p_keys)
|
||||
for p_name, p_keys in self.perspective_servers.items()
|
||||
])
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
for results in perspective_results:
|
||||
if results is not None:
|
||||
keys = results
|
||||
union_of_keys = {}
|
||||
for result in results:
|
||||
for server_name, keys in result.items():
|
||||
union_of_keys.setdefault(server_name, {}).update(keys)
|
||||
|
||||
defer.returnValue(union_of_keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_server(self, server_name_and_key_ids):
|
||||
@defer.inlineCallbacks
|
||||
def get_key(server_name, key_ids):
|
||||
limiter = yield get_retry_limiter(
|
||||
server_name,
|
||||
self.clock,
|
||||
self.store,
|
||||
)
|
||||
|
||||
with limiter:
|
||||
if not keys:
|
||||
keys = None
|
||||
try:
|
||||
keys = yield self.get_server_verify_key_v2_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
except Exception as e:
|
||||
logging.info(
|
||||
logger.info(
|
||||
"Unable to getting key %r for %r directly: %s %s",
|
||||
key_ids, server_name,
|
||||
type(e).__name__, str(e.message),
|
||||
|
@ -176,14 +349,30 @@ class Keyring(object):
|
|||
server_name, key_ids
|
||||
)
|
||||
|
||||
for key_id in key_ids:
|
||||
if key_id in keys:
|
||||
defer.returnValue(keys[key_id])
|
||||
return
|
||||
raise ValueError("No verification key found for given key ids")
|
||||
keys = {server_name: keys}
|
||||
|
||||
defer.returnValue(keys)
|
||||
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
get_key(server_name, key_ids)
|
||||
for server_name, key_ids in server_name_and_key_ids
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
merged = {}
|
||||
for result in results:
|
||||
merged.update(result)
|
||||
|
||||
defer.returnValue({
|
||||
server_name: keys
|
||||
for server_name, keys in merged.items()
|
||||
if keys
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v2_indirect(self, server_name, key_ids,
|
||||
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
|
||||
perspective_name,
|
||||
perspective_keys):
|
||||
limiter = yield get_retry_limiter(
|
||||
|
@ -204,6 +393,7 @@ class Keyring(object):
|
|||
u"minimum_valid_until_ts": 0
|
||||
} for key_id in key_ids
|
||||
}
|
||||
for server_name, key_ids in server_names_and_key_ids
|
||||
}
|
||||
},
|
||||
)
|
||||
|
@ -243,23 +433,29 @@ class Keyring(object):
|
|||
" server %r" % (perspective_name,)
|
||||
)
|
||||
|
||||
response_keys = yield self.process_v2_response(
|
||||
server_name, perspective_name, response
|
||||
processed_response = yield self.process_v2_response(
|
||||
perspective_name, response
|
||||
)
|
||||
|
||||
keys.update(response_keys)
|
||||
for server_name, response_keys in processed_response.items():
|
||||
keys.setdefault(server_name, {}).update(response_keys)
|
||||
|
||||
yield self.store_keys(
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self.store_keys(
|
||||
server_name=server_name,
|
||||
from_server=perspective_name,
|
||||
verify_keys=keys,
|
||||
verify_keys=response_keys,
|
||||
)
|
||||
for server_name, response_keys in keys.items()
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
||||
|
||||
keys = {}
|
||||
|
||||
for requested_key_id in key_ids:
|
||||
|
@ -295,25 +491,30 @@ class Keyring(object):
|
|||
raise ValueError("TLS certificate not allowed by fingerprints")
|
||||
|
||||
response_keys = yield self.process_v2_response(
|
||||
server_name=server_name,
|
||||
from_server=server_name,
|
||||
requested_id=requested_key_id,
|
||||
requested_ids=[requested_key_id],
|
||||
response_json=response,
|
||||
)
|
||||
|
||||
keys.update(response_keys)
|
||||
|
||||
yield self.store_keys(
|
||||
server_name=server_name,
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self.store_keys(
|
||||
server_name=key_server_name,
|
||||
from_server=server_name,
|
||||
verify_keys=keys,
|
||||
verify_keys=verify_keys,
|
||||
)
|
||||
for key_server_name, verify_keys in keys.items()
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def process_v2_response(self, server_name, from_server, response_json,
|
||||
requested_id=None):
|
||||
def process_v2_response(self, from_server, response_json,
|
||||
requested_ids=[]):
|
||||
time_now_ms = self.clock.time_msec()
|
||||
response_keys = {}
|
||||
verify_keys = {}
|
||||
|
@ -335,6 +536,8 @@ class Keyring(object):
|
|||
verify_key.time_added = time_now_ms
|
||||
old_verify_keys[key_id] = verify_key
|
||||
|
||||
results = {}
|
||||
server_name = response_json["server_name"]
|
||||
for key_id in response_json["signatures"].get(server_name, {}):
|
||||
if key_id not in response_json["verify_keys"]:
|
||||
raise ValueError(
|
||||
|
@ -357,17 +560,16 @@ class Keyring(object):
|
|||
signed_key_json_bytes = encode_canonical_json(signed_key_json)
|
||||
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
|
||||
|
||||
updated_key_ids = set()
|
||||
if requested_id is not None:
|
||||
updated_key_ids.add(requested_id)
|
||||
updated_key_ids = set(requested_ids)
|
||||
updated_key_ids.update(verify_keys)
|
||||
updated_key_ids.update(old_verify_keys)
|
||||
|
||||
response_keys.update(verify_keys)
|
||||
response_keys.update(old_verify_keys)
|
||||
|
||||
for key_id in updated_key_ids:
|
||||
yield self.store.store_server_keys_json(
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self.store.store_server_keys_json(
|
||||
server_name=server_name,
|
||||
key_id=key_id,
|
||||
from_server=server_name,
|
||||
|
@ -375,10 +577,14 @@ class Keyring(object):
|
|||
ts_expires_ms=ts_valid_until_ms,
|
||||
key_json_bytes=signed_key_json_bytes,
|
||||
)
|
||||
for key_id in updated_key_ids
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(response_keys)
|
||||
results[server_name] = response_keys
|
||||
|
||||
raise ValueError("No verification key found for given key ids")
|
||||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v1_direct(self, server_name, key_ids):
|
||||
|
@ -462,8 +668,13 @@ class Keyring(object):
|
|||
Returns:
|
||||
A deferred that completes when the keys are stored.
|
||||
"""
|
||||
for key_id, key in verify_keys.items():
|
||||
# TODO(markjh): Store whether the keys have expired.
|
||||
yield self.store.store_server_verify_key(
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self.store.store_server_verify_key(
|
||||
server_name, server_name, key.time_added, key
|
||||
)
|
||||
for key_id, key in verify_keys.items()
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
|
|
@ -74,6 +74,8 @@ def prune_event(event):
|
|||
)
|
||||
elif event_type == EventTypes.Aliases:
|
||||
add_fields("aliases")
|
||||
elif event_type == EventTypes.RoomHistoryVisibility:
|
||||
add_fields("history_visibility")
|
||||
|
||||
allowed_fields = {
|
||||
k: v
|
||||
|
|
|
@ -18,8 +18,6 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.events.utils import prune_event
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
|
@ -34,7 +32,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class FederationBase(object):
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
|
||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||
include_none=False):
|
||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||
one. If a PDU fails its signature check then we check if we have it in
|
||||
the database and if not then request if from the originating server of
|
||||
|
@ -52,85 +51,108 @@ class FederationBase(object):
|
|||
Returns:
|
||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||
"""
|
||||
deferreds = self._check_sigs_and_hashes(pdus)
|
||||
|
||||
signed_pdus = []
|
||||
def callback(pdu):
|
||||
return pdu
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do(pdu):
|
||||
try:
|
||||
new_pdu = yield self._check_sigs_and_hash(pdu)
|
||||
signed_pdus.append(new_pdu)
|
||||
except SynapseError:
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
def errback(failure, pdu):
|
||||
failure.trap(SynapseError)
|
||||
return None
|
||||
|
||||
def try_local_db(res, pdu):
|
||||
if not res:
|
||||
# Check local db.
|
||||
new_pdu = yield self.store.get_event(
|
||||
return self.store.get_event(
|
||||
pdu.event_id,
|
||||
allow_rejected=True,
|
||||
allow_none=True,
|
||||
)
|
||||
if new_pdu:
|
||||
signed_pdus.append(new_pdu)
|
||||
return
|
||||
return res
|
||||
|
||||
# Check pdu.origin
|
||||
if pdu.origin != origin:
|
||||
try:
|
||||
new_pdu = yield self.get_pdu(
|
||||
def try_remote(res, pdu):
|
||||
if not res and pdu.origin != origin:
|
||||
return self.get_pdu(
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
)
|
||||
|
||||
if new_pdu:
|
||||
signed_pdus.append(new_pdu)
|
||||
return
|
||||
except:
|
||||
pass
|
||||
).addErrback(lambda e: None)
|
||||
return res
|
||||
|
||||
def warn(res, pdu):
|
||||
if not res:
|
||||
logger.warn(
|
||||
"Failed to find copy of %s with valid signature",
|
||||
pdu.event_id,
|
||||
)
|
||||
return res
|
||||
|
||||
yield defer.gatherResults(
|
||||
[do(pdu) for pdu in pdus],
|
||||
for pdu, deferred in zip(pdus, deferreds):
|
||||
deferred.addCallbacks(
|
||||
callback, errback, errbackArgs=[pdu]
|
||||
).addCallback(
|
||||
try_local_db, pdu
|
||||
).addCallback(
|
||||
try_remote, pdu
|
||||
).addCallback(
|
||||
warn, pdu
|
||||
)
|
||||
|
||||
valid_pdus = yield defer.gatherResults(
|
||||
deferreds,
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(signed_pdus)
|
||||
if include_none:
|
||||
defer.returnValue(valid_pdus)
|
||||
else:
|
||||
defer.returnValue([p for p in valid_pdus if p])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash(self, pdu):
|
||||
"""Throws a SynapseError if the PDU does not have the correct
|
||||
return self._check_sigs_and_hashes([pdu])[0]
|
||||
|
||||
def _check_sigs_and_hashes(self, pdus):
|
||||
"""Throws a SynapseError if a PDU does not have the correct
|
||||
signatures.
|
||||
|
||||
Returns:
|
||||
FrozenEvent: Either the given event or it redacted if it failed the
|
||||
content hash check.
|
||||
"""
|
||||
# Check signatures are correct.
|
||||
redacted_event = prune_event(pdu)
|
||||
redacted_pdu_json = redacted_event.get_pdu_json()
|
||||
|
||||
try:
|
||||
yield self.keyring.verify_json_for_server(
|
||||
pdu.origin, redacted_pdu_json
|
||||
)
|
||||
except SynapseError:
|
||||
logger.warn(
|
||||
"Signature check failed for %s redacted to %s",
|
||||
encode_canonical_json(pdu.get_pdu_json()),
|
||||
encode_canonical_json(redacted_pdu_json),
|
||||
)
|
||||
raise
|
||||
redacted_pdus = [
|
||||
prune_event(pdu)
|
||||
for pdu in pdus
|
||||
]
|
||||
|
||||
deferreds = self.keyring.verify_json_objects_for_server([
|
||||
(p.origin, p.get_pdu_json())
|
||||
for p in redacted_pdus
|
||||
])
|
||||
|
||||
def callback(_, pdu, redacted):
|
||||
if not check_event_content_hash(pdu):
|
||||
logger.warn(
|
||||
"Event content has been tampered, redacting %s, %s",
|
||||
pdu.event_id, encode_canonical_json(pdu.get_dict())
|
||||
"Event content has been tampered, redacting %s: %s",
|
||||
pdu.event_id, pdu.get_pdu_json()
|
||||
)
|
||||
defer.returnValue(redacted_event)
|
||||
return redacted
|
||||
return pdu
|
||||
|
||||
defer.returnValue(pdu)
|
||||
def errback(failure, pdu):
|
||||
failure.trap(SynapseError)
|
||||
logger.warn(
|
||||
"Signature check failed for %s",
|
||||
pdu.event_id,
|
||||
)
|
||||
return failure
|
||||
|
||||
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
|
||||
deferred.addCallbacks(
|
||||
callback, errback,
|
||||
callbackArgs=[pdu, redacted],
|
||||
errbackArgs=[pdu],
|
||||
)
|
||||
|
||||
return deferreds
|
||||
|
|
|
@ -30,6 +30,7 @@ import synapse.metrics
|
|||
|
||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
import random
|
||||
|
@ -167,7 +168,7 @@ class FederationClient(FederationBase):
|
|||
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
pdus[:] = yield defer.gatherResults(
|
||||
[self._check_sigs_and_hash(pdu) for pdu in pdus],
|
||||
self._check_sigs_and_hashes(pdus),
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
|
@ -230,7 +231,7 @@ class FederationClient(FederationBase):
|
|||
pdu = pdu_list[0]
|
||||
|
||||
# Check signatures are correct.
|
||||
pdu = yield self._check_sigs_and_hash(pdu)
|
||||
pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
||||
|
||||
break
|
||||
|
||||
|
@ -327,6 +328,9 @@ class FederationClient(FederationBase):
|
|||
@defer.inlineCallbacks
|
||||
def make_join(self, destinations, room_id, user_id):
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
|
||||
try:
|
||||
ret = yield self.transport_layer.make_join(
|
||||
destination, room_id, user_id
|
||||
|
@ -353,6 +357,9 @@ class FederationClient(FederationBase):
|
|||
@defer.inlineCallbacks
|
||||
def send_join(self, destinations, pdu):
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
|
||||
try:
|
||||
time_now = self._clock.time_msec()
|
||||
_, content = yield self.transport_layer.send_join(
|
||||
|
@ -374,17 +381,39 @@ class FederationClient(FederationBase):
|
|||
for p in content.get("auth_chain", [])
|
||||
]
|
||||
|
||||
signed_state, signed_auth = yield defer.gatherResults(
|
||||
[
|
||||
self._check_sigs_and_hash_and_fetch(
|
||||
destination, state, outlier=True
|
||||
),
|
||||
self._check_sigs_and_hash_and_fetch(
|
||||
destination, auth_chain, outlier=True
|
||||
pdus = {
|
||||
p.event_id: p
|
||||
for p in itertools.chain(state, auth_chain)
|
||||
}
|
||||
|
||||
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, pdus.values(),
|
||||
outlier=True,
|
||||
)
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
valid_pdus_map = {
|
||||
p.event_id: p
|
||||
for p in valid_pdus
|
||||
}
|
||||
|
||||
# NB: We *need* to copy to ensure that we don't have multiple
|
||||
# references being passed on, as that causes... issues.
|
||||
signed_state = [
|
||||
copy.copy(valid_pdus_map[p.event_id])
|
||||
for p in state
|
||||
if p.event_id in valid_pdus_map
|
||||
]
|
||||
|
||||
signed_auth = [
|
||||
valid_pdus_map[p.event_id]
|
||||
for p in auth_chain
|
||||
if p.event_id in valid_pdus_map
|
||||
]
|
||||
|
||||
# NB: We *need* to copy to ensure that we don't have multiple
|
||||
# references being passed on, as that causes... issues.
|
||||
for s in signed_state:
|
||||
s.internal_metadata = copy.deepcopy(s.internal_metadata)
|
||||
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
|
@ -396,7 +425,7 @@ class FederationClient(FederationBase):
|
|||
except CodeMessageException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
logger.exception(
|
||||
"Failed to send_join via %s: %s",
|
||||
destination, e.message
|
||||
)
|
||||
|
|
|
@ -93,6 +93,9 @@ class TransportLayerServer(object):
|
|||
|
||||
yield self.keyring.verify_json_for_server(origin, json_request)
|
||||
|
||||
logger.info("Request from %s", origin)
|
||||
request.authenticated_entity = origin
|
||||
|
||||
defer.returnValue((origin, content))
|
||||
|
||||
@log_function
|
||||
|
|
|
@ -32,6 +32,7 @@ from .appservice import ApplicationServicesHandler
|
|||
from .sync import SyncHandler
|
||||
from .auth import AuthHandler
|
||||
from .identity import IdentityHandler
|
||||
from .receipts import ReceiptsHandler
|
||||
|
||||
|
||||
class Handlers(object):
|
||||
|
@ -57,6 +58,7 @@ class Handlers(object):
|
|||
self.directory_handler = DirectoryHandler(hs)
|
||||
self.typing_notification_handler = TypingNotificationHandler(hs)
|
||||
self.admin_handler = AdminHandler(hs)
|
||||
self.receipts_handler = ReceiptsHandler(hs)
|
||||
asapi = ApplicationServiceApi(hs)
|
||||
self.appservice_handler = ApplicationServicesHandler(
|
||||
hs, asapi, AppServiceScheduler(
|
||||
|
|
|
@ -78,7 +78,9 @@ class BaseHandler(object):
|
|||
context = yield state_handler.compute_event_context(builder)
|
||||
|
||||
if builder.is_state():
|
||||
builder.prev_state = context.prev_state_events
|
||||
builder.prev_state = yield self.store.add_event_hashes(
|
||||
context.prev_state_events
|
||||
)
|
||||
|
||||
yield self.auth.add_auth_events(builder, context)
|
||||
|
||||
|
|
|
@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
|
|||
return
|
||||
|
||||
user_info = yield self.store.get_user_by_id(user_id)
|
||||
if not user_info:
|
||||
if user_info:
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
|
|
|
@ -85,8 +85,10 @@ class AuthHandler(BaseHandler):
|
|||
# email auth link on there). It's probably too open to abuse
|
||||
# because it lets unauthenticated clients store arbitrary objects
|
||||
# on a home server.
|
||||
# sess['clientdict'] = clientdict
|
||||
# self._save_session(sess)
|
||||
# Revisit: Assumimg the REST APIs do sensible validation, the data
|
||||
# isn't arbintrary.
|
||||
sess['clientdict'] = clientdict
|
||||
self._save_session(sess)
|
||||
pass
|
||||
elif 'clientdict' in sess:
|
||||
clientdict = sess['clientdict']
|
||||
|
|
|
@ -31,6 +31,8 @@ from synapse.crypto.event_signing import (
|
|||
)
|
||||
from synapse.types import UserID
|
||||
|
||||
from synapse.events.utils import prune_event
|
||||
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -138,25 +140,28 @@ class FederationHandler(BaseHandler):
|
|||
if state and auth_chain is not None:
|
||||
# If we have any state or auth_chain given to us by the replication
|
||||
# layer, then we should handle them (if we haven't before.)
|
||||
|
||||
event_infos = []
|
||||
|
||||
for e in itertools.chain(auth_chain, state):
|
||||
if e.event_id in seen_ids:
|
||||
continue
|
||||
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||
auth = {
|
||||
(e.type, e.state_key): e for e in auth_chain
|
||||
if e.event_id in auth_ids
|
||||
}
|
||||
yield self._handle_new_event(
|
||||
origin, e, auth_events=auth
|
||||
)
|
||||
event_infos.append({
|
||||
"event": e,
|
||||
"auth_events": auth,
|
||||
})
|
||||
seen_ids.add(e.event_id)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle state event %s",
|
||||
e.event_id,
|
||||
|
||||
yield self._handle_new_events(
|
||||
origin,
|
||||
event_infos,
|
||||
outliers=True
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -222,6 +227,56 @@ class FederationHandler(BaseHandler):
|
|||
"user_joined_room", user=user, room_id=event.room_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_server(self, server_name, room_id, events):
|
||||
states = yield self.store.get_state_for_events(
|
||||
room_id, [e.event_id for e in events],
|
||||
)
|
||||
|
||||
events_and_states = zip(events, states)
|
||||
|
||||
def redact_disallowed(event_and_state):
|
||||
event, state = event_and_state
|
||||
|
||||
if not state:
|
||||
return event
|
||||
|
||||
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||
if history:
|
||||
visibility = history.content.get("history_visibility", "shared")
|
||||
if visibility in ["invited", "joined"]:
|
||||
# We now loop through all state events looking for
|
||||
# membership states for the requesting server to determine
|
||||
# if the server is either in the room or has been invited
|
||||
# into the room.
|
||||
for ev in state.values():
|
||||
if ev.type != EventTypes.Member:
|
||||
continue
|
||||
try:
|
||||
domain = UserID.from_string(ev.state_key).domain
|
||||
except:
|
||||
continue
|
||||
|
||||
if domain != server_name:
|
||||
continue
|
||||
|
||||
memtype = ev.membership
|
||||
if memtype == Membership.JOIN:
|
||||
return event
|
||||
elif memtype == Membership.INVITE:
|
||||
if visibility == "invited":
|
||||
return event
|
||||
else:
|
||||
return prune_event(event)
|
||||
|
||||
return event
|
||||
|
||||
res = map(redact_disallowed, events_and_states)
|
||||
|
||||
logger.info("_filter_events_for_server %r", res)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def backfill(self, dest, room_id, limit, extremities=[]):
|
||||
|
@ -247,9 +302,15 @@ class FederationHandler(BaseHandler):
|
|||
if set(e_id for e_id, _ in ev.prev_events) - event_ids
|
||||
]
|
||||
|
||||
logger.info(
|
||||
"backfill: Got %d events with %d edges",
|
||||
len(events), len(edges),
|
||||
)
|
||||
|
||||
# For each edge get the current state.
|
||||
|
||||
auth_events = {}
|
||||
state_events = {}
|
||||
events_to_state = {}
|
||||
for e_id in edges:
|
||||
state, auth = yield self.replication_layer.get_state_for_room(
|
||||
|
@ -258,27 +319,57 @@ class FederationHandler(BaseHandler):
|
|||
event_id=e_id
|
||||
)
|
||||
auth_events.update({a.event_id: a for a in auth})
|
||||
auth_events.update({s.event_id: s for s in state})
|
||||
state_events.update({s.event_id: s for s in state})
|
||||
events_to_state[e_id] = state
|
||||
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self._handle_new_event(dest, a)
|
||||
for a in auth_events.values()
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self._handle_new_event(
|
||||
dest, event_map[e_id],
|
||||
state=events_to_state[e_id],
|
||||
backfilled=True,
|
||||
seen_events = yield self.store.have_events(
|
||||
set(auth_events.keys()) | set(state_events.keys())
|
||||
)
|
||||
for e_id in events_to_state
|
||||
|
||||
all_events = events + state_events.values() + auth_events.values()
|
||||
required_auth = set(
|
||||
a_id for event in all_events for a_id, _ in event.auth_events
|
||||
)
|
||||
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.replication_layer.get_pdu(
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
for event_id in missing_auth
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results})
|
||||
|
||||
ev_infos = []
|
||||
for a in auth_events.values():
|
||||
if a.event_id in seen_events:
|
||||
continue
|
||||
ev_infos.append({
|
||||
"event": a,
|
||||
"auth_events": {
|
||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||
auth_events[a_id]
|
||||
for a_id, _ in a.auth_events
|
||||
}
|
||||
})
|
||||
|
||||
for e_id in events_to_state:
|
||||
ev_infos.append({
|
||||
"event": event_map[e_id],
|
||||
"state": events_to_state[e_id],
|
||||
"auth_events": {
|
||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||
auth_events[a_id]
|
||||
for a_id, _ in event_map[e_id].auth_events
|
||||
}
|
||||
})
|
||||
|
||||
events.sort(key=lambda e: e.depth)
|
||||
|
||||
|
@ -286,8 +377,12 @@ class FederationHandler(BaseHandler):
|
|||
if event in events_to_state:
|
||||
continue
|
||||
|
||||
yield self._handle_new_event(
|
||||
dest, event,
|
||||
ev_infos.append({
|
||||
"event": event,
|
||||
})
|
||||
|
||||
yield self._handle_new_events(
|
||||
dest, ev_infos,
|
||||
backfilled=True,
|
||||
)
|
||||
|
||||
|
@ -555,32 +650,22 @@ class FederationHandler(BaseHandler):
|
|||
# FIXME
|
||||
pass
|
||||
|
||||
yield self._handle_auth_events(
|
||||
origin, [e for e in auth_chain if e.event_id != event.event_id]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_state(e):
|
||||
ev_infos = []
|
||||
for e in itertools.chain(state, auth_chain):
|
||||
if e.event_id == event.event_id:
|
||||
return
|
||||
continue
|
||||
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||
auth = {
|
||||
ev_infos.append({
|
||||
"event": e,
|
||||
"auth_events": {
|
||||
(e.type, e.state_key): e for e in auth_chain
|
||||
if e.event_id in auth_ids
|
||||
}
|
||||
yield self._handle_new_event(
|
||||
origin, e, auth_events=auth
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle state event %s",
|
||||
e.event_id,
|
||||
)
|
||||
})
|
||||
|
||||
yield defer.DeferredList([handle_state(e) for e in state])
|
||||
yield self._handle_new_events(origin, ev_infos, outliers=True)
|
||||
|
||||
auth_ids = [e_id for e_id, _ in event.auth_events]
|
||||
auth_events = {
|
||||
|
@ -837,6 +922,8 @@ class FederationHandler(BaseHandler):
|
|||
limit
|
||||
)
|
||||
|
||||
events = yield self._filter_events_for_server(origin, room_id, events)
|
||||
|
||||
defer.returnValue(events)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -895,25 +982,63 @@ class FederationHandler(BaseHandler):
|
|||
def _handle_new_event(self, origin, event, state=None, backfilled=False,
|
||||
current_state=None, auth_events=None):
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
context = yield self._prep_event(
|
||||
origin, event,
|
||||
state=state,
|
||||
backfilled=backfilled,
|
||||
current_state=current_state,
|
||||
auth_events=auth_events,
|
||||
)
|
||||
|
||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
is_new_state=(not outlier and not backfilled),
|
||||
current_state=current_state,
|
||||
)
|
||||
|
||||
defer.returnValue((context, event_stream_id, max_stream_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_events(self, origin, event_infos, backfilled=False,
|
||||
outliers=False):
|
||||
contexts = yield defer.gatherResults(
|
||||
[
|
||||
self._prep_event(
|
||||
origin,
|
||||
ev_info["event"],
|
||||
state=ev_info.get("state"),
|
||||
backfilled=backfilled,
|
||||
auth_events=ev_info.get("auth_events"),
|
||||
)
|
||||
for ev_info in event_infos
|
||||
]
|
||||
)
|
||||
|
||||
yield self.store.persist_events(
|
||||
[
|
||||
(ev_info["event"], context)
|
||||
for ev_info, context in itertools.izip(event_infos, contexts)
|
||||
],
|
||||
backfilled=backfilled,
|
||||
is_new_state=(not outliers and not backfilled),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _prep_event(self, origin, event, state=None, backfilled=False,
|
||||
current_state=None, auth_events=None):
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
context = yield self.state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
event, old_state=state, outlier=outlier,
|
||||
)
|
||||
|
||||
if not auth_events:
|
||||
auth_events = context.current_state
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: %s, auth_events: %s",
|
||||
event.event_id, auth_events,
|
||||
)
|
||||
|
||||
is_new_state = not event.internal_metadata.is_outlier()
|
||||
|
||||
# This is a hack to fix some old rooms where the initial join event
|
||||
# didn't reference the create event in its auth events.
|
||||
if event.type == EventTypes.Member and not event.auth_events:
|
||||
|
@ -937,26 +1062,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't
|
||||
# seen all the auth events.
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
is_new_state=False,
|
||||
current_state=current_state,
|
||||
)
|
||||
raise
|
||||
|
||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
is_new_state=(is_new_state and not backfilled),
|
||||
current_state=current_state,
|
||||
)
|
||||
|
||||
defer.returnValue((context, event_stream_id, max_stream_id))
|
||||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
||||
|
@ -1019,14 +1125,24 @@ class FederationHandler(BaseHandler):
|
|||
@log_function
|
||||
def do_auth(self, origin, event, context, auth_events):
|
||||
# Check if we have all the auth events.
|
||||
have_events = yield self.store.have_events(
|
||||
[e_id for e_id, _ in event.auth_events]
|
||||
)
|
||||
|
||||
current_state = set(e.event_id for e in auth_events.values())
|
||||
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||
|
||||
if event_auth_events - current_state:
|
||||
have_events = yield self.store.have_events(
|
||||
event_auth_events - current_state
|
||||
)
|
||||
else:
|
||||
have_events = {}
|
||||
|
||||
have_events.update({
|
||||
e.event_id: ""
|
||||
for e in auth_events.values()
|
||||
})
|
||||
|
||||
seen_events = set(have_events.keys())
|
||||
|
||||
missing_auth = event_auth_events - seen_events
|
||||
missing_auth = event_auth_events - seen_events - current_state
|
||||
|
||||
if missing_auth:
|
||||
logger.info("Missing auth: %s", missing_auth)
|
||||
|
|
|
@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
|
|||
http_client = SimpleHttpClient(self.hs)
|
||||
# XXX: make this configurable!
|
||||
# trustedIdServers = ['matrix.org', 'localhost:8090']
|
||||
trustedIdServers = ['matrix.org']
|
||||
trustedIdServers = ['matrix.org', 'vector.im']
|
||||
|
||||
if 'id_server' in creds:
|
||||
id_server = creds['id_server']
|
||||
|
|
|
@ -113,11 +113,21 @@ class MessageHandler(BaseHandler):
|
|||
"room_key", next_key
|
||||
)
|
||||
|
||||
if not events:
|
||||
defer.returnValue({
|
||||
"chunk": [],
|
||||
"start": pagin_config.from_token.to_string(),
|
||||
"end": next_token.to_string(),
|
||||
})
|
||||
|
||||
events = yield self._filter_events_for_client(user_id, room_id, events)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
chunk = {
|
||||
"chunk": [
|
||||
serialize_event(e, time_now, as_client_event) for e in events
|
||||
serialize_event(e, time_now, as_client_event)
|
||||
for e in events
|
||||
],
|
||||
"start": pagin_config.from_token.to_string(),
|
||||
"end": next_token.to_string(),
|
||||
|
@ -125,6 +135,52 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
defer.returnValue(chunk)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_client(self, user_id, room_id, events):
|
||||
states = yield self.store.get_state_for_events(
|
||||
room_id, [e.event_id for e in events],
|
||||
)
|
||||
|
||||
events_and_states = zip(events, states)
|
||||
|
||||
def allowed(event_and_state):
|
||||
event, state = event_and_state
|
||||
|
||||
if event.type == EventTypes.RoomHistoryVisibility:
|
||||
return True
|
||||
|
||||
membership_ev = state.get((EventTypes.Member, user_id), None)
|
||||
if membership_ev:
|
||||
membership = membership_ev.membership
|
||||
else:
|
||||
membership = Membership.LEAVE
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
return True
|
||||
|
||||
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||
if history:
|
||||
visibility = history.content.get("history_visibility", "shared")
|
||||
else:
|
||||
visibility = "shared"
|
||||
|
||||
if visibility == "public":
|
||||
return True
|
||||
elif visibility == "shared":
|
||||
return True
|
||||
elif visibility == "joined":
|
||||
return membership == Membership.JOIN
|
||||
elif visibility == "invited":
|
||||
return membership == Membership.INVITE
|
||||
|
||||
return True
|
||||
|
||||
events_and_states = filter(allowed, events_and_states)
|
||||
defer.returnValue([
|
||||
ev
|
||||
for ev, _ in events_and_states
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_and_send_event(self, event_dict, ratelimit=True,
|
||||
client=None, txn_id=None):
|
||||
|
@ -278,6 +334,11 @@ class MessageHandler(BaseHandler):
|
|||
user, pagination_config.get_source_config("presence"), None
|
||||
)
|
||||
|
||||
receipt_stream = self.hs.get_event_sources().sources["receipt"]
|
||||
receipt, _ = yield receipt_stream.get_pagination_rows(
|
||||
user, pagination_config.get_source_config("receipt"), None
|
||||
)
|
||||
|
||||
public_room_ids = yield self.store.get_public_room_ids()
|
||||
|
||||
limit = pagin_config.limit
|
||||
|
@ -316,6 +377,10 @@ class MessageHandler(BaseHandler):
|
|||
]
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
messages = yield self._filter_events_for_client(
|
||||
user_id, event.room_id, messages
|
||||
)
|
||||
|
||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||
end_token = now_token.copy_and_replace("room_key", token[1])
|
||||
time_now = self.clock.time_msec()
|
||||
|
@ -344,7 +409,8 @@ class MessageHandler(BaseHandler):
|
|||
ret = {
|
||||
"rooms": rooms_ret,
|
||||
"presence": presence,
|
||||
"end": now_token.to_string()
|
||||
"receipts": receipt,
|
||||
"end": now_token.to_string(),
|
||||
}
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
@ -380,15 +446,6 @@ class MessageHandler(BaseHandler):
|
|||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
messages, token = yield self.store.get_recent_events_for_room(
|
||||
room_id,
|
||||
limit=limit,
|
||||
end_token=now_token.room_key,
|
||||
)
|
||||
|
||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||
end_token = now_token.copy_and_replace("room_key", token[1])
|
||||
|
||||
room_members = [
|
||||
m for m in current_state.values()
|
||||
if m.type == EventTypes.Member
|
||||
|
@ -396,20 +453,46 @@ class MessageHandler(BaseHandler):
|
|||
]
|
||||
|
||||
presence_handler = self.hs.get_handlers().presence_handler
|
||||
presence = []
|
||||
for m in room_members:
|
||||
try:
|
||||
member_presence = yield presence_handler.get_state(
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_presence():
|
||||
presence_defs = yield defer.DeferredList(
|
||||
[
|
||||
presence_handler.get_state(
|
||||
target_user=UserID.from_string(m.user_id),
|
||||
auth_user=auth_user,
|
||||
as_event=True,
|
||||
check_auth=False,
|
||||
)
|
||||
presence.append(member_presence)
|
||||
except SynapseError:
|
||||
logger.exception(
|
||||
"Failed to get member presence of %r", m.user_id
|
||||
for m in room_members
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
|
||||
defer.returnValue([p for success, p in presence_defs if success])
|
||||
|
||||
receipts_handler = self.hs.get_handlers().receipts_handler
|
||||
|
||||
presence, receipts, (messages, token) = yield defer.gatherResults(
|
||||
[
|
||||
get_presence(),
|
||||
receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key),
|
||||
self.store.get_recent_events_for_room(
|
||||
room_id,
|
||||
limit=limit,
|
||||
end_token=now_token.room_key,
|
||||
)
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
messages = yield self._filter_events_for_client(
|
||||
user_id, room_id, messages
|
||||
)
|
||||
|
||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||
end_token = now_token.copy_and_replace("room_key", token[1])
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
defer.returnValue({
|
||||
|
@ -421,5 +504,6 @@ class MessageHandler(BaseHandler):
|
|||
"end": end_token.to_string(),
|
||||
},
|
||||
"state": state,
|
||||
"presence": presence
|
||||
"presence": presence,
|
||||
"receipts": receipts,
|
||||
})
|
||||
|
|
|
@ -191,8 +191,9 @@ class PresenceHandler(BaseHandler):
|
|||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state(self, target_user, auth_user, as_event=False):
|
||||
def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
|
||||
if self.hs.is_mine(target_user):
|
||||
if check_auth:
|
||||
visible = yield self.is_presence_visible(
|
||||
observer_user=auth_user,
|
||||
observed_user=target_user
|
||||
|
@ -200,15 +201,14 @@ class PresenceHandler(BaseHandler):
|
|||
|
||||
if not visible:
|
||||
raise SynapseError(404, "Presence information not visible")
|
||||
|
||||
if target_user in self._user_cachemap:
|
||||
state = self._user_cachemap[target_user].get_state()
|
||||
else:
|
||||
state = yield self.store.get_presence_state(target_user.localpart)
|
||||
if "mtime" in state:
|
||||
del state["mtime"]
|
||||
state["presence"] = state.pop("state")
|
||||
|
||||
if target_user in self._user_cachemap:
|
||||
cached_state = self._user_cachemap[target_user].get_state()
|
||||
if "last_active" in cached_state:
|
||||
state["last_active"] = cached_state["last_active"]
|
||||
else:
|
||||
# TODO(paul): Have remote server send us permissions set
|
||||
state = self._get_or_offline_usercache(target_user).get_state()
|
||||
|
@ -992,7 +992,7 @@ class PresenceHandler(BaseHandler):
|
|||
room_ids([str]): List of room_ids to notify.
|
||||
"""
|
||||
with PreserveLoggingContext():
|
||||
self.notifier.on_new_user_event(
|
||||
self.notifier.on_new_event(
|
||||
"presence_key",
|
||||
self._user_cachemap_latest_serial,
|
||||
users_to_push,
|
||||
|
|
|
@ -0,0 +1,212 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReceiptsHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(ReceiptsHandler, self).__init__(hs)
|
||||
|
||||
self.hs = hs
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation.register_edu_handler(
|
||||
"m.receipt", self._received_remote_receipt
|
||||
)
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
self._receipt_cache = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def received_client_receipt(self, room_id, receipt_type, user_id,
|
||||
event_id):
|
||||
"""Called when a client tells us a local user has read up to the given
|
||||
event_id in the room.
|
||||
"""
|
||||
receipt = {
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_ids": [event_id],
|
||||
"data": {
|
||||
"ts": int(self.clock.time_msec()),
|
||||
}
|
||||
}
|
||||
|
||||
is_new = yield self._handle_new_receipts([receipt])
|
||||
|
||||
if is_new:
|
||||
self._push_remotes([receipt])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _received_remote_receipt(self, origin, content):
|
||||
"""Called when we receive an EDU of type m.receipt from a remote HS.
|
||||
"""
|
||||
receipts = [
|
||||
{
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_ids": user_values["event_ids"],
|
||||
"data": user_values.get("data", {}),
|
||||
}
|
||||
for room_id, room_values in content.items()
|
||||
for receipt_type, users in room_values.items()
|
||||
for user_id, user_values in users.items()
|
||||
]
|
||||
|
||||
yield self._handle_new_receipts(receipts)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_receipts(self, receipts):
|
||||
"""Takes a list of receipts, stores them and informs the notifier.
|
||||
"""
|
||||
for receipt in receipts:
|
||||
room_id = receipt["room_id"]
|
||||
receipt_type = receipt["receipt_type"]
|
||||
user_id = receipt["user_id"]
|
||||
event_ids = receipt["event_ids"]
|
||||
data = receipt["data"]
|
||||
|
||||
res = yield self.store.insert_receipt(
|
||||
room_id, receipt_type, user_id, event_ids, data
|
||||
)
|
||||
|
||||
if not res:
|
||||
# res will be None if this read receipt is 'old'
|
||||
defer.returnValue(False)
|
||||
|
||||
stream_id, max_persisted_id = res
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.notifier.on_new_event(
|
||||
"receipt_key", max_persisted_id, rooms=[room_id]
|
||||
)
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _push_remotes(self, receipts):
|
||||
"""Given a list of receipts, works out which remote servers should be
|
||||
poked and pokes them.
|
||||
"""
|
||||
# TODO: Some of this stuff should be coallesced.
|
||||
for receipt in receipts:
|
||||
room_id = receipt["room_id"]
|
||||
receipt_type = receipt["receipt_type"]
|
||||
user_id = receipt["user_id"]
|
||||
event_ids = receipt["event_ids"]
|
||||
data = receipt["data"]
|
||||
|
||||
remotedomains = set()
|
||||
|
||||
rm_handler = self.hs.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=None, remotedomains=remotedomains
|
||||
)
|
||||
|
||||
logger.debug("Sending receipt to: %r", remotedomains)
|
||||
|
||||
for domain in remotedomains:
|
||||
self.federation.send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.receipt",
|
||||
content={
|
||||
room_id: {
|
||||
receipt_type: {
|
||||
user_id: {
|
||||
"event_ids": event_ids,
|
||||
"data": data,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_receipts_for_room(self, room_id, to_key):
|
||||
"""Gets all receipts for a room, upto the given key.
|
||||
"""
|
||||
result = yield self.store.get_linearized_receipts_for_room(
|
||||
room_id,
|
||||
to_key=to_key,
|
||||
)
|
||||
|
||||
if not result:
|
||||
defer.returnValue([])
|
||||
|
||||
event = {
|
||||
"type": "m.receipt",
|
||||
"room_id": room_id,
|
||||
"content": result,
|
||||
}
|
||||
|
||||
defer.returnValue([event])
|
||||
|
||||
|
||||
class ReceiptEventSource(object):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
defer.returnValue(([], from_key))
|
||||
from_key = int(from_key)
|
||||
to_key = yield self.get_current_key()
|
||||
|
||||
if from_key == to_key:
|
||||
defer.returnValue(([], to_key))
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||
rooms = [room.room_id for room in rooms]
|
||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||
rooms,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
)
|
||||
|
||||
defer.returnValue((events, to_key))
|
||||
|
||||
def get_current_key(self, direction='f'):
|
||||
return self.store.get_max_receipt_stream_id()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_pagination_rows(self, user, config, key):
|
||||
to_key = int(config.from_key)
|
||||
defer.returnValue(([], to_key))
|
||||
|
||||
if config.to_key:
|
||||
from_key = int(config.to_key)
|
||||
else:
|
||||
from_key = None
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||
rooms = [room.room_id for room in rooms]
|
||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||
rooms,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
)
|
||||
|
||||
defer.returnValue((events, to_key))
|
|
@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
|
|||
localpart : The local part of the user ID to register. If None,
|
||||
one will be randomly generated.
|
||||
password (str) : The password to assign to this user so they can
|
||||
login again.
|
||||
login again. This can be None which means they cannot login again
|
||||
via a password (e.g. the user is an application service user).
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
|
@ -192,6 +193,35 @@ class RegistrationHandler(BaseHandler):
|
|||
else:
|
||||
logger.info("Valid captcha entered from %s", ip)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register_saml2(self, localpart):
|
||||
"""
|
||||
Registers email_id as SAML2 Based Auth.
|
||||
"""
|
||||
if urllib.quote(localpart) != localpart:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID must only contain characters which do not"
|
||||
" require URL encoding."
|
||||
)
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
yield self.check_user_id_is_valid(user_id)
|
||||
token = self._generate_token(user_id)
|
||||
try:
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
)
|
||||
yield self.distributor.fire("registered_user", user)
|
||||
except Exception, e:
|
||||
yield self.store.add_access_token_to_user(user_id, token)
|
||||
# Ignore Registration errors
|
||||
logger.exception(e)
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register_email(self, threepidCreds):
|
||||
"""
|
||||
|
|
|
@ -19,12 +19,15 @@ from twisted.internet import defer
|
|||
from ._base import BaseHandler
|
||||
|
||||
from synapse.types import UserID, RoomAlias, RoomID
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.constants import (
|
||||
EventTypes, Membership, JoinRules, RoomCreationPreset,
|
||||
)
|
||||
from synapse.api.errors import StoreError, SynapseError
|
||||
from synapse.util import stringutils, unwrapFirstError
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.events.utils import serialize_event
|
||||
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
import string
|
||||
|
||||
|
@ -33,6 +36,19 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class RoomCreationHandler(BaseHandler):
|
||||
|
||||
PRESETS_DICT = {
|
||||
RoomCreationPreset.PRIVATE_CHAT: {
|
||||
"join_rules": JoinRules.INVITE,
|
||||
"history_visibility": "invited",
|
||||
"original_invitees_have_ops": False,
|
||||
},
|
||||
RoomCreationPreset.PUBLIC_CHAT: {
|
||||
"join_rules": JoinRules.PUBLIC,
|
||||
"history_visibility": "shared",
|
||||
"original_invitees_have_ops": False,
|
||||
},
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self, user_id, room_id, config):
|
||||
""" Creates a new room.
|
||||
|
@ -121,9 +137,25 @@ class RoomCreationHandler(BaseHandler):
|
|||
servers=[self.hs.hostname],
|
||||
)
|
||||
|
||||
preset_config = config.get(
|
||||
"preset",
|
||||
RoomCreationPreset.PUBLIC_CHAT
|
||||
if is_public
|
||||
else RoomCreationPreset.PRIVATE_CHAT
|
||||
)
|
||||
|
||||
raw_initial_state = config.get("initial_state", [])
|
||||
|
||||
initial_state = OrderedDict()
|
||||
for val in raw_initial_state:
|
||||
initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
|
||||
|
||||
user = UserID.from_string(user_id)
|
||||
creation_events = self._create_events_for_new_room(
|
||||
user, room_id, is_public=is_public
|
||||
user, room_id,
|
||||
preset_config=preset_config,
|
||||
invite_list=invite_list,
|
||||
initial_state=initial_state,
|
||||
)
|
||||
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
|
@ -170,7 +202,10 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
defer.returnValue(result)
|
||||
|
||||
def _create_events_for_new_room(self, creator, room_id, is_public=False):
|
||||
def _create_events_for_new_room(self, creator, room_id, preset_config,
|
||||
invite_list, initial_state):
|
||||
config = RoomCreationHandler.PRESETS_DICT[preset_config]
|
||||
|
||||
creator_id = creator.to_string()
|
||||
|
||||
event_keys = {
|
||||
|
@ -203,9 +238,10 @@ class RoomCreationHandler(BaseHandler):
|
|||
},
|
||||
)
|
||||
|
||||
power_levels_event = create(
|
||||
etype=EventTypes.PowerLevels,
|
||||
content={
|
||||
returned_events = [creation_event, join_event]
|
||||
|
||||
if (EventTypes.PowerLevels, '') not in initial_state:
|
||||
power_level_content = {
|
||||
"users": {
|
||||
creator.to_string(): 100,
|
||||
},
|
||||
|
@ -213,6 +249,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
"events": {
|
||||
EventTypes.Name: 100,
|
||||
EventTypes.PowerLevels: 100,
|
||||
EventTypes.RoomHistoryVisibility: 100,
|
||||
},
|
||||
"events_default": 0,
|
||||
"state_default": 50,
|
||||
|
@ -220,21 +257,43 @@ class RoomCreationHandler(BaseHandler):
|
|||
"kick": 50,
|
||||
"redact": 50,
|
||||
"invite": 0,
|
||||
},
|
||||
}
|
||||
|
||||
if config["original_invitees_have_ops"]:
|
||||
for invitee in invite_list:
|
||||
power_level_content["users"][invitee] = 100
|
||||
|
||||
power_levels_event = create(
|
||||
etype=EventTypes.PowerLevels,
|
||||
content=power_level_content,
|
||||
)
|
||||
|
||||
join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE
|
||||
returned_events.append(power_levels_event)
|
||||
|
||||
if (EventTypes.JoinRules, '') not in initial_state:
|
||||
join_rules_event = create(
|
||||
etype=EventTypes.JoinRules,
|
||||
content={"join_rule": join_rule},
|
||||
content={"join_rule": config["join_rules"]},
|
||||
)
|
||||
|
||||
return [
|
||||
creation_event,
|
||||
join_event,
|
||||
power_levels_event,
|
||||
join_rules_event,
|
||||
]
|
||||
returned_events.append(join_rules_event)
|
||||
|
||||
if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
|
||||
history_event = create(
|
||||
etype=EventTypes.RoomHistoryVisibility,
|
||||
content={"history_visibility": config["history_visibility"]}
|
||||
)
|
||||
|
||||
returned_events.append(history_event)
|
||||
|
||||
for (etype, state_key), content in initial_state.items():
|
||||
returned_events.append(create(
|
||||
etype=etype,
|
||||
state_key=state_key,
|
||||
content=content,
|
||||
))
|
||||
|
||||
return returned_events
|
||||
|
||||
|
||||
class RoomMemberHandler(BaseHandler):
|
||||
|
|
|
@ -292,6 +292,51 @@ class SyncHandler(BaseHandler):
|
|||
next_batch=now_token,
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_client(self, user_id, room_id, events):
|
||||
states = yield self.store.get_state_for_events(
|
||||
room_id, [e.event_id for e in events],
|
||||
)
|
||||
|
||||
events_and_states = zip(events, states)
|
||||
|
||||
def allowed(event_and_state):
|
||||
event, state = event_and_state
|
||||
|
||||
if event.type == EventTypes.RoomHistoryVisibility:
|
||||
return True
|
||||
|
||||
membership_ev = state.get((EventTypes.Member, user_id), None)
|
||||
if membership_ev:
|
||||
membership = membership_ev.membership
|
||||
else:
|
||||
membership = Membership.LEAVE
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
return True
|
||||
|
||||
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||
if history:
|
||||
visibility = history.content.get("history_visibility", "shared")
|
||||
else:
|
||||
visibility = "shared"
|
||||
|
||||
if visibility == "public":
|
||||
return True
|
||||
elif visibility == "shared":
|
||||
return True
|
||||
elif visibility == "joined":
|
||||
return membership == Membership.JOIN
|
||||
elif visibility == "invited":
|
||||
return membership == Membership.INVITE
|
||||
|
||||
return True
|
||||
events_and_states = filter(allowed, events_and_states)
|
||||
defer.returnValue([
|
||||
ev
|
||||
for ev, _ in events_and_states
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def load_filtered_recents(self, room_id, sync_config, now_token,
|
||||
since_token=None):
|
||||
|
@ -313,6 +358,9 @@ class SyncHandler(BaseHandler):
|
|||
(room_key, _) = keys
|
||||
end_key = "s" + room_key.split('-')[-1]
|
||||
loaded_recents = sync_config.filter.filter_room_events(events)
|
||||
loaded_recents = yield self._filter_events_for_client(
|
||||
sync_config.user.to_string(), room_id, loaded_recents,
|
||||
)
|
||||
loaded_recents.extend(recents)
|
||||
recents = loaded_recents
|
||||
if len(events) <= load_limit:
|
||||
|
|
|
@ -218,7 +218,7 @@ class TypingNotificationHandler(BaseHandler):
|
|||
self._room_serials[room_id] = self._latest_room_serial
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.notifier.on_new_user_event(
|
||||
self.notifier.on_new_event(
|
||||
"typing_key", self._latest_room_serial, rooms=[room_id]
|
||||
)
|
||||
|
||||
|
|
|
@ -61,21 +61,31 @@ class SimpleHttpClient(object):
|
|||
self.agent = Agent(reactor, pool=pool)
|
||||
self.version_string = hs.version_string
|
||||
|
||||
def request(self, method, *args, **kwargs):
|
||||
def request(self, method, uri, *args, **kwargs):
|
||||
# A small wrapper around self.agent.request() so we can easily attach
|
||||
# counters to it
|
||||
outgoing_requests_counter.inc(method)
|
||||
d = preserve_context_over_fn(
|
||||
self.agent.request,
|
||||
method, *args, **kwargs
|
||||
method, uri, *args, **kwargs
|
||||
)
|
||||
|
||||
logger.info("Sending request %s %s", method, uri)
|
||||
|
||||
def _cb(response):
|
||||
incoming_responses_counter.inc(method, response.code)
|
||||
logger.info(
|
||||
"Received response to %s %s: %s",
|
||||
method, uri, response.code
|
||||
)
|
||||
return response
|
||||
|
||||
def _eb(failure):
|
||||
incoming_responses_counter.inc(method, "ERR")
|
||||
logger.info(
|
||||
"Error sending request to %s %s: %s %s",
|
||||
method, uri, failure.type, failure.getErrorMessage()
|
||||
)
|
||||
return failure
|
||||
|
||||
d.addCallbacks(_cb, _eb)
|
||||
|
@ -84,7 +94,9 @@ class SimpleHttpClient(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def post_urlencoded_get_json(self, uri, args={}):
|
||||
# TODO: Do we ever want to log message contents?
|
||||
logger.debug("post_urlencoded_get_json args: %s", args)
|
||||
|
||||
query_bytes = urllib.urlencode(args, True)
|
||||
|
||||
response = yield self.request(
|
||||
|
@ -97,7 +109,7 @@ class SimpleHttpClient(object):
|
|||
bodyProducer=FileBodyProducer(StringIO(query_bytes))
|
||||
)
|
||||
|
||||
body = yield readBody(response)
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
|
@ -105,7 +117,7 @@ class SimpleHttpClient(object):
|
|||
def post_json_get_json(self, uri, post_json):
|
||||
json_str = encode_canonical_json(post_json)
|
||||
|
||||
logger.info("HTTP POST %s -> %s", json_str, uri)
|
||||
logger.debug("HTTP POST %s -> %s", json_str, uri)
|
||||
|
||||
response = yield self.request(
|
||||
"POST",
|
||||
|
@ -116,7 +128,7 @@ class SimpleHttpClient(object):
|
|||
bodyProducer=FileBodyProducer(StringIO(json_str))
|
||||
)
|
||||
|
||||
body = yield readBody(response)
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
|
@ -149,7 +161,7 @@ class SimpleHttpClient(object):
|
|||
})
|
||||
)
|
||||
|
||||
body = yield readBody(response)
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
defer.returnValue(json.loads(body))
|
||||
|
@ -192,7 +204,7 @@ class SimpleHttpClient(object):
|
|||
bodyProducer=FileBodyProducer(StringIO(json_str))
|
||||
)
|
||||
|
||||
body = yield readBody(response)
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
defer.returnValue(json.loads(body))
|
||||
|
@ -226,7 +238,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
|
|||
)
|
||||
|
||||
try:
|
||||
body = yield readBody(response)
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
defer.returnValue(body)
|
||||
except PartialDownloadError as e:
|
||||
# twisted dislikes google's response, no content length.
|
||||
|
|
|
@ -35,11 +35,13 @@ from syutil.crypto.jsonsign import sign_json
|
|||
|
||||
import simplejson as json
|
||||
import logging
|
||||
import sys
|
||||
import urllib
|
||||
import urlparse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
outbound_logger = logging.getLogger("synapse.http.outbound")
|
||||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
|
@ -86,6 +88,7 @@ class MatrixFederationHttpClient(object):
|
|||
)
|
||||
self.clock = hs.get_clock()
|
||||
self.version_string = hs.version_string
|
||||
self._next_id = 1
|
||||
|
||||
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
|
||||
return urlparse.urlunparse(
|
||||
|
@ -106,16 +109,12 @@ class MatrixFederationHttpClient(object):
|
|||
destination, path_bytes, param_bytes, query_bytes
|
||||
)
|
||||
|
||||
logger.info("Sending request to %s: %s %s",
|
||||
destination, method, url_bytes)
|
||||
txn_id = "%s-O-%s" % (method, self._next_id)
|
||||
self._next_id = (self._next_id + 1) % (sys.maxint - 1)
|
||||
|
||||
logger.debug(
|
||||
"Types: %s",
|
||||
[
|
||||
type(destination), type(method), type(path_bytes),
|
||||
type(param_bytes),
|
||||
type(query_bytes)
|
||||
]
|
||||
outbound_logger.info(
|
||||
"{%s} [%s] Sending request: %s %s",
|
||||
txn_id, destination, method, url_bytes
|
||||
)
|
||||
|
||||
# XXX: Would be much nicer to retry only at the transaction-layer
|
||||
|
@ -126,12 +125,15 @@ class MatrixFederationHttpClient(object):
|
|||
("", "", path_bytes, param_bytes, query_bytes, "")
|
||||
)
|
||||
|
||||
log_result = None
|
||||
try:
|
||||
while True:
|
||||
producer = None
|
||||
if body_callback:
|
||||
producer = body_callback(method, http_url_bytes, headers_dict)
|
||||
|
||||
try:
|
||||
def send_request():
|
||||
request_deferred = preserve_context_over_fn(
|
||||
self.agent.request,
|
||||
method,
|
||||
|
@ -140,12 +142,17 @@ class MatrixFederationHttpClient(object):
|
|||
producer
|
||||
)
|
||||
|
||||
response = yield self.clock.time_bound_deferred(
|
||||
|
||||
return self.clock.time_bound_deferred(
|
||||
request_deferred,
|
||||
time_out=timeout/1000. if timeout else 60,
|
||||
)
|
||||
|
||||
logger.debug("Got response to %s", method)
|
||||
response = yield preserve_context_over_fn(
|
||||
send_request,
|
||||
)
|
||||
|
||||
log_result = "%d %s" % (response.code, response.phrase,)
|
||||
break
|
||||
except Exception as e:
|
||||
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
|
||||
|
@ -154,10 +161,14 @@ class MatrixFederationHttpClient(object):
|
|||
destination,
|
||||
e
|
||||
)
|
||||
log_result = "DNS Lookup failed to %s with %s" % (
|
||||
destination, e
|
||||
)
|
||||
raise
|
||||
|
||||
logger.warn(
|
||||
"Sending request failed to %s: %s %s: %s - %s",
|
||||
"{%s} Sending request failed to %s: %s %s: %s - %s",
|
||||
txn_id,
|
||||
destination,
|
||||
method,
|
||||
url_bytes,
|
||||
|
@ -165,19 +176,21 @@ class MatrixFederationHttpClient(object):
|
|||
_flatten_response_never_received(e),
|
||||
)
|
||||
|
||||
log_result = "%s - %s" % (
|
||||
type(e).__name__, _flatten_response_never_received(e),
|
||||
)
|
||||
|
||||
if retries_left and not timeout:
|
||||
yield sleep(2 ** (5 - retries_left))
|
||||
retries_left -= 1
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
"Received response %d %s for %s: %s %s",
|
||||
response.code,
|
||||
response.phrase,
|
||||
finally:
|
||||
outbound_logger.info(
|
||||
"{%s} [%s] Result: %s",
|
||||
txn_id,
|
||||
destination,
|
||||
method,
|
||||
url_bytes
|
||||
log_result,
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
|
@ -185,7 +198,7 @@ class MatrixFederationHttpClient(object):
|
|||
else:
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
body = yield readBody(response)
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase, body
|
||||
)
|
||||
|
@ -265,10 +278,7 @@ class MatrixFederationHttpClient(object):
|
|||
"Content-Type not application/json"
|
||||
)
|
||||
|
||||
logger.debug("Getting resp body")
|
||||
body = yield readBody(response)
|
||||
logger.debug("Got resp body")
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -311,9 +321,7 @@ class MatrixFederationHttpClient(object):
|
|||
"Content-Type not application/json"
|
||||
)
|
||||
|
||||
logger.debug("Getting resp body")
|
||||
body = yield readBody(response)
|
||||
logger.debug("Got resp body")
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
|
@ -371,9 +379,7 @@ class MatrixFederationHttpClient(object):
|
|||
"Content-Type not application/json"
|
||||
)
|
||||
|
||||
logger.debug("Getting resp body")
|
||||
body = yield readBody(response)
|
||||
logger.debug("Got resp body")
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
|
@ -416,7 +422,10 @@ class MatrixFederationHttpClient(object):
|
|||
headers = dict(response.headers.getAllRawHeaders())
|
||||
|
||||
try:
|
||||
length = yield _readBodyToFile(response, output_stream, max_size)
|
||||
length = yield preserve_context_over_fn(
|
||||
_readBodyToFile,
|
||||
response, output_stream, max_size
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to download body")
|
||||
raise
|
||||
|
|
|
@ -79,17 +79,11 @@ def request_handler(request_handler):
|
|||
_next_request_id += 1
|
||||
with LoggingContext(request_id) as request_context:
|
||||
request_context.request = request_id
|
||||
code = None
|
||||
start = self.clock.time_msec()
|
||||
with request.processing():
|
||||
try:
|
||||
logger.info(
|
||||
"Received request: %s %s",
|
||||
request.method, request.path
|
||||
)
|
||||
d = request_handler(self, request)
|
||||
with PreserveLoggingContext():
|
||||
yield d
|
||||
code = request.code
|
||||
except CodeMessageException as e:
|
||||
code = e.code
|
||||
if isinstance(e, SynapseError):
|
||||
|
@ -105,7 +99,6 @@ def request_handler(request_handler):
|
|||
version_string=self.version_string,
|
||||
)
|
||||
except:
|
||||
code = 500
|
||||
logger.exception(
|
||||
"Failed handle request %s.%s on %r: %r",
|
||||
request_handler.__module__,
|
||||
|
@ -119,13 +112,6 @@ def request_handler(request_handler):
|
|||
{"error": "Internal server error"},
|
||||
send_cors=True
|
||||
)
|
||||
finally:
|
||||
code = str(code) if code else "-"
|
||||
end = self.clock.time_msec()
|
||||
logger.info(
|
||||
"Processed request: %dms %s %s %s",
|
||||
end-start, code, request.method, request.path
|
||||
)
|
||||
return wrapped_request_handler
|
||||
|
||||
|
||||
|
@ -221,7 +207,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
incoming_requests_counter.inc(request.method, servlet_classname)
|
||||
|
||||
args = [
|
||||
urllib.unquote(u).decode("UTF-8") for u in m.groups()
|
||||
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
|
||||
]
|
||||
|
||||
callback_return = yield callback(request, *args)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.async import run_on_reactor, ObservableDeferred
|
||||
from synapse.types import StreamToken
|
||||
import synapse.metrics
|
||||
|
||||
|
@ -45,21 +45,11 @@ class _NotificationListener(object):
|
|||
The events stream handler will have yielded to the deferred, so to
|
||||
notify the handler it is sufficient to resolve the deferred.
|
||||
"""
|
||||
__slots__ = ["deferred"]
|
||||
|
||||
def __init__(self, deferred):
|
||||
self.deferred = deferred
|
||||
|
||||
def notified(self):
|
||||
return self.deferred.called
|
||||
|
||||
def notify(self, token):
|
||||
""" Inform whoever is listening about the new events.
|
||||
"""
|
||||
try:
|
||||
self.deferred.callback(token)
|
||||
except defer.AlreadyCalledError:
|
||||
pass
|
||||
|
||||
|
||||
class _NotifierUserStream(object):
|
||||
"""This represents a user connected to the event stream.
|
||||
|
@ -75,11 +65,12 @@ class _NotifierUserStream(object):
|
|||
appservice=None):
|
||||
self.user = str(user)
|
||||
self.appservice = appservice
|
||||
self.listeners = set()
|
||||
self.rooms = set(rooms)
|
||||
self.current_token = current_token
|
||||
self.last_notified_ms = time_now_ms
|
||||
|
||||
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
||||
|
||||
def notify(self, stream_key, stream_id, time_now_ms):
|
||||
"""Notify any listeners for this user of a new event from an
|
||||
event source.
|
||||
|
@ -91,12 +82,10 @@ class _NotifierUserStream(object):
|
|||
self.current_token = self.current_token.copy_and_advance(
|
||||
stream_key, stream_id
|
||||
)
|
||||
if self.listeners:
|
||||
self.last_notified_ms = time_now_ms
|
||||
listeners = self.listeners
|
||||
self.listeners = set()
|
||||
for listener in listeners:
|
||||
listener.notify(self.current_token)
|
||||
noify_deferred = self.notify_deferred
|
||||
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
||||
noify_deferred.callback(self.current_token)
|
||||
|
||||
def remove(self, notifier):
|
||||
""" Remove this listener from all the indexes in the Notifier
|
||||
|
@ -114,6 +103,18 @@ class _NotifierUserStream(object):
|
|||
self.appservice, set()
|
||||
).discard(self)
|
||||
|
||||
def count_listeners(self):
|
||||
return len(self.notify_deferred.observers())
|
||||
|
||||
def new_listener(self, token):
|
||||
"""Returns a deferred that is resolved when there is a new token
|
||||
greater than the given token.
|
||||
"""
|
||||
if self.current_token.is_after(token):
|
||||
return _NotificationListener(defer.succeed(self.current_token))
|
||||
else:
|
||||
return _NotificationListener(self.notify_deferred.observe())
|
||||
|
||||
|
||||
class Notifier(object):
|
||||
""" This class is responsible for notifying any listeners when there are
|
||||
|
@ -158,7 +159,7 @@ class Notifier(object):
|
|||
for x in self.appservice_to_user_streams.values():
|
||||
all_user_streams |= x
|
||||
|
||||
return sum(len(stream.listeners) for stream in all_user_streams)
|
||||
return sum(stream.count_listeners() for stream in all_user_streams)
|
||||
metrics.register_callback("listeners", count_listeners)
|
||||
|
||||
metrics.register_callback(
|
||||
|
@ -220,16 +221,7 @@ class Notifier(object):
|
|||
event
|
||||
)
|
||||
|
||||
room_id = event.room_id
|
||||
|
||||
room_user_streams = self.room_to_user_streams.get(room_id, set())
|
||||
|
||||
user_streams = room_user_streams.copy()
|
||||
|
||||
for user in extra_users:
|
||||
user_stream = self.user_to_user_stream.get(str(user))
|
||||
if user_stream is not None:
|
||||
user_streams.add(user_stream)
|
||||
app_streams = set()
|
||||
|
||||
for appservice in self.appservice_to_user_streams:
|
||||
# TODO (kegan): Redundant appservice listener checks?
|
||||
|
@ -241,24 +233,20 @@ class Notifier(object):
|
|||
app_user_streams = self.appservice_to_user_streams.get(
|
||||
appservice, set()
|
||||
)
|
||||
user_streams |= app_user_streams
|
||||
app_streams |= app_user_streams
|
||||
|
||||
logger.debug("on_new_room_event listeners %s", user_streams)
|
||||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
for user_stream in user_streams:
|
||||
try:
|
||||
user_stream.notify(
|
||||
"room_key", "s%d" % (room_stream_id,), time_now_ms
|
||||
self.on_new_event(
|
||||
"room_key", room_stream_id,
|
||||
users=extra_users,
|
||||
rooms=[event.room_id],
|
||||
extra_streams=app_streams,
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to notify listener")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
|
||||
""" Used to inform listeners that something has happend
|
||||
presence/user event wise.
|
||||
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
|
||||
extra_streams=set()):
|
||||
""" Used to inform listeners that something has happend event wise.
|
||||
|
||||
Will wake up all listeners for the given users and rooms.
|
||||
"""
|
||||
|
@ -282,14 +270,10 @@ class Notifier(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_events(self, user, rooms, timeout, callback,
|
||||
from_token=StreamToken("s0", "0", "0")):
|
||||
from_token=StreamToken("s0", "0", "0", "0")):
|
||||
"""Wait until the callback returns a non empty response or the
|
||||
timeout fires.
|
||||
"""
|
||||
|
||||
deferred = defer.Deferred()
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
user = str(user)
|
||||
user_stream = self.user_to_user_stream.get(user)
|
||||
if user_stream is None:
|
||||
|
@ -302,55 +286,44 @@ class Notifier(object):
|
|||
rooms=rooms,
|
||||
appservice=appservice,
|
||||
current_token=current_token,
|
||||
time_now_ms=time_now_ms,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
)
|
||||
self._register_with_keys(user_stream)
|
||||
else:
|
||||
|
||||
result = None
|
||||
if timeout:
|
||||
# Will be set to a _NotificationListener that we'll be waiting on.
|
||||
# Allows us to cancel it.
|
||||
listener = None
|
||||
|
||||
def timed_out():
|
||||
if listener:
|
||||
listener.deferred.cancel()
|
||||
timer = self.clock.call_later(timeout/1000., timed_out)
|
||||
|
||||
prev_token = from_token
|
||||
while not result:
|
||||
try:
|
||||
current_token = user_stream.current_token
|
||||
|
||||
listener = [_NotificationListener(deferred)]
|
||||
|
||||
if timeout and not current_token.is_after(from_token):
|
||||
user_stream.listeners.add(listener[0])
|
||||
|
||||
if current_token.is_after(from_token):
|
||||
result = yield callback(from_token, current_token)
|
||||
else:
|
||||
result = None
|
||||
|
||||
timer = [None]
|
||||
|
||||
result = yield callback(prev_token, current_token)
|
||||
if result:
|
||||
user_stream.listeners.discard(listener[0])
|
||||
defer.returnValue(result)
|
||||
return
|
||||
break
|
||||
|
||||
if timeout:
|
||||
timed_out = [False]
|
||||
# Now we wait for the _NotifierUserStream to be told there
|
||||
# is a new token.
|
||||
# We need to supply the token we supplied to callback so
|
||||
# that we don't miss any current_token updates.
|
||||
prev_token = current_token
|
||||
listener = user_stream.new_listener(prev_token)
|
||||
yield listener.deferred
|
||||
except defer.CancelledError:
|
||||
break
|
||||
|
||||
def _timeout_listener():
|
||||
timed_out[0] = True
|
||||
timer[0] = None
|
||||
user_stream.listeners.discard(listener[0])
|
||||
listener[0].notify(current_token)
|
||||
|
||||
# We create multiple notification listeners so we have to manage
|
||||
# canceling the timeout ourselves.
|
||||
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
|
||||
|
||||
while not result and not timed_out[0]:
|
||||
new_token = yield deferred
|
||||
deferred = defer.Deferred()
|
||||
listener[0] = _NotificationListener(deferred)
|
||||
user_stream.listeners.add(listener[0])
|
||||
result = yield callback(current_token, new_token)
|
||||
current_token = new_token
|
||||
|
||||
if timer[0] is not None:
|
||||
try:
|
||||
self.clock.cancel_call_later(timer[0])
|
||||
except:
|
||||
logger.exception("Failed to cancel notifer timer")
|
||||
self.clock.cancel_call_later(timer, ignore_errs=True)
|
||||
else:
|
||||
current_token = user_stream.current_token
|
||||
result = yield callback(from_token, current_token)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
|
@ -368,6 +341,9 @@ class Notifier(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def check_for_updates(before_token, after_token):
|
||||
if not after_token.is_after(before_token):
|
||||
defer.returnValue(None)
|
||||
|
||||
events = []
|
||||
end_token = from_token
|
||||
for name, source in self.event_sources.sources.items():
|
||||
|
@ -376,10 +352,10 @@ class Notifier(object):
|
|||
after_id = getattr(after_token, keyname)
|
||||
if before_id == after_id:
|
||||
continue
|
||||
stuff, new_key = yield source.get_new_events_for_user(
|
||||
new_events, new_key = yield source.get_new_events_for_user(
|
||||
user, getattr(from_token, keyname), limit,
|
||||
)
|
||||
events.extend(stuff)
|
||||
events.extend(new_events)
|
||||
end_token = end_token.copy_and_replace(keyname, new_key)
|
||||
|
||||
if events:
|
||||
|
@ -402,7 +378,7 @@ class Notifier(object):
|
|||
expired_streams = []
|
||||
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
|
||||
for stream in self.user_to_user_stream.values():
|
||||
if stream.listeners:
|
||||
if stream.count_listeners():
|
||||
continue
|
||||
if stream.last_notified_ms < expire_before_ts:
|
||||
expired_streams.append(stream)
|
||||
|
|
|
@ -24,6 +24,7 @@ import baserules
|
|||
import logging
|
||||
import simplejson as json
|
||||
import re
|
||||
import random
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -256,12 +257,31 @@ class Pusher(object):
|
|||
logger.info("Pusher %s for user %s starting from token %s",
|
||||
self.pushkey, self.user_name, self.last_token)
|
||||
|
||||
wait = 0
|
||||
while self.alive:
|
||||
try:
|
||||
if wait > 0:
|
||||
yield synapse.util.async.sleep(wait)
|
||||
yield self.get_and_dispatch()
|
||||
wait = 0
|
||||
except:
|
||||
if wait == 0:
|
||||
wait = 1
|
||||
else:
|
||||
wait = min(wait * 2, 1800)
|
||||
logger.exception(
|
||||
"Exception in pusher loop for pushkey %s. Pausing for %ds",
|
||||
self.pushkey, wait
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_and_dispatch(self):
|
||||
from_tok = StreamToken.from_string(self.last_token)
|
||||
config = PaginationConfig(from_token=from_tok, limit='1')
|
||||
timeout = (300 + random.randint(-60, 60)) * 1000
|
||||
chunk = yield self.evStreamHandler.get_stream(
|
||||
self.user_name, config,
|
||||
timeout=100*365*24*60*60*1000, affect_presence=False
|
||||
timeout=timeout, affect_presence=False
|
||||
)
|
||||
|
||||
# limiting to 1 may get 1 event plus 1 presence event, so
|
||||
|
@ -273,10 +293,11 @@ class Pusher(object):
|
|||
break
|
||||
if not single_event:
|
||||
self.last_token = chunk['end']
|
||||
continue
|
||||
logger.debug("Event stream timeout for pushkey %s", self.pushkey)
|
||||
return
|
||||
|
||||
if not self.alive:
|
||||
continue
|
||||
return
|
||||
|
||||
processed = False
|
||||
actions = yield self._actions_for_event(single_event)
|
||||
|
@ -319,7 +340,7 @@ class Pusher(object):
|
|||
)
|
||||
|
||||
if not self.alive:
|
||||
continue
|
||||
return
|
||||
|
||||
if processed:
|
||||
self.backoff_delay = Pusher.INITIAL_BACKOFF
|
||||
|
|
|
@ -164,7 +164,7 @@ def make_base_append_underride_rules(user):
|
|||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.contains_display_name',
|
||||
'rule_id': 'global/underride/.m.rule.contains_display_name',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'contains_display_name'
|
||||
|
|
|
@ -31,6 +31,8 @@ REQUIREMENTS = {
|
|||
"pillow": ["PIL"],
|
||||
"pydenticon": ["pydenticon"],
|
||||
"ujson": ["ujson"],
|
||||
"blist": ["blist"],
|
||||
"pysaml2": ["saml2"],
|
||||
}
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
"web_client": {
|
||||
|
|
|
@ -20,14 +20,32 @@ from synapse.types import UserID
|
|||
from base import ClientV1RestServlet, client_path_pattern
|
||||
|
||||
import simplejson as json
|
||||
import urllib
|
||||
|
||||
import logging
|
||||
from saml2 import BINDING_HTTP_POST
|
||||
from saml2 import config
|
||||
from saml2.client import Saml2Client
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginRestServlet(ClientV1RestServlet):
|
||||
PATTERN = client_path_pattern("/login$")
|
||||
PASS_TYPE = "m.login.password"
|
||||
SAML2_TYPE = "m.login.saml2"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(LoginRestServlet, self).__init__(hs)
|
||||
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
||||
self.saml2_enabled = hs.config.saml2_enabled
|
||||
|
||||
def on_GET(self, request):
|
||||
return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
|
||||
flows = [{"type": LoginRestServlet.PASS_TYPE}]
|
||||
if self.saml2_enabled:
|
||||
flows.append({"type": LoginRestServlet.SAML2_TYPE})
|
||||
return (200, {"flows": flows})
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
||||
result = yield self.do_password_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
elif self.saml2_enabled and (login_submission["type"] ==
|
||||
LoginRestServlet.SAML2_TYPE):
|
||||
relay_state = ""
|
||||
if "relay_state" in login_submission:
|
||||
relay_state = "&RelayState="+urllib.quote(
|
||||
login_submission["relay_state"])
|
||||
result = {
|
||||
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
|
||||
}
|
||||
defer.returnValue((200, result))
|
||||
else:
|
||||
raise SynapseError(400, "Bad login type.")
|
||||
except KeyError:
|
||||
|
@ -94,6 +122,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
|
||||
|
||||
class SAML2RestServlet(ClientV1RestServlet):
|
||||
PATTERN = client_path_pattern("/login/saml2")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(SAML2RestServlet, self).__init__(hs)
|
||||
self.sp_config = hs.config.saml2_config_path
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
saml2_auth = None
|
||||
try:
|
||||
conf = config.SPConfig()
|
||||
conf.load_file(self.sp_config)
|
||||
SP = Saml2Client(conf)
|
||||
saml2_auth = SP.parse_authn_request_response(
|
||||
request.args['SAMLResponse'][0], BINDING_HTTP_POST)
|
||||
except Exception, e: # Not authenticated
|
||||
logger.exception(e)
|
||||
if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
|
||||
username = saml2_auth.name_id.text
|
||||
handler = self.handlers.registration_handler
|
||||
(user_id, token) = yield handler.register_saml2(username)
|
||||
# Forward to the RelayState callback along with ava
|
||||
if 'RelayState' in request.args:
|
||||
request.redirect(urllib.unquote(
|
||||
request.args['RelayState'][0]) +
|
||||
'?status=authenticated&access_token=' +
|
||||
token + '&user_id=' + user_id + '&ava=' +
|
||||
urllib.quote(json.dumps(saml2_auth.ava)))
|
||||
request.finish()
|
||||
defer.returnValue(None)
|
||||
defer.returnValue((200, {"status": "authenticated",
|
||||
"user_id": user_id, "token": token,
|
||||
"ava": saml2_auth.ava}))
|
||||
elif 'RelayState' in request.args:
|
||||
request.redirect(urllib.unquote(
|
||||
request.args['RelayState'][0]) +
|
||||
'?status=not_authenticated')
|
||||
request.finish()
|
||||
defer.returnValue(None)
|
||||
defer.returnValue((200, {"status": "not_authenticated"}))
|
||||
|
||||
|
||||
def _parse_json(request):
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
|
@ -106,4 +177,6 @@ def _parse_json(request):
|
|||
|
||||
def register_servlets(hs, http_server):
|
||||
LoginRestServlet(hs).register(http_server)
|
||||
if hs.config.saml2_enabled:
|
||||
SAML2RestServlet(hs).register(http_server)
|
||||
# TODO PasswordResetRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -412,6 +412,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
if "user_id" not in content:
|
||||
raise SynapseError(400, "Missing user_id key.")
|
||||
state_key = content["user_id"]
|
||||
# make sure it looks like a user ID; it'll throw if it's invalid.
|
||||
UserID.from_string(state_key)
|
||||
|
||||
if membership_action == "kick":
|
||||
membership_action = "leave"
|
||||
|
|
|
@ -39,10 +39,10 @@ class HttpTransactionStore(object):
|
|||
A tuple of (HTTP response code, response content) or None.
|
||||
"""
|
||||
try:
|
||||
logger.debug("get_response Key: %s TxnId: %s", key, txn_id)
|
||||
logger.debug("get_response TxnId: %s", txn_id)
|
||||
(last_txn_id, response) = self.transactions[key]
|
||||
if txn_id == last_txn_id:
|
||||
logger.info("get_response: Returning a response for %s", key)
|
||||
logger.info("get_response: Returning a response for %s", txn_id)
|
||||
return response
|
||||
except KeyError:
|
||||
pass
|
||||
|
@ -58,7 +58,7 @@ class HttpTransactionStore(object):
|
|||
txn_id (str): The transaction ID for this request.
|
||||
response (tuple): A tuple of (HTTP response code, response content)
|
||||
"""
|
||||
logger.debug("store_response Key: %s TxnId: %s", key, txn_id)
|
||||
logger.debug("store_response TxnId: %s", txn_id)
|
||||
self.transactions[key] = (txn_id, response)
|
||||
|
||||
def store_client_transaction(self, request, txn_id, response):
|
||||
|
|
|
@ -18,7 +18,9 @@ from . import (
|
|||
filter,
|
||||
account,
|
||||
register,
|
||||
auth
|
||||
auth,
|
||||
receipts,
|
||||
keys,
|
||||
)
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
|
@ -38,3 +40,5 @@ class ClientV2AlphaRestResource(JsonResource):
|
|||
account.register_servlets(hs, client_resource)
|
||||
register.register_servlets(hs, client_resource)
|
||||
auth.register_servlets(hs, client_resource)
|
||||
receipts.register_servlets(hs, client_resource)
|
||||
keys.register_servlets(hs, client_resource)
|
||||
|
|
|
@ -0,0 +1,276 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import RestServlet
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from ._base import client_v2_pattern
|
||||
|
||||
import simplejson as json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyUploadServlet(RestServlet):
|
||||
"""
|
||||
POST /keys/upload/<device_id> HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"device_keys": {
|
||||
"user_id": "<user_id>",
|
||||
"device_id": "<device_id>",
|
||||
"valid_until_ts": <millisecond_timestamp>,
|
||||
"algorithms": [
|
||||
"m.olm.curve25519-aes-sha256",
|
||||
]
|
||||
"keys": {
|
||||
"<algorithm>:<device_id>": "<key_base64>",
|
||||
},
|
||||
"signatures:" {
|
||||
"<user_id>" {
|
||||
"<algorithm>:<device_id>": "<signature_base64>"
|
||||
} } },
|
||||
"one_time_keys": {
|
||||
"<algorithm>:<key_id>": "<key_base64>"
|
||||
},
|
||||
}
|
||||
"""
|
||||
PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(KeyUploadServlet, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, device_id):
|
||||
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||
user_id = auth_user.to_string()
|
||||
# TODO: Check that the device_id matches that in the authentication
|
||||
# or derive the device_id from the authentication instead.
|
||||
try:
|
||||
body = json.loads(request.content.read())
|
||||
except:
|
||||
raise SynapseError(400, "Invalid key JSON")
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
# TODO: Validate the JSON to make sure it has the right keys.
|
||||
device_keys = body.get("device_keys", None)
|
||||
if device_keys:
|
||||
logger.info(
|
||||
"Updating device_keys for device %r for user %r at %d",
|
||||
device_id, auth_user, time_now
|
||||
)
|
||||
# TODO: Sign the JSON with the server key
|
||||
yield self.store.set_e2e_device_keys(
|
||||
user_id, device_id, time_now,
|
||||
encode_canonical_json(device_keys)
|
||||
)
|
||||
|
||||
one_time_keys = body.get("one_time_keys", None)
|
||||
if one_time_keys:
|
||||
logger.info(
|
||||
"Adding %d one_time_keys for device %r for user %r at %d",
|
||||
len(one_time_keys), device_id, user_id, time_now
|
||||
)
|
||||
key_list = []
|
||||
for key_id, key_json in one_time_keys.items():
|
||||
algorithm, key_id = key_id.split(":")
|
||||
key_list.append((
|
||||
algorithm, key_id, encode_canonical_json(key_json)
|
||||
))
|
||||
|
||||
yield self.store.add_e2e_one_time_keys(
|
||||
user_id, device_id, time_now, key_list
|
||||
)
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, device_id):
|
||||
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||
user_id = auth_user.to_string()
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||
|
||||
|
||||
class KeyQueryServlet(RestServlet):
|
||||
"""
|
||||
GET /keys/query/<user_id> HTTP/1.1
|
||||
|
||||
GET /keys/query/<user_id>/<device_id> HTTP/1.1
|
||||
|
||||
POST /keys/query HTTP/1.1
|
||||
Content-Type: application/json
|
||||
{
|
||||
"device_keys": {
|
||||
"<user_id>": ["<device_id>"]
|
||||
} }
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
{
|
||||
"device_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": {
|
||||
"user_id": "<user_id>", // Duplicated to be signed
|
||||
"device_id": "<device_id>", // Duplicated to be signed
|
||||
"valid_until_ts": <millisecond_timestamp>,
|
||||
"algorithms": [ // List of supported algorithms
|
||||
"m.olm.curve25519-aes-sha256",
|
||||
],
|
||||
"keys": { // Must include a ed25519 signing key
|
||||
"<algorithm>:<key_id>": "<key_base64>",
|
||||
},
|
||||
"signatures:" {
|
||||
// Must be signed with device's ed25519 key
|
||||
"<user_id>/<device_id>": {
|
||||
"<algorithm>:<key_id>": "<signature_base64>"
|
||||
}
|
||||
// Must be signed by this server.
|
||||
"<server_name>": {
|
||||
"<algorithm>:<key_id>": "<signature_base64>"
|
||||
} } } } } }
|
||||
"""
|
||||
|
||||
PATTERN = client_v2_pattern(
|
||||
"/keys/query(?:"
|
||||
"/(?P<user_id>[^/]*)(?:"
|
||||
"/(?P<device_id>[^/]*)"
|
||||
")?"
|
||||
")?"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(KeyQueryServlet, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id, device_id):
|
||||
logger.debug("onPOST")
|
||||
yield self.auth.get_user_by_req(request)
|
||||
try:
|
||||
body = json.loads(request.content.read())
|
||||
except:
|
||||
raise SynapseError(400, "Invalid key JSON")
|
||||
query = []
|
||||
for user_id, device_ids in body.get("device_keys", {}).items():
|
||||
if not device_ids:
|
||||
query.append((user_id, None))
|
||||
else:
|
||||
for device_id in device_ids:
|
||||
query.append((user_id, device_id))
|
||||
results = yield self.store.get_e2e_device_keys(query)
|
||||
defer.returnValue(self.json_result(request, results))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, device_id):
|
||||
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||
auth_user_id = auth_user.to_string()
|
||||
if not user_id:
|
||||
user_id = auth_user_id
|
||||
if not device_id:
|
||||
device_id = None
|
||||
# Returns a map of user_id->device_id->json_bytes.
|
||||
results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
|
||||
defer.returnValue(self.json_result(request, results))
|
||||
|
||||
def json_result(self, request, results):
|
||||
json_result = {}
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, json_bytes in device_keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||
json_bytes
|
||||
)
|
||||
return (200, {"device_keys": json_result})
|
||||
|
||||
|
||||
class OneTimeKeyServlet(RestServlet):
|
||||
"""
|
||||
GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
|
||||
|
||||
POST /keys/claim HTTP/1.1
|
||||
{
|
||||
"one_time_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": "<algorithm>"
|
||||
} } }
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
{
|
||||
"one_time_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": {
|
||||
"<algorithm>:<key_id>": "<key_base64>"
|
||||
} } } }
|
||||
|
||||
"""
|
||||
PATTERN = client_v2_pattern(
|
||||
"/keys/claim(?:/?|(?:/"
|
||||
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
|
||||
")?)"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(OneTimeKeyServlet, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, device_id, algorithm):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
results = yield self.store.claim_e2e_one_time_keys(
|
||||
[(user_id, device_id, algorithm)]
|
||||
)
|
||||
defer.returnValue(self.json_result(request, results))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id, device_id, algorithm):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
try:
|
||||
body = json.loads(request.content.read())
|
||||
except:
|
||||
raise SynapseError(400, "Invalid key JSON")
|
||||
query = []
|
||||
for user_id, device_keys in body.get("one_time_keys", {}).items():
|
||||
for device_id, algorithm in device_keys.items():
|
||||
query.append((user_id, device_id, algorithm))
|
||||
results = yield self.store.claim_e2e_one_time_keys(query)
|
||||
defer.returnValue(self.json_result(request, results))
|
||||
|
||||
def json_result(self, request, results):
|
||||
json_result = {}
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, json_bytes in keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = {
|
||||
key_id: json.loads(json_bytes)
|
||||
}
|
||||
return (200, {"one_time_keys": json_result})
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
KeyUploadServlet(hs).register(http_server)
|
||||
KeyQueryServlet(hs).register(http_server)
|
||||
OneTimeKeyServlet(hs).register(http_server)
|
|
@ -0,0 +1,55 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
from ._base import client_v2_pattern
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReceiptRestServlet(RestServlet):
|
||||
PATTERN = client_v2_pattern(
|
||||
"/rooms/(?P<room_id>[^/]*)"
|
||||
"/receipt/(?P<receipt_type>[^/]*)"
|
||||
"/(?P<event_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ReceiptRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.receipts_handler = hs.get_handlers().receipts_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, receipt_type, event_id):
|
||||
user, client = yield self.auth.get_user_by_req(request)
|
||||
|
||||
yield self.receipts_handler.received_client_receipt(
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id=user.to_string(),
|
||||
event_id=event_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ReceiptRestServlet(hs).register(http_server)
|
|
@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
|
|||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_pattern, parse_request_allow_empty
|
||||
from ._base import client_v2_pattern, parse_json_dict_from_request
|
||||
|
||||
import logging
|
||||
import hmac
|
||||
|
@ -55,21 +55,55 @@ class RegisterRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield run_on_reactor()
|
||||
body = parse_json_dict_from_request(request)
|
||||
|
||||
body = parse_request_allow_empty(request)
|
||||
if 'password' not in body:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
# we do basic sanity checks here because the auth layer will store these
|
||||
# in sessions. Pull out the username/password provided to us.
|
||||
desired_password = None
|
||||
if 'password' in body:
|
||||
if (not isinstance(body['password'], basestring) or
|
||||
len(body['password']) > 512):
|
||||
raise SynapseError(400, "Invalid password")
|
||||
desired_password = body["password"]
|
||||
|
||||
desired_username = None
|
||||
if 'username' in body:
|
||||
if (not isinstance(body['username'], basestring) or
|
||||
len(body['username']) > 512):
|
||||
raise SynapseError(400, "Invalid username")
|
||||
desired_username = body['username']
|
||||
yield self.registration_handler.check_username(desired_username)
|
||||
|
||||
is_using_shared_secret = False
|
||||
is_application_server = False
|
||||
|
||||
service = None
|
||||
appservice = None
|
||||
if 'access_token' in request.args:
|
||||
service = yield self.auth.get_appservice_by_req(request)
|
||||
appservice = yield self.auth.get_appservice_by_req(request)
|
||||
|
||||
# fork off as soon as possible for ASes and shared secret auth which
|
||||
# have completely different registration flows to normal users
|
||||
|
||||
# == Application Service Registration ==
|
||||
if appservice:
|
||||
result = yield self._do_appservice_registration(
|
||||
desired_username, request.args["access_token"][0]
|
||||
)
|
||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||
return
|
||||
|
||||
# == Shared Secret Registration == (e.g. create new user scripts)
|
||||
if 'mac' in body:
|
||||
# FIXME: Should we really be determining if this is shared secret
|
||||
# auth based purely on the 'mac' key?
|
||||
result = yield self._do_shared_secret_registration(
|
||||
desired_username, desired_password, body["mac"]
|
||||
)
|
||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||
return
|
||||
|
||||
# == Normal User Registration == (everyone else)
|
||||
if self.hs.config.disable_registration:
|
||||
raise SynapseError(403, "Registration has been disabled")
|
||||
|
||||
if desired_username is not None:
|
||||
yield self.registration_handler.check_username(desired_username)
|
||||
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
flows = [
|
||||
|
@ -82,39 +116,20 @@ class RegisterRestServlet(RestServlet):
|
|||
[LoginType.EMAIL_IDENTITY]
|
||||
]
|
||||
|
||||
result = None
|
||||
if service:
|
||||
is_application_server = True
|
||||
params = body
|
||||
elif 'mac' in body:
|
||||
# Check registration-specific shared secret auth
|
||||
if 'username' not in body:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
self._check_shared_secret_auth(
|
||||
body['username'], body['mac']
|
||||
)
|
||||
is_using_shared_secret = True
|
||||
params = body
|
||||
else:
|
||||
authed, result, params = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
return
|
||||
|
||||
can_register = (
|
||||
not self.hs.config.disable_registration
|
||||
or is_application_server
|
||||
or is_using_shared_secret
|
||||
)
|
||||
if not can_register:
|
||||
raise SynapseError(403, "Registration has been disabled")
|
||||
|
||||
# NB: This may be from the auth handler and NOT from the POST
|
||||
if 'password' not in params:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
desired_username = params['username'] if 'username' in params else None
|
||||
new_password = params['password']
|
||||
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
||||
|
||||
desired_username = params.get("username", None)
|
||||
new_password = params.get("password", None)
|
||||
|
||||
(user_id, token) = yield self.registration_handler.register(
|
||||
localpart=desired_username,
|
||||
|
@ -147,18 +162,21 @@ class RegisterRestServlet(RestServlet):
|
|||
else:
|
||||
logger.info("bind_email not specified: not binding email")
|
||||
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
|
||||
result = self._create_registration_details(user_id, token)
|
||||
defer.returnValue((200, result))
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
def _check_shared_secret_auth(self, username, mac):
|
||||
@defer.inlineCallbacks
|
||||
def _do_appservice_registration(self, username, as_token):
|
||||
(user_id, token) = yield self.registration_handler.appservice_register(
|
||||
username, as_token
|
||||
)
|
||||
defer.returnValue(self._create_registration_details(user_id, token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_shared_secret_registration(self, username, password, mac):
|
||||
if not self.hs.config.registration_shared_secret:
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
|
@ -174,13 +192,23 @@ class RegisterRestServlet(RestServlet):
|
|||
digestmod=sha1,
|
||||
).hexdigest()
|
||||
|
||||
if compare_digest(want_mac, got_mac):
|
||||
return True
|
||||
else:
|
||||
if not compare_digest(want_mac, got_mac):
|
||||
raise SynapseError(
|
||||
403, "HMAC incorrect",
|
||||
)
|
||||
|
||||
(user_id, token) = yield self.registration_handler.register(
|
||||
localpart=username, password=password
|
||||
)
|
||||
defer.returnValue(self._create_registration_details(user_id, token))
|
||||
|
||||
def _create_registration_details(self, user_id, token):
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -15,20 +15,23 @@
|
|||
|
||||
from .thumbnailer import Thumbnailer
|
||||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.http.server import respond_with_json
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_error, Codes, SynapseError
|
||||
)
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import defer, threads
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
||||
import os
|
||||
|
||||
import cgi
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -36,8 +39,13 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
def parse_media_id(request):
|
||||
try:
|
||||
server_name, media_id = request.postpath
|
||||
return (server_name, media_id)
|
||||
# This allows users to append e.g. /test.png to the URL. Useful for
|
||||
# clients that parse the URL to see content type.
|
||||
server_name, media_id = request.postpath[:2]
|
||||
if len(request.postpath) > 2 and is_ascii(request.postpath[-1]):
|
||||
return server_name, media_id, request.postpath[-1]
|
||||
else:
|
||||
return server_name, media_id, None
|
||||
except:
|
||||
raise SynapseError(
|
||||
404,
|
||||
|
@ -52,7 +60,7 @@ class BaseMediaResource(Resource):
|
|||
def __init__(self, hs, filepaths):
|
||||
Resource.__init__(self)
|
||||
self.auth = hs.get_auth()
|
||||
self.client = hs.get_http_client()
|
||||
self.client = MatrixFederationHttpClient(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.hostname
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -127,12 +135,21 @@ class BaseMediaResource(Resource):
|
|||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
content_disposition = headers.get("Content-Disposition", None)
|
||||
if content_disposition:
|
||||
_, params = cgi.parse_header(content_disposition[0],)
|
||||
upload_name = params.get("filename", None)
|
||||
if upload_name and not is_ascii(upload_name):
|
||||
upload_name = None
|
||||
else:
|
||||
upload_name = None
|
||||
|
||||
yield self.store.store_cached_remote_media(
|
||||
origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=None,
|
||||
upload_name=upload_name,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
)
|
||||
|
@ -143,7 +160,7 @@ class BaseMediaResource(Resource):
|
|||
media_info = {
|
||||
"media_type": media_type,
|
||||
"media_length": length,
|
||||
"upload_name": None,
|
||||
"upload_name": upload_name,
|
||||
"created_ts": time_now_ms,
|
||||
"filesystem_id": file_id,
|
||||
}
|
||||
|
@ -156,11 +173,16 @@ class BaseMediaResource(Resource):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_with_file(self, request, media_type, file_path,
|
||||
file_size=None):
|
||||
file_size=None, upload_name=None):
|
||||
logger.debug("Responding with %r", file_path)
|
||||
|
||||
if os.path.isfile(file_path):
|
||||
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
||||
if upload_name:
|
||||
request.setHeader(
|
||||
b"Content-Disposition",
|
||||
b"inline; filename=%s" % (upload_name.encode("utf-8"),),
|
||||
)
|
||||
|
||||
# cache for at least a day.
|
||||
# XXX: we might want to turn this off for data we don't want to
|
||||
|
@ -222,6 +244,9 @@ class BaseMediaResource(Resource):
|
|||
)
|
||||
return
|
||||
|
||||
local_thumbnails = []
|
||||
|
||||
def generate_thumbnails():
|
||||
scales = set()
|
||||
crops = set()
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
|
@ -240,9 +265,10 @@ class BaseMediaResource(Resource):
|
|||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_local_thumbnail(
|
||||
|
||||
local_thumbnails.append((
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
))
|
||||
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
|
@ -256,9 +282,14 @@ class BaseMediaResource(Resource):
|
|||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_local_thumbnail(
|
||||
local_thumbnails.append((
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
))
|
||||
|
||||
yield threads.deferToThread(generate_thumbnails)
|
||||
|
||||
for l in local_thumbnails:
|
||||
yield self.store.store_local_thumbnail(*l)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
|
@ -273,11 +304,14 @@ class BaseMediaResource(Resource):
|
|||
if not requirements:
|
||||
return
|
||||
|
||||
remote_thumbnails = []
|
||||
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
def generate_thumbnails():
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
|
@ -303,10 +337,10 @@ class BaseMediaResource(Resource):
|
|||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
remote_thumbnails.append([
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
])
|
||||
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
|
@ -320,10 +354,15 @@ class BaseMediaResource(Resource):
|
|||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
remote_thumbnails.append([
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
])
|
||||
|
||||
yield threads.deferToThread(generate_thumbnails)
|
||||
|
||||
for r in remote_thumbnails:
|
||||
yield self.store.store_remote_media_thumbnail(*r)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
|
|
|
@ -32,14 +32,16 @@ class DownloadResource(BaseMediaResource):
|
|||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
server_name, media_id = parse_media_id(request)
|
||||
server_name, media_id, name = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_file(request, media_id)
|
||||
yield self._respond_local_file(request, media_id, name)
|
||||
else:
|
||||
yield self._respond_remote_file(request, server_name, media_id)
|
||||
yield self._respond_remote_file(
|
||||
request, server_name, media_id, name
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_local_file(self, request, media_id):
|
||||
def _respond_local_file(self, request, media_id, name):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
if not media_info:
|
||||
self._respond_404(request)
|
||||
|
@ -47,24 +49,28 @@ class DownloadResource(BaseMediaResource):
|
|||
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
file_path = self.filepaths.local_media_filepath(media_id)
|
||||
|
||||
yield self._respond_with_file(
|
||||
request, media_type, file_path, media_length
|
||||
request, media_type, file_path, media_length,
|
||||
upload_name=upload_name,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_file(self, request, server_name, media_id):
|
||||
def _respond_remote_file(self, request, server_name, media_id, name):
|
||||
media_info = yield self._get_remote_media(server_name, media_id)
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
filesystem_id = media_info["filesystem_id"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
|
||||
file_path = self.filepaths.remote_media_filepath(
|
||||
server_name, filesystem_id
|
||||
)
|
||||
|
||||
yield self._respond_with_file(
|
||||
request, media_type, file_path, media_length
|
||||
request, media_type, file_path, media_length,
|
||||
upload_name=upload_name,
|
||||
)
|
||||
|
|
|
@ -36,7 +36,7 @@ class ThumbnailResource(BaseMediaResource):
|
|||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
server_name, media_id = parse_media_id(request)
|
||||
server_name, media_id, _ = parse_media_id(request)
|
||||
width = parse_integer(request, "width")
|
||||
height = parse_integer(request, "height")
|
||||
method = parse_string(request, "method", "scale")
|
||||
|
@ -162,11 +162,12 @@ class ThumbnailResource(BaseMediaResource):
|
|||
t_method = info["thumbnail_method"]
|
||||
if t_method == "scale" or t_method == "crop":
|
||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info["thumbnail_type"]
|
||||
length_quality = info["thumbnail_length"]
|
||||
info_list.append((
|
||||
aspect_quality, size_quality, type_quality,
|
||||
aspect_quality, min_quality, size_quality, type_quality,
|
||||
length_quality, info
|
||||
))
|
||||
if info_list:
|
||||
|
|
|
@ -82,7 +82,7 @@ class Thumbnailer(object):
|
|||
|
||||
def save_image(self, output_image, output_type, output_path):
|
||||
output_bytes_io = BytesIO()
|
||||
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=70)
|
||||
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
|
||||
output_bytes = output_bytes_io.getvalue()
|
||||
with open(output_path, "wb") as output_file:
|
||||
output_file.write(output_bytes)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
from synapse.http.server import respond_with_json, request_handler
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.util.stringutils import random_string, is_ascii
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
@ -84,6 +84,12 @@ class UploadResource(BaseMediaResource):
|
|||
code=413,
|
||||
)
|
||||
|
||||
upload_name = request.args.get("filename", None)
|
||||
if upload_name:
|
||||
upload_name = upload_name[0]
|
||||
if upload_name and not is_ascii(upload_name):
|
||||
raise SynapseError(400, "filename must be ascii")
|
||||
|
||||
headers = request.requestHeaders
|
||||
|
||||
if headers.hasHeader("Content-Type"):
|
||||
|
@ -99,7 +105,7 @@ class UploadResource(BaseMediaResource):
|
|||
# TODO(markjh): parse content-dispostion
|
||||
|
||||
content_uri = yield self.create_content(
|
||||
media_type, None, request.content.read(),
|
||||
media_type, upload_name, request.content.read(),
|
||||
content_length, auth_user
|
||||
)
|
||||
|
||||
|
|
|
@ -132,16 +132,8 @@ class BaseHomeServer(object):
|
|||
setattr(BaseHomeServer, "get_%s" % (depname), _get)
|
||||
|
||||
def get_ip_from_request(self, request):
|
||||
# May be an X-Forwarding-For header depending on config
|
||||
ip_addr = request.getClientIP()
|
||||
if self.config.captcha_ip_origin_is_x_forwarded:
|
||||
# use the header
|
||||
if request.requestHeaders.hasHeader("X-Forwarded-For"):
|
||||
ip_addr = request.requestHeaders.getRawHeaders(
|
||||
"X-Forwarded-For"
|
||||
)[0]
|
||||
|
||||
return ip_addr
|
||||
# X-Forwarded-For is handled by our custom request type.
|
||||
return request.getClientIP()
|
||||
|
||||
def is_mine(self, domain_specific_string):
|
||||
return domain_specific_string.domain == self.hostname
|
||||
|
|
|
@ -106,7 +106,7 @@ class StateHandler(object):
|
|||
defer.returnValue(state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def compute_event_context(self, event, old_state=None):
|
||||
def compute_event_context(self, event, old_state=None, outlier=False):
|
||||
""" Fills out the context with the `current state` of the graph. The
|
||||
`current state` here is defined to be the state of the event graph
|
||||
just before the event - i.e. it never includes `event`
|
||||
|
@ -119,9 +119,23 @@ class StateHandler(object):
|
|||
Returns:
|
||||
an EventContext
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
|
||||
context = EventContext()
|
||||
|
||||
yield run_on_reactor()
|
||||
if outlier:
|
||||
# If this is an outlier, then we know it shouldn't have any current
|
||||
# state. Certainly store.get_current_state won't return any, and
|
||||
# persisting the event won't store the state group.
|
||||
if old_state:
|
||||
context.current_state = {
|
||||
(s.type, s.state_key): s for s in old_state
|
||||
}
|
||||
else:
|
||||
context.current_state = {}
|
||||
context.prev_state_events = []
|
||||
context.state_group = None
|
||||
defer.returnValue(context)
|
||||
|
||||
if old_state:
|
||||
context.current_state = {
|
||||
|
@ -155,10 +169,6 @@ class StateHandler(object):
|
|||
context.current_state = curr_state
|
||||
context.state_group = group if not event.is_state() else None
|
||||
|
||||
prev_state = yield self.store.add_event_hashes(
|
||||
prev_state
|
||||
)
|
||||
|
||||
if event.is_state():
|
||||
key = (event.type, event.state_key)
|
||||
if key in context.current_state:
|
||||
|
|
|
@ -37,6 +37,9 @@ from .rejections import RejectionsStore
|
|||
from .state import StateStore
|
||||
from .signatures import SignatureStore
|
||||
from .filtering import FilteringStore
|
||||
from .end_to_end_keys import EndToEndKeyStore
|
||||
|
||||
from .receipts import ReceiptsStore
|
||||
|
||||
|
||||
import fnmatch
|
||||
|
@ -51,7 +54,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 19
|
||||
SCHEMA_VERSION = 21
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
@ -74,6 +77,8 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
EventsStore,
|
||||
ReceiptsStore,
|
||||
EndToEndKeyStore,
|
||||
):
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -94,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
key = (user.to_string(), access_token, device_id, ip)
|
||||
|
||||
try:
|
||||
last_seen = self.client_ip_last_seen.get(*key)
|
||||
last_seen = self.client_ip_last_seen.get(key)
|
||||
except KeyError:
|
||||
last_seen = None
|
||||
|
||||
|
@ -102,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||
defer.returnValue(None)
|
||||
|
||||
self.client_ip_last_seen.prefill(*key + (now,))
|
||||
self.client_ip_last_seen.prefill(key, now)
|
||||
|
||||
# It's safe not to lock here: a) no unique constraint,
|
||||
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
|
||||
|
@ -348,7 +353,12 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
|||
module_name, absolute_path, python_file
|
||||
)
|
||||
logger.debug("Running script %s", relative_path)
|
||||
module.run_upgrade(cur)
|
||||
module.run_upgrade(cur, database_engine)
|
||||
elif ext == ".pyc":
|
||||
# Sometimes .pyc files turn up anyway even though we've
|
||||
# disabled their generation; e.g. from distribution package
|
||||
# installers. Silently skip it
|
||||
pass
|
||||
elif ext == ".sql":
|
||||
# A plain old .sql file, just read and execute it
|
||||
logger.debug("Applying schema %s", relative_path)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import logging
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
||||
from synapse.util.lrucache import LruCache
|
||||
|
@ -27,6 +28,7 @@ from twisted.internet import defer
|
|||
from collections import namedtuple, OrderedDict
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
|
@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
|
|||
)
|
||||
|
||||
|
||||
_CacheSentinel = object()
|
||||
|
||||
|
||||
class Cache(object):
|
||||
|
||||
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
|
||||
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
|
||||
if lru:
|
||||
self.cache = LruCache(max_size=max_entries)
|
||||
self.max_entries = None
|
||||
|
@ -81,45 +86,44 @@ class Cache(object):
|
|||
"Cache objects can only be accessed from the main thread"
|
||||
)
|
||||
|
||||
def get(self, *keyargs):
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
if keyargs in self.cache:
|
||||
def get(self, key, default=_CacheSentinel):
|
||||
val = self.cache.get(key, _CacheSentinel)
|
||||
if val is not _CacheSentinel:
|
||||
cache_counter.inc_hits(self.name)
|
||||
return self.cache[keyargs]
|
||||
return val
|
||||
|
||||
cache_counter.inc_misses(self.name)
|
||||
raise KeyError()
|
||||
|
||||
def update(self, sequence, *args):
|
||||
if default is _CacheSentinel:
|
||||
raise KeyError()
|
||||
else:
|
||||
return default
|
||||
|
||||
def update(self, sequence, key, value):
|
||||
self.check_thread()
|
||||
if self.sequence == sequence:
|
||||
# Only update the cache if the caches sequence number matches the
|
||||
# number that the cache had before the SELECT was started (SYN-369)
|
||||
self.prefill(*args)
|
||||
|
||||
def prefill(self, *args): # because I can't *keyargs, value
|
||||
keyargs = args[:-1]
|
||||
value = args[-1]
|
||||
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
self.prefill(key, value)
|
||||
|
||||
def prefill(self, key, value):
|
||||
if self.max_entries is not None:
|
||||
while len(self.cache) >= self.max_entries:
|
||||
self.cache.popitem(last=False)
|
||||
|
||||
self.cache[keyargs] = value
|
||||
self.cache[key] = value
|
||||
|
||||
def invalidate(self, *keyargs):
|
||||
def invalidate(self, key):
|
||||
self.check_thread()
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
if not isinstance(key, tuple):
|
||||
raise TypeError(
|
||||
"The cache key must be a tuple not %r" % (type(key),)
|
||||
)
|
||||
|
||||
# Increment the sequence number so that any SELECT statements that
|
||||
# raced with the INSERT don't update the cache (SYN-369)
|
||||
self.sequence += 1
|
||||
self.cache.pop(keyargs, None)
|
||||
self.cache.pop(key, None)
|
||||
|
||||
def invalidate_all(self):
|
||||
self.check_thread()
|
||||
|
@ -127,9 +131,12 @@ class Cache(object):
|
|||
self.cache.clear()
|
||||
|
||||
|
||||
def cached(max_entries=1000, num_args=1, lru=False):
|
||||
class CacheDescriptor(object):
|
||||
""" A method decorator that applies a memoizing cache around the function.
|
||||
|
||||
This caches deferreds, rather than the results themselves. Deferreds that
|
||||
fail are removed from the cache.
|
||||
|
||||
The function is presumed to take zero or more arguments, which are used in
|
||||
a tuple as the key for the cache. Hits are served directly from the cache;
|
||||
misses use the function body to generate the value.
|
||||
|
@ -141,47 +148,108 @@ def cached(max_entries=1000, num_args=1, lru=False):
|
|||
which can be used to insert values into the cache specifically, without
|
||||
calling the calculation function.
|
||||
"""
|
||||
def wrap(orig):
|
||||
cache = Cache(
|
||||
name=orig.__name__,
|
||||
max_entries=max_entries,
|
||||
keylen=num_args,
|
||||
lru=lru,
|
||||
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
|
||||
inlineCallbacks=False):
|
||||
self.orig = orig
|
||||
|
||||
if inlineCallbacks:
|
||||
self.function_to_call = defer.inlineCallbacks(orig)
|
||||
else:
|
||||
self.function_to_call = orig
|
||||
|
||||
self.max_entries = max_entries
|
||||
self.num_args = num_args
|
||||
self.lru = lru
|
||||
|
||||
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
|
||||
|
||||
if len(self.arg_names) < self.num_args:
|
||||
raise Exception(
|
||||
"Not enough explicit positional arguments to key off of for %r."
|
||||
" (@cached cannot key off of *args or **kwars)"
|
||||
% (orig.__name__,)
|
||||
)
|
||||
|
||||
@functools.wraps(orig)
|
||||
@defer.inlineCallbacks
|
||||
def wrapped(self, *keyargs):
|
||||
self.cache = Cache(
|
||||
name=self.orig.__name__,
|
||||
max_entries=self.max_entries,
|
||||
keylen=self.num_args,
|
||||
lru=self.lru,
|
||||
)
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def wrapped(*args, **kwargs):
|
||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
||||
try:
|
||||
cached_result = cache.get(*keyargs)
|
||||
cached_result_d = self.cache.get(cache_key)
|
||||
|
||||
observer = cached_result_d.observe()
|
||||
if DEBUG_CACHES:
|
||||
actual_result = yield orig(self, *keyargs)
|
||||
@defer.inlineCallbacks
|
||||
def check_result(cached_result):
|
||||
actual_result = yield self.function_to_call(obj, *args, **kwargs)
|
||||
if actual_result != cached_result:
|
||||
logger.error(
|
||||
"Stale cache entry %s%r: cached: %r, actual %r",
|
||||
orig.__name__, keyargs,
|
||||
self.orig.__name__, cache_key,
|
||||
cached_result, actual_result,
|
||||
)
|
||||
raise ValueError("Stale cache entry")
|
||||
defer.returnValue(cached_result)
|
||||
observer.addCallback(check_result)
|
||||
|
||||
return observer
|
||||
except KeyError:
|
||||
# Get the sequence number of the cache before reading from the
|
||||
# database so that we can tell if the cache is invalidated
|
||||
# while the SELECT is executing (SYN-369)
|
||||
sequence = cache.sequence
|
||||
sequence = self.cache.sequence
|
||||
|
||||
ret = yield orig(self, *keyargs)
|
||||
ret = defer.maybeDeferred(
|
||||
self.function_to_call,
|
||||
obj, *args, **kwargs
|
||||
)
|
||||
|
||||
cache.update(sequence, *keyargs + (ret,))
|
||||
def onErr(f):
|
||||
self.cache.invalidate(cache_key)
|
||||
return f
|
||||
|
||||
defer.returnValue(ret)
|
||||
ret.addErrback(onErr)
|
||||
|
||||
ret = ObservableDeferred(ret, consumeErrors=True)
|
||||
self.cache.update(sequence, cache_key, ret)
|
||||
|
||||
return ret.observe()
|
||||
|
||||
wrapped.invalidate = self.cache.invalidate
|
||||
wrapped.invalidate_all = self.cache.invalidate_all
|
||||
wrapped.prefill = self.cache.prefill
|
||||
|
||||
obj.__dict__[self.orig.__name__] = wrapped
|
||||
|
||||
wrapped.invalidate = cache.invalidate
|
||||
wrapped.invalidate_all = cache.invalidate_all
|
||||
wrapped.prefill = cache.prefill
|
||||
return wrapped
|
||||
|
||||
return wrap
|
||||
|
||||
def cached(max_entries=1000, num_args=1, lru=True):
|
||||
return lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
lru=lru
|
||||
)
|
||||
|
||||
|
||||
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
|
||||
return lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
lru=lru,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
|
||||
|
||||
class LoggingTransaction(object):
|
||||
|
@ -312,13 +380,14 @@ class SQLBaseStore(object):
|
|||
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
self._stream_id_gen = StreamIdGenerator()
|
||||
self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
|
||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
|
||||
|
||||
def start_profiling(self):
|
||||
self._previous_loop_ts = self._clock.time_msec()
|
||||
|
|
|
@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
|
|||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
self.get_aliases_for_room.invalidate(room_id)
|
||||
self.get_aliases_for_room.invalidate((room_id,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_room_alias(self, room_alias):
|
||||
|
@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
|
|||
room_alias,
|
||||
)
|
||||
|
||||
self.get_aliases_for_room.invalidate(room_id)
|
||||
self.get_aliases_for_room.invalidate((room_id,))
|
||||
defer.returnValue(room_id)
|
||||
|
||||
def _delete_room_alias_txn(self, txn, room_alias):
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from _base import SQLBaseStore
|
||||
|
||||
|
||||
class EndToEndKeyStore(SQLBaseStore):
|
||||
def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
|
||||
return self._simple_upsert(
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={
|
||||
"ts_added_ms": time_now,
|
||||
"key_json": json_bytes,
|
||||
}
|
||||
)
|
||||
|
||||
def get_e2e_device_keys(self, query_list):
|
||||
"""Fetch a list of device keys.
|
||||
Args:
|
||||
query_list(list): List of pairs of user_ids and device_ids.
|
||||
Returns:
|
||||
Dict mapping from user-id to dict mapping from device_id to
|
||||
key json byte strings.
|
||||
"""
|
||||
def _get_e2e_device_keys(txn):
|
||||
result = {}
|
||||
for user_id, device_id in query_list:
|
||||
user_result = result.setdefault(user_id, {})
|
||||
keyvalues = {"user_id": user_id}
|
||||
if device_id:
|
||||
keyvalues["device_id"] = device_id
|
||||
rows = self._simple_select_list_txn(
|
||||
txn, table="e2e_device_keys_json",
|
||||
keyvalues=keyvalues,
|
||||
retcols=["device_id", "key_json"]
|
||||
)
|
||||
for row in rows:
|
||||
user_result[row["device_id"]] = row["key_json"]
|
||||
return result
|
||||
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
|
||||
|
||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
||||
def _add_e2e_one_time_keys(txn):
|
||||
for (algorithm, key_id, json_bytes) in key_list:
|
||||
self._simple_upsert_txn(
|
||||
txn, table="e2e_one_time_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
"key_id": key_id,
|
||||
},
|
||||
values={
|
||||
"ts_added_ms": time_now,
|
||||
"key_json": json_bytes,
|
||||
}
|
||||
)
|
||||
return self.runInteraction(
|
||||
"add_e2e_one_time_keys", _add_e2e_one_time_keys
|
||||
)
|
||||
|
||||
def count_e2e_one_time_keys(self, user_id, device_id):
|
||||
""" Count the number of one time keys the server has for a device
|
||||
Returns:
|
||||
Dict mapping from algorithm to number of keys for that algorithm.
|
||||
"""
|
||||
def _count_e2e_one_time_keys(txn):
|
||||
sql = (
|
||||
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
|
||||
" WHERE user_id = ? AND device_id = ?"
|
||||
" GROUP BY algorithm"
|
||||
)
|
||||
txn.execute(sql, (user_id, device_id))
|
||||
result = {}
|
||||
for algorithm, key_count in txn.fetchall():
|
||||
result[algorithm] = key_count
|
||||
return result
|
||||
return self.runInteraction(
|
||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
def claim_e2e_one_time_keys(self, query_list):
|
||||
"""Take a list of one time keys out of the database"""
|
||||
def _claim_e2e_one_time_keys(txn):
|
||||
sql = (
|
||||
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
|
||||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||
" LIMIT 1"
|
||||
)
|
||||
result = {}
|
||||
delete = []
|
||||
for user_id, device_id, algorithm in query_list:
|
||||
user_result = result.setdefault(user_id, {})
|
||||
device_result = user_result.setdefault(device_id, {})
|
||||
txn.execute(sql, (user_id, device_id, algorithm))
|
||||
for key_id, key_json in txn.fetchall():
|
||||
device_result[algorithm + ":" + key_id] = key_json
|
||||
delete.append((user_id, device_id, algorithm, key_id))
|
||||
sql = (
|
||||
"DELETE FROM e2e_one_time_keys_json"
|
||||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||
" AND key_id = ?"
|
||||
)
|
||||
for user_id, device_id, algorithm, key_id in delete:
|
||||
txn.execute(sql, (user_id, device_id, algorithm, key_id))
|
||||
return result
|
||||
return self.runInteraction(
|
||||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||
)
|
|
@ -49,14 +49,22 @@ class EventFederationStore(SQLBaseStore):
|
|||
results = set()
|
||||
|
||||
base_sql = (
|
||||
"SELECT auth_id FROM event_auth WHERE event_id = ?"
|
||||
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
|
||||
)
|
||||
|
||||
front = set(event_ids)
|
||||
while front:
|
||||
new_front = set()
|
||||
for f in front:
|
||||
txn.execute(base_sql, (f,))
|
||||
front_list = list(front)
|
||||
chunks = [
|
||||
front_list[x:x+100]
|
||||
for x in xrange(0, len(front), 100)
|
||||
]
|
||||
for chunk in chunks:
|
||||
txn.execute(
|
||||
base_sql % (",".join(["?"] * len(chunk)),),
|
||||
chunk
|
||||
)
|
||||
new_front.update([r[0] for r in txn.fetchall()])
|
||||
|
||||
new_front -= results
|
||||
|
@ -274,8 +282,7 @@ class EventFederationStore(SQLBaseStore):
|
|||
},
|
||||
)
|
||||
|
||||
def _handle_prev_events(self, txn, outlier, event_id, prev_events,
|
||||
room_id):
|
||||
def _handle_mult_prev_events(self, txn, events):
|
||||
"""
|
||||
For the given event, update the event edges table and forward and
|
||||
backward extremities tables.
|
||||
|
@ -285,45 +292,47 @@ class EventFederationStore(SQLBaseStore):
|
|||
table="event_edges",
|
||||
values=[
|
||||
{
|
||||
"event_id": event_id,
|
||||
"event_id": ev.event_id,
|
||||
"prev_event_id": e_id,
|
||||
"room_id": room_id,
|
||||
"room_id": ev.room_id,
|
||||
"is_state": False,
|
||||
}
|
||||
for e_id, _ in prev_events
|
||||
for ev in events
|
||||
for e_id, _ in ev.prev_events
|
||||
],
|
||||
)
|
||||
|
||||
# Update the extremities table if this is not an outlier.
|
||||
if not outlier:
|
||||
for e_id, _ in prev_events:
|
||||
# TODO (erikj): This could be done as a bulk insert
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="event_forward_extremities",
|
||||
keyvalues={
|
||||
"event_id": e_id,
|
||||
"room_id": room_id,
|
||||
}
|
||||
events_by_room = {}
|
||||
for ev in events:
|
||||
events_by_room.setdefault(ev.room_id, []).append(ev)
|
||||
|
||||
for room_id, room_events in events_by_room.items():
|
||||
prevs = [
|
||||
e_id for ev in room_events for e_id, _ in ev.prev_events
|
||||
if not ev.internal_metadata.is_outlier()
|
||||
]
|
||||
if prevs:
|
||||
txn.execute(
|
||||
"DELETE FROM event_forward_extremities"
|
||||
" WHERE room_id = ?"
|
||||
" AND event_id in (%s)" % (
|
||||
",".join(["?"] * len(prevs)),
|
||||
),
|
||||
[room_id] + prevs,
|
||||
)
|
||||
|
||||
# We only insert as a forward extremity the new event if there are
|
||||
# no other events that reference it as a prev event
|
||||
query = (
|
||||
"SELECT 1 FROM event_edges WHERE prev_event_id = ?"
|
||||
"INSERT INTO event_forward_extremities (event_id, room_id)"
|
||||
" SELECT ?, ? WHERE NOT EXISTS ("
|
||||
" SELECT 1 FROM event_edges WHERE prev_event_id = ?"
|
||||
" )"
|
||||
)
|
||||
|
||||
txn.execute(query, (event_id,))
|
||||
|
||||
if not txn.fetchone():
|
||||
query = (
|
||||
"INSERT INTO event_forward_extremities"
|
||||
" (event_id, room_id)"
|
||||
" VALUES (?, ?)"
|
||||
txn.executemany(
|
||||
query,
|
||||
[(ev.event_id, ev.room_id, ev.event_id) for ev in events]
|
||||
)
|
||||
|
||||
txn.execute(query, (event_id, room_id))
|
||||
|
||||
query = (
|
||||
"INSERT INTO event_backward_extremities (event_id, room_id)"
|
||||
" SELECT ?, ? WHERE NOT EXISTS ("
|
||||
|
@ -337,18 +346,23 @@ class EventFederationStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
txn.executemany(query, [
|
||||
(e_id, room_id, e_id, room_id, e_id, room_id, False)
|
||||
for e_id, _ in prev_events
|
||||
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
|
||||
for ev in events for e_id, _ in ev.prev_events
|
||||
if not ev.internal_metadata.is_outlier()
|
||||
])
|
||||
|
||||
query = (
|
||||
"DELETE FROM event_backward_extremities"
|
||||
" WHERE event_id = ? AND room_id = ?"
|
||||
)
|
||||
txn.execute(query, (event_id, room_id))
|
||||
txn.executemany(
|
||||
query,
|
||||
[(ev.event_id, ev.room_id) for ev in events]
|
||||
)
|
||||
|
||||
for room_id in events_by_room:
|
||||
txn.call_after(
|
||||
self.get_latest_event_ids_in_room.invalidate, room_id
|
||||
self.get_latest_event_ids_in_room.invalidate, (room_id,)
|
||||
)
|
||||
|
||||
def get_backfill_events(self, room_id, event_list, limit):
|
||||
|
@ -400,9 +414,11 @@ class EventFederationStore(SQLBaseStore):
|
|||
keyvalues={
|
||||
"event_id": event_id,
|
||||
},
|
||||
retcol="depth"
|
||||
retcol="depth",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if depth:
|
||||
queue.put((-depth, event_id))
|
||||
|
||||
while not queue.empty() and len(event_results) < limit:
|
||||
|
@ -489,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
|
|||
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||
|
||||
txn.execute(query, (room_id,))
|
||||
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
|
||||
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
|
||||
|
|
|
@ -23,9 +23,7 @@ from synapse.events.utils import prune_event
|
|||
from synapse.util.logcontext import preserve_context_over_deferred
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
|
||||
from syutil.base64util import decode_base64
|
||||
from syutil.jsonutil import encode_json
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
@ -46,6 +44,48 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
|
|||
|
||||
|
||||
class EventsStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
def persist_events(self, events_and_contexts, backfilled=False,
|
||||
is_new_state=True):
|
||||
if not events_and_contexts:
|
||||
return
|
||||
|
||||
if backfilled:
|
||||
if not self.min_token_deferred.called:
|
||||
yield self.min_token_deferred
|
||||
start = self.min_token - 1
|
||||
self.min_token -= len(events_and_contexts) + 1
|
||||
stream_orderings = range(start, self.min_token, -1)
|
||||
|
||||
@contextmanager
|
||||
def stream_ordering_manager():
|
||||
yield stream_orderings
|
||||
stream_ordering_manager = stream_ordering_manager()
|
||||
else:
|
||||
stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
|
||||
self, len(events_and_contexts)
|
||||
)
|
||||
|
||||
with stream_ordering_manager as stream_orderings:
|
||||
for (event, _), stream in zip(events_and_contexts, stream_orderings):
|
||||
event.internal_metadata.stream_ordering = stream
|
||||
|
||||
chunks = [
|
||||
events_and_contexts[x:x+100]
|
||||
for x in xrange(0, len(events_and_contexts), 100)
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
# We can't easily parallelize these since different chunks
|
||||
# might contain the same event. :(
|
||||
yield self.runInteraction(
|
||||
"persist_events",
|
||||
self._persist_events_txn,
|
||||
events_and_contexts=chunk,
|
||||
backfilled=backfilled,
|
||||
is_new_state=is_new_state,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_event(self, event, context, backfilled=False,
|
||||
|
@ -67,13 +107,13 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
try:
|
||||
with stream_ordering_manager as stream_ordering:
|
||||
event.internal_metadata.stream_ordering = stream_ordering
|
||||
yield self.runInteraction(
|
||||
"persist_event",
|
||||
self._persist_event_txn,
|
||||
event=event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
|
@ -116,19 +156,14 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
@log_function
|
||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||
stream_ordering=None, is_new_state=True,
|
||||
current_state=None):
|
||||
|
||||
# Remove the any existing cache entries for the event_id
|
||||
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
||||
|
||||
is_new_state=True, current_state=None):
|
||||
# We purposefully do this first since if we include a `current_state`
|
||||
# key, we *want* to update the `current_state_events` table
|
||||
if current_state:
|
||||
txn.call_after(self.get_current_state_for_key.invalidate_all)
|
||||
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
||||
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
|
||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
||||
txn.call_after(self.get_room_name_and_aliases, event.room_id)
|
||||
|
||||
self._simple_delete_txn(
|
||||
|
@ -149,21 +184,72 @@ class EventsStore(SQLBaseStore):
|
|||
}
|
||||
)
|
||||
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
if not outlier:
|
||||
self._update_min_depth_for_room_txn(
|
||||
return self._persist_events_txn(
|
||||
txn,
|
||||
event.room_id,
|
||||
event.depth
|
||||
[(event, context)],
|
||||
backfilled=backfilled,
|
||||
is_new_state=is_new_state,
|
||||
)
|
||||
|
||||
have_persisted = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="events",
|
||||
keyvalues={"event_id": event.event_id},
|
||||
retcols=["event_id", "outlier"],
|
||||
allow_none=True,
|
||||
@log_function
|
||||
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
|
||||
is_new_state=True):
|
||||
|
||||
# Remove the any existing cache entries for the event_ids
|
||||
for event, _ in events_and_contexts:
|
||||
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
||||
|
||||
depth_updates = {}
|
||||
for event, _ in events_and_contexts:
|
||||
if event.internal_metadata.is_outlier():
|
||||
continue
|
||||
depth_updates[event.room_id] = max(
|
||||
event.depth, depth_updates.get(event.room_id, event.depth)
|
||||
)
|
||||
|
||||
for room_id, depth in depth_updates.items():
|
||||
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
||||
|
||||
txn.execute(
|
||||
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
|
||||
",".join(["?"] * len(events_and_contexts)),
|
||||
),
|
||||
[event.event_id for event, _ in events_and_contexts]
|
||||
)
|
||||
have_persisted = {
|
||||
event_id: outlier
|
||||
for event_id, outlier in txn.fetchall()
|
||||
}
|
||||
|
||||
event_map = {}
|
||||
to_remove = set()
|
||||
for event, context in events_and_contexts:
|
||||
# Handle the case of the list including the same event multiple
|
||||
# times. The tricky thing here is when they differ by whether
|
||||
# they are an outlier.
|
||||
if event.event_id in event_map:
|
||||
other = event_map[event.event_id]
|
||||
|
||||
if not other.internal_metadata.is_outlier():
|
||||
to_remove.add(event)
|
||||
continue
|
||||
elif not event.internal_metadata.is_outlier():
|
||||
to_remove.add(event)
|
||||
continue
|
||||
else:
|
||||
to_remove.add(other)
|
||||
|
||||
event_map[event.event_id] = event
|
||||
|
||||
if event.event_id not in have_persisted:
|
||||
continue
|
||||
|
||||
to_remove.add(event)
|
||||
|
||||
outlier_persisted = have_persisted[event.event_id]
|
||||
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
||||
self._store_state_groups_txn(
|
||||
txn, event, context,
|
||||
)
|
||||
|
||||
metadata_json = encode_json(
|
||||
|
@ -171,16 +257,6 @@ class EventsStore(SQLBaseStore):
|
|||
using_frozen_dicts=USE_FROZEN_DICTS
|
||||
).decode("UTF-8")
|
||||
|
||||
# If we have already persisted this event, we don't need to do any
|
||||
# more processing.
|
||||
# The processing above must be done on every call to persist event,
|
||||
# since they might not have happened on previous calls. For example,
|
||||
# if we are persisting an event that we had persisted as an outlier,
|
||||
# but is no longer one.
|
||||
if have_persisted:
|
||||
if not outlier and have_persisted["outlier"]:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
sql = (
|
||||
"UPDATE event_json SET internal_metadata = ?"
|
||||
" WHERE event_id = ?"
|
||||
|
@ -198,29 +274,45 @@ class EventsStore(SQLBaseStore):
|
|||
sql,
|
||||
(False, event.event_id,)
|
||||
)
|
||||
return
|
||||
|
||||
if not outlier:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
self._handle_prev_events(
|
||||
txn,
|
||||
outlier=outlier,
|
||||
event_id=event.event_id,
|
||||
prev_events=event.prev_events,
|
||||
room_id=event.room_id,
|
||||
events_and_contexts = filter(
|
||||
lambda ec: ec[0] not in to_remove,
|
||||
events_and_contexts
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
self._store_room_member_txn(txn, event)
|
||||
elif event.type == EventTypes.Name:
|
||||
if not events_and_contexts:
|
||||
return
|
||||
|
||||
self._store_mult_state_groups_txn(txn, [
|
||||
(event, context)
|
||||
for event, context in events_and_contexts
|
||||
if not event.internal_metadata.is_outlier()
|
||||
])
|
||||
|
||||
self._handle_mult_prev_events(
|
||||
txn,
|
||||
events=[event for event, _ in events_and_contexts],
|
||||
)
|
||||
|
||||
for event, _ in events_and_contexts:
|
||||
if event.type == EventTypes.Name:
|
||||
self._store_room_name_txn(txn, event)
|
||||
elif event.type == EventTypes.Topic:
|
||||
self._store_room_topic_txn(txn, event)
|
||||
elif event.type == EventTypes.Redaction:
|
||||
self._store_redaction(txn, event)
|
||||
|
||||
event_dict = {
|
||||
self._store_room_members_txn(
|
||||
txn,
|
||||
[
|
||||
event
|
||||
for event, _ in events_and_contexts
|
||||
if event.type == EventTypes.Member
|
||||
]
|
||||
)
|
||||
|
||||
def event_dict(event):
|
||||
return {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in [
|
||||
|
@ -229,63 +321,44 @@ class EventsStore(SQLBaseStore):
|
|||
]
|
||||
}
|
||||
|
||||
self._simple_insert_txn(
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
values={
|
||||
values=[
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": metadata_json,
|
||||
"json": encode_json(
|
||||
event_dict, using_frozen_dicts=USE_FROZEN_DICTS
|
||||
"internal_metadata": encode_json(
|
||||
event.internal_metadata.get_dict(),
|
||||
using_frozen_dicts=USE_FROZEN_DICTS
|
||||
).decode("UTF-8"),
|
||||
},
|
||||
"json": encode_json(
|
||||
event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS
|
||||
).decode("UTF-8"),
|
||||
}
|
||||
for event, _ in events_and_contexts
|
||||
],
|
||||
)
|
||||
|
||||
content = encode_json(
|
||||
event.content, using_frozen_dicts=USE_FROZEN_DICTS
|
||||
).decode("UTF-8")
|
||||
|
||||
vals = {
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="events",
|
||||
values=[
|
||||
{
|
||||
"stream_ordering": event.internal_metadata.stream_ordering,
|
||||
"topological_ordering": event.depth,
|
||||
"event_id": event.event_id,
|
||||
"type": event.type,
|
||||
"room_id": event.room_id,
|
||||
"content": content,
|
||||
"processed": True,
|
||||
"outlier": outlier,
|
||||
"depth": event.depth,
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"processed": True,
|
||||
"outlier": event.internal_metadata.is_outlier(),
|
||||
"content": encode_json(
|
||||
event.content, using_frozen_dicts=USE_FROZEN_DICTS
|
||||
).decode("UTF-8"),
|
||||
}
|
||||
|
||||
unrec = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in vals.keys() and k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
"signatures",
|
||||
"hashes",
|
||||
"prev_events",
|
||||
]
|
||||
}
|
||||
|
||||
vals["unrecognized_keys"] = encode_json(
|
||||
unrec, using_frozen_dicts=USE_FROZEN_DICTS
|
||||
).decode("UTF-8")
|
||||
|
||||
sql = (
|
||||
"INSERT INTO events"
|
||||
" (stream_ordering, topological_ordering, event_id, type,"
|
||||
" room_id, content, processed, outlier, depth)"
|
||||
" VALUES (?,?,?,?,?,?,?,?,?)"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
stream_ordering, event.depth, event.event_id, event.type,
|
||||
event.room_id, content, True, outlier, event.depth
|
||||
)
|
||||
for event, _ in events_and_contexts
|
||||
],
|
||||
)
|
||||
|
||||
if context.rejected:
|
||||
|
@ -293,20 +366,6 @@ class EventsStore(SQLBaseStore):
|
|||
txn, event.event_id, context.rejected
|
||||
)
|
||||
|
||||
for hash_alg, hash_base64 in event.hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_event_content_hash_txn(
|
||||
txn, event.event_id, hash_alg, hash_bytes,
|
||||
)
|
||||
|
||||
for prev_event_id, prev_hashes in event.prev_events:
|
||||
for alg, hash_base64 in prev_hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_prev_event_hash_txn(
|
||||
txn, event.event_id, prev_event_id, alg,
|
||||
hash_bytes
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
|
@ -316,16 +375,22 @@ class EventsStore(SQLBaseStore):
|
|||
"room_id": event.room_id,
|
||||
"auth_id": auth_id,
|
||||
}
|
||||
for event, _ in events_and_contexts
|
||||
for auth_id, _ in event.auth_events
|
||||
],
|
||||
)
|
||||
|
||||
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
|
||||
self._store_event_reference_hash_txn(
|
||||
txn, event.event_id, ref_alg, ref_hash_bytes
|
||||
self._store_event_reference_hashes_txn(
|
||||
txn, [event for event, _ in events_and_contexts]
|
||||
)
|
||||
|
||||
if event.is_state():
|
||||
state_events_and_contexts = filter(
|
||||
lambda i: i[0].is_state(),
|
||||
events_and_contexts,
|
||||
)
|
||||
|
||||
state_values = []
|
||||
for event, context in state_events_and_contexts:
|
||||
vals = {
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
|
@ -337,10 +402,12 @@ class EventsStore(SQLBaseStore):
|
|||
if hasattr(event, "replaces_state"):
|
||||
vals["prev_state"] = event.replaces_state
|
||||
|
||||
self._simple_insert_txn(
|
||||
state_values.append(vals)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
"state_events",
|
||||
vals,
|
||||
table="state_events",
|
||||
values=state_values,
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
|
@ -349,25 +416,27 @@ class EventsStore(SQLBaseStore):
|
|||
values=[
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"prev_event_id": e_id,
|
||||
"prev_event_id": prev_id,
|
||||
"room_id": event.room_id,
|
||||
"is_state": True,
|
||||
}
|
||||
for e_id, h in event.prev_state
|
||||
for event, _ in state_events_and_contexts
|
||||
for prev_id, _ in event.prev_state
|
||||
],
|
||||
)
|
||||
|
||||
if is_new_state and not context.rejected:
|
||||
if is_new_state:
|
||||
for event, _ in state_events_and_contexts:
|
||||
if not context.rejected:
|
||||
txn.call_after(
|
||||
self.get_current_state_for_key.invalidate,
|
||||
event.room_id, event.type, event.state_key
|
||||
(event.room_id, event.type, event.state_key,)
|
||||
)
|
||||
|
||||
if (event.type == EventTypes.Name
|
||||
or event.type == EventTypes.Aliases):
|
||||
if event.type in [EventTypes.Name, EventTypes.Aliases]:
|
||||
txn.call_after(
|
||||
self.get_room_name_and_aliases.invalidate,
|
||||
event.room_id
|
||||
(event.room_id,)
|
||||
)
|
||||
|
||||
self._simple_upsert_txn(
|
||||
|
@ -498,8 +567,9 @@ class EventsStore(SQLBaseStore):
|
|||
def _invalidate_get_event_cache(self, event_id):
|
||||
for check_redacted in (False, True):
|
||||
for get_prev_content in (False, True):
|
||||
self._get_event_cache.invalidate(event_id, check_redacted,
|
||||
get_prev_content)
|
||||
self._get_event_cache.invalidate(
|
||||
(event_id, check_redacted, get_prev_content)
|
||||
)
|
||||
|
||||
def _get_event_txn(self, txn, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
|
@ -520,7 +590,7 @@ class EventsStore(SQLBaseStore):
|
|||
for event_id in events:
|
||||
try:
|
||||
ret = self._get_event_cache.get(
|
||||
event_id, check_redacted, get_prev_content
|
||||
(event_id, check_redacted, get_prev_content,)
|
||||
)
|
||||
|
||||
if allow_rejected or not ret.rejected_reason:
|
||||
|
@ -736,7 +806,8 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
because = yield self.get_event(
|
||||
redaction_id,
|
||||
check_redacted=False
|
||||
check_redacted=False,
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if because:
|
||||
|
@ -746,12 +817,13 @@ class EventsStore(SQLBaseStore):
|
|||
prev = yield self.get_event(
|
||||
ev.unsigned["replaces_state"],
|
||||
get_prev_content=False,
|
||||
allow_none=True,
|
||||
)
|
||||
if prev:
|
||||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||
|
||||
self._get_event_cache.prefill(
|
||||
ev.event_id, check_redacted, get_prev_content, ev
|
||||
(ev.event_id, check_redacted, get_prev_content), ev
|
||||
)
|
||||
|
||||
defer.returnValue(ev)
|
||||
|
@ -808,7 +880,7 @@ class EventsStore(SQLBaseStore):
|
|||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||
|
||||
self._get_event_cache.prefill(
|
||||
ev.event_id, check_redacted, get_prev_content, ev
|
||||
(ev.event_id, check_redacted, get_prev_content), ev
|
||||
)
|
||||
|
||||
return ev
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from _base import SQLBaseStore
|
||||
from _base import SQLBaseStore, cachedInlineCallbacks
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -71,6 +71,24 @@ class KeyStore(SQLBaseStore):
|
|||
desc="store_server_certificate",
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_all_server_verify_keys(self, server_name):
|
||||
rows = yield self._simple_select_list(
|
||||
table="server_signature_keys",
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
},
|
||||
retcols=["key_id", "verify_key"],
|
||||
desc="get_all_server_verify_keys",
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
row["key_id"]: decode_verify_key_bytes(
|
||||
row["key_id"], str(row["verify_key"])
|
||||
)
|
||||
for row in rows
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_keys(self, server_name, key_ids):
|
||||
"""Retrieve the NACL verification key for a given server for the given
|
||||
|
@ -81,24 +99,14 @@ class KeyStore(SQLBaseStore):
|
|||
Returns:
|
||||
(list of VerifyKey): The verification keys.
|
||||
"""
|
||||
sql = (
|
||||
"SELECT key_id, verify_key FROM server_signature_keys"
|
||||
" WHERE server_name = ?"
|
||||
" AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
|
||||
)
|
||||
|
||||
rows = yield self._execute_and_decode(
|
||||
"get_server_verify_keys", sql, server_name, *key_ids
|
||||
)
|
||||
|
||||
keys = []
|
||||
for row in rows:
|
||||
key_id = row["key_id"]
|
||||
key_bytes = row["verify_key"]
|
||||
key = decode_verify_key_bytes(key_id, str(key_bytes))
|
||||
keys.append(key)
|
||||
defer.returnValue(keys)
|
||||
keys = yield self.get_all_server_verify_keys(server_name)
|
||||
defer.returnValue({
|
||||
k: keys[k]
|
||||
for k in key_ids
|
||||
if k in keys and keys[k]
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|
||||
verify_key):
|
||||
"""Stores a NACL verification key for the given server.
|
||||
|
@ -109,7 +117,7 @@ class KeyStore(SQLBaseStore):
|
|||
ts_now_ms (int): The time now in milliseconds
|
||||
verification_key (VerifyKey): The NACL verify key.
|
||||
"""
|
||||
return self._simple_upsert(
|
||||
yield self._simple_upsert(
|
||||
table="server_signature_keys",
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
|
@ -123,6 +131,8 @@ class KeyStore(SQLBaseStore):
|
|||
desc="store_server_verify_key",
|
||||
)
|
||||
|
||||
self.get_all_server_verify_keys.invalidate((server_name,))
|
||||
|
||||
def store_server_keys_json(self, server_name, key_id, from_server,
|
||||
ts_now_ms, ts_expires_ms, key_json_bytes):
|
||||
"""Stores the JSON bytes for a set of keys from a server
|
||||
|
@ -152,6 +162,7 @@ class KeyStore(SQLBaseStore):
|
|||
"ts_valid_until_ms": ts_expires_ms,
|
||||
"key_json": buffer(key_json_bytes),
|
||||
},
|
||||
desc="store_server_keys_json",
|
||||
)
|
||||
|
||||
def get_server_keys_json(self, server_keys):
|
||||
|
|
|
@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
|
|||
updatevalues={"accepted": True},
|
||||
desc="set_presence_list_accepted",
|
||||
)
|
||||
self.get_presence_list_accepted.invalidate(observer_localpart)
|
||||
self.get_presence_list_accepted.invalidate((observer_localpart,))
|
||||
defer.returnValue(result)
|
||||
|
||||
def get_presence_list(self, observer_localpart, accepted=None):
|
||||
|
@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
|
|||
"observed_user_id": observed_userid},
|
||||
desc="del_presence_list",
|
||||
)
|
||||
self.get_presence_list_accepted.invalidate(observer_localpart)
|
||||
self.get_presence_list_accepted.invalidate((observer_localpart,))
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore, cached
|
||||
from ._base import SQLBaseStore, cachedInlineCallbacks
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class PushRuleStore(SQLBaseStore):
|
||||
@cached()
|
||||
@defer.inlineCallbacks
|
||||
@cachedInlineCallbacks()
|
||||
def get_push_rules_for_user(self, user_name):
|
||||
rows = yield self._simple_select_list(
|
||||
table=PushRuleTable.table_name,
|
||||
|
@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue(rows)
|
||||
|
||||
@cached()
|
||||
@defer.inlineCallbacks
|
||||
@cachedInlineCallbacks()
|
||||
def get_push_rules_enabled_for_user(self, user_name):
|
||||
results = yield self._simple_select_list(
|
||||
table=PushRuleEnableTable.table_name,
|
||||
|
@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
|
|||
txn.execute(sql, (user_name, priority_class, new_rule_priority))
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_for_user.invalidate, user_name
|
||||
self.get_push_rules_for_user.invalidate, (user_name,)
|
||||
)
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
|
@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
|
|||
new_rule['priority'] = new_prio
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_for_user.invalidate, user_name
|
||||
self.get_push_rules_for_user.invalidate, (user_name,)
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
|
@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
|
|||
desc="delete_push_rule",
|
||||
)
|
||||
|
||||
self.get_push_rules_for_user.invalidate(user_name)
|
||||
self.get_push_rules_enabled_for_user.invalidate(user_name)
|
||||
self.get_push_rules_for_user.invalidate((user_name,))
|
||||
self.get_push_rules_enabled_for_user.invalidate((user_name,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_push_rule_enabled(self, user_name, rule_id, enabled):
|
||||
|
@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
|
|||
{'id': new_id},
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_push_rules_for_user.invalidate, user_name
|
||||
self.get_push_rules_for_user.invalidate, (user_name,)
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,347 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014, 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore, cachedInlineCallbacks
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util import unwrapFirstError
|
||||
|
||||
from blist import sorteddict
|
||||
import logging
|
||||
import ujson as json
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReceiptsStore(SQLBaseStore):
|
||||
def __init__(self, hs):
|
||||
super(ReceiptsStore, self).__init__(hs)
|
||||
|
||||
self._receipts_stream_cache = _RoomStreamChangeCache()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
"""Get receipts for multiple rooms for sending to clients.
|
||||
|
||||
Args:
|
||||
room_ids (list): List of room_ids.
|
||||
to_key (int): Max stream id to fetch receipts upto.
|
||||
from_key (int): Min stream id to fetch receipts from. None fetches
|
||||
from the start.
|
||||
|
||||
Returns:
|
||||
list: A list of receipts.
|
||||
"""
|
||||
room_ids = set(room_ids)
|
||||
|
||||
if from_key:
|
||||
room_ids = yield self._receipts_stream_cache.get_rooms_changed(
|
||||
self, room_ids, from_key
|
||||
)
|
||||
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.get_linearized_receipts_for_room(
|
||||
room_id, to_key, from_key=from_key
|
||||
)
|
||||
for room_id in room_ids
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue([ev for res in results for ev in res])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||
"""Get receipts for a single room for sending to clients.
|
||||
|
||||
Args:
|
||||
room_ids (str): The room id.
|
||||
to_key (int): Max stream id to fetch receipts upto.
|
||||
from_key (int): Min stream id to fetch receipts from. None fetches
|
||||
from the start.
|
||||
|
||||
Returns:
|
||||
list: A list of receipts.
|
||||
"""
|
||||
def f(txn):
|
||||
if from_key:
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
" room_id = ? AND stream_id > ? AND stream_id <= ?"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(room_id, from_key, to_key)
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
" room_id = ? AND stream_id <= ?"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(room_id, to_key)
|
||||
)
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
return rows
|
||||
|
||||
rows = yield self.runInteraction(
|
||||
"get_linearized_receipts_for_room", f
|
||||
)
|
||||
|
||||
if not rows:
|
||||
defer.returnValue([])
|
||||
|
||||
content = {}
|
||||
for row in rows:
|
||||
content.setdefault(
|
||||
row["event_id"], {}
|
||||
).setdefault(
|
||||
row["receipt_type"], {}
|
||||
)[row["user_id"]] = json.loads(row["data"])
|
||||
|
||||
defer.returnValue([{
|
||||
"type": "m.receipt",
|
||||
"room_id": room_id,
|
||||
"content": content,
|
||||
}])
|
||||
|
||||
def get_max_receipt_stream_id(self):
|
||||
return self._receipts_id_gen.get_max_token(self)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_graph_receipts_for_room(self, room_id):
|
||||
"""Get receipts for sending to remote servers.
|
||||
"""
|
||||
rows = yield self._simple_select_list(
|
||||
table="receipts_graph",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=["receipt_type", "user_id", "event_id"],
|
||||
desc="get_linearized_receipts_for_room",
|
||||
)
|
||||
|
||||
result = {}
|
||||
for row in rows:
|
||||
result.setdefault(
|
||||
row["user_id"], {}
|
||||
).setdefault(
|
||||
row["receipt_type"], []
|
||||
).append(row["event_id"])
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
|
||||
user_id, event_id, data, stream_id):
|
||||
|
||||
# We don't want to clobber receipts for more recent events, so we
|
||||
# have to compare orderings of existing receipts
|
||||
sql = (
|
||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" INNER JOIN receipts_linearized as r USING (event_id, room_id)"
|
||||
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (room_id, receipt_type, user_id))
|
||||
results = txn.fetchall()
|
||||
|
||||
if results:
|
||||
res = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="events",
|
||||
retcols=["topological_ordering", "stream_ordering"],
|
||||
keyvalues={"event_id": event_id},
|
||||
)
|
||||
topological_ordering = int(res["topological_ordering"])
|
||||
stream_ordering = int(res["stream_ordering"])
|
||||
|
||||
for to, so, _ in results:
|
||||
if int(to) > topological_ordering:
|
||||
return False
|
||||
elif int(to) == topological_ordering and int(so) >= stream_ordering:
|
||||
return False
|
||||
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="receipts_linearized",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="receipts_linearized",
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_id": event_id,
|
||||
"data": json.dumps(data),
|
||||
}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
|
||||
"""Insert a receipt, either from local client or remote server.
|
||||
|
||||
Automatically does conversion between linearized and graph
|
||||
representations.
|
||||
"""
|
||||
if not event_ids:
|
||||
return
|
||||
|
||||
if len(event_ids) == 1:
|
||||
linearized_event_id = event_ids[0]
|
||||
else:
|
||||
# we need to points in graph -> linearized form.
|
||||
# TODO: Make this better.
|
||||
def graph_to_linear(txn):
|
||||
query = (
|
||||
"SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
|
||||
" SELECT max(stream_ordering) WHERE event_id IN (%s)"
|
||||
")"
|
||||
) % (",".join(["?"] * len(event_ids)))
|
||||
|
||||
txn.execute(query, [room_id] + event_ids)
|
||||
rows = txn.fetchall()
|
||||
if rows:
|
||||
return rows[0][0]
|
||||
else:
|
||||
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
|
||||
|
||||
linearized_event_id = yield self.runInteraction(
|
||||
"insert_receipt_conv", graph_to_linear
|
||||
)
|
||||
|
||||
stream_id_manager = yield self._receipts_id_gen.get_next(self)
|
||||
with stream_id_manager as stream_id:
|
||||
yield self._receipts_stream_cache.room_has_changed(
|
||||
self, room_id, stream_id
|
||||
)
|
||||
have_persisted = yield self.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
room_id, receipt_type, user_id, linearized_event_id,
|
||||
data,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
|
||||
if not have_persisted:
|
||||
defer.returnValue(None)
|
||||
|
||||
yield self.insert_graph_receipt(
|
||||
room_id, receipt_type, user_id, event_ids, data
|
||||
)
|
||||
|
||||
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
|
||||
defer.returnValue((stream_id, max_persisted_id))
|
||||
|
||||
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
|
||||
data):
|
||||
return self.runInteraction(
|
||||
"insert_graph_receipt",
|
||||
self.insert_graph_receipt_txn,
|
||||
room_id, receipt_type, user_id, event_ids, data
|
||||
)
|
||||
|
||||
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
|
||||
user_id, event_ids, data):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="receipts_graph",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="receipts_graph",
|
||||
values={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_ids": json.dumps(event_ids),
|
||||
"data": json.dumps(data),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class _RoomStreamChangeCache(object):
|
||||
"""Keeps track of the stream_id of the latest change in rooms.
|
||||
|
||||
Given a list of rooms and stream key, it will give a subset of rooms that
|
||||
may have changed since that key. If the key is too old then the cache
|
||||
will simply return all rooms.
|
||||
"""
|
||||
def __init__(self, size_of_cache=10000):
|
||||
self._size_of_cache = size_of_cache
|
||||
self._room_to_key = {}
|
||||
self._cache = sorteddict()
|
||||
self._earliest_key = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_changed(self, store, room_ids, key):
|
||||
"""Returns subset of room ids that have had new receipts since the
|
||||
given key. If the key is too old it will just return the given list.
|
||||
"""
|
||||
if key > (yield self._get_earliest_key(store)):
|
||||
keys = self._cache.keys()
|
||||
i = keys.bisect_right(key)
|
||||
|
||||
result = set(
|
||||
self._cache[k] for k in keys[i:]
|
||||
).intersection(room_ids)
|
||||
else:
|
||||
result = room_ids
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def room_has_changed(self, store, room_id, key):
|
||||
"""Informs the cache that the room has been changed at the given key.
|
||||
"""
|
||||
if key > (yield self._get_earliest_key(store)):
|
||||
old_key = self._room_to_key.get(room_id, None)
|
||||
if old_key:
|
||||
key = max(key, old_key)
|
||||
self._cache.pop(old_key, None)
|
||||
self._cache[key] = room_id
|
||||
|
||||
while len(self._cache) > self._size_of_cache:
|
||||
k, r = self._cache.popitem()
|
||||
self._earliest_key = max(k, self._earliest_key)
|
||||
self._room_to_key.pop(r, None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_earliest_key(self, store):
|
||||
if self._earliest_key is None:
|
||||
self._earliest_key = yield store.get_max_receipt_stream_id()
|
||||
self._earliest_key = int(self._earliest_key)
|
||||
|
||||
defer.returnValue(self._earliest_key)
|
|
@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
|
|||
user_id
|
||||
)
|
||||
for r in rows:
|
||||
self.get_user_by_token.invalidate(r)
|
||||
self.get_user_by_token.invalidate((r,))
|
||||
|
||||
@cached()
|
||||
def get_user_by_token(self, token):
|
||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
from ._base import SQLBaseStore, cached
|
||||
from ._base import SQLBaseStore, cachedInlineCallbacks
|
||||
|
||||
import collections
|
||||
import logging
|
||||
|
@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
|
|||
}
|
||||
)
|
||||
|
||||
@cached()
|
||||
@defer.inlineCallbacks
|
||||
@cachedInlineCallbacks()
|
||||
def get_room_name_and_aliases(self, room_id):
|
||||
def f(txn):
|
||||
sql = (
|
||||
|
|
|
@ -35,38 +35,28 @@ RoomsForUser = namedtuple(
|
|||
|
||||
class RoomMemberStore(SQLBaseStore):
|
||||
|
||||
def _store_room_member_txn(self, txn, event):
|
||||
def _store_room_members_txn(self, txn, events):
|
||||
"""Store a room member in the database.
|
||||
"""
|
||||
try:
|
||||
target_user_id = event.state_key
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to parse target_user_id=%s", target_user_id
|
||||
)
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
"_store_room_member_txn: target_user_id=%s, membership=%s",
|
||||
target_user_id,
|
||||
event.membership,
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
"room_memberships",
|
||||
table="room_memberships",
|
||||
values=[
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"user_id": target_user_id,
|
||||
"user_id": event.state_key,
|
||||
"sender": event.user_id,
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
}
|
||||
for event in events
|
||||
]
|
||||
)
|
||||
|
||||
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
|
||||
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
|
||||
for event in events:
|
||||
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||
|
||||
def get_room_member(self, user_id, room_id):
|
||||
"""Retrieve the current state of a room member.
|
||||
|
@ -88,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||
lambda events: events[0] if events else None
|
||||
)
|
||||
|
||||
@cached()
|
||||
@cached(max_entries=5000)
|
||||
def get_users_in_room(self, room_id):
|
||||
def f(txn):
|
||||
|
||||
|
@ -164,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
|
||||
]
|
||||
|
||||
@cached()
|
||||
@cached(max_entries=5000)
|
||||
def get_joined_hosts_for_room(self, room_id):
|
||||
return self.runInteraction(
|
||||
"get_joined_hosts_for_room",
|
||||
|
|
|
@ -18,7 +18,7 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_upgrade(cur):
|
||||
def run_upgrade(cur, *args, **kwargs):
|
||||
cur.execute("SELECT id, regex FROM application_services_regex")
|
||||
for row in cur.fetchall():
|
||||
try:
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
SELECT 1;
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
Main purpose of this upgrade is to change the unique key on the
|
||||
pushers table again (it was missed when the v16 full schema was
|
||||
made) but this also changes the pushkey and data columns to text.
|
||||
When selecting a bytea column into a text column, postgres inserts
|
||||
the hex encoded data, and there's no portable way of getting the
|
||||
UTF-8 bytes, so we have to do it in Python.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||
logger.info("Porting pushers table...")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS pushers2 (
|
||||
id BIGINT PRIMARY KEY,
|
||||
user_name TEXT NOT NULL,
|
||||
access_token BIGINT DEFAULT NULL,
|
||||
profile_tag VARCHAR(32) NOT NULL,
|
||||
kind VARCHAR(8) NOT NULL,
|
||||
app_id VARCHAR(64) NOT NULL,
|
||||
app_display_name VARCHAR(64) NOT NULL,
|
||||
device_display_name VARCHAR(128) NOT NULL,
|
||||
pushkey TEXT NOT NULL,
|
||||
ts BIGINT NOT NULL,
|
||||
lang VARCHAR(8),
|
||||
data TEXT,
|
||||
last_token TEXT,
|
||||
last_success BIGINT,
|
||||
failing_since BIGINT,
|
||||
UNIQUE (app_id, pushkey, user_name)
|
||||
)
|
||||
""")
|
||||
cur.execute("""SELECT
|
||||
id, user_name, access_token, profile_tag, kind,
|
||||
app_id, app_display_name, device_display_name,
|
||||
pushkey, ts, lang, data, last_token, last_success,
|
||||
failing_since
|
||||
FROM pushers
|
||||
""")
|
||||
count = 0
|
||||
for row in cur.fetchall():
|
||||
row = list(row)
|
||||
row[8] = bytes(row[8]).decode("utf-8")
|
||||
row[11] = bytes(row[11]).decode("utf-8")
|
||||
cur.execute(database_engine.convert_param_style("""
|
||||
INSERT into pushers2 (
|
||||
id, user_name, access_token, profile_tag, kind,
|
||||
app_id, app_display_name, device_display_name,
|
||||
pushkey, ts, lang, data, last_token, last_success,
|
||||
failing_since
|
||||
) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
|
||||
row
|
||||
)
|
||||
count += 1
|
||||
cur.execute("DROP TABLE pushers")
|
||||
cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
|
||||
logger.info("Moved %d pushers to new table", count)
|
|
@ -0,0 +1,34 @@
|
|||
/* Copyright 2015 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS e2e_device_keys_json (
|
||||
user_id TEXT NOT NULL, -- The user these keys are for.
|
||||
device_id TEXT NOT NULL, -- Which of the user's devices these keys are for.
|
||||
ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded.
|
||||
key_json TEXT NOT NULL, -- The keys for the device as a JSON blob.
|
||||
CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id)
|
||||
);
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json (
|
||||
user_id TEXT NOT NULL, -- The user this one-time key is for.
|
||||
device_id TEXT NOT NULL, -- The device this one-time key is for.
|
||||
algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for.
|
||||
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
|
||||
ts_added_ms BIGINT NOT NULL, -- When this key was uploaded.
|
||||
key_json TEXT NOT NULL, -- The key as a JSON blob.
|
||||
CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id)
|
||||
);
|
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2015 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS receipts_graph(
|
||||
room_id TEXT NOT NULL,
|
||||
receipt_type TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
event_ids TEXT NOT NULL,
|
||||
data TEXT NOT NULL,
|
||||
CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS receipts_linearized (
|
||||
stream_id BIGINT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
receipt_type TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
data TEXT NOT NULL,
|
||||
CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX receipts_linearized_id ON receipts_linearized(
|
||||
stream_id
|
||||
);
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
|||
from _base import SQLBaseStore
|
||||
|
||||
from syutil.base64util import encode_base64
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
|
||||
|
||||
class SignatureStore(SQLBaseStore):
|
||||
|
@ -101,23 +102,26 @@ class SignatureStore(SQLBaseStore):
|
|||
txn.execute(query, (event_id, ))
|
||||
return {k: v for k, v in txn.fetchall()}
|
||||
|
||||
def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
|
||||
hash_bytes):
|
||||
def _store_event_reference_hashes_txn(self, txn, events):
|
||||
"""Store a hash for a PDU
|
||||
Args:
|
||||
txn (cursor):
|
||||
event_id (str): Id for the Event.
|
||||
algorithm (str): Hashing algorithm.
|
||||
hash_bytes (bytes): Hash function output bytes.
|
||||
events (list): list of Events.
|
||||
"""
|
||||
self._simple_insert_txn(
|
||||
|
||||
vals = []
|
||||
for event in events:
|
||||
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
|
||||
vals.append({
|
||||
"event_id": event.event_id,
|
||||
"algorithm": ref_alg,
|
||||
"hash": buffer(ref_hash_bytes),
|
||||
})
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
"event_reference_hashes",
|
||||
{
|
||||
"event_id": event_id,
|
||||
"algorithm": algorithm,
|
||||
"hash": buffer(hash_bytes),
|
||||
},
|
||||
table="event_reference_hashes",
|
||||
values=vals,
|
||||
)
|
||||
|
||||
def _get_event_signatures_txn(self, txn, event_id):
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore, cached
|
||||
from ._base import SQLBaseStore, cachedInlineCallbacks
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -81,31 +81,41 @@ class StateStore(SQLBaseStore):
|
|||
f,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def c(vals):
|
||||
vals[:] = yield self._get_events(vals, get_prev_content=False)
|
||||
|
||||
yield defer.gatherResults(
|
||||
state_list = yield defer.gatherResults(
|
||||
[
|
||||
c(vals)
|
||||
for vals in states.values()
|
||||
self._fetch_events_for_group(group, vals)
|
||||
for group, vals in states.items()
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
|
||||
defer.returnValue(states)
|
||||
defer.returnValue(dict(state_list))
|
||||
|
||||
def _fetch_events_for_group(self, key, events):
|
||||
return self._get_events(
|
||||
events, get_prev_content=False
|
||||
).addCallback(
|
||||
lambda evs: (key, evs)
|
||||
)
|
||||
|
||||
def _store_state_groups_txn(self, txn, event, context):
|
||||
return self._store_mult_state_groups_txn(txn, [(event, context)])
|
||||
|
||||
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
||||
state_groups = {}
|
||||
for event, context in events_and_contexts:
|
||||
if context.current_state is None:
|
||||
return
|
||||
continue
|
||||
|
||||
if context.state_group is not None:
|
||||
state_groups[event.event_id] = context.state_group
|
||||
continue
|
||||
|
||||
state_events = dict(context.current_state)
|
||||
|
||||
if event.is_state():
|
||||
state_events[(event.type, event.state_key)] = event
|
||||
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
state_group = self._state_groups_id_gen.get_next_txn(txn)
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
|
@ -131,14 +141,19 @@ class StateStore(SQLBaseStore):
|
|||
for state in state_events.values()
|
||||
],
|
||||
)
|
||||
state_groups[event.event_id] = state_group
|
||||
|
||||
self._simple_insert_txn(
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_to_state_groups",
|
||||
values={
|
||||
"state_group": state_group,
|
||||
values=[
|
||||
{
|
||||
"state_group": state_groups[event.event_id],
|
||||
"event_id": event.event_id,
|
||||
},
|
||||
}
|
||||
for event, context in events_and_contexts
|
||||
if context.current_state is not None
|
||||
],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -173,8 +188,7 @@ class StateStore(SQLBaseStore):
|
|||
events = yield self._get_events(event_ids, get_prev_content=False)
|
||||
defer.returnValue(events)
|
||||
|
||||
@cached(num_args=3)
|
||||
@defer.inlineCallbacks
|
||||
@cachedInlineCallbacks(num_args=3)
|
||||
def get_current_state_for_key(self, room_id, event_type, state_key):
|
||||
def f(txn):
|
||||
sql = (
|
||||
|
@ -190,6 +204,65 @@ class StateStore(SQLBaseStore):
|
|||
events = yield self._get_events(event_ids, get_prev_content=False)
|
||||
defer.returnValue(events)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_events(self, room_id, event_ids):
|
||||
def f(txn):
|
||||
groups = set()
|
||||
event_to_group = {}
|
||||
for event_id in event_ids:
|
||||
# TODO: Remove this loop.
|
||||
group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="event_to_state_groups",
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="state_group",
|
||||
allow_none=True,
|
||||
)
|
||||
if group:
|
||||
event_to_group[event_id] = group
|
||||
groups.add(group)
|
||||
|
||||
group_to_state_ids = {}
|
||||
for group in groups:
|
||||
state_ids = self._simple_select_onecol_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
keyvalues={"state_group": group},
|
||||
retcol="event_id",
|
||||
)
|
||||
|
||||
group_to_state_ids[group] = state_ids
|
||||
|
||||
return event_to_group, group_to_state_ids
|
||||
|
||||
res = yield self.runInteraction(
|
||||
"annotate_events_with_state_groups",
|
||||
f,
|
||||
)
|
||||
|
||||
event_to_group, group_to_state_ids = res
|
||||
|
||||
state_list = yield defer.gatherResults(
|
||||
[
|
||||
self._fetch_events_for_group(group, vals)
|
||||
for group, vals in group_to_state_ids.items()
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
|
||||
state_dict = {
|
||||
group: {
|
||||
(ev.type, ev.state_key): ev
|
||||
for ev in state
|
||||
}
|
||||
for group, state in state_list
|
||||
}
|
||||
|
||||
defer.returnValue([
|
||||
state_dict.get(event_to_group.get(event, None), None)
|
||||
for event in event_ids
|
||||
])
|
||||
|
||||
|
||||
def _make_group_id(clock):
|
||||
return str(int(clock.time_msec())) + random_string(5)
|
||||
|
|
|
@ -72,7 +72,10 @@ class StreamIdGenerator(object):
|
|||
with stream_id_gen.get_next_txn(txn) as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
def __init__(self):
|
||||
def __init__(self, table, column):
|
||||
self.table = table
|
||||
self.column = column
|
||||
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._current_max = None
|
||||
|
@ -107,6 +110,37 @@ class StreamIdGenerator(object):
|
|||
|
||||
defer.returnValue(manager())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_next_mult(self, store, n):
|
||||
"""
|
||||
Usage:
|
||||
with yield stream_id_gen.get_next(store, n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
if not self._current_max:
|
||||
yield store.runInteraction(
|
||||
"_compute_current_max",
|
||||
self._get_or_compute_current_max,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
next_ids = range(self._current_max + 1, self._current_max + n + 1)
|
||||
self._current_max += n
|
||||
|
||||
for next_id in next_ids:
|
||||
self._unfinished_ids.append(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
finally:
|
||||
with self._lock:
|
||||
for next_id in next_ids:
|
||||
self._unfinished_ids.remove(next_id)
|
||||
|
||||
defer.returnValue(manager())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_max_token(self, store):
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
|
@ -126,7 +160,7 @@ class StreamIdGenerator(object):
|
|||
|
||||
def _get_or_compute_current_max(self, txn):
|
||||
with self._lock:
|
||||
txn.execute("SELECT MAX(stream_ordering) FROM events")
|
||||
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
|
||||
rows = txn.fetchall()
|
||||
val, = rows[0]
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from synapse.types import StreamToken
|
|||
from synapse.handlers.presence import PresenceEventSource
|
||||
from synapse.handlers.room import RoomEventSource
|
||||
from synapse.handlers.typing import TypingNotificationEventSource
|
||||
from synapse.handlers.receipts import ReceiptEventSource
|
||||
|
||||
|
||||
class NullSource(object):
|
||||
|
@ -43,6 +44,7 @@ class EventSources(object):
|
|||
"room": RoomEventSource,
|
||||
"presence": PresenceEventSource,
|
||||
"typing": TypingNotificationEventSource,
|
||||
"receipt": ReceiptEventSource,
|
||||
}
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -62,7 +64,10 @@ class EventSources(object):
|
|||
),
|
||||
typing_key=(
|
||||
yield self.sources["typing"].get_current_key()
|
||||
)
|
||||
),
|
||||
receipt_key=(
|
||||
yield self.sources["receipt"].get_current_key()
|
||||
),
|
||||
)
|
||||
defer.returnValue(token)
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ class EventID(DomainSpecificString):
|
|||
class StreamToken(
|
||||
namedtuple(
|
||||
"Token",
|
||||
("room_key", "presence_key", "typing_key")
|
||||
("room_key", "presence_key", "typing_key", "receipt_key")
|
||||
)
|
||||
):
|
||||
_SEPARATOR = "_"
|
||||
|
@ -109,6 +109,9 @@ class StreamToken(
|
|||
def from_string(cls, string):
|
||||
try:
|
||||
keys = string.split(cls._SEPARATOR)
|
||||
if len(keys) == len(cls._fields) - 1:
|
||||
# i.e. old token from before receipt_key
|
||||
keys.append("0")
|
||||
return cls(*keys)
|
||||
except:
|
||||
raise SynapseError(400, "Invalid Token")
|
||||
|
@ -131,6 +134,7 @@ class StreamToken(
|
|||
(other_token.room_stream_id < self.room_stream_id)
|
||||
or (int(other_token.presence_key) < int(self.presence_key))
|
||||
or (int(other_token.typing_key) < int(self.typing_key))
|
||||
or (int(other_token.receipt_key) < int(self.receipt_key))
|
||||
)
|
||||
|
||||
def copy_and_advance(self, key, new_value):
|
||||
|
@ -174,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||
|
||||
Live tokens start with an "s" followed by the "stream_ordering" id of the
|
||||
event it comes after. Historic tokens start with a "t" followed by the
|
||||
"topological_ordering" id of the event it comes after, follewed by "-",
|
||||
"topological_ordering" id of the event it comes after, followed by "-",
|
||||
followed by the "stream_ordering" id of the event it comes after.
|
||||
"""
|
||||
__slots__ = []
|
||||
|
@ -207,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||
return "s%d" % (self.stream,)
|
||||
|
||||
|
||||
# token_id is the primary key ID of the access token, not the access token itself.
|
||||
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
|
||||
|
|
|
@ -91,8 +91,12 @@ class Clock(object):
|
|||
with PreserveLoggingContext():
|
||||
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
||||
|
||||
def cancel_call_later(self, timer):
|
||||
def cancel_call_later(self, timer, ignore_errs=False):
|
||||
try:
|
||||
timer.cancel()
|
||||
except:
|
||||
if not ignore_errs:
|
||||
raise
|
||||
|
||||
def time_bound_deferred(self, given_deferred, time_out):
|
||||
if given_deferred.called:
|
||||
|
|
|
@ -38,6 +38,9 @@ class ObservableDeferred(object):
|
|||
deferred.
|
||||
|
||||
If consumeErrors is true errors will be captured from the origin deferred.
|
||||
|
||||
Cancelling or otherwise resolving an observer will not affect the original
|
||||
ObservableDeferred.
|
||||
"""
|
||||
|
||||
__slots__ = ["_deferred", "_observers", "_result"]
|
||||
|
@ -45,10 +48,10 @@ class ObservableDeferred(object):
|
|||
def __init__(self, deferred, consumeErrors=False):
|
||||
object.__setattr__(self, "_deferred", deferred)
|
||||
object.__setattr__(self, "_result", None)
|
||||
object.__setattr__(self, "_observers", [])
|
||||
object.__setattr__(self, "_observers", set())
|
||||
|
||||
def callback(r):
|
||||
self._result = (True, r)
|
||||
object.__setattr__(self, "_result", (True, r))
|
||||
while self._observers:
|
||||
try:
|
||||
self._observers.pop().callback(r)
|
||||
|
@ -57,7 +60,7 @@ class ObservableDeferred(object):
|
|||
return r
|
||||
|
||||
def errback(f):
|
||||
self._result = (False, f)
|
||||
object.__setattr__(self, "_result", (False, f))
|
||||
while self._observers:
|
||||
try:
|
||||
self._observers.pop().errback(f)
|
||||
|
@ -74,14 +77,28 @@ class ObservableDeferred(object):
|
|||
def observe(self):
|
||||
if not self._result:
|
||||
d = defer.Deferred()
|
||||
self._observers.append(d)
|
||||
|
||||
def remove(r):
|
||||
self._observers.discard(d)
|
||||
return r
|
||||
d.addBoth(remove)
|
||||
|
||||
self._observers.add(d)
|
||||
return d
|
||||
else:
|
||||
success, res = self._result
|
||||
return defer.succeed(res) if success else defer.fail(res)
|
||||
|
||||
def observers(self):
|
||||
return self._observers
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._deferred, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
setattr(self._deferred, name, value)
|
||||
|
||||
def __repr__(self):
|
||||
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
|
||||
id(self), self._result, self._deferred,
|
||||
)
|
||||
|
|
|
@ -140,6 +140,37 @@ class PreserveLoggingContext(object):
|
|||
)
|
||||
|
||||
|
||||
class _PreservingContextDeferred(defer.Deferred):
|
||||
"""A deferred that ensures that all callbacks and errbacks are called with
|
||||
the given logging context.
|
||||
"""
|
||||
def __init__(self, context):
|
||||
self._log_context = context
|
||||
defer.Deferred.__init__(self)
|
||||
|
||||
def addCallbacks(self, callback, errback=None,
|
||||
callbackArgs=None, callbackKeywords=None,
|
||||
errbackArgs=None, errbackKeywords=None):
|
||||
callback = self._wrap_callback(callback)
|
||||
errback = self._wrap_callback(errback)
|
||||
return defer.Deferred.addCallbacks(
|
||||
self, callback,
|
||||
errback=errback,
|
||||
callbackArgs=callbackArgs,
|
||||
callbackKeywords=callbackKeywords,
|
||||
errbackArgs=errbackArgs,
|
||||
errbackKeywords=errbackKeywords,
|
||||
)
|
||||
|
||||
def _wrap_callback(self, f):
|
||||
def g(res, *args, **kwargs):
|
||||
with PreserveLoggingContext():
|
||||
LoggingContext.thread_local.current_context = self._log_context
|
||||
res = f(res, *args, **kwargs)
|
||||
return res
|
||||
return g
|
||||
|
||||
|
||||
def preserve_context_over_fn(fn, *args, **kwargs):
|
||||
"""Takes a function and invokes it with the given arguments, but removes
|
||||
and restores the current logging context while doing so.
|
||||
|
@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred):
|
|||
"""Given a deferred wrap it such that any callbacks added later to it will
|
||||
be invoked with the current context.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
|
||||
current_context = LoggingContext.current_context()
|
||||
|
||||
def cb(res):
|
||||
with PreserveLoggingContext():
|
||||
LoggingContext.thread_local.current_context = current_context
|
||||
res = d.callback(res)
|
||||
return res
|
||||
|
||||
def eb(failure):
|
||||
with PreserveLoggingContext():
|
||||
LoggingContext.thread_local.current_context = current_context
|
||||
res = d.errback(failure)
|
||||
return res
|
||||
|
||||
if deferred.called:
|
||||
return deferred
|
||||
|
||||
deferred.addCallbacks(cb, eb)
|
||||
d = _PreservingContextDeferred(current_context)
|
||||
deferred.chainDeferred(d)
|
||||
return d
|
||||
|
|
|
@ -33,3 +33,12 @@ def random_string_with_symbols(length):
|
|||
return ''.join(
|
||||
random.choice(_string_with_symbols) for _ in xrange(length)
|
||||
)
|
||||
|
||||
|
||||
def is_ascii(s):
|
||||
try:
|
||||
s.encode("ascii")
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
|
|
@ -57,6 +57,49 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
interested_service, event
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_query_user_exists_unknown_user(self):
|
||||
user_id = "@someone:anywhere"
|
||||
services = [self._mkservice(is_interested=True)]
|
||||
services[0].is_interested_in_user = Mock(return_value=True)
|
||||
self.mock_store.get_app_services = Mock(return_value=services)
|
||||
self.mock_store.get_user_by_id = Mock(return_value=None)
|
||||
|
||||
event = Mock(
|
||||
sender=user_id,
|
||||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
yield self.handler.notify_interested_services(event)
|
||||
self.mock_as_api.query_user.assert_called_once_with(
|
||||
services[0], user_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_query_user_exists_known_user(self):
|
||||
user_id = "@someone:anywhere"
|
||||
services = [self._mkservice(is_interested=True)]
|
||||
services[0].is_interested_in_user = Mock(return_value=True)
|
||||
self.mock_store.get_app_services = Mock(return_value=services)
|
||||
self.mock_store.get_user_by_id = Mock(return_value={
|
||||
"name": user_id
|
||||
})
|
||||
|
||||
event = Mock(
|
||||
sender=user_id,
|
||||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
yield self.handler.notify_interested_services(event)
|
||||
self.assertFalse(
|
||||
self.mock_as_api.query_user.called,
|
||||
"query_user called when it shouldn't have been."
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_query_room_alias_exists(self):
|
||||
room_alias_str = "#foo:bar"
|
||||
|
|
|
@ -100,7 +100,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
return defer.succeed({})
|
||||
self.datastore.have_events.side_effect = have_events
|
||||
|
||||
def annotate(ev, old_state=None):
|
||||
def annotate(ev, old_state=None, outlier=False):
|
||||
context = Mock()
|
||||
context.current_state = {}
|
||||
context.auth_events = {}
|
||||
|
@ -120,7 +120,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
self.state_handler.compute_event_context.assert_called_once_with(
|
||||
ANY, old_state=None,
|
||||
ANY, old_state=None, outlier=False
|
||||
)
|
||||
|
||||
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
||||
|
|
|
@ -42,6 +42,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
"get_room",
|
||||
"store_room",
|
||||
"get_latest_events_in_room",
|
||||
"add_event_hashes",
|
||||
]),
|
||||
resource_for_federation=NonCallableMock(),
|
||||
http_client=NonCallableMock(spec_set=[]),
|
||||
|
@ -88,6 +89,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
|
||||
self.datastore.persist_event.return_value = (1,1)
|
||||
self.datastore.add_event_hashes.return_value = []
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invite(self):
|
||||
|
|
|
@ -66,8 +66,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
|
||||
self.mock_federation_resource = MockHttpResource()
|
||||
|
||||
mock_notifier = Mock(spec=["on_new_user_event"])
|
||||
self.on_new_user_event = mock_notifier.on_new_user_event
|
||||
mock_notifier = Mock(spec=["on_new_event"])
|
||||
self.on_new_event = mock_notifier.on_new_event
|
||||
|
||||
self.auth = Mock(spec=[])
|
||||
|
||||
|
@ -182,7 +182,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
timeout=20000,
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 1, rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
|
@ -245,7 +245,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 1, rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
|
@ -299,7 +299,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
room_id=self.room_id,
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 1, rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
|
@ -331,10 +331,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
timeout=10000,
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 1, rooms=[self.room_id]),
|
||||
])
|
||||
self.on_new_user_event.reset_mock()
|
||||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
||||
|
@ -351,7 +351,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
|
||||
self.clock.advance_time(11)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 2, rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
|
@ -377,10 +377,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
timeout=10000,
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
self.on_new_event.assert_has_calls([
|
||||
call('typing_key', 3, rooms=[self.room_id]),
|
||||
])
|
||||
self.on_new_user_event.reset_mock()
|
||||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 3)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
||||
|
|
|
@ -183,7 +183,17 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
|||
)
|
||||
self.assertEquals(200, code, msg=str(response))
|
||||
|
||||
self.assertEquals(0, len(response["chunk"]))
|
||||
# We may get a presence event for ourselves down
|
||||
self.assertEquals(
|
||||
0,
|
||||
len([
|
||||
c for c in response["chunk"]
|
||||
if not (
|
||||
c.get("type") == "m.presence"
|
||||
and c["content"].get("user_id") == self.user_id
|
||||
)
|
||||
])
|
||||
)
|
||||
|
||||
# joined room (expect all content for room)
|
||||
yield self.join(room=room_id, user=self.user_id, tok=self.token)
|
||||
|
|
|
@ -357,7 +357,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
|||
# all be ours
|
||||
|
||||
# I'll already get my own presence state change
|
||||
self.assertEquals({"start": "0_1_0", "end": "0_1_0", "chunk": []},
|
||||
self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []},
|
||||
response
|
||||
)
|
||||
|
||||
|
@ -376,7 +376,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
|||
"/events?from=s0_1_0&timeout=0", None)
|
||||
|
||||
self.assertEquals(200, code)
|
||||
self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [
|
||||
self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [
|
||||
{"type": "m.presence",
|
||||
"content": {
|
||||
"user_id": "@banana:test",
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||
from synapse.api.errors import SynapseError
|
||||
from twisted.internet import defer
|
||||
from mock import Mock, MagicMock
|
||||
from tests import unittest
|
||||
import json
|
||||
|
||||
|
||||
class RegisterRestServletTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# do the dance to hook up request data to self.request_data
|
||||
self.request_data = ""
|
||||
self.request = Mock(
|
||||
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
|
||||
)
|
||||
self.request.args = {}
|
||||
|
||||
self.appservice = None
|
||||
self.auth = Mock(get_appservice_by_req=Mock(
|
||||
side_effect=lambda x: defer.succeed(self.appservice))
|
||||
)
|
||||
|
||||
self.auth_result = (False, None, None)
|
||||
self.auth_handler = Mock(
|
||||
check_auth=Mock(side_effect=lambda x,y,z: self.auth_result)
|
||||
)
|
||||
self.registration_handler = Mock()
|
||||
self.identity_handler = Mock()
|
||||
self.login_handler = Mock()
|
||||
|
||||
# do the dance to hook it up to the hs global
|
||||
self.handlers = Mock(
|
||||
auth_handler=self.auth_handler,
|
||||
registration_handler=self.registration_handler,
|
||||
identity_handler=self.identity_handler,
|
||||
login_handler=self.login_handler
|
||||
)
|
||||
self.hs = Mock()
|
||||
self.hs.hostname = "superbig~testing~thing.com"
|
||||
self.hs.get_auth = Mock(return_value=self.auth)
|
||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||
self.hs.config.disable_registration = False
|
||||
|
||||
# init the thing we're testing
|
||||
self.servlet = RegisterRestServlet(self.hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_appservice_registration_valid(self):
|
||||
user_id = "@kermit:muppet"
|
||||
token = "kermits_access_token"
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit"
|
||||
})
|
||||
self.appservice = {
|
||||
"id": "1234"
|
||||
}
|
||||
self.registration_handler.appservice_register = Mock(
|
||||
return_value=(user_id, token)
|
||||
)
|
||||
result = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(result, (200, {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname
|
||||
}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_appservice_registration_invalid(self):
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit"
|
||||
})
|
||||
self.appservice = None # no application service exists
|
||||
result = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(result, (401, None))
|
||||
|
||||
def test_POST_bad_password(self):
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit",
|
||||
"password": 666
|
||||
})
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
||||
|
||||
def test_POST_bad_username(self):
|
||||
self.request_data = json.dumps({
|
||||
"username": 777,
|
||||
"password": "monkey"
|
||||
})
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_user_valid(self):
|
||||
user_id = "@kermit:muppet"
|
||||
token = "kermits_access_token"
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
})
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (True, None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
})
|
||||
self.registration_handler.register = Mock(return_value=(user_id, token))
|
||||
|
||||
result = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(result, (200, {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname
|
||||
}))
|
||||
|
||||
def test_POST_disabled_registration(self):
|
||||
self.hs.config.disable_registration = True
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
})
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (True, None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
})
|
||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
|
@ -17,6 +17,8 @@
|
|||
from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
|
||||
from synapse.storage._base import Cache, cached
|
||||
|
||||
|
||||
|
@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
|
|||
self.assertEquals(self.cache.get("foo"), 123)
|
||||
|
||||
def test_invalidate(self):
|
||||
self.cache.prefill("foo", 123)
|
||||
self.cache.invalidate("foo")
|
||||
self.cache.prefill(("foo",), 123)
|
||||
self.cache.invalidate(("foo",))
|
||||
|
||||
failed = False
|
||||
try:
|
||||
self.cache.get("foo")
|
||||
self.cache.get(("foo",))
|
||||
except KeyError:
|
||||
failed = True
|
||||
|
||||
|
@ -96,87 +98,102 @@ class CacheDecoratorTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_passthrough(self):
|
||||
class A(object):
|
||||
@cached()
|
||||
def func(self, key):
|
||||
return key
|
||||
|
||||
self.assertEquals((yield func(self, "foo")), "foo")
|
||||
self.assertEquals((yield func(self, "bar")), "bar")
|
||||
a = A()
|
||||
|
||||
self.assertEquals((yield a.func("foo")), "foo")
|
||||
self.assertEquals((yield a.func("bar")), "bar")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_hit(self):
|
||||
callcount = [0]
|
||||
|
||||
class A(object):
|
||||
@cached()
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
yield func(self, "foo")
|
||||
a = A()
|
||||
yield a.func("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 1)
|
||||
|
||||
self.assertEquals((yield func(self, "foo")), "foo")
|
||||
self.assertEquals((yield a.func("foo")), "foo")
|
||||
self.assertEquals(callcount[0], 1)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invalidate(self):
|
||||
callcount = [0]
|
||||
|
||||
class A(object):
|
||||
@cached()
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
yield func(self, "foo")
|
||||
a = A()
|
||||
yield a.func("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 1)
|
||||
|
||||
func.invalidate("foo")
|
||||
a.func.invalidate(("foo",))
|
||||
|
||||
yield func(self, "foo")
|
||||
yield a.func("foo")
|
||||
|
||||
self.assertEquals(callcount[0], 2)
|
||||
|
||||
def test_invalidate_missing(self):
|
||||
class A(object):
|
||||
@cached()
|
||||
def func(self, key):
|
||||
return key
|
||||
|
||||
func.invalidate("what")
|
||||
A().func.invalidate(("what",))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_max_entries(self):
|
||||
callcount = [0]
|
||||
|
||||
class A(object):
|
||||
@cached(max_entries=10)
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
for k in range(0,12):
|
||||
yield func(self, k)
|
||||
a = A()
|
||||
|
||||
for k in range(0, 12):
|
||||
yield a.func(k)
|
||||
|
||||
self.assertEquals(callcount[0], 12)
|
||||
|
||||
# There must have been at least 2 evictions, meaning if we calculate
|
||||
# all 12 values again, we must get called at least 2 more times
|
||||
for k in range(0,12):
|
||||
yield func(self, k)
|
||||
yield a.func(k)
|
||||
|
||||
self.assertTrue(callcount[0] >= 14,
|
||||
msg="Expected callcount >= 14, got %d" % (callcount[0]))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_prefill(self):
|
||||
callcount = [0]
|
||||
|
||||
d = defer.succeed(123)
|
||||
|
||||
class A(object):
|
||||
@cached()
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
return d
|
||||
|
||||
func.prefill("foo", 123)
|
||||
a = A()
|
||||
|
||||
self.assertEquals((yield func(self, "foo")), 123)
|
||||
a.func.prefill(("foo",), ObservableDeferred(d))
|
||||
|
||||
self.assertEquals(a.func("foo").result, d.result)
|
||||
self.assertEquals(callcount[0], 0)
|
||||
|
|
|
@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||
(yield self.store.get_user_by_id(self.user_id))
|
||||
)
|
||||
|
||||
result = yield self.store.get_user_by_token(self.tokens[1])
|
||||
result = yield self.store.get_user_by_token(self.tokens[0])
|
||||
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
|
|
|
@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
|
|||
yield d
|
||||
self.assertTrue(d.called)
|
||||
|
||||
observers[0].assert_called_once("Go")
|
||||
observers[1].assert_called_once("Go")
|
||||
observers[0].assert_called_once_with("Go")
|
||||
observers[1].assert_called_once_with("Go")
|
||||
|
||||
self.assertEquals(mock_logger.warning.call_count, 1)
|
||||
self.assertIsInstance(mock_logger.warning.call_args[0][0],
|
||||
|
|
|
@ -114,6 +114,8 @@ class MockHttpResource(HttpServer):
|
|||
mock_request.method = http_method
|
||||
mock_request.uri = path
|
||||
|
||||
mock_request.getClientIP.return_value = "-"
|
||||
|
||||
mock_request.requestHeaders.getRawHeaders.return_value=[
|
||||
"X-Matrix origin=test,key=,sig="
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue