Merge branch 'develop' of https://github.com/matrix-org/synapse into cohort_analytics

Neil Johnson 2018-05-14 09:31:42 +01:00
commit 977765bde2
120 changed files with 2388 additions and 1183 deletions

.dockerignore (new file, 5 lines)

@@ -0,0 +1,5 @@
Dockerfile
.travis.yml
.gitignore
demo/etc
tox.ini

.gitignore (1 line added)

@@ -32,6 +32,7 @@ demo/media_store.*
 demo/etc
 uploads
+cache
 .idea/
 media_store/

.travis.yml

@@ -1,14 +1,22 @@
 sudo: false
 language: python
-python: 2.7

 # tell travis to cache ~/.cache/pip
 cache: pip

-env:
- - TOX_ENV=packaging
- - TOX_ENV=pep8
- - TOX_ENV=py27
+matrix:
+  include:
+    - python: 2.7
+      env: TOX_ENV=packaging
+
+    - python: 2.7
+      env: TOX_ENV=pep8
+
+    - python: 2.7
+      env: TOX_ENV=py27
+
+    - python: 3.6
+      env: TOX_ENV=py36

 install:
  - pip install tox

AUTHORS.rst

@@ -60,3 +60,6 @@ Niklas Riekenbrauck <nikriek at gmail dot.com>
 Christoph Witzany <christoph at web.crofting.com>
  * Add LDAP support for authentication
+
+Pierre Jaury <pierre at jaury.eu>
+ * Docker packaging

CHANGES.rst

@@ -1,9 +1,55 @@
+Changes in synapse <unreleased>
+===============================
+
+Potentially breaking change:
+
+* Make Client-Server API return 401 for invalid token (PR #3161).
+
+  This changes the Client-server spec to return a 401 error code instead of 403
+  when the access token is unrecognised. This is the behaviour required by the
+  specification, but some clients may be relying on the old, incorrect
+  behaviour.
+
+  Thanks to @NotAFile for fixing this.
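For illustration only (not part of the release notes): after this change, a client probing with an unknown access token would see a response along the lines of:

```
$ curl -i 'https://my.matrix.host/_matrix/client/r0/account/whoami?access_token=bogus'
HTTP/1.1 401 Unauthorized

{"errcode": "M_UNKNOWN_TOKEN", "error": "Unrecognised access token."}
```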
+Changes in synapse v0.28.1 (2018-05-01)
+=======================================
+
+SECURITY UPDATE
+
+* Clamp the allowed values of event depth received over federation to be
+  [0, 2^63 - 1]. This mitigates an attack where malicious events
+  injected with depth = 2^63 - 1 render rooms unusable. Depth is used to
+  determine the cosmetic ordering of events within a room, and so the ordering
+  of events in such a room will default to using stream_ordering rather than
+  depth (topological_ordering).
+
+  This is a temporary solution to mitigate abuse in the wild, whilst a long
+  term solution is being implemented to improve how the depth parameter is
+  used.
+
+  Full details at
+  https://docs.google.com/document/d/1I3fi2S-XnpO45qrpCsowZv8P8dHcNZ4fsBsbOW7KABI
+
+* Pin Twisted to <18.4 until we stop using the private _OpenSSLECCurve API.
+
+
+Changes in synapse v0.28.0 (2018-04-26)
+=======================================
+
+Bug Fixes:
+
+* Fix quarantine media admin API and search reindex (PR #3130)
+* Fix media admin APIs (PR #3134)
+
+
 Changes in synapse v0.28.0-rc1 (2018-04-24)
 ===========================================
 
 Minor performance improvement to federation sending and bug fixes.
 
-(Note: This release does not include state resolutions discussed in matrix live)
+(Note: This release does not include the delta state resolution implementation
+discussed in matrix live)
Features: Features:
@@ -16,8 +62,7 @@ Changes:
 * move handling of auto_join_rooms to RegisterHandler (PR #2996) Thanks to @krombel!
 * Improve handling of SRV records for federation connections (PR #3016) Thanks to @silkeh!
 * Document the behaviour of ResponseCache (PR #3059)
-* Preparation for py3 (PR #3061, #3073, #3074, #3075, #3103, #3104, #3106, #3107
-  #3109, #3110) Thanks to @NotAFile!
+* Preparation for py3 (PR #3061, #3073, #3074, #3075, #3103, #3104, #3106, #3107, #3109, #3110) Thanks to @NotAFile!
 * update prometheus dashboard to use new metric names (PR #3069) Thanks to @krombel!
 * use python3-compatible prints (PR #3074) Thanks to @NotAFile!
 * Send federation events concurrently (PR #3078)

Dockerfile (new file, 19 lines)

@@ -0,0 +1,19 @@
FROM docker.io/python:2-alpine3.7

RUN apk add --no-cache --virtual .nacl_deps su-exec build-base libffi-dev zlib-dev libressl-dev libjpeg-turbo-dev linux-headers postgresql-dev

COPY . /synapse

# A wheel cache may be provided in ./cache for faster build
RUN cd /synapse \
 && pip install --upgrade pip setuptools psycopg2 \
 && mkdir -p /synapse/cache \
 && pip install -f /synapse/cache --upgrade --process-dependency-links . \
 && mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
 && rm -rf setup.py setup.cfg synapse

VOLUME ["/data"]

EXPOSE 8008/tcp 8448/tcp

ENTRYPOINT ["/start.py"]

MANIFEST.in

@@ -25,6 +25,8 @@ recursive-include synapse/static *.js
 exclude jenkins.sh
 exclude jenkins*.sh
 exclude jenkins*
+exclude Dockerfile
+exclude .dockerignore
 recursive-exclude jenkins *.sh
 prune .github

contrib/docker/README.md (new file, 148 lines)

@@ -0,0 +1,148 @@
# Synapse Docker
This Docker image will run Synapse as a single process. It does not provide a
database server or a TURN server; you should run these separately.

If you run a Postgres server, either put it in the same Compose project or set
the proper environment variables, and the image will automatically use that
server.
## Build
Build the docker image with the `docker build` command from the root of the synapse repository.
```
docker build -t docker.io/matrixdotorg/synapse .
```
The `-t` option sets the image tag. Official images are tagged `matrixdotorg/synapse:<version>` where `<version>` is the same as the release tag in the synapse git repository.
If you have a local Python wheel cache available, copy the relevant packages
into the ``cache/`` directory at the root of the project.
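One plausible way to populate that cache before building (``pip wheel`` builds wheels for the project and all of its dependencies; the target directory must match the ``cache/`` directory the Dockerfile expects):

```
pip wheel --wheel-dir cache .
```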
## Run
This image is designed to run either with an automatically generated
configuration file or with a custom configuration that requires manual editing.
### Automated configuration
It is recommended that you use Docker Compose to run your containers, including
this image and a Postgres server. A sample ``docker-compose.yml`` is provided,
including example labels for reverse proxying and other artifacts.
Read the section about environment variables and set at least the mandatory
variables, then run the server:
```
docker-compose up -d
```
### Manual configuration
A sample ``docker-compose.yml`` is provided, including example labels for
reverse proxying and other artifacts.
To use manual configuration, specify a ``SYNAPSE_CONFIG_PATH``, preferably
pointing at a persistent path. To generate a fresh ``homeserver.yaml``, simply run:
```
docker-compose run --rm -e SYNAPSE_SERVER_NAME=my.matrix.host synapse generate
```
Then, customize your configuration and run the server:
```
docker-compose up -d
```
### Without Compose
If you do not wish to use Compose, you may still run this image using plain
Docker commands. Note that the following is just a guideline, and you may need
to add parameters to the ``docker run`` command to account for how your
postgres database is reached.
```
docker run \
-d \
--name synapse \
-v ${DATA_PATH}:/data \
-e SYNAPSE_SERVER_NAME=my.matrix.host \
-e SYNAPSE_REPORT_STATS=yes \
docker.io/matrixdotorg/synapse:latest
```
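Note that the command above publishes no ports; a plausible variant, assuming you want to expose the plain HTTP and federation listeners directly, adds ``-p`` flags:

```
docker run -d \
    --name synapse \
    -v ${DATA_PATH}:/data \
    -p 8008:8008 \
    -p 8448:8448 \
    -e SYNAPSE_SERVER_NAME=my.matrix.host \
    -e SYNAPSE_REPORT_STATS=yes \
    docker.io/matrixdotorg/synapse:latest
```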
## Volumes
The image expects a single volume, located at ``/data``, that will hold:
* temporary files during uploads;
* uploaded media and thumbnails;
* the SQLite database if you do not configure postgres;
* the appservices configuration.
You are free to use separate volumes depending on the storage endpoints at your
disposal. For instance, ``/data/media`` could be stored on a large but
low-performance HDD while other files could be stored on high-performance
endpoints.
To set up an application service, create an ``appservices`` directory in the
data volume and write the application service YAML configuration file there
(a minimal example follows). Multiple application services are supported.
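As a sketch, a minimal hypothetical registration file might look like this; all values are placeholders, following the standard Matrix application service registration format:

```
# /data/appservices/my-bridge.yaml (hypothetical example)
id: my-bridge
url: "http://bridge.local:9000"   # where Synapse can reach the application service
as_token: "<random string>"
hs_token: "<random string>"
sender_localpart: my-bridge
namespaces:
  users:
    - exclusive: true
      regex: "@bridge_.*:my.matrix.host"
```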
## Environment
Unless you specify a custom path for the configuration file, a very generic
file will be generated, based on the following environment settings.
These are a good starting point for setting up your own deployment.
Global settings:
* ``UID``, the user id Synapse will run as [default 991]
* ``GID``, the group id Synapse will run as [default 991]
* ``SYNAPSE_CONFIG_PATH``, path to a custom config file
If ``SYNAPSE_CONFIG_PATH`` is set, you should generate a configuration file
then customize it manually. No other environment variable is required.
Otherwise, a dynamic configuration file will be used. The following environment
variables are available for configuration:
* ``SYNAPSE_SERVER_NAME`` (mandatory), the public hostname of this server.
* ``SYNAPSE_REPORT_STATS`` (mandatory, ``yes`` or ``no``), enable anonymous
  statistics reporting back to the Matrix project, which helps us to get funding.
* ``SYNAPSE_MACAROON_SECRET_KEY`` (mandatory), a secret used to sign access
  tokens to the server; set this to a strong random value.
* ``SYNAPSE_NO_TLS``, set this variable to disable TLS in Synapse (use this if
you run your own TLS-capable reverse proxy).
* ``SYNAPSE_ENABLE_REGISTRATION``, set this variable to enable registration on
the Synapse instance.
* ``SYNAPSE_ALLOW_GUEST``, set this variable to allow guests to join this server.
* ``SYNAPSE_EVENT_CACHE_SIZE``, the event cache size [default `10K`].
* ``SYNAPSE_CACHE_FACTOR``, the cache factor [default `0.5`].
* ``SYNAPSE_RECAPTCHA_PUBLIC_KEY``, set this variable to the recaptcha public
key in order to enable recaptcha upon registration.
* ``SYNAPSE_RECAPTCHA_PRIVATE_KEY``, set this variable to the recaptcha private
key in order to enable recaptcha upon registration.
* ``SYNAPSE_TURN_URIS``, set this variable to a comma-separated list of TURN
  URIs to enable TURN for this homeserver.
* ``SYNAPSE_TURN_SECRET``, set this to the TURN shared secret if required.
Shared secrets, which will be initialized to random values if not set:
* ``SYNAPSE_REGISTRATION_SHARED_SECRET``, secret for registering users when
  registration is disabled.
Database specific values (will use SQLite if not set):
* `POSTGRES_DB` - The database name for the synapse postgres database. [default: `synapse`]
* `POSTGRES_HOST` - The host of the postgres database if you wish to use postgresql instead of sqlite3. [default: `db` which is useful when using a container on the same docker network in a compose file where the postgres service is called `db`]
* `POSTGRES_PASSWORD` - The password for the synapse postgres database. **If this is set then postgres will be used instead of sqlite3.** [default: none] **NOTE**: You are highly encouraged to use postgresql! Please use the compose file to make it easier to deploy.
* `POSTGRES_USER` - The user for the synapse postgres database. [default: `matrix`]
Mail server specific values (will not send emails if not set):
* ``SYNAPSE_SMTP_HOST``, hostname of the mail server.
* ``SYNAPSE_SMTP_PORT``, TCP port for reaching the mail server [default ``25``].
* ``SYNAPSE_SMTP_USER``, username for authenticating against the mail server, if any.
* ``SYNAPSE_SMTP_PASSWORD``, password for authenticating against the mail server, if any.
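Putting a few of these together, a plausible invocation against an external Postgres server (all values are placeholders) could look like:

```
docker run -d \
    --name synapse \
    -v ${DATA_PATH}:/data \
    -e SYNAPSE_SERVER_NAME=my.matrix.host \
    -e SYNAPSE_REPORT_STATS=yes \
    -e SYNAPSE_ENABLE_REGISTRATION=yes \
    -e POSTGRES_HOST=db.local \
    -e POSTGRES_PASSWORD=changeme \
    docker.io/matrixdotorg/synapse:latest
```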

contrib/docker/conf/homeserver.yaml (new file, 219 lines)

@@ -0,0 +1,219 @@
# vim:ft=yaml
## TLS ##
tls_certificate_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.crt"
tls_private_key_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.key"
tls_dh_params_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.dh"
no_tls: {{ "True" if SYNAPSE_NO_TLS else "False" }}
tls_fingerprints: []
## Server ##
server_name: "{{ SYNAPSE_SERVER_NAME }}"
pid_file: /homeserver.pid
web_client: False
soft_file_limit: 0
## Ports ##
listeners:
  {% if not SYNAPSE_NO_TLS %}
  -
    port: 8448
    bind_addresses: ['0.0.0.0']
    type: http
    tls: true
    x_forwarded: false
    resources:
      - names: [client]
        compress: true
      - names: [federation]  # Federation APIs
        compress: false
  {% endif %}

  - port: 8008
    tls: false
    bind_addresses: ['0.0.0.0']
    type: http
    x_forwarded: false

    resources:
      - names: [client]
        compress: true
      - names: [federation]
        compress: false
## Database ##
{% if POSTGRES_PASSWORD %}
database:
  name: "psycopg2"
  args:
    user: "{{ POSTGRES_USER or "synapse" }}"
    password: "{{ POSTGRES_PASSWORD }}"
    database: "{{ POSTGRES_DB or "synapse" }}"
    host: "{{ POSTGRES_HOST or "db" }}"
    port: "{{ POSTGRES_PORT or "5432" }}"
    cp_min: 5
    cp_max: 10
{% else %}
database:
  name: "sqlite3"
  args:
    database: "/data/homeserver.db"
{% endif %}
## Performance ##
event_cache_size: "{{ SYNAPSE_EVENT_CACHE_SIZE or "10K" }}"
verbose: 0
log_file: "/data/homeserver.log"
log_config: "/compiled/log.config"
## Ratelimiting ##
rc_messages_per_second: 0.2
rc_message_burst_count: 10.0
federation_rc_window_size: 1000
federation_rc_sleep_limit: 10
federation_rc_sleep_delay: 500
federation_rc_reject_limit: 50
federation_rc_concurrent: 3
## Files ##
media_store_path: "/data/media"
uploads_path: "/data/uploads"
max_upload_size: "10M"
max_image_pixels: "32M"
dynamic_thumbnails: false
# List of thumbnails to precalculate when an image is uploaded.
thumbnail_sizes:
- width: 32
  height: 32
  method: crop
- width: 96
  height: 96
  method: crop
- width: 320
  height: 240
  method: scale
- width: 640
  height: 480
  method: scale
- width: 800
  height: 600
  method: scale
url_preview_enabled: False
max_spider_size: "10M"
## Captcha ##
{% if SYNAPSE_RECAPTCHA_PUBLIC_KEY %}
recaptcha_public_key: "{{ SYNAPSE_RECAPTCHA_PUBLIC_KEY }}"
recaptcha_private_key: "{{ SYNAPSE_RECAPTCHA_PRIVATE_KEY }}"
enable_registration_captcha: True
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
{% else %}
recaptcha_public_key: "YOUR_PUBLIC_KEY"
recaptcha_private_key: "YOUR_PRIVATE_KEY"
enable_registration_captcha: False
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
{% endif %}
## Turn ##
{% if SYNAPSE_TURN_URIS %}
turn_uris:
{% for uri in SYNAPSE_TURN_URIS.split(',') %} - "{{ uri }}"
{% endfor %}
turn_shared_secret: "{{ SYNAPSE_TURN_SECRET }}"
turn_user_lifetime: "1h"
turn_allow_guests: True
{% else %}
turn_uris: []
turn_shared_secret: "YOUR_SHARED_SECRET"
turn_user_lifetime: "1h"
turn_allow_guests: True
{% endif %}
## Registration ##
enable_registration: {{ "True" if SYNAPSE_ENABLE_REGISTRATION else "False" }}
registration_shared_secret: "{{ SYNAPSE_REGISTRATION_SHARED_SECRET }}"
bcrypt_rounds: 12
allow_guest_access: {{ "True" if SYNAPSE_ALLOW_GUEST else "False" }}
enable_group_creation: true
# The list of identity servers trusted to verify third party
# identifiers by this server.
trusted_third_party_id_servers:
- matrix.org
- vector.im
- riot.im
## Metrics ##
{% if SYNAPSE_REPORT_STATS.lower() == "yes" %}
enable_metrics: True
report_stats: True
{% else %}
enable_metrics: False
report_stats: False
{% endif %}
## API Configuration ##
room_invite_state_types:
- "m.room.join_rules"
- "m.room.canonical_alias"
- "m.room.avatar"
- "m.room.name"
{% if SYNAPSE_APPSERVICES %}
app_service_config_files:
{% for appservice in SYNAPSE_APPSERVICES %} - "{{ appservice }}"
{% endfor %}
{% else %}
app_service_config_files: []
{% endif %}
macaroon_secret_key: "{{ SYNAPSE_MACAROON_SECRET_KEY }}"
expire_access_token: False
## Signing Keys ##
signing_key_path: "/data/{{ SYNAPSE_SERVER_NAME }}.signing.key"
old_signing_keys: {}
key_refresh_interval: "1d" # 1 Day.
# The trusted servers to download signing keys from.
perspectives:
  servers:
    "matrix.org":
      verify_keys:
        "ed25519:auto":
          key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
password_config:
  enabled: true
{% if SYNAPSE_SMTP_HOST %}
email:
  enable_notifs: false
  smtp_host: "{{ SYNAPSE_SMTP_HOST }}"
  smtp_port: {{ SYNAPSE_SMTP_PORT or "25" }}
  smtp_user: "{{ SYNAPSE_SMTP_USER }}"
  smtp_pass: "{{ SYNAPSE_SMTP_PASSWORD }}"
  require_transport_security: False
  notif_from: "{{ SYNAPSE_SMTP_FROM or "hostmaster@" + SYNAPSE_SERVER_NAME }}"
  app_name: Matrix
  template_dir: res/templates
  notif_template_html: notif_mail.html
  notif_template_text: notif_mail.txt
  notif_for_new_users: True
  riot_base_url: "https://{{ SYNAPSE_SERVER_NAME }}"
{% endif %}

contrib/docker/conf/log.config (new file, 29 lines)

@@ -0,0 +1,29 @@
version: 1

formatters:
  precise:
    format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s- %(message)s'

filters:
  context:
    (): synapse.util.logcontext.LoggingContextFilter
    request: ""

handlers:
  console:
    class: logging.StreamHandler
    formatter: precise
    filters: [context]

loggers:
  synapse:
    level: {{ SYNAPSE_LOG_LEVEL or "WARNING" }}
  synapse.storage.SQL:
    # beware: increasing this to DEBUG will make synapse log sensitive
    # information such as access tokens.
    level: {{ SYNAPSE_LOG_LEVEL or "WARNING" }}

root:
  level: {{ SYNAPSE_LOG_LEVEL or "WARNING" }}
  handlers: [console]

contrib/docker/docker-compose.yml (new file, 49 lines)

@@ -0,0 +1,49 @@
# This compose file is compatible with Compose itself; it might need some
# adjustments to run properly with stack.

version: '3'

services:

  synapse:
    image: docker.io/matrixdotorg/synapse:latest
    # Since synapse does not retry connecting to the database, restart upon
    # failure
    restart: unless-stopped
    # See the readme for a full documentation of the environment settings
    environment:
      - SYNAPSE_SERVER_NAME=my.matrix.host
      - SYNAPSE_REPORT_STATS=no
      - SYNAPSE_ENABLE_REGISTRATION=yes
      - SYNAPSE_LOG_LEVEL=INFO
      - POSTGRES_PASSWORD=changeme
    volumes:
      # You may either store all the files in a local folder
      - ./files:/data
      # .. or you may split this between different storage points
      # - ./files:/data
      # - /path/to/ssd:/data/uploads
      # - /path/to/large_hdd:/data/media
    depends_on:
      - db
    # To expose Synapse, remove one of the following; you might for
    # instance expose the TLS port directly:
    ports:
      - 8448:8448/tcp
    # ... or use a reverse proxy, here is an example for traefik:
    labels:
      - traefik.enable=true
      - traefik.frontend.rule=Host:my.matrix.host
      - traefik.port=8448

  db:
    image: docker.io/postgres:10-alpine
    # Change that password, of course!
    environment:
      - POSTGRES_USER=synapse
      - POSTGRES_PASSWORD=changeme
    volumes:
      # You may store the database tables in a local folder..
      - ./schemas:/var/lib/postgresql/data
      # .. or store them on some high performance storage for better results
      # - /path/to/ssd/storage:/var/lib/postgresql/data

contrib/docker/start.py (new executable file, 66 lines)

@@ -0,0 +1,66 @@
#!/usr/local/bin/python

import jinja2
import os
import sys
import subprocess
import glob

# Utility functions
convert = lambda src, dst, environ: open(dst, "w").write(jinja2.Template(open(src).read()).render(**environ))

def check_arguments(environ, args):
    for argument in args:
        if argument not in environ:
            print("Environment variable %s is mandatory, exiting." % argument)
            sys.exit(2)

def generate_secrets(environ, secrets):
    for name, secret in secrets.items():
        if secret not in environ:
            filename = "/data/%s.%s.key" % (environ["SYNAPSE_SERVER_NAME"], name)
            if os.path.exists(filename):
                with open(filename) as handle: value = handle.read()
            else:
                print("Generating a random secret for {}".format(name))
                value = os.urandom(32).encode("hex")
                with open(filename, "w") as handle: handle.write(value)
            environ[secret] = value

# Prepare the configuration
mode = sys.argv[1] if len(sys.argv) > 1 else None
environ = os.environ.copy()
ownership = "{}:{}".format(environ.get("UID", 991), environ.get("GID", 991))
args = ["python", "-m", "synapse.app.homeserver"]

# In generate mode, generate a configuration, missing keys, then exit
if mode == "generate":
    check_arguments(environ, ("SYNAPSE_SERVER_NAME", "SYNAPSE_REPORT_STATS", "SYNAPSE_CONFIG_PATH"))
    args += [
        "--server-name", environ["SYNAPSE_SERVER_NAME"],
        "--report-stats", environ["SYNAPSE_REPORT_STATS"],
        "--config-path", environ["SYNAPSE_CONFIG_PATH"],
        "--generate-config"
    ]
    os.execv("/usr/local/bin/python", args)

# In normal mode, generate missing keys if any, then run synapse
else:
    # Parse the configuration file
    if "SYNAPSE_CONFIG_PATH" in environ:
        args += ["--config-path", environ["SYNAPSE_CONFIG_PATH"]]
    else:
        check_arguments(environ, ("SYNAPSE_SERVER_NAME", "SYNAPSE_REPORT_STATS"))
        generate_secrets(environ, {
            "registration": "SYNAPSE_REGISTRATION_SHARED_SECRET",
            "macaroon": "SYNAPSE_MACAROON_SECRET_KEY"
        })
        environ["SYNAPSE_APPSERVICES"] = glob.glob("/data/appservices/*.yaml")
        if not os.path.exists("/compiled"): os.mkdir("/compiled")
        convert("/conf/homeserver.yaml", "/compiled/homeserver.yaml", environ)
        convert("/conf/log.config", "/compiled/log.config", environ)
        subprocess.check_output(["chown", "-R", ownership, "/data"])
        args += ["--config-path", "/compiled/homeserver.yaml"]

    # Generate missing keys and start synapse
    subprocess.check_output(args + ["--generate-keys"])
    os.execv("/sbin/su-exec", ["su-exec", ownership] + args)
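For reference, a plausible plain-Docker invocation of the generate mode implemented above (paths and names are placeholders, matching the README's conventions):

```
docker run --rm \
    -v ${DATA_PATH}:/data \
    -e SYNAPSE_SERVER_NAME=my.matrix.host \
    -e SYNAPSE_REPORT_STATS=yes \
    -e SYNAPSE_CONFIG_PATH=/data/homeserver.yaml \
    docker.io/matrixdotorg/synapse:latest generate
```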

jenkins/prepare_synapse.sh

@@ -1,5 +1,7 @@
 #! /bin/bash
 
+set -eux
+
 cd "`dirname $0`/.."
 
 TOX_DIR=$WORKSPACE/.tox

@@ -14,7 +16,20 @@ fi
 tox -e py27 --notest -v
 
 TOX_BIN=$TOX_DIR/py27/bin
-$TOX_BIN/pip install setuptools
+
+# cryptography 2.2 requires setuptools >= 18.5.
+#
+# older versions of virtualenv (?) give us a virtualenv with the same version
+# of setuptools as is installed on the system python (and tox runs virtualenv
+# under python3, so we get the version of setuptools that is installed on that).
+#
+# anyway, make sure that we have a recent enough setuptools.
+$TOX_BIN/pip install 'setuptools>=18.5'
+
+# we also need a semi-recent version of pip, because old ones fail to install
+# the "enum34" dependency of cryptography.
+$TOX_BIN/pip install 'pip>=10'
+
 { python synapse/python_dependencies.py
   echo lxml psycopg2
 } | xargs $TOX_BIN/pip install

scripts-dev/nuke-room-from-db.sh

@@ -6,9 +6,19 @@
 ## Do not run it lightly.
 
+set -e
+
+if [ "$1" == "-h" ] || [ "$1" == "" ]; then
+  echo "Call with ROOM_ID as first option and then pipe it into the database. So for instance you might run"
+  echo " nuke-room-from-db.sh <room_id> | sqlite3 homeserver.db"
+  echo "or"
+  echo " nuke-room-from-db.sh <room_id> | psql --dbname=synapse"
+  exit
+fi
+
 ROOMID="$1"
 
-sqlite3 homeserver.db <<EOF
+cat <<EOF
 DELETE FROM event_forward_extremities WHERE room_id = '$ROOMID';
 DELETE FROM event_backward_extremities WHERE room_id = '$ROOMID';
 DELETE FROM event_edges WHERE room_id = '$ROOMID';

@@ -29,7 +39,7 @@ DELETE FROM state_groups WHERE room_id = '$ROOMID';
 DELETE FROM state_groups_state WHERE room_id = '$ROOMID';
 DELETE FROM receipts_graph WHERE room_id = '$ROOMID';
 DELETE FROM receipts_linearized WHERE room_id = '$ROOMID';
-DELETE FROM event_search_content WHERE c1room_id = '$ROOMID';
+DELETE FROM event_search WHERE room_id = '$ROOMID';
 DELETE FROM guest_access WHERE room_id = '$ROOMID';
 DELETE FROM history_visibility WHERE room_id = '$ROOMID';
 DELETE FROM room_tags WHERE room_id = '$ROOMID';

synapse/__init__.py

@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.28.0-rc1"
+__version__ = "0.28.1"

synapse/api/constants.py

@@ -16,6 +16,9 @@
 
 """Contains constants from the specification."""
 
+# the "depth" field on events is limited to 2**63 - 1
+MAX_DEPTH = 2**63 - 1
+
 
 class Membership(object):

synapse/app/appservice.py

@@ -32,10 +32,10 @@ from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
-from twisted.internet import reactor
+from twisted.internet import reactor, defer
 from twisted.web.resource import NoResource
 
 logger = logging.getLogger("synapse.app.appservice")

@@ -74,6 +74,7 @@ class AppserviceServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

@@ -112,9 +113,14 @@ class ASReplicationHandler(ReplicationClientHandler):
         if stream_name == "events":
             max_stream_id = self.store.get_room_max_stream_ordering()
-            preserve_fn(
-                self.appservice_handler.notify_interested_services
-            )(max_stream_id)
+            run_in_background(self._notify_app_services, max_stream_id)
+
+    @defer.inlineCallbacks
+    def _notify_app_services(self, room_stream_id):
+        try:
+            yield self.appservice_handler.notify_interested_services(room_stream_id)
+        except Exception:
+            logger.exception("Error notifying application services of event")
 
 
 def start(config_options):
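The change above is the template for most of the files that follow: replace ``preserve_fn(f)(args)`` with ``run_in_background(f, args)`` and wrap the body of the background function in a try/except that logs failures. A minimal sketch of the pattern (the handler and its ``process`` method are hypothetical; ``run_in_background`` is the real helper from ``synapse.util.logcontext``):

```python
import logging

from twisted.internet import defer

from synapse.util.logcontext import run_in_background

logger = logging.getLogger(__name__)


def on_update(handler, arg):
    # fire and forget: failures must be logged inside the task, because
    # nothing will ever wait on the returned deferred
    run_in_background(_safe_task, handler, arg)


@defer.inlineCallbacks
def _safe_task(handler, arg):
    try:
        yield handler.process(arg)  # hypothetical async operation
    except Exception:
        logger.exception("Error in background task")
```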

synapse/app/client_reader.py

@@ -98,6 +98,7 @@ class ClientReaderServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

synapse/app/event_creator.py

@@ -114,6 +114,7 @@ class EventCreatorServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

synapse/app/federation_reader.py

@@ -87,6 +87,7 @@ class FederationReaderServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

synapse/app/federation_sender.py

@@ -38,7 +38,7 @@ from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
 from synapse.util.async import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
 from twisted.internet import defer, reactor

@@ -101,6 +101,7 @@ class FederationSenderServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

@@ -229,7 +230,7 @@ class FederationSenderHandler(object):
         # presence, typing, etc.
         if stream_name == "federation":
             send_queue.process_rows_for_federation(self.federation_sender, rows)
-            preserve_fn(self.update_token)(token)
+            run_in_background(self.update_token, token)
 
         # We also need to poke the federation sender when new events happen
         elif stream_name == "events":

@@ -237,19 +238,22 @@ class FederationSenderHandler(object):
 
     @defer.inlineCallbacks
     def update_token(self, token):
-        self.federation_position = token
-
-        # We linearize here to ensure we don't have races updating the token
-        with (yield self._fed_position_linearizer.queue(None)):
-            if self._last_ack < self.federation_position:
-                yield self.store.update_federation_out_pos(
-                    "federation", self.federation_position
-                )
-
-                # We ACK this token over replication so that the master can drop
-                # its in memory queues
-                self.replication_client.send_federation_ack(self.federation_position)
-                self._last_ack = self.federation_position
+        try:
+            self.federation_position = token
+
+            # We linearize here to ensure we don't have races updating the token
+            with (yield self._fed_position_linearizer.queue(None)):
+                if self._last_ack < self.federation_position:
+                    yield self.store.update_federation_out_pos(
+                        "federation", self.federation_position
+                    )
+
+                    # We ACK this token over replication so that the master can drop
+                    # its in memory queues
+                    self.replication_client.send_federation_ack(self.federation_position)
+                    self._last_ack = self.federation_position
+        except Exception:
+            logger.exception("Error updating federation stream position")
 
 
 if __name__ == '__main__':

synapse/app/frontend_proxy.py

@@ -152,6 +152,7 @@ class FrontendProxyServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

synapse/app/homeserver.py

@@ -140,6 +140,7 @@ class SynapseHomeServer(HomeServer):
                         site_tag,
                         listener_config,
                         root_resource,
+                        self.version_string,
                     ),
                     self.tls_server_context_factory,
                 )

@@ -153,6 +154,7 @@ class SynapseHomeServer(HomeServer):
                         site_tag,
                         listener_config,
                         root_resource,
+                        self.version_string,
                     )
                 )
             logger.info("Synapse now listening on port %d", port)

synapse/app/media_repository.py

@@ -94,6 +94,7 @@ class MediaRepositoryServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

synapse/app/pusher.py

@@ -33,7 +33,7 @@ from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
 from twisted.internet import defer, reactor

@@ -104,6 +104,7 @@ class PusherServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

@@ -140,24 +141,27 @@ class PusherReplicationHandler(ReplicationClientHandler):
 
     def on_rdata(self, stream_name, token, rows):
         super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
-        preserve_fn(self.poke_pushers)(stream_name, token, rows)
+        run_in_background(self.poke_pushers, stream_name, token, rows)
 
     @defer.inlineCallbacks
     def poke_pushers(self, stream_name, token, rows):
-        if stream_name == "pushers":
-            for row in rows:
-                if row.deleted:
-                    yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
-                else:
-                    yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
-        elif stream_name == "events":
-            yield self.pusher_pool.on_new_notifications(
-                token, token,
-            )
-        elif stream_name == "receipts":
-            yield self.pusher_pool.on_new_receipts(
-                token, token, set(row.room_id for row in rows)
-            )
+        try:
+            if stream_name == "pushers":
+                for row in rows:
+                    if row.deleted:
+                        yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+                    else:
+                        yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
+            elif stream_name == "events":
+                yield self.pusher_pool.on_new_notifications(
+                    token, token,
+                )
+            elif stream_name == "receipts":
+                yield self.pusher_pool.on_new_receipts(
+                    token, token, set(row.room_id for row in rows)
+                )
+        except Exception:
+            logger.exception("Error poking pushers")
 
     def stop_pusher(self, user_id, app_id, pushkey):
         key = "%s:%s" % (app_id, pushkey)

synapse/app/synchrotron.py

@@ -51,7 +51,7 @@ from synapse.storage.engines import create_engine
 from synapse.storage.presence import UserPresenceState
 from synapse.storage.roommember import RoomMemberStore
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
 from synapse.util.stringutils import random_string
 from synapse.util.versionstring import get_version_string

@@ -281,6 +281,7 @@ class SynchrotronServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

@@ -327,8 +328,7 @@ class SyncReplicationHandler(ReplicationClientHandler):
     def on_rdata(self, stream_name, token, rows):
         super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
-
-        preserve_fn(self.process_and_notify)(stream_name, token, rows)
+        run_in_background(self.process_and_notify, stream_name, token, rows)
 
     def get_streams_to_replicate(self):
         args = super(SyncReplicationHandler, self).get_streams_to_replicate()

@@ -340,55 +340,58 @@ class SyncReplicationHandler(ReplicationClientHandler):
 
     @defer.inlineCallbacks
     def process_and_notify(self, stream_name, token, rows):
-        if stream_name == "events":
-            # We shouldn't get multiple rows per token for events stream, so
-            # we don't need to optimise this for multiple rows.
-            for row in rows:
-                event = yield self.store.get_event(row.event_id)
-                extra_users = ()
-                if event.type == EventTypes.Member:
-                    extra_users = (event.state_key,)
-                max_token = self.store.get_room_max_stream_ordering()
-                self.notifier.on_new_room_event(
-                    event, token, max_token, extra_users
-                )
-        elif stream_name == "push_rules":
-            self.notifier.on_new_event(
-                "push_rules_key", token, users=[row.user_id for row in rows],
-            )
-        elif stream_name in ("account_data", "tag_account_data",):
-            self.notifier.on_new_event(
-                "account_data_key", token, users=[row.user_id for row in rows],
-            )
-        elif stream_name == "receipts":
-            self.notifier.on_new_event(
-                "receipt_key", token, rooms=[row.room_id for row in rows],
-            )
-        elif stream_name == "typing":
-            self.typing_handler.process_replication_rows(token, rows)
-            self.notifier.on_new_event(
-                "typing_key", token, rooms=[row.room_id for row in rows],
-            )
-        elif stream_name == "to_device":
-            entities = [row.entity for row in rows if row.entity.startswith("@")]
-            if entities:
-                self.notifier.on_new_event(
-                    "to_device_key", token, users=entities,
-                )
-        elif stream_name == "device_lists":
-            all_room_ids = set()
-            for row in rows:
-                room_ids = yield self.store.get_rooms_for_user(row.user_id)
-                all_room_ids.update(room_ids)
-            self.notifier.on_new_event(
-                "device_list_key", token, rooms=all_room_ids,
-            )
-        elif stream_name == "presence":
-            yield self.presence_handler.process_replication_rows(token, rows)
-        elif stream_name == "receipts":
-            self.notifier.on_new_event(
-                "groups_key", token, users=[row.user_id for row in rows],
-            )
+        try:
+            if stream_name == "events":
+                # We shouldn't get multiple rows per token for events stream, so
+                # we don't need to optimise this for multiple rows.
+                for row in rows:
+                    event = yield self.store.get_event(row.event_id)
+                    extra_users = ()
+                    if event.type == EventTypes.Member:
+                        extra_users = (event.state_key,)
+                    max_token = self.store.get_room_max_stream_ordering()
+                    self.notifier.on_new_room_event(
+                        event, token, max_token, extra_users
+                    )
+            elif stream_name == "push_rules":
+                self.notifier.on_new_event(
+                    "push_rules_key", token, users=[row.user_id for row in rows],
+                )
+            elif stream_name in ("account_data", "tag_account_data",):
+                self.notifier.on_new_event(
+                    "account_data_key", token, users=[row.user_id for row in rows],
+                )
+            elif stream_name == "receipts":
+                self.notifier.on_new_event(
+                    "receipt_key", token, rooms=[row.room_id for row in rows],
+                )
+            elif stream_name == "typing":
+                self.typing_handler.process_replication_rows(token, rows)
+                self.notifier.on_new_event(
+                    "typing_key", token, rooms=[row.room_id for row in rows],
+                )
+            elif stream_name == "to_device":
+                entities = [row.entity for row in rows if row.entity.startswith("@")]
+                if entities:
+                    self.notifier.on_new_event(
+                        "to_device_key", token, users=entities,
+                    )
+            elif stream_name == "device_lists":
+                all_room_ids = set()
+                for row in rows:
+                    room_ids = yield self.store.get_rooms_for_user(row.user_id)
+                    all_room_ids.update(room_ids)
+                self.notifier.on_new_event(
+                    "device_list_key", token, rooms=all_room_ids,
+                )
+            elif stream_name == "presence":
+                yield self.presence_handler.process_replication_rows(token, rows)
+            elif stream_name == "receipts":
+                self.notifier.on_new_event(
+                    "groups_key", token, users=[row.user_id for row in rows],
+                )
+        except Exception:
+            logger.exception("Error processing replication")
 
 
 def start(config_options):

synapse/app/user_dir.py

@@ -39,10 +39,10 @@ from synapse.storage.engines import create_engine
 from synapse.storage.user_directory import UserDirectoryStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
-from twisted.internet import reactor
+from twisted.internet import reactor, defer
 from twisted.web.resource import NoResource
 
 logger = logging.getLogger("synapse.app.user_dir")

@@ -126,6 +126,7 @@ class UserDirectoryServer(HomeServer):
                     site_tag,
                     listener_config,
                     root_resource,
+                    self.version_string,
                 )
             )

@@ -164,7 +165,14 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
             stream_name, token, rows
         )
         if stream_name == "current_state_deltas":
-            preserve_fn(self.user_directory.notify_new_event)()
+            run_in_background(self._notify_directory)
+
+    @defer.inlineCallbacks
+    def _notify_directory(self):
+        try:
+            yield self.user_directory.notify_new_event()
+        except Exception:
+            logger.exception("Error notifying user directory of state update")
 
 
 def start(config_options):

synapse/appservice/scheduler.py

@@ -51,7 +51,7 @@ components.
 from twisted.internet import defer
 
 from synapse.appservice import ApplicationServiceState
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
 from synapse.util.metrics import Measure
 
 import logging

@@ -106,7 +106,7 @@ class _ServiceQueuer(object):
     def enqueue(self, service, event):
         # if this service isn't being sent something
         self.queued_events.setdefault(service.id, []).append(event)
-        preserve_fn(self._send_request)(service)
+        run_in_background(self._send_request, service)
 
     @defer.inlineCallbacks
     def _send_request(self, service):

@@ -152,10 +152,10 @@ class _TransactionController(object):
                 if sent:
                     yield txn.complete(self.store)
                 else:
-                    preserve_fn(self._start_recoverer)(service)
-        except Exception as e:
-            logger.exception(e)
-            preserve_fn(self._start_recoverer)(service)
+                    run_in_background(self._start_recoverer, service)
+        except Exception:
+            logger.exception("Error creating appservice transaction")
+            run_in_background(self._start_recoverer, service)
 
     @defer.inlineCallbacks
     def on_recovered(self, recoverer):

@@ -176,17 +176,20 @@ class _TransactionController(object):
 
     @defer.inlineCallbacks
     def _start_recoverer(self, service):
-        yield self.store.set_appservice_state(
-            service,
-            ApplicationServiceState.DOWN
-        )
-        logger.info(
-            "Application service falling behind. Starting recoverer. AS ID %s",
-            service.id
-        )
-        recoverer = self.recoverer_fn(service, self.on_recovered)
-        self.add_recoverers([recoverer])
-        recoverer.recover()
+        try:
+            yield self.store.set_appservice_state(
+                service,
+                ApplicationServiceState.DOWN
+            )
+            logger.info(
+                "Application service falling behind. Starting recoverer. AS ID %s",
+                service.id
+            )
+            recoverer = self.recoverer_fn(service, self.on_recovered)
+            self.add_recoverers([recoverer])
+            recoverer.recover()
+        except Exception:
+            logger.exception("Error starting AS recoverer")
 
     @defer.inlineCallbacks
     def _is_service_up(self, service):

synapse/config/_base.py

@@ -281,15 +281,15 @@ class Config(object):
                 )
                 if not cls.path_exists(config_dir_path):
                     os.makedirs(config_dir_path)
-                with open(config_path, "wb") as config_file:
-                    config_bytes, config = obj.generate_config(
+                with open(config_path, "w") as config_file:
+                    config_str, config = obj.generate_config(
                         config_dir_path=config_dir_path,
                         server_name=server_name,
                         report_stats=(config_args.report_stats == "yes"),
                         is_generating_file=True
                     )
                     obj.invoke_all("generate_files", config)
-                    config_file.write(config_bytes)
+                    config_file.write(config_str)
                 print((
                     "A config file has been generated in %r for server name"
                     " %r with corresponding SSL keys and self-signed"

synapse/config/appservice.py

@@ -17,11 +17,11 @@ from ._base import Config, ConfigError
 from synapse.appservice import ApplicationService
 from synapse.types import UserID
 
-import urllib
 import yaml
 import logging
 
 from six import string_types
+from six.moves.urllib import parse as urlparse
 
 logger = logging.getLogger(__name__)

@@ -105,7 +105,7 @@ def _load_appservice(hostname, as_info, config_filename):
         )
 
     localpart = as_info["sender_localpart"]
-    if urllib.quote(localpart) != localpart:
+    if urlparse.quote(localpart) != localpart:
         raise ValueError(
             "sender_localpart needs characters which are not URL encoded."
         )

synapse/config/logger.py

@@ -117,7 +117,7 @@ class LoggingConfig(Config):
         log_config = config.get("log_config")
         if log_config and not os.path.exists(log_config):
             log_file = self.abspath("homeserver.log")
-            with open(log_config, "wb") as log_config_file:
+            with open(log_config, "w") as log_config_file:
                 log_config_file.write(
                     DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
                 )

synapse/config/tls.py

@@ -133,7 +133,7 @@ class TlsConfig(Config):
         tls_dh_params_path = config["tls_dh_params_path"]
 
         if not self.path_exists(tls_private_key_path):
-            with open(tls_private_key_path, "w") as private_key_file:
+            with open(tls_private_key_path, "wb") as private_key_file:
                 tls_private_key = crypto.PKey()
                 tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
                 private_key_pem = crypto.dump_privatekey(

@@ -148,7 +148,7 @@ class TlsConfig(Config):
                 )
 
         if not self.path_exists(tls_certificate_path):
-            with open(tls_certificate_path, "w") as certificate_file:
+            with open(tls_certificate_path, "wb") as certificate_file:
                 cert = crypto.X509()
                 subject = cert.get_subject()
                 subject.CN = config["server_name"]

synapse/crypto/context_factory.py

@@ -13,8 +13,8 @@
 # limitations under the License.
 
 from twisted.internet import ssl
-from OpenSSL import SSL
-from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName
+from OpenSSL import SSL, crypto
+from twisted.internet._sslverify import _defaultCurveName
 
 import logging

@@ -32,8 +32,9 @@ class ServerContextFactory(ssl.ContextFactory):
     @staticmethod
     def configure_context(context, config):
         try:
-            _ecCurve = _OpenSSLECCurve(_defaultCurveName)
-            _ecCurve.addECKeyToContext(context)
+            _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
+            context.set_tmp_ecdh(_ecCurve)
         except Exception:
             logger.exception("Failed to enable elliptic curve for TLS")
         context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)

synapse/crypto/keyring.py

@@ -19,7 +19,8 @@ from synapse.api.errors import SynapseError, Codes
 from synapse.util import unwrapFirstError, logcontext
 from synapse.util.logcontext import (
     PreserveLoggingContext,
-    preserve_fn
+    preserve_fn,
+    run_in_background,
 )
 from synapse.util.metrics import Measure

@@ -127,7 +128,7 @@ class Keyring(object):
 
             verify_requests.append(verify_request)
 
-        preserve_fn(self._start_key_lookups)(verify_requests)
+        run_in_background(self._start_key_lookups, verify_requests)
 
         # Pass those keys to handle_key_deferred so that the json object
         # signatures can be verified

@@ -146,53 +147,56 @@ class Keyring(object):
             verify_requests (List[VerifyKeyRequest]):
         """
 
-        # create a deferred for each server we're going to look up the keys
-        # for; we'll resolve them once we have completed our lookups.
-        # These will be passed into wait_for_previous_lookups to block
-        # any other lookups until we have finished.
-        # The deferreds are called with no logcontext.
-        server_to_deferred = {
-            rq.server_name: defer.Deferred()
-            for rq in verify_requests
-        }
-
-        # We want to wait for any previous lookups to complete before
-        # proceeding.
-        yield self.wait_for_previous_lookups(
-            [rq.server_name for rq in verify_requests],
-            server_to_deferred,
-        )
-
-        # Actually start fetching keys.
-        self._get_server_verify_keys(verify_requests)
-
-        # When we've finished fetching all the keys for a given server_name,
-        # resolve the deferred passed to `wait_for_previous_lookups` so that
-        # any lookups waiting will proceed.
-        #
-        # map from server name to a set of request ids
-        server_to_request_ids = {}
-
-        for verify_request in verify_requests:
-            server_name = verify_request.server_name
-            request_id = id(verify_request)
-            server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
-        def remove_deferreds(res, verify_request):
-            server_name = verify_request.server_name
-            request_id = id(verify_request)
-            server_to_request_ids[server_name].discard(request_id)
-            if not server_to_request_ids[server_name]:
-                d = server_to_deferred.pop(server_name, None)
-                if d:
-                    d.callback(None)
-            return res
-
-        for verify_request in verify_requests:
-            verify_request.deferred.addBoth(
-                remove_deferreds, verify_request,
-            )
+        try:
+            # create a deferred for each server we're going to look up the keys
+            # for; we'll resolve them once we have completed our lookups.
+            # These will be passed into wait_for_previous_lookups to block
+            # any other lookups until we have finished.
+            # The deferreds are called with no logcontext.
+            server_to_deferred = {
+                rq.server_name: defer.Deferred()
+                for rq in verify_requests
+            }
+
+            # We want to wait for any previous lookups to complete before
+            # proceeding.
+            yield self.wait_for_previous_lookups(
+                [rq.server_name for rq in verify_requests],
+                server_to_deferred,
+            )
+
+            # Actually start fetching keys.
+            self._get_server_verify_keys(verify_requests)
+
+            # When we've finished fetching all the keys for a given server_name,
+            # resolve the deferred passed to `wait_for_previous_lookups` so that
+            # any lookups waiting will proceed.
+            #
+            # map from server name to a set of request ids
+            server_to_request_ids = {}
+
+            for verify_request in verify_requests:
+                server_name = verify_request.server_name
+                request_id = id(verify_request)
+                server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+            def remove_deferreds(res, verify_request):
+                server_name = verify_request.server_name
+                request_id = id(verify_request)
+                server_to_request_ids[server_name].discard(request_id)
+                if not server_to_request_ids[server_name]:
+                    d = server_to_deferred.pop(server_name, None)
+                    if d:
+                        d.callback(None)
+                return res
+
+            for verify_request in verify_requests:
+                verify_request.deferred.addBoth(
+                    remove_deferreds, verify_request,
+                )
+        except Exception:
+            logger.exception("Error starting key lookups")
 
     @defer.inlineCallbacks
     def wait_for_previous_lookups(self, server_names, server_to_deferred):
         """Waits for any previous key lookups for the given servers to finish.

@@ -313,7 +317,7 @@ class Keyring(object):
                 if not verify_request.deferred.called:
                     verify_request.deferred.errback(err)
 
-        preserve_fn(do_iterations)().addErrback(on_err)
+        run_in_background(do_iterations).addErrback(on_err)
 
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):

@@ -329,8 +333,9 @@ class Keyring(object):
         """
         res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
-                preserve_fn(self.store.get_server_verify_keys)(
-                    server_name, key_ids
+                run_in_background(
+                    self.store.get_server_verify_keys,
+                    server_name, key_ids,
                ).addCallback(lambda ks, server: (server, ks), server_name)
                for server_name, key_ids in server_name_and_key_ids
            ],

@@ -358,7 +363,7 @@ class Keyring(object):
 
         results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
-                preserve_fn(get_key)(p_name, p_keys)
+                run_in_background(get_key, p_name, p_keys)
                for p_name, p_keys in self.perspective_servers.items()
            ],
            consumeErrors=True,

@@ -398,7 +403,7 @@ class Keyring(object):
 
         results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
-                preserve_fn(get_key)(server_name, key_ids)
+                run_in_background(get_key, server_name, key_ids)
                for server_name, key_ids in server_name_and_key_ids
            ],
            consumeErrors=True,

@@ -481,7 +486,8 @@ class Keyring(object):
 
         yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
-                preserve_fn(self.store_keys)(
+                run_in_background(
+                    self.store_keys,
                    server_name=server_name,
                    from_server=perspective_name,
                    verify_keys=response_keys,

@@ -539,7 +545,8 @@ class Keyring(object):
 
         yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
-                preserve_fn(self.store_keys)(
+                run_in_background(
+                    self.store_keys,
                    server_name=key_server_name,
                    from_server=server_name,
                    verify_keys=verify_keys,

@@ -615,7 +622,8 @@ class Keyring(object):
 
         yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
-                preserve_fn(self.store.store_server_keys_json)(
+                run_in_background(
+                    self.store.store_server_keys_json,
                    server_name=server_name,
                    key_id=key_id,
                    from_server=server_name,

@@ -716,7 +724,8 @@ class Keyring(object):
         # TODO(markjh): Store whether the keys have expired.
         return logcontext.make_deferred_yieldable(defer.gatherResults(
            [
-                preserve_fn(self.store.store_server_verify_key)(
+                run_in_background(
+                    self.store.store_server_verify_key,
                    server_name, server_name, key.time_added, key
                )
                for key_id, key in verify_keys.items()

synapse/events/__init__.py

@@ -47,14 +47,26 @@ class _EventInternalMetadata(object):
 
 def _event_dict_property(key):
+    # We want to be able to use hasattr with the event dict properties.
+    # However, (on python3) hasattr expects AttributeError to be raised. Hence,
+    # we need to transform the KeyError into an AttributeError
     def getter(self):
-        return self._event_dict[key]
+        try:
+            return self._event_dict[key]
+        except KeyError:
+            raise AttributeError(key)
 
     def setter(self, v):
-        self._event_dict[key] = v
+        try:
+            self._event_dict[key] = v
+        except KeyError:
+            raise AttributeError(key)
 
     def delete(self):
-        del self._event_dict[key]
+        try:
+            del self._event_dict[key]
+        except KeyError:
+            raise AttributeError(key)
 
     return property(
         getter,
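A quick illustration (not part of the commit) of why the ``KeyError`` has to become an ``AttributeError``: on python 3, ``hasattr`` only swallows ``AttributeError``, so a property that leaks ``KeyError`` breaks it:

```python
class Leaky(object):
    @property
    def origin(self):
        raise KeyError("origin")        # escapes hasattr() on python 3


class Fixed(object):
    @property
    def origin(self):
        raise AttributeError("origin")  # hasattr() treats this as "absent"


print(hasattr(Fixed(), "origin"))   # False
try:
    hasattr(Leaky(), "origin")
except KeyError:
    print("KeyError escaped hasattr")  # happens on python 3
```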

View File

@@ -14,7 +14,10 @@
# limitations under the License.
import logging

-from synapse.api.errors import SynapseError
+import six
+
+from synapse.api.constants import MAX_DEPTH
+from synapse.api.errors import SynapseError, Codes
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
@@ -190,11 +193,23 @@ def event_from_pdu_json(pdu_json, outlier=False):
        FrozenEvent

    Raises:
-        SynapseError: if the pdu is missing required fields
+        SynapseError: if the pdu is missing required fields or is otherwise
+            not a valid matrix event
    """
    # we could probably enforce a bunch of other fields here (room_id, sender,
    # origin, etc etc)
    assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
+
+    depth = pdu_json['depth']
+    if not isinstance(depth, six.integer_types):
+        raise SynapseError(400, "Depth %r not an integer" % (depth, ),
+                           Codes.BAD_JSON)
+
+    if depth < 0:
+        raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
+    elif depth > MAX_DEPTH:
+        raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
+
    event = FrozenEvent(
        pdu_json
    )
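
For reference, six.integer_types is (int, long) on Python 2 and (int,) on Python 3, so the check accepts either integer type while rejecting floats and strings. A hedged sketch of what the new validation accepts and rejects (assuming MAX_DEPTH in synapse.api.constants is 2**63 - 1):

    event_from_pdu_json({'event_id': '$a:hs', 'type': 'm.room.message', 'depth': 5})
    # -> FrozenEvent
    event_from_pdu_json({'event_id': '$a:hs', 'type': 'm.room.message', 'depth': -1})
    # -> SynapseError(400, "Depth too small", Codes.BAD_JSON)
    event_from_pdu_json({'event_id': '$a:hs', 'type': 'm.room.message', 'depth': 2**63})
    # -> SynapseError(400, "Depth too large", Codes.BAD_JSON)
    event_from_pdu_json({'event_id': '$a:hs', 'type': 'm.room.message', 'depth': "5"})
    # -> SynapseError(400, "Depth '5' not an integer", Codes.BAD_JSON)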

View File

@@ -19,6 +19,8 @@ import itertools
import logging
import random

+from six.moves import range
+
from twisted.internet import defer

from synapse.api.constants import Membership
@@ -33,7 +35,7 @@ from synapse.federation.federation_base import (
import synapse.metrics
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
@@ -413,11 +415,12 @@ class FederationClient(FederationBase):
        batch_size = 20
        missing_events = list(missing_events)
-        for i in xrange(0, len(missing_events), batch_size):
+        for i in range(0, len(missing_events), batch_size):
            batch = set(missing_events[i:i + batch_size])

            deferreds = [
-                preserve_fn(self.get_pdu)(
+                run_in_background(
+                    self.get_pdu,
                    destinations=random_server_list(),
                    event_id=e_id,
                )

View File

@@ -323,6 +323,8 @@ class TransactionQueue(object):
                    break
                yield self._process_presence_inner(states_map.values())
+        except Exception:
+            logger.exception("Error sending presence states to servers")
        finally:
            self._processing_pending_presence = False
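
This is the other recurring theme of the commit: anything kicked off with run_in_background has no caller left to observe its failure, so each background task now catches and logs its own exceptions; otherwise the failure would be dropped silently, or at best surface as an opaque unhandled-error log at garbage collection. The shape, sketched with hypothetical names (defer, logger and run_in_background as in the surrounding hunks):

    @defer.inlineCallbacks
    def _background_task(self):
        try:
            yield self._do_work()
        except Exception:
            # nothing is waiting on this deferred, so log rather than raise
            logger.exception("Error in background task")

    run_in_background(self._background_task)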

View File

@@ -25,7 +25,7 @@ from synapse.http.servlet import (
)
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.types import ThirdPartyInstanceID, get_domain_from_id

import functools
@@ -152,11 +152,18 @@ class Authenticator(object):
        # alive
        retry_timings = yield self.store.get_destination_retry_timings(origin)
        if retry_timings and retry_timings["retry_last_ts"]:
-            logger.info("Marking origin %r as up", origin)
-            preserve_fn(self.store.set_destination_retry_timings)(origin, 0, 0)
+            run_in_background(self._reset_retry_timings, origin)

        defer.returnValue(origin)

+    @defer.inlineCallbacks
+    def _reset_retry_timings(self, origin):
+        try:
+            logger.info("Marking origin %r as up", origin)
+            yield self.store.set_destination_retry_timings(origin, 0, 0)
+        except Exception:
+            logger.exception("Error resetting retry timings on %s", origin)
+

class BaseFederationServlet(object):
    REQUIRE_AUTH = True

View File

@@ -74,8 +74,6 @@ class Transaction(JsonEncodedObject):
        "previous_ids",
        "pdus",
        "edus",
-        "transaction_id",
-        "destination",
        "pdu_failures",
    ]

View File

@@ -42,7 +42,7 @@ from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background

from signedjson.sign import sign_json
@@ -165,31 +165,35 @@ class GroupAttestionRenewer(object):
        @defer.inlineCallbacks
        def _renew_attestation(group_id, user_id):
-            if not self.is_mine_id(group_id):
-                destination = get_domain_from_id(group_id)
-            elif not self.is_mine_id(user_id):
-                destination = get_domain_from_id(user_id)
-            else:
-                logger.warn(
-                    "Incorrectly trying to do attestations for user: %r in %r",
-                    user_id, group_id,
-                )
-                yield self.store.remove_attestation_renewal(group_id, user_id)
-                return
-
-            attestation = self.attestations.create_attestation(group_id, user_id)
-
-            yield self.transport_client.renew_group_attestation(
-                destination, group_id, user_id,
-                content={"attestation": attestation},
-            )
-
-            yield self.store.update_attestation_renewal(
-                group_id, user_id, attestation
-            )
+            try:
+                if not self.is_mine_id(group_id):
+                    destination = get_domain_from_id(group_id)
+                elif not self.is_mine_id(user_id):
+                    destination = get_domain_from_id(user_id)
+                else:
+                    logger.warn(
+                        "Incorrectly trying to do attestations for user: %r in %r",
+                        user_id, group_id,
+                    )
+                    yield self.store.remove_attestation_renewal(group_id, user_id)
+                    return
+
+                attestation = self.attestations.create_attestation(group_id, user_id)
+
+                yield self.transport_client.renew_group_attestation(
+                    destination, group_id, user_id,
+                    content={"attestation": attestation},
+                )
+
+                yield self.store.update_attestation_renewal(
+                    group_id, user_id, attestation
+                )
+            except Exception:
+                logger.exception("Error renewing attestation of %r in %r",
+                                 user_id, group_id)

        for row in rows:
            group_id = row["group_id"]
            user_id = row["user_id"]
-            preserve_fn(_renew_attestation)(group_id, user_id)
+            run_in_background(_renew_attestation, group_id, user_id)

View File

@@ -19,7 +19,7 @@ import synapse
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import (
-    make_deferred_yieldable, preserve_fn, run_in_background,
+    make_deferred_yieldable, run_in_background,
)

import logging
@@ -111,9 +111,7 @@ class ApplicationServicesHandler(object):
                # Fork off pushes to these services
                for service in services:
-                    preserve_fn(self.scheduler.submit_event_for_as)(
-                        service, event
-                    )
+                    self.scheduler.submit_event_for_as(service, event)

        @defer.inlineCallbacks
        def handle_room_events(events):
@@ -198,7 +196,10 @@ class ApplicationServicesHandler(object):
        services = yield self._get_services_for_3pn(protocol)

        results = yield make_deferred_yieldable(defer.DeferredList([
-            preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
+            run_in_background(
+                self.appservice_api.query_3pe,
+                service, kind, protocol, fields,
+            )
            for service in services
        ], consumeErrors=True))
@@ -259,11 +260,15 @@ class ApplicationServicesHandler(object):
            event based on the service regex.
        """
        services = self.store.get_app_services()
-        interested_list = [
-            s for s in services if (
-                yield s.is_interested(event, self.store)
-            )
-        ]
+
+        # we can't use a list comprehension here. Since python 3, list
+        # comprehensions use a generator internally. This means you can't yield
+        # inside of a list comprehension anymore.
+        interested_list = []
+        for s in services:
+            if (yield s.is_interested(event, self.store)):
+                interested_list.append(s)
+
        defer.returnValue(interested_list)

    def _get_services_for_user(self, user_id):
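
The comment in the last hunk deserves unpacking: on Python 3 a list comprehension has its own function scope, so a yield inside it turns the comprehension itself into a nested generator (and is a SyntaxError from Python 3.8 on) rather than suspending the enclosing inlineCallbacks generator; on Python 2 it happened to work because the comprehension shared the enclosing scope. A minimal sketch of the failure mode, with a hypothetical deferred-returning predicate:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def pick(items, check):
        # Broken on py3: this would bind a generator object, not a list:
        #   picked = [x for x in items if (yield check(x))]
        picked = []
        for x in items:
            if (yield check(x)):
                picked.append(x)
        defer.returnValue(picked)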

View File

@@ -24,7 +24,7 @@ from synapse.api.errors import (
    SynapseError, CodeMessageException, FederationDeniedError,
)
from synapse.types import get_domain_from_id, UserID
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.retryutils import NotRetryingDestination

logger = logging.getLogger(__name__)
@@ -139,9 +139,9 @@ class E2eKeysHandler(object):
                failures[destination] = _exception_to_failure(e)

        yield make_deferred_yieldable(defer.gatherResults([
-            preserve_fn(do_remote_query)(destination)
+            run_in_background(do_remote_query, destination)
            for destination in remote_queries_not_in_cache
-        ]))
+        ], consumeErrors=True))

        defer.returnValue({
            "device_keys": results, "failures": failures,
@@ -242,9 +242,9 @@ class E2eKeysHandler(object):
                failures[destination] = _exception_to_failure(e)

        yield make_deferred_yieldable(defer.gatherResults([
-            preserve_fn(claim_client_keys)(destination)
+            run_in_background(claim_client_keys, destination)
            for destination in remote_queries
-        ]))
+        ], consumeErrors=True))

        logger.info(
            "Claimed one-time-keys: %s",

View File

@@ -16,12 +16,14 @@

"""Contains handlers for federation events."""

-import httplib
import itertools
import logging
+import sys

from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
+import six
+from six.moves import http_client
from twisted.internet import defer
from unpaddedbase64 import decode_base64
@@ -637,7 +639,8 @@ class FederationHandler(BaseHandler):
        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
-                logcontext.preserve_fn(self.replication_layer.get_pdu)(
+                logcontext.run_in_background(
+                    self.replication_layer.get_pdu,
                    [dest],
                    event_id,
                    outlier=True,
@@ -887,7 +890,7 @@ class FederationHandler(BaseHandler):
            logger.warn("Rejecting event %s which has %i prev_events",
                        ev.event_id, len(ev.prev_events))
            raise SynapseError(
-                httplib.BAD_REQUEST,
+                http_client.BAD_REQUEST,
                "Too many prev_events",
            )
@@ -895,7 +898,7 @@ class FederationHandler(BaseHandler):
            logger.warn("Rejecting event %s which has %i auth_events",
                        ev.event_id, len(ev.auth_events))
            raise SynapseError(
-                httplib.BAD_REQUEST,
+                http_client.BAD_REQUEST,
                "Too many auth_events",
            )
@@ -1023,7 +1026,7 @@ class FederationHandler(BaseHandler):
        # lots of requests for missing prev_events which we do actually
        # have. Hence we fire off the deferred, but don't wait for it.
-        logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
+        logcontext.run_in_background(self._handle_queued_pdus, room_queue)

        defer.returnValue(True)
@@ -1513,18 +1516,21 @@ class FederationHandler(BaseHandler):
                backfilled=backfilled,
            )
        except:  # noqa: E722, as we reraise the exception this is fine.
-            # Ensure that we actually remove the entries in the push actions
-            # staging area
-            logcontext.preserve_fn(
-                self.store.remove_push_actions_from_staging
-            )(event.event_id)
-            raise
+            tp, value, tb = sys.exc_info()
+
+            logcontext.run_in_background(
+                self.store.remove_push_actions_from_staging,
+                event.event_id,
+            )
+
+            six.reraise(tp, value, tb)

        if not backfilled:
            # this intentionally does not yield: we don't care about the result
            # and don't need to wait for it.
-            logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
-                event_stream_id, max_stream_id
+            logcontext.run_in_background(
+                self.pusher_pool.on_new_notifications,
+                event_stream_id, max_stream_id,
            )

        defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1538,7 +1544,8 @@ class FederationHandler(BaseHandler):
        """
        contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
-                logcontext.preserve_fn(self._prep_event)(
+                logcontext.run_in_background(
+                    self._prep_event,
                    origin,
                    ev_info["event"],
                    state=ev_info.get("state"),
@@ -1867,7 +1874,8 @@ class FederationHandler(BaseHandler):
        different_events = yield logcontext.make_deferred_yieldable(
            defer.gatherResults([
-                logcontext.preserve_fn(self.store.get_event)(
+                logcontext.run_in_background(
+                    self.store.get_event,
                    d,
                    allow_none=True,
                    allow_rejected=False,
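
One subtlety in the bare `except:` hunk above: the cleanup is now fired off as a background task, and code that runs between catching an exception and re-raising it can clobber the interpreter's current exception state (with different behaviour between Python 2 and 3). Capturing sys.exc_info() first and re-raising with six.reraise keeps the original exception and its traceback intact. The pattern in isolation, with hypothetical do_work/cleanup stand-ins:

    import sys

    import six

    try:
        do_work()
    except Exception:
        tp, value, tb = sys.exc_info()
        # fire-and-forget cleanup; may itself run code that resets the
        # current exception state before we get to re-raise
        run_in_background(cleanup)
        # re-raise the *original* exception with its original traceback
        six.reraise(tp, value, tb)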

View File

@@ -27,7 +27,7 @@ from synapse.types import (
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.visibility import filter_events_for_client

from ._base import BaseHandler
@@ -166,7 +166,8 @@ class InitialSyncHandler(BaseHandler):
            (messages, token), current_state = yield make_deferred_yieldable(
                defer.gatherResults(
                    [
-                        preserve_fn(self.store.get_recent_events_for_room)(
+                        run_in_background(
+                            self.store.get_recent_events_for_room,
                            event.room_id,
                            limit=limit,
                            end_token=room_end_token,
@@ -180,8 +181,8 @@ class InitialSyncHandler(BaseHandler):
                self.store, user_id, messages
            )

-            start_token = now_token.copy_and_replace("room_key", token[0])
-            end_token = now_token.copy_and_replace("room_key", token[1])
+            start_token = now_token.copy_and_replace("room_key", token)
+            end_token = now_token.copy_and_replace("room_key", room_end_token)

            time_now = self.clock.time_msec()

            d["messages"] = {
@@ -324,8 +325,8 @@ class InitialSyncHandler(BaseHandler):
            self.store, user_id, messages, is_peeking=is_peeking
        )

-        start_token = StreamToken.START.copy_and_replace("room_key", token[0])
-        end_token = StreamToken.START.copy_and_replace("room_key", token[1])
+        start_token = StreamToken.START.copy_and_replace("room_key", token)
+        end_token = StreamToken.START.copy_and_replace("room_key", stream_token)

        time_now = self.clock.time_msec()
@@ -391,9 +392,10 @@ class InitialSyncHandler(BaseHandler):
        presence, receipts, (messages, token) = yield defer.gatherResults(
            [
-                preserve_fn(get_presence)(),
-                preserve_fn(get_receipts)(),
-                preserve_fn(self.store.get_recent_events_for_room)(
+                run_in_background(get_presence),
+                run_in_background(get_receipts),
+                run_in_background(
+                    self.store.get_recent_events_for_room,
                    room_id,
                    limit=limit,
                    end_token=now_token.room_key,
@@ -406,8 +408,8 @@ class InitialSyncHandler(BaseHandler):
            self.store, user_id, messages, is_peeking=is_peeking,
        )

-        start_token = now_token.copy_and_replace("room_key", token[0])
-        end_token = now_token.copy_and_replace("room_key", token[1])
+        start_token = now_token.copy_and_replace("room_key", token)
+        end_token = now_token

        time_now = self.clock.time_msec()

View File

@@ -13,10 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import simplejson
+import sys
+
+from canonicaljson import encode_canonical_json
+import six
from twisted.internet import defer, reactor
from twisted.python.failure import Failure

-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, MAX_DEPTH
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
@@ -25,7 +31,7 @@ from synapse.types import (
    UserID, RoomAlias, RoomStreamToken,
)
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
-from synapse.util.logcontext import preserve_fn, run_in_background
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.stringutils import random_string
@@ -34,11 +40,6 @@ from synapse.replication.http.send_event import send_event_to_master

from ._base import BaseHandler

-from canonicaljson import encode_canonical_json
-
-import logging
-import simplejson
-
logger = logging.getLogger(__name__)
@@ -624,6 +625,10 @@ class EventCreationHandler(object):
        if prev_events_and_hashes:
            depth = max([d for _, _, d in prev_events_and_hashes]) + 1
+            # we cap depth of generated events, to ensure that they are not
+            # rejected by other servers (and so that they can be persisted in
+            # the db)
+            depth = min(depth, MAX_DEPTH)
        else:
            depth = 1
@@ -729,8 +734,14 @@ class EventCreationHandler(object):
        except:  # noqa: E722, as we reraise the exception this is fine.
            # Ensure that we actually remove the entries in the push actions
            # staging area, if we calculated them.
-            preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id)
-            raise
+            tp, value, tb = sys.exc_info()
+
+            run_in_background(
+                self.store.remove_push_actions_from_staging,
+                event.event_id,
+            )
+
+            six.reraise(tp, value, tb)

    @defer.inlineCallbacks
    def persist_and_notify_client_event(
@@ -850,22 +861,33 @@ class EventCreationHandler(object):

        # this intentionally does not yield: we don't care about the result
        # and don't need to wait for it.
-        preserve_fn(self.pusher_pool.on_new_notifications)(
+        run_in_background(
+            self.pusher_pool.on_new_notifications,
            event_stream_id, max_stream_id
        )

        @defer.inlineCallbacks
        def _notify():
            yield run_on_reactor()
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id,
-                extra_users=extra_users
-            )
+            try:
+                self.notifier.on_new_room_event(
+                    event, event_stream_id, max_stream_id,
+                    extra_users=extra_users
+                )
+            except Exception:
+                logger.exception("Error notifying about new room event")

-        preserve_fn(_notify)()
+        run_in_background(_notify)

        if event.type == EventTypes.Message:
-            presence = self.hs.get_presence_handler()
            # We don't want to block sending messages on any presence code. This
            # matters as sometimes presence code can take a while.
-            preserve_fn(presence.bump_presence_active_time)(requester.user)
+            run_in_background(self._bump_active_time, requester.user)
+
+    @defer.inlineCallbacks
+    def _bump_active_time(self, user):
+        try:
+            presence = self.hs.get_presence_handler()
+            yield presence.bump_presence_active_time(user)
+        except Exception:
+            logger.exception("Error bumping presence active time")

View File

@@ -31,7 +31,7 @@ from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.async import Linearizer
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -254,6 +254,14 @@ class PresenceHandler(object):
        logger.info("Finished _persist_unpersisted_changes")

+    @defer.inlineCallbacks
+    def _update_states_and_catch_exception(self, new_states):
+        try:
+            res = yield self._update_states(new_states)
+            defer.returnValue(res)
+        except Exception:
+            logger.exception("Error updating presence")
+
    @defer.inlineCallbacks
    def _update_states(self, new_states):
        """Updates presence of users. Sets the appropriate timeouts. Pokes
@@ -364,7 +372,7 @@ class PresenceHandler(object):
                now=now,
            )

-            preserve_fn(self._update_states)(changes)
+            run_in_background(self._update_states_and_catch_exception, changes)
        except Exception:
            logger.exception("Exception in _handle_timeouts loop")
@@ -422,20 +430,23 @@ class PresenceHandler(object):
        @defer.inlineCallbacks
        def _end():
-            if affect_presence:
+            try:
                self.user_to_num_current_syncs[user_id] -= 1

                prev_state = yield self.current_state_for_user(user_id)
                yield self._update_states([prev_state.copy_and_replace(
                    last_user_sync_ts=self.clock.time_msec(),
                )])
+            except Exception:
+                logger.exception("Error updating presence after sync")

        @contextmanager
        def _user_syncing():
            try:
                yield
            finally:
-                preserve_fn(_end)()
+                if affect_presence:
+                    run_in_background(_end)

        defer.returnValue(_user_syncing())

View File

@@ -135,37 +135,40 @@ class ReceiptsHandler(BaseHandler):
        """Given a list of receipts, works out which remote servers should be
        poked and pokes them.
        """
-        # TODO: Some of this stuff should be coalesced.
-        for receipt in receipts:
-            room_id = receipt["room_id"]
-            receipt_type = receipt["receipt_type"]
-            user_id = receipt["user_id"]
-            event_ids = receipt["event_ids"]
-            data = receipt["data"]
-
-            users = yield self.state.get_current_user_in_room(room_id)
-            remotedomains = set(get_domain_from_id(u) for u in users)
-
-            remotedomains = remotedomains.copy()
-            remotedomains.discard(self.server_name)
-
-            logger.debug("Sending receipt to: %r", remotedomains)
-
-            for domain in remotedomains:
-                self.federation.send_edu(
-                    destination=domain,
-                    edu_type="m.receipt",
-                    content={
-                        room_id: {
-                            receipt_type: {
-                                user_id: {
-                                    "event_ids": event_ids,
-                                    "data": data,
-                                }
-                            }
-                        },
-                    },
-                    key=(room_id, receipt_type, user_id),
-                )
+        try:
+            # TODO: Some of this stuff should be coalesced.
+            for receipt in receipts:
+                room_id = receipt["room_id"]
+                receipt_type = receipt["receipt_type"]
+                user_id = receipt["user_id"]
+                event_ids = receipt["event_ids"]
+                data = receipt["data"]
+
+                users = yield self.state.get_current_user_in_room(room_id)
+                remotedomains = set(get_domain_from_id(u) for u in users)
+
+                remotedomains = remotedomains.copy()
+                remotedomains.discard(self.server_name)
+
+                logger.debug("Sending receipt to: %r", remotedomains)
+
+                for domain in remotedomains:
+                    self.federation.send_edu(
+                        destination=domain,
+                        edu_type="m.receipt",
+                        content={
+                            room_id: {
+                                receipt_type: {
+                                    user_id: {
+                                        "event_ids": event_ids,
+                                        "data": data,
+                                    }
+                                },
+                            },
+                        },
+                        key=(room_id, receipt_type, user_id),
+                    )
+        except Exception:
+            logger.exception("Error pushing receipts to remote servers")

    @defer.inlineCallbacks
    def get_receipts_for_room(self, room_id, to_key):

View File

@@ -15,6 +15,8 @@

from twisted.internet import defer

+from six.moves import range
+
from ._base import BaseHandler

from synapse.api.constants import (
@@ -200,7 +202,7 @@ class RoomListHandler(BaseHandler):
        step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1

        chunk = []
-        for i in xrange(0, len(rooms_to_scan), step):
+        for i in range(0, len(rooms_to_scan), step):
            batch = rooms_to_scan[i:i + step]
            logger.info("Processing %i rooms for result", len(batch))
            yield concurrently_execute(

View File

@@ -354,12 +354,24 @@ class SyncHandler(object):
            since_key = since_token.room_key

            while limited and len(recents) < timeline_limit and max_repeat:
-                events, end_key = yield self.store.get_room_events_stream_for_room(
-                    room_id,
-                    limit=load_limit + 1,
-                    from_key=since_key,
-                    to_key=end_key,
-                )
+                # If we have a since_key then we are trying to get any events
+                # that have happened since `since_key` up to `end_key`, so we
+                # can just use `get_room_events_stream_for_room`.
+                # Otherwise, we want to return the last N events in the room
+                # in topological ordering.
+                if since_key:
+                    events, end_key = yield self.store.get_room_events_stream_for_room(
+                        room_id,
+                        limit=load_limit + 1,
+                        from_key=since_key,
+                        to_key=end_key,
+                    )
+                else:
+                    events, end_key = yield self.store.get_recent_events_for_room(
+                        room_id,
+                        limit=load_limit + 1,
+                        end_token=end_key,
+                    )

                loaded_recents = sync_config.filter_collection.filter_room_timeline(
                    events
                )
@@ -429,7 +441,7 @@ class SyncHandler(object):
        Returns:
            A Deferred map from ((type, state_key)->Event)
        """
-        last_events, token = yield self.store.get_recent_events_for_room(
+        last_events, _ = yield self.store.get_recent_events_for_room(
            room_id, end_token=stream_position.room_key, limit=1,
        )

View File

@@ -16,7 +16,7 @@
from twisted.internet import defer

from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
@@ -97,7 +97,8 @@ class TypingHandler(object):
            if self.hs.is_mine_id(member.user_id):
                last_fed_poke = self._member_last_federation_poke.get(member, None)
                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
-                    preserve_fn(self._push_remote)(
+                    run_in_background(
+                        self._push_remote,
                        member=member,
                        typing=True
                    )
@@ -196,7 +197,7 @@ class TypingHandler(object):
    def _push_update(self, member, typing):
        if self.hs.is_mine_id(member.user_id):
            # Only send updates for changes to our own users.
-            preserve_fn(self._push_remote)(member, typing)
+            run_in_background(self._push_remote, member, typing)

        self._push_update_local(
            member=member,
@@ -205,28 +206,31 @@ class TypingHandler(object):

    @defer.inlineCallbacks
    def _push_remote(self, member, typing):
-        users = yield self.state.get_current_user_in_room(member.room_id)
-        self._member_last_federation_poke[member] = self.clock.time_msec()
-
-        now = self.clock.time_msec()
-        self.wheel_timer.insert(
-            now=now,
-            obj=member,
-            then=now + FEDERATION_PING_INTERVAL,
-        )
-
-        for domain in set(get_domain_from_id(u) for u in users):
-            if domain != self.server_name:
-                self.federation.send_edu(
-                    destination=domain,
-                    edu_type="m.typing",
-                    content={
-                        "room_id": member.room_id,
-                        "user_id": member.user_id,
-                        "typing": typing,
-                    },
-                    key=member,
-                )
+        try:
+            users = yield self.state.get_current_user_in_room(member.room_id)
+            self._member_last_federation_poke[member] = self.clock.time_msec()
+
+            now = self.clock.time_msec()
+            self.wheel_timer.insert(
+                now=now,
+                obj=member,
+                then=now + FEDERATION_PING_INTERVAL,
+            )
+
+            for domain in set(get_domain_from_id(u) for u in users):
+                if domain != self.server_name:
+                    self.federation.send_edu(
+                        destination=domain,
+                        edu_type="m.typing",
+                        content={
+                            "room_id": member.room_id,
+                            "user_id": member.user_id,
+                            "typing": typing,
+                        },
+                        key=member,
+                    )
+        except Exception:
+            logger.exception("Error pushing typing notif to remotes")

    @defer.inlineCallbacks
    def _recv_edu(self, origin, content):

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
+
+from synapse.api.errors import SynapseError
+
+
+class RequestTimedOutError(SynapseError):
+    """Exception representing timeout of an outbound request"""
+    def __init__(self):
+        super(RequestTimedOutError, self).__init__(504, "Timed out")
+
+
+def cancelled_to_request_timed_out_error(value, timeout):
+    """Turns CancelledErrors into RequestTimedOutErrors.
+
+    For use with async.add_timeout_to_deferred
+    """
+    if isinstance(value, failure.Failure):
+        value.trap(CancelledError)
+        raise RequestTimedOutError()
+    return value
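
A hedged sketch of how the HTTP clients later in this diff wire this up: the timeout cancels the request deferred, and the converter rewrites the resulting CancelledError as a 504 before it reaches the caller (agent, method and uri here stand in for the client's own attributes):

    from synapse.http import cancelled_to_request_timed_out_error
    from synapse.util.async import add_timeout_to_deferred
    from synapse.util.logcontext import make_deferred_yieldable

    request_deferred = agent.request(method, uri)
    add_timeout_to_deferred(
        request_deferred,
        60, cancelled_to_request_timed_out_error,
    )
    # raises RequestTimedOutError (504) if the 60s timeout fires first
    response = yield make_deferred_yieldable(request_deferred)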

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from synapse.http.server import wrap_request_handler
+from synapse.http.server import wrap_json_request_handler
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -42,14 +42,13 @@ class AdditionalResource(Resource):
        Resource.__init__(self)
        self._handler = handler

-        # these are required by the request_handler wrapper
-        self.version_string = hs.version_string
+        # required by the request_handler wrapper
        self.clock = hs.get_clock()

    def render(self, request):
        self._async_render(request)
        return NOT_DONE_YET

-    @wrap_request_handler
+    @wrap_json_request_handler
    def _async_render(self, request):
        return self._handler(request)
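
For orientation, AdditionalResource adapts an arbitrary deferred-returning handler into a Twisted Resource, with the decorator above supplying the JSON error handling. A rough usage sketch (my_handler is hypothetical; hs is assumed to be the HomeServer):

    from twisted.internet import defer

    from synapse.http.server import respond_with_json

    @defer.inlineCallbacks
    def my_handler(request):
        yield defer.succeed(None)  # stand-in for real asynchronous work
        respond_with_json(request, 200, {"ok": True}, send_cors=True)

    resource = AdditionalResource(hs, my_handler)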

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,9 +19,10 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
    CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
+from synapse.http import cancelled_to_request_timed_out_error
+from synapse.util.async import add_timeout_to_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
-from synapse.util import logcontext
import synapse.metrics
from synapse.http.endpoint import SpiderEndpoint
@@ -38,7 +40,7 @@ from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone

-from StringIO import StringIO
+from six import StringIO

import simplejson as json
import logging
@@ -95,21 +97,17 @@ class SimpleHttpClient(object):
        # counters to it
        outgoing_requests_counter.inc(method)

-        def send_request():
-            request_deferred = self.agent.request(
-                method, uri, *args, **kwargs
-            )
-
-            return self.clock.time_bound_deferred(
-                request_deferred,
-                time_out=60,
-            )
-
        logger.info("Sending request %s %s", method, uri)

        try:
-            with logcontext.PreserveLoggingContext():
-                response = yield send_request()
+            request_deferred = self.agent.request(
+                method, uri, *args, **kwargs
+            )
+            add_timeout_to_deferred(
+                request_deferred,
+                60, cancelled_to_request_timed_out_error,
+            )
+            response = yield make_deferred_yieldable(request_deferred)

            incoming_responses_counter.inc(method, response.code)
            logger.info(
@@ -509,7 +507,7 @@ class SpiderHttpClient(SimpleHttpClient):
                    reactor,
                    SpiderEndpointFactory(hs)
                )
-            ), [('gzip', GzipDecoder)]
+            ), [(b'gzip', GzipDecoder)]
        )
        # We could look like Chrome:
        # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)

View File

@@ -115,10 +115,15 @@ class _WrappedConnection(object):
            if time.time() - self.last_request >= 2.5 * 60:
                self.abort()
                # Abort the underlying TLS connection. The abort() method calls
-                # loseConnection() on the underlying TLS connection which tries to
+                # loseConnection() on the TLS connection which tries to
                # shutdown the connection cleanly. We call abortConnection()
-                # since that will promptly close the underlying TCP connection.
-                self.transport.abortConnection()
+                # since that will promptly close the TLS connection.
+                #
+                # In Twisted >18.4, the TLS connection will be None if it has
+                # closed, which will make abortConnection() throw. Check that
+                # the TLS connection is not None before trying to close it.
+                if self.transport.getHandle() is not None:
+                    self.transport.abortConnection()

    def request(self, request):
        self.last_request = time.time()
@@ -286,7 +291,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
        if (len(answers) == 1
                and answers[0].type == dns.SRV
                and answers[0].payload
-                and answers[0].payload.target == dns.Name('.')):
+                and answers[0].payload.target == dns.Name(b'.')):
            raise ConnectError("Service %s unavailable" % service_name)

        for answer in answers:

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,17 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.util.retryutils
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone

+from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.async import sleep
-from synapse.util import logcontext
import synapse.metrics
+from synapse.util.async import sleep, add_timeout_to_deferred
+from synapse.util import logcontext
+from synapse.util.logcontext import make_deferred_yieldable
+import synapse.util.retryutils

from canonicaljson import encode_canonical_json
@@ -38,8 +41,7 @@ import logging
import random
import sys
import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse

logger = logging.getLogger(__name__)
outbound_logger = logging.getLogger("synapse.http.outbound")
@@ -184,21 +186,20 @@ class MatrixFederationHttpClient(object):
            producer = body_callback(method, http_url_bytes, headers_dict)

            try:
-                def send_request():
-                    request_deferred = self.agent.request(
-                        method,
-                        url_bytes,
-                        Headers(headers_dict),
-                        producer
-                    )
-
-                    return self.clock.time_bound_deferred(
-                        request_deferred,
-                        time_out=timeout / 1000. if timeout else 60,
-                    )
-
-                with logcontext.PreserveLoggingContext():
-                    response = yield send_request()
+                request_deferred = self.agent.request(
+                    method,
+                    url_bytes,
+                    Headers(headers_dict),
+                    producer
+                )
+                add_timeout_to_deferred(
+                    request_deferred,
+                    timeout / 1000. if timeout else 60,
+                    cancelled_to_request_timed_out_error,
+                )
+                response = yield make_deferred_yieldable(
+                    request_deferred,
+                )

                log_result = "%d %s" % (response.code, response.phrase,)
                break

View File

@@ -0,0 +1,149 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import synapse.metrics
from synapse.util.logcontext import LoggingContext

logger = logging.getLogger(__name__)

metrics = synapse.metrics.get_metrics_for("synapse.http.server")

# total number of responses served, split by method/servlet/tag
response_count = metrics.register_counter(
    "response_count",
    labels=["method", "servlet", "tag"],
    alternative_names=(
        # the following are all deprecated aliases for the same metric
        metrics.name_prefix + x for x in (
            "_requests",
            "_response_time:count",
            "_response_ru_utime:count",
            "_response_ru_stime:count",
            "_response_db_txn_count:count",
            "_response_db_txn_duration:count",
        )
    )
)

requests_counter = metrics.register_counter(
    "requests_received",
    labels=["method", "servlet", ],
)

outgoing_responses_counter = metrics.register_counter(
    "responses",
    labels=["method", "code"],
)

response_timer = metrics.register_counter(
    "response_time_seconds",
    labels=["method", "servlet", "tag"],
    alternative_names=(
        metrics.name_prefix + "_response_time:total",
    ),
)

response_ru_utime = metrics.register_counter(
    "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
    alternative_names=(
        metrics.name_prefix + "_response_ru_utime:total",
    ),
)

response_ru_stime = metrics.register_counter(
    "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
    alternative_names=(
        metrics.name_prefix + "_response_ru_stime:total",
    ),
)

response_db_txn_count = metrics.register_counter(
    "response_db_txn_count", labels=["method", "servlet", "tag"],
    alternative_names=(
        metrics.name_prefix + "_response_db_txn_count:total",
    ),
)

# seconds spent waiting for db txns, excluding scheduling time, when processing
# this request
response_db_txn_duration = metrics.register_counter(
    "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
    alternative_names=(
        metrics.name_prefix + "_response_db_txn_duration:total",
    ),
)

# seconds spent waiting for a db connection, when processing this request
response_db_sched_duration = metrics.register_counter(
    "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
)

# size in bytes of the response written
response_size = metrics.register_counter(
    "response_size", labels=["method", "servlet", "tag"]
)


class RequestMetrics(object):
    def start(self, time_msec, name):
        self.start = time_msec
        self.start_context = LoggingContext.current_context()
        self.name = name

    def stop(self, time_msec, request):
        context = LoggingContext.current_context()

        tag = ""
        if context:
            tag = context.tag
            if context != self.start_context:
                logger.warn(
                    "Context has unexpectedly changed %r, %r",
                    context, self.start_context
                )
                return

        outgoing_responses_counter.inc(request.method, str(request.code))

        response_count.inc(request.method, self.name, tag)

        response_timer.inc_by(
            time_msec - self.start, request.method,
            self.name, tag
        )

        ru_utime, ru_stime = context.get_resource_usage()

        response_ru_utime.inc_by(
            ru_utime, request.method, self.name, tag
        )
        response_ru_stime.inc_by(
            ru_stime, request.method, self.name, tag
        )
        response_db_txn_count.inc_by(
            context.db_txn_count, request.method, self.name, tag
        )
        response_db_txn_duration.inc_by(
            context.db_txn_duration_ms / 1000., request.method, self.name, tag
        )
        response_db_sched_duration.inc_by(
            context.db_sched_duration_ms / 1000., request.method, self.name, tag
        )

        response_size.inc_by(request.sentLength, request.method, self.name, tag)
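
A rough usage sketch of RequestMetrics, mirroring how the rewritten wrapper in the next file drives it (clock, request, servlet_name and handle_request are assumptions standing in for the surrounding request-handling code):

    metrics = RequestMetrics()
    metrics.start(clock.time_msec(), name=servlet_name)
    try:
        handle_request(request)  # hypothetical dispatch
    finally:
        # records response code, CPU time and DB time against
        # the method/servlet/tag labels registered above
        metrics.stop(clock.time_msec(), request)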

View File

@@ -18,6 +18,9 @@
from synapse.api.errors import (
    cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes
)
+from synapse.http.request_metrics import (
+    requests_counter,
+)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import intern_dict
from synapse.util.metrics import Measure
@@ -41,178 +44,103 @@ import simplejson

logger = logging.getLogger(__name__)

-metrics = synapse.metrics.get_metrics_for(__name__)
-
-# total number of responses served, split by method/servlet/tag
-response_count = metrics.register_counter(
-    "response_count",
-    labels=["method", "servlet", "tag"],
-    alternative_names=(
-        # the following are all deprecated aliases for the same metric
-        metrics.name_prefix + x for x in (
-            "_requests",
-            "_response_time:count",
-            "_response_ru_utime:count",
-            "_response_ru_stime:count",
-            "_response_db_txn_count:count",
-            "_response_db_txn_duration:count",
-        )
-    )
-)
-
-requests_counter = metrics.register_counter(
-    "requests_received",
-    labels=["method", "servlet", ],
-)
-
-outgoing_responses_counter = metrics.register_counter(
-    "responses",
-    labels=["method", "code"],
-)
-
-response_timer = metrics.register_counter(
-    "response_time_seconds",
-    labels=["method", "servlet", "tag"],
-    alternative_names=(
-        metrics.name_prefix + "_response_time:total",
-    ),
-)
-
-response_ru_utime = metrics.register_counter(
-    "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
-    alternative_names=(
-        metrics.name_prefix + "_response_ru_utime:total",
-    ),
-)
-
-response_ru_stime = metrics.register_counter(
-    "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
-    alternative_names=(
-        metrics.name_prefix + "_response_ru_stime:total",
-    ),
-)
-
-response_db_txn_count = metrics.register_counter(
-    "response_db_txn_count", labels=["method", "servlet", "tag"],
-    alternative_names=(
-        metrics.name_prefix + "_response_db_txn_count:total",
-    ),
-)
-
-# seconds spent waiting for db txns, excluding scheduling time, when processing
-# this request
-response_db_txn_duration = metrics.register_counter(
-    "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
-    alternative_names=(
-        metrics.name_prefix + "_response_db_txn_duration:total",
-    ),
-)
-
-# seconds spent waiting for a db connection, when processing this request
-response_db_sched_duration = metrics.register_counter(
-    "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
-)
-
-# size in bytes of the response written
-response_size = metrics.register_counter(
-    "response_size", labels=["method", "servlet", "tag"]
-)
-
-_next_request_id = 0
-
-
-def request_handler(include_metrics=False):
-    """Decorator for ``wrap_request_handler``"""
-    return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
-
-
-def wrap_request_handler(request_handler, include_metrics=False):
-    """Wraps a method that acts as a request handler with the necessary logging
-    and exception handling.
-
-    The method must have a signature of "handle_foo(self, request)". The
-    argument "self" must have "version_string" and "clock" attributes. The
-    argument "request" must be a twisted HTTP request.
+def wrap_json_request_handler(h):
+    """Wraps a request handler method with exception handling.
+
+    Also adds logging as per wrap_request_handler_with_logging.
+
+    The handler method must have a signature of "handle_foo(self, request)",
+    where "self" must have a "clock" attribute (and "request" must be a
+    SynapseRequest).

-    The method must return a deferred. If the deferred succeeds we assume that
+    The handler must return a deferred. If the deferred succeeds we assume that
    a response has been sent. If the deferred fails with a SynapseError we use
    it to send a JSON response with the appropriate HTTP response code. If the
    deferred fails with any other type of error we send a 500 response.
-
-    We insert a unique request-id into the logging context for this request and
-    log the response and duration for this request.
    """
    @defer.inlineCallbacks
    def wrapped_request_handler(self, request):
-        global _next_request_id
-        request_id = "%s-%s" % (request.method, _next_request_id)
-        _next_request_id += 1
+        try:
+            yield h(self, request)
+        except CodeMessageException as e:
+            code = e.code
+            if isinstance(e, SynapseError):
+                logger.info(
+                    "%s SynapseError: %s - %s", request, code, e.msg
+                )
+            else:
+                logger.exception(e)
+            respond_with_json(
+                request, code, cs_exception(e), send_cors=True,
+                pretty_print=_request_user_agent_is_curl(request),
+            )
+        except Exception:
+            # failure.Failure() fishes the original Failure out
+            # of our stack, and thus gives us a sensible stack
+            # trace.
+            f = failure.Failure()
+            logger.error(
+                "Failed handle request via %r: %r: %s",
+                h,
+                request,
+                f.getTraceback().rstrip(),
+            )
+            respond_with_json(
+                request,
+                500,
+                {
+                    "error": "Internal server error",
+                    "errcode": Codes.UNKNOWN,
+                },
+                send_cors=True,
+                pretty_print=_request_user_agent_is_curl(request),
+            )
+
+    return wrap_request_handler_with_logging(wrapped_request_handler)
+
+
+def wrap_request_handler_with_logging(h):
+    """Wraps a request handler to provide logging and metrics
+
+    The handler method must have a signature of "handle_foo(self, request)",
+    where "self" must have a "clock" attribute (and "request" must be a
+    SynapseRequest).
+
+    As well as calling `request.processing` (which will log the response and
+    duration for this request), the wrapped request handler will insert the
+    request id into the logging context.
+    """
+    @defer.inlineCallbacks
+    def wrapped_request_handler(self, request):
+        """
+        Args:
+            self:
+            request (synapse.http.site.SynapseRequest):
+        """
+        request_id = request.get_request_id()
        with LoggingContext(request_id) as request_context:
+            request_context.request = request_id
            with Measure(self.clock, "wrapped_request_handler"):
-                request_metrics = RequestMetrics()
                # we start the request metrics timer here with an initial stab
                # at the servlet name. For most requests that name will be
                # JsonResource (or a subclass), and JsonResource._async_render
                # will update it once it picks a servlet.
                servlet_name = self.__class__.__name__
-                request_metrics.start(self.clock, name=servlet_name)
-
-                request_context.request = request_id
-                with request.processing():
-                    try:
-                        with PreserveLoggingContext(request_context):
-                            if include_metrics:
-                                yield request_handler(self, request, request_metrics)
-                            else:
-                                requests_counter.inc(request.method, servlet_name)
-                                yield request_handler(self, request)
-                    except CodeMessageException as e:
-                        code = e.code
-                        if isinstance(e, SynapseError):
-                            logger.info(
-                                "%s SynapseError: %s - %s", request, code, e.msg
-                            )
-                        else:
-                            logger.exception(e)
-                        outgoing_responses_counter.inc(request.method, str(code))
-                        respond_with_json(
-                            request, code, cs_exception(e), send_cors=True,
-                            pretty_print=_request_user_agent_is_curl(request),
-                            version_string=self.version_string,
-                        )
-                    except Exception:
-                        # failure.Failure() fishes the original Failure out
-                        # of our stack, and thus gives us a sensible stack
-                        # trace.
-                        f = failure.Failure()
-                        logger.error(
-                            "Failed handle request %s.%s on %r: %r: %s",
-                            request_handler.__module__,
-                            request_handler.__name__,
-                            self,
-                            request,
-                            f.getTraceback().rstrip(),
-                        )
-                        respond_with_json(
-                            request,
-                            500,
-                            {
-                                "error": "Internal server error",
-                                "errcode": Codes.UNKNOWN,
-                            },
-                            send_cors=True,
-                            pretty_print=_request_user_agent_is_curl(request),
-                            version_string=self.version_string,
-                        )
-                    finally:
-                        try:
-                            request_metrics.stop(
-                                self.clock, request
-                            )
-                        except Exception as e:
-                            logger.warn("Failed to stop metrics: %r", e)
+                with request.processing(servlet_name):
+                    with PreserveLoggingContext(request_context):
+                        d = h(self, request)
+
+                        # record the arrival of the request *after*
+                        # dispatching to the handler, so that the handler
+                        # can update the servlet name in the request
+                        # metrics
+                        requests_counter.inc(request.method,
+                                             request.request_metrics.name)
+                        yield d

    return wrapped_request_handler
@@ -262,7 +190,6 @@ class JsonResource(HttpServer, resource.Resource):
        self.canonical_json = canonical_json
        self.clock = hs.get_clock()
        self.path_regexs = {}
-        self.version_string = hs.version_string
        self.hs = hs

    def register_paths(self, method, path_patterns, callback):
@@ -278,13 +205,9 @@ class JsonResource(HttpServer, resource.Resource):
        self._async_render(request)
        return server.NOT_DONE_YET

-    # Disable metric reporting because _async_render does its own metrics.
-    # It does its own metric reporting because _async_render dispatches to
-    # a callback and it's the class name of that callback we want to report
-    # against rather than the JsonResource itself.
-    @request_handler(include_metrics=True)
+    @wrap_json_request_handler
    @defer.inlineCallbacks
-    def _async_render(self, request, request_metrics):
+    def _async_render(self, request):
        """ This gets called from render() every time someone sends us a request.
            This checks if anyone has registered a callback for that method and
            path.
@@ -296,9 +219,7 @@ class JsonResource(HttpServer, resource.Resource):
            servlet_classname = servlet_instance.__class__.__name__
        else:
            servlet_classname = "%r" % callback

-        request_metrics.name = servlet_classname
-        requests_counter.inc(request.method, servlet_classname)
+        request.request_metrics.name = servlet_classname

        # Now trigger the callback. If it returns a response, we send it
        # here. If it throws an exception, that is handled by the wrapper
@@ -345,15 +266,12 @@ class JsonResource(HttpServer, resource.Resource):

    def _send_response(self, request, code, response_json_object,
                       response_code_message=None):
-        outgoing_responses_counter.inc(request.method, str(code))
-
        # TODO: Only enable CORS for the requests that need it.
        respond_with_json(
            request, code, response_json_object,
            send_cors=True,
            response_code_message=response_code_message,
            pretty_print=_request_user_agent_is_curl(request),
-            version_string=self.version_string,
            canonical_json=self.canonical_json,
        )
@@ -386,54 +304,6 @@ def _unrecognised_request_handler(request):
    raise UnrecognizedRequestError()


-class RequestMetrics(object):
-    def start(self, clock, name):
-        self.start = clock.time_msec()
-        self.start_context = LoggingContext.current_context()
-        self.name = name
-
-    def stop(self, clock, request):
-        context = LoggingContext.current_context()
-
-        tag = ""
-        if context:
-            tag = context.tag
-            if context != self.start_context:
-                logger.warn(
-                    "Context have unexpectedly changed %r, %r",
-                    context, self.start_context
-                )
-                return
-
-        response_count.inc(request.method, self.name, tag)
-
-        response_timer.inc_by(
-            clock.time_msec() - self.start, request.method,
-            self.name, tag
-        )
-
-        ru_utime, ru_stime = context.get_resource_usage()
-
-        response_ru_utime.inc_by(
-            ru_utime, request.method, self.name, tag
-        )
-        response_ru_stime.inc_by(
-            ru_stime, request.method, self.name, tag
-        )
-        response_db_txn_count.inc_by(
-            context.db_txn_count, request.method, self.name, tag
-        )
-        response_db_txn_duration.inc_by(
-            context.db_txn_duration_ms / 1000., request.method, self.name, tag
-        )
-        response_db_sched_duration.inc_by(
-            context.db_sched_duration_ms / 1000., request.method, self.name, tag
-        )
-
-        response_size.inc_by(request.sentLength, request.method, self.name, tag)
-
-
class RootRedirect(resource.Resource):
    """Redirects the root '/' path to another path."""
@@ -452,7 +322,7 @@ class RootRedirect(resource.Resource):

def respond_with_json(request, code, json_object, send_cors=False,
                      response_code_message=None, pretty_print=False,
-                      version_string="", canonical_json=True):
+                      canonical_json=True):
    # could alternatively use request.notifyFinish() and flip a flag when
    # the Deferred fires, but since the flag is RIGHT THERE it seems like
    # a waste.
@@ -474,12 +344,11 @@ def respond_with_json(request, code, json_object, send_cors=False,
        request, code, json_bytes,
        send_cors=send_cors,
        response_code_message=response_code_message,
-        version_string=version_string
    )


def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
-                            version_string="", response_code_message=None):
+                            response_code_message=None):
    """Sends encoded JSON in response to the given request.

    Args:
@@ -493,7 +362,6 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
    request.setResponseCode(code, message=response_code_message)
request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Server", version_string)
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
@ -546,6 +414,6 @@ def _request_user_agent_is_curl(request):
b"User-Agent", default=[] b"User-Agent", default=[]
) )
for user_agent in user_agents: for user_agent in user_agents:
if "curl" in user_agent: if b"curl" in user_agent:
return True return True
return False return False
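For orientation, since only uses of the new decorator appear in this part of the diff: a minimal sketch of the wrap_json_request_handler pattern that replaces the old @request_handler(include_metrics=True) decorator. This is not Synapse's actual implementation; respond_with_json below is a simplified stand-in for the real helper.

import json
import logging

from twisted.internet import defer

logger = logging.getLogger(__name__)


def respond_with_json(request, code, json_object):
    # simplified stand-in for synapse.http.server.respond_with_json
    request.setResponseCode(code)
    request.setHeader(b"Content-Type", b"application/json")
    request.write(json.dumps(json_object).encode("utf8"))
    request.finish()


def wrap_json_request_handler(h):
    """Run the wrapped handler, turning any stray exception into a JSON 500,
    so that every JSON resource gets uniform error handling."""
    @defer.inlineCallbacks
    def wrapped_request_handler(self, request):
        try:
            yield h(self, request)
        except Exception:
            logger.exception("Failed handle request %r", request)
            respond_with_json(request, 500, {"error": "Internal server error"})
    return wrapped_request_handler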
View File
@@ -12,24 +12,48 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.util.logcontext import LoggingContext
-from twisted.web.server import Site, Request
-
 import contextlib
 import logging
 import re
 import time

+from twisted.web.server import Site, Request
+
+from synapse.http.request_metrics import RequestMetrics
+from synapse.util.logcontext import LoggingContext
+
+logger = logging.getLogger(__name__)
+
 ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')

+_next_request_seq = 0
+

 class SynapseRequest(Request):
+    """Class which encapsulates an HTTP request to synapse.
+
+    All of the requests processed in synapse are of this type.
+
+    It extends twisted's twisted.web.server.Request, and adds:
+     * Unique request ID
+     * Redaction of access_token query-params in __repr__
+     * Logging at start and end
+     * Metrics to record CPU, wallclock and DB time by endpoint.
+
+    It provides a method `processing` which should be called by the Resource
+    which is handling the request, and returns a context manager.
+    """
     def __init__(self, site, *args, **kw):
         Request.__init__(self, *args, **kw)
         self.site = site
         self.authenticated_entity = None
         self.start_time = 0

+        global _next_request_seq
+        self.request_seq = _next_request_seq
+        _next_request_seq += 1
+
     def __repr__(self):
         # We overwrite this so that we don't log ``access_token``
         return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
@@ -41,6 +65,9 @@ class SynapseRequest(Request):
             self.site.site_tag,
         )

+    def get_request_id(self):
+        return "%s-%i" % (self.method, self.request_seq)
+
     def get_redacted_uri(self):
         return ACCESS_TOKEN_RE.sub(
             br'\1<redacted>\3',
@@ -50,7 +77,16 @@ class SynapseRequest(Request):
     def get_user_agent(self):
         return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]

-    def started_processing(self):
+    def render(self, resrc):
+        # override the Server header which is set by twisted
+        self.setHeader("Server", self.site.server_version_string)
+        return Request.render(self, resrc)
+
+    def _started_processing(self, servlet_name):
+        self.start_time = int(time.time() * 1000)
+        self.request_metrics = RequestMetrics()
+        self.request_metrics.start(self.start_time, name=servlet_name)
+
         self.site.access_logger.info(
             "%s - %s - Received request: %s %s",
             self.getClientIP(),
@@ -58,10 +94,8 @@ class SynapseRequest(Request):
             self.method,
             self.get_redacted_uri()
         )
-        self.start_time = int(time.time() * 1000)

-    def finished_processing(self):
+    def _finished_processing(self):
         try:
             context = LoggingContext.current_context()
             ru_utime, ru_stime = context.get_resource_usage()
@@ -72,6 +106,8 @@ class SynapseRequest(Request):
             ru_utime, ru_stime = (0, 0)
             db_txn_count, db_txn_duration_ms = (0, 0)

+        end_time = int(time.time() * 1000)
+
         self.site.access_logger.info(
             "%s - %s - {%s}"
             " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
@@ -79,7 +115,7 @@ class SynapseRequest(Request):
             self.getClientIP(),
             self.site.site_tag,
             self.authenticated_entity,
-            int(time.time() * 1000) - self.start_time,
+            end_time - self.start_time,
             int(ru_utime * 1000),
             int(ru_stime * 1000),
             db_sched_duration_ms,
@@ -93,11 +129,38 @@ class SynapseRequest(Request):
             self.get_user_agent(),
         )

+        try:
+            self.request_metrics.stop(end_time, self)
+        except Exception as e:
+            logger.warn("Failed to stop metrics: %r", e)
+
     @contextlib.contextmanager
-    def processing(self):
-        self.started_processing()
+    def processing(self, servlet_name):
+        """Record the fact that we are processing this request.
+
+        Returns a context manager; the correct way to use this is:
+
+        @defer.inlineCallbacks
+        def handle_request(request):
+            with request.processing("FooServlet"):
+                yield really_handle_the_request()
+
+        This will log the request's arrival. Once the context manager is
+        closed, the completion of the request will be logged, and the various
+        metrics will be updated.
+
+        Args:
+            servlet_name (str): the name of the servlet which will be
+                processing this request. This is used in the metrics.
+
+                It is possible to update this afterwards by updating
+                self.request_metrics.servlet_name.
+        """
+        # TODO: we should probably just move this into render() and finish(),
+        # to save having to call a separate method.
+        self._started_processing(servlet_name)
         yield
-        self.finished_processing()
+        self._finished_processing()


 class XForwardedForRequest(SynapseRequest):
@@ -135,7 +198,8 @@ class SynapseSite(Site):
     Subclass of a twisted http Site that does access logging with python's
     standard logging
     """
-    def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
+    def __init__(self, logger_name, site_tag, config, resource,
+                 server_version_string, *args, **kwargs):
         Site.__init__(self, resource, *args, **kwargs)

         self.site_tag = site_tag
@@ -143,6 +207,7 @@ class SynapseSite(Site):
         proxied = config.get("x_forwarded", False)
         self.requestFactory = SynapseRequestFactory(self, proxied)
         self.access_logger = logging.getLogger(logger_name)
+        self.server_version_string = server_version_string

     def log(self, request):
         pass
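A self-contained toy (invented names, standalone rather than Synapse's classes) of the same context-manager pattern `processing` implements: log on entry, then log and record elapsed time on exit, so every handler gets timing for free.

import contextlib
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("access")


@contextlib.contextmanager
def processing(servlet_name):
    start = time.time()
    logger.info("Received request for %s", servlet_name)
    try:
        yield
    finally:
        logger.info(
            "Processed request for %s in %dms",
            servlet_name, (time.time() - start) * 1000,
        )


with processing("FooServlet"):
    time.sleep(0.1)  # stand-in for really_handle_the_request()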
View File
@@ -16,6 +16,7 @@
 from itertools import chain
 import logging
+import re

 logger = logging.getLogger(__name__)
@@ -56,8 +57,7 @@ class BaseMetric(object):
         return not len(self.labels)

     def _render_labelvalue(self, value):
-        # TODO: escape backslashes, quotes and newlines
-        return '"%s"' % (value)
+        return '"%s"' % (_escape_label_value(value),)

     def _render_key(self, values):
         if self.is_scalar():
@@ -71,7 +71,8 @@ class BaseMetric(object):
         """Render this metric for a single set of labels

         Args:
-            label_values (list[str]): values for each of the labels
+            label_values (list[object]): values for each of the labels,
+                (which get stringified).
             value: value of the metric at with these labels

         Returns:
@@ -299,3 +300,29 @@ class MemoryUsageMetric(object):
             "process_psutil_rss:total %d" % sum_rss,
             "process_psutil_rss:count %d" % len_rss,
         ]
+
+
+def _escape_character(m):
+    """Replaces a single character with its escape sequence.
+
+    Args:
+        m (re.MatchObject): A match object whose first group is the single
+            character to replace
+
+    Returns:
+        str
+    """
+    c = m.group(1)
+    if c == "\\":
+        return "\\\\"
+    elif c == "\"":
+        return "\\\""
+    elif c == "\n":
+        return "\\n"
+    return c
+
+
+def _escape_label_value(value):
+    """Takes a label value and escapes quotes, newlines and backslashes
+    """
+    return re.sub(r"([\n\"\\])", _escape_character, str(value))
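A quick illustration of the new escaping, assuming the two helpers above are in scope (each assertion shows input and the Prometheus-safe output):

assert _escape_label_value('say "hi"') == 'say \\"hi\\"'
assert _escape_label_value("line1\nline2") == "line1\\nline2"
assert _escape_label_value("c:\\temp") == "c:\\\\temp"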
View File
@@ -14,14 +14,17 @@
 # limitations under the License.

 from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError
 from synapse.handlers.presence import format_user_presence_state
-from synapse.util import DeferredTimedOutError
 from synapse.util.logutils import log_function
-from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.async import (
+    ObservableDeferred, add_timeout_to_deferred,
+    DeferredTimeoutError,
+)
+from synapse.util.logcontext import PreserveLoggingContext, run_in_background
 from synapse.util.metrics import Measure
 from synapse.types import StreamToken
 from synapse.visibility import filter_events_for_client
@@ -251,9 +254,7 @@ class Notifier(object):
     def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
         """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
-        preserve_fn(self.appservice_handler.notify_interested_services)(
-            room_stream_id
-        )
+        run_in_background(self._notify_app_services, room_stream_id)

         if self.federation_sender:
             self.federation_sender.notify_new_events(room_stream_id)
@@ -267,6 +268,13 @@ class Notifier(object):
             rooms=[event.room_id],
         )

+    @defer.inlineCallbacks
+    def _notify_app_services(self, room_stream_id):
+        try:
+            yield self.appservice_handler.notify_interested_services(room_stream_id)
+        except Exception:
+            logger.exception("Error notifying application services of event")
+
     def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
         """ Used to inform listeners that something has happend event wise.
@@ -331,11 +339,12 @@ class Notifier(object):
                     # Now we wait for the _NotifierUserStream to be told there
                     # is a new token.
                     listener = user_stream.new_listener(prev_token)
+                    add_timeout_to_deferred(
+                        listener.deferred,
+                        (end_time - now) / 1000.,
+                    )
                     with PreserveLoggingContext():
-                        yield self.clock.time_bound_deferred(
-                            listener.deferred,
-                            time_out=(end_time - now) / 1000.
-                        )
+                        yield listener.deferred

                     current_token = user_stream.current_token
@@ -346,7 +355,7 @@ class Notifier(object):
                     # Update the prev_token to the current_token since nothing
                     # has happened between the old prev_token and the current_token
                     prev_token = current_token
-                except DeferredTimedOutError:
+                except DeferredTimeoutError:
                     break
                 except defer.CancelledError:
                     break
@@ -551,13 +560,14 @@ class Notifier(object):
             if end_time <= now:
                 break

+            add_timeout_to_deferred(
+                listener.deferred,
+                (end_time - now) / 1000.,
+            )
             try:
                 with PreserveLoggingContext():
-                    yield self.clock.time_bound_deferred(
-                        listener.deferred,
-                        time_out=(end_time - now) / 1000.
-                    )
-            except DeferredTimedOutError:
+                    yield listener.deferred
+            except DeferredTimeoutError:
                 break
             except defer.CancelledError:
                 break
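A rough sketch (an assumption about the helper's behaviour, not the code of synapse.util.async itself) of what an add_timeout_to_deferred-style helper does: cancel the deferred if it has not fired within the given number of seconds, converting the resulting CancelledError into a dedicated timeout error that callers such as the notifier loops above can catch.

from twisted.internet import defer, reactor
from twisted.python import failure


class DeferredTimeoutError(Exception):
    """Stand-in for the synapse exception of the same name."""


def add_timeout_to_deferred(deferred, timeout):
    timed_out = [False]

    def time_it_out():
        timed_out[0] = True
        deferred.cancel()

    delayed_call = reactor.callLater(timeout, time_it_out)

    def check_timeout(value):
        if delayed_call.active():
            # the deferred fired before the timer did: stop the timer
            delayed_call.cancel()
        if timed_out[0] and isinstance(value, failure.Failure):
            # our timer caused the cancellation: surface it as a timeout
            value.trap(defer.CancelledError)
            raise DeferredTimeoutError("Timed out after %s seconds" % (timeout,))
        return value

    deferred.addBoth(check_timeout)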
View File
@@ -77,10 +77,13 @@ class EmailPusher(object):
     @defer.inlineCallbacks
     def on_started(self):
         if self.mailer is not None:
-            self.throttle_params = yield self.store.get_throttle_params_by_room(
-                self.pusher_id
-            )
-            yield self._process()
+            try:
+                self.throttle_params = yield self.store.get_throttle_params_by_room(
+                    self.pusher_id
+                )
+                yield self._process()
+            except Exception:
+                logger.exception("Error starting email pusher")

     def on_stop(self):
         if self.timed_call:
View File
@@ -18,8 +18,8 @@ import logging
 from twisted.internet import defer, reactor
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled

-import push_rule_evaluator
-import push_tools
+from . import push_rule_evaluator
+from . import push_tools
 import synapse
 from synapse.push import PusherConfigException
 from synapse.util.logcontext import LoggingContext
@@ -94,7 +94,10 @@ class HttpPusher(object):
     @defer.inlineCallbacks
     def on_started(self):
-        yield self._process()
+        try:
+            yield self._process()
+        except Exception:
+            logger.exception("Error starting http pusher")

     @defer.inlineCallbacks
     def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
View File
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from httppusher import HttpPusher
+from .httppusher import HttpPusher

 import logging
 logger = logging.getLogger(__name__)
View File
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+
 from twisted.internet import defer

-from .pusher import PusherFactory
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.push.pusher import PusherFactory
 from synapse.util.async import run_on_reactor
-
-import logging
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background

 logger = logging.getLogger(__name__)
@@ -137,12 +137,15 @@ class PusherPool:
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         deferreds.append(
-                            preserve_fn(p.on_new_notifications)(
-                                min_stream_id, max_stream_id
+                            run_in_background(
+                                p.on_new_notifications,
+                                min_stream_id, max_stream_id,
                             )
                         )

-            yield make_deferred_yieldable(defer.gatherResults(deferreds))
+            yield make_deferred_yieldable(
+                defer.gatherResults(deferreds, consumeErrors=True),
+            )
         except Exception:
             logger.exception("Exception in pusher on_new_notifications")
@@ -164,10 +167,15 @@ class PusherPool:
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         deferreds.append(
-                            preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
+                            run_in_background(
+                                p.on_new_receipts,
+                                min_stream_id, max_stream_id,
+                            )
                         )

-            yield make_deferred_yieldable(defer.gatherResults(deferreds))
+            yield make_deferred_yieldable(
+                defer.gatherResults(deferreds, consumeErrors=True),
+            )
         except Exception:
             logger.exception("Exception in pusher on_new_receipts")
@@ -207,7 +215,7 @@ class PusherPool:
             if appid_pushkey in byuser:
                 byuser[appid_pushkey].on_stop()
             byuser[appid_pushkey] = p
-            preserve_fn(p.on_started)()
+            run_in_background(p.on_started)

         logger.info("Started pushers")
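A small standalone illustration of why these calls now pass consumeErrors=True to gatherResults: without it, each failing deferred keeps its failure marked as unhandled even though gatherResults has already wrapped it in a FirstError, so Twisted logs spurious "Unhandled error in Deferred" messages at garbage-collection time.

from twisted.internet import defer


def notify_all(pusher_calls):
    """pusher_calls: a list of Deferreds, any of which may fail."""
    d = defer.gatherResults(pusher_calls, consumeErrors=True)
    # a single errback now sees the first failure; the individual failures
    # have been consumed and will not be logged separately
    d.addErrback(lambda f: f.trap(defer.FirstError))
    return d


notify_all([defer.succeed(1), defer.fail(RuntimeError("boom"))])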
View File
@@ -1,5 +1,6 @@
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +19,18 @@ from distutils.version import LooseVersion

 logger = logging.getLogger(__name__)

+# this dict maps from python package name to a list of modules we expect it to
+# provide.
+#
+# the key is a "requirement specifier", as used as a parameter to `pip
+# install`[1], or an `install_requires` argument to `setuptools.setup` [2].
+#
+# the value is a sequence of strings; each entry should be the name of the
+# python module, optionally followed by a version assertion which can be either
+# ">=<ver>" or "==<ver>".
+#
+# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
+# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies
 REQUIREMENTS = {
     "jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
     "frozendict>=0.4": ["frozendict"],
@@ -27,7 +40,10 @@ REQUIREMENTS = {
     "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
     "service_identity>=1.0.0": ["service_identity>=1.0.0"],
     "Twisted>=16.0.0": ["twisted>=16.0.0"],
-    "pyopenssl>=0.14": ["OpenSSL>=0.14"],
+
+    # We use crypto.get_elliptic_curve which is only supported in >=0.15
+    "pyopenssl>=0.15": ["OpenSSL>=0.15"],
+
     "pyyaml": ["yaml"],
     "pyasn1": ["pyasn1"],
     "daemonize": ["daemonize"],
@@ -39,6 +55,7 @@ REQUIREMENTS = {
     "pymacaroons-pynacl": ["pymacaroons"],
     "msgpack-python>=0.3.0": ["msgpack"],
     "phonenumbers>=8.2.0": ["phonenumbers"],
+    "six": ["six"],
 }
 CONDITIONAL_REQUIREMENTS = {
     "web_client": {
View File
@@ -53,12 +53,12 @@ from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure

-from commands import (
+from .commands import (
     COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
     ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
     NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
 )
-from streams import STREAMS_MAP
+from .streams import STREAMS_MAP

 from synapse.util.stringutils import random_string
 from synapse.metrics.metric import CounterMetric
View File
@@ -18,8 +18,8 @@
 from twisted.internet import defer, reactor
 from twisted.internet.protocol import Factory

-from streams import STREAMS_MAP, FederationStream
-from protocol import ServerReplicationStreamProtocol
+from .streams import STREAMS_MAP, FederationStream
+from .protocol import ServerReplicationStreamProtocol

 from synapse.util.metrics import Measure, measure_func
View File
@@ -168,11 +168,24 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
                 yield self.store.find_first_stream_ordering_after_ts(ts)
             )

-            (_, depth, _) = (
+            room_event_after_stream_ordering = (
                 yield self.store.get_room_event_after_stream_ordering(
                     room_id, stream_ordering,
                 )
             )
+            if room_event_after_stream_ordering:
+                (_, depth, _) = room_event_after_stream_ordering
+            else:
+                logger.warn(
+                    "[purge] purging events not possible: No event found "
+                    "(received_ts %i => stream_ordering %i)",
+                    ts, stream_ordering,
+                )
+                raise SynapseError(
+                    404,
+                    "there is no event to be purged",
+                    errcode=Codes.NOT_FOUND,
+                )

             logger.info(
                 "[purge] purging up to depth %i (received_ts %i => "
                 "stream_ordering %i)",
View File
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
     """A base Synapse REST Servlet for the client version 1 API.
     """

+    # This subclass was presumably created to allow the auth for the v1
+    # protocol version to be different, however this behaviour was removed.
+    # it may no longer be necessary
+
     def __init__(self, hs):
         """
         Args:
@@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet):
         """
         self.hs = hs
         self.builder_factory = hs.get_event_builder_factory()
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
         self.txns = HttpTransactionCache(hs.get_clock())
View File
@@ -25,7 +25,7 @@ from .base import ClientV1RestServlet, client_path_patterns

 import simplejson as json
 import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse

 import logging
 from saml2 import BINDING_HTTP_POST
View File
@@ -150,7 +150,7 @@ class PushersRemoveRestServlet(RestServlet):
         super(RestServlet, self).__init__()
         self.hs = hs
         self.notifier = hs.get_notifier()
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
         self.pusher_pool = self.hs.get_pusherpool()

     @defer.inlineCallbacks
@@ -176,7 +176,6 @@ class PushersRemoveRestServlet(RestServlet):
         request.setResponseCode(200)
         request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-        request.setHeader(b"Server", self.hs.version_string)
         request.setHeader(b"Content-Length", b"%d" % (
             len(PushersRemoveRestServlet.SUCCESS_HTML),
         ))
View File
@@ -30,6 +30,8 @@ from hashlib import sha1
 import hmac
 import logging

+from six import string_types
+
 logger = logging.getLogger(__name__)
@@ -333,11 +335,11 @@ class RegisterRestServlet(ClientV1RestServlet):
     def _do_shared_secret(self, request, register_json, session):
         yield run_on_reactor()

-        if not isinstance(register_json.get("mac", None), basestring):
+        if not isinstance(register_json.get("mac", None), string_types):
             raise SynapseError(400, "Expected mac.")
-        if not isinstance(register_json.get("user", None), basestring):
+        if not isinstance(register_json.get("user", None), string_types):
             raise SynapseError(400, "Expected 'user' key.")
-        if not isinstance(register_json.get("password", None), basestring):
+        if not isinstance(register_json.get("password", None), string_types):
             raise SynapseError(400, "Expected 'password' key.")

         if not self.hs.config.registration_shared_secret:
@@ -358,14 +360,14 @@ class RegisterRestServlet(ClientV1RestServlet):
             got_mac = str(register_json["mac"])

             want_mac = hmac.new(
-                key=self.hs.config.registration_shared_secret,
+                key=self.hs.config.registration_shared_secret.encode(),
                 digestmod=sha1,
             )
             want_mac.update(user)
-            want_mac.update("\x00")
+            want_mac.update(b"\x00")
             want_mac.update(password)
-            want_mac.update("\x00")
-            want_mac.update("admin" if admin else "notadmin")
+            want_mac.update(b"\x00")
+            want_mac.update(b"admin" if admin else b"notadmin")
             want_mac = want_mac.hexdigest()

             if compare_digest(want_mac, got_mac):
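For reference, a client-side sketch (hypothetical helper; it mirrors the want_mac computation in the hunk above) of how the shared-secret registration MAC is built: HMAC-SHA1 over the username, password and admin flag, NUL-separated, with everything as bytes so the same code runs on Python 2 and 3.

import hmac
from hashlib import sha1


def registration_mac(shared_secret, user, password, admin=False):
    mac = hmac.new(key=shared_secret.encode(), digestmod=sha1)
    mac.update(user)
    mac.update(b"\x00")
    mac.update(password)
    mac.update(b"\x00")
    mac.update(b"admin" if admin else b"notadmin")
    return mac.hexdigest()


print(registration_mac("secret", b"alice", b"wonderland"))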
View File
@@ -28,8 +28,9 @@ from synapse.http.servlet import (
     parse_json_object_from_request, parse_string, parse_integer
 )

+from six.moves.urllib import parse as urlparse
+
 import logging
-import urllib
 import simplejson as json

 logger = logging.getLogger(__name__)
@@ -433,7 +434,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
         as_client_event = "raw" not in request.args
         filter_bytes = request.args.get("filter", None)
         if filter_bytes:
-            filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+            filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
             event_filter = Filter(json.loads(filter_json))
         else:
             event_filter = None
@@ -718,8 +719,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
     def on_PUT(self, request, room_id, user_id):
         requester = yield self.auth.get_user_by_req(request)

-        room_id = urllib.unquote(room_id)
-        target_user = UserID.from_string(urllib.unquote(user_id))
+        room_id = urlparse.unquote(room_id)
+        target_user = UserID.from_string(urlparse.unquote(user_id))

         content = parse_json_object_from_request(request)
View File
@@ -129,7 +129,6 @@ class AuthRestServlet(RestServlet):
             html_bytes = html.encode("utf8")
             request.setResponseCode(200)
             request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-            request.setHeader(b"Server", self.hs.version_string)
             request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))

             request.write(html_bytes)
@@ -175,7 +174,6 @@ class AuthRestServlet(RestServlet):
             html_bytes = html.encode("utf8")
             request.setResponseCode(200)
             request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-            request.setHeader(b"Server", self.hs.version_string)
             request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))

             request.write(html_bytes)
View File
@@ -88,7 +88,7 @@ class NotificationsServlet(RestServlet):
                 pa["topological_ordering"], pa["stream_ordering"]
             )
             returned_push_actions.append(returned_pa)
-            next_token = pa["stream_ordering"]
+            next_token = str(pa["stream_ordering"])

         defer.returnValue((200, {
             "notifications": returned_push_actions,
View File
@@ -35,6 +35,8 @@ from hashlib import sha1
 from synapse.util.async import run_on_reactor
 from synapse.util.ratelimitutils import FederationRateLimiter

+from six import string_types
+

 # We ought to be using hmac.compare_digest() but on older pythons it doesn't
 # exist. It's a _really minor_ security flaw to use plain string comparison
@@ -210,14 +212,14 @@ class RegisterRestServlet(RestServlet):
         # in sessions. Pull out the username/password provided to us.
         desired_password = None
         if 'password' in body:
-            if (not isinstance(body['password'], basestring) or
+            if (not isinstance(body['password'], string_types) or
                     len(body['password']) > 512):
                 raise SynapseError(400, "Invalid password")
             desired_password = body["password"]

         desired_username = None
         if 'username' in body:
-            if (not isinstance(body['username'], basestring) or
+            if (not isinstance(body['username'], string_types) or
                     len(body['username']) > 512):
                 raise SynapseError(400, "Invalid username")
             desired_username = body['username']
@@ -243,7 +245,7 @@ class RegisterRestServlet(RestServlet):

             access_token = get_access_token_from_request(request)

-            if isinstance(desired_username, basestring):
+            if isinstance(desired_username, string_types):
                 result = yield self._do_appservice_registration(
                     desired_username, access_token, body
                 )
@@ -464,7 +466,7 @@ class RegisterRestServlet(RestServlet):
         # includes the password and admin flag in the hashed text. Why are
         # these different?
         want_mac = hmac.new(
-            key=self.hs.config.registration_shared_secret,
+            key=self.hs.config.registration_shared_secret.encode(),
             msg=user,
             digestmod=sha1,
         ).hexdigest()
View File
@@ -49,7 +49,6 @@ class LocalKey(Resource):
     """

     def __init__(self, hs):
-        self.version_string = hs.version_string
         self.response_body = encode_canonical_json(
             self.response_json_object(hs.config)
         )
@@ -84,7 +83,6 @@ class LocalKey(Resource):
     def render_GET(self, request):
         return respond_with_json_bytes(
             request, 200, self.response_body,
-            version_string=self.version_string
         )

     def getChild(self, name, request):
View File
@@ -63,7 +63,6 @@ class LocalKey(Resource):
     isLeaf = True

     def __init__(self, hs):
-        self.version_string = hs.version_string
         self.config = hs.config
         self.clock = hs.clock
         self.update_response_body(self.clock.time_msec())
@@ -115,5 +114,4 @@ class LocalKey(Resource):
             self.update_response_body(time_now)
         return respond_with_json_bytes(
             request, 200, self.response_body,
-            version_string=self.version_string
         )
View File
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.http.server import request_handler, respond_with_json_bytes
+from synapse.http.server import (
+    respond_with_json_bytes, wrap_json_request_handler,
+)
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 from synapse.api.errors import SynapseError, Codes
 from synapse.crypto.keyring import KeyLookupError
@@ -91,7 +93,6 @@ class RemoteKey(Resource):
     def __init__(self, hs):
         self.keyring = hs.get_keyring()
         self.store = hs.get_datastore()
-        self.version_string = hs.version_string
         self.clock = hs.get_clock()
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -99,7 +100,7 @@ class RemoteKey(Resource):
         self.async_render_GET(request)
         return NOT_DONE_YET

-    @request_handler()
+    @wrap_json_request_handler
     @defer.inlineCallbacks
     def async_render_GET(self, request):
         if len(request.postpath) == 1:
@@ -124,7 +125,7 @@ class RemoteKey(Resource):
         self.async_render_POST(request)
         return NOT_DONE_YET

-    @request_handler()
+    @wrap_json_request_handler
     @defer.inlineCallbacks
     def async_render_POST(self, request):
         content = parse_json_object_from_request(request)
@@ -240,5 +241,4 @@ class RemoteKey(Resource):
         respond_with_json_bytes(
             request, 200, result_io.getvalue(),
-            version_string=self.version_string
         )
View File
@@ -28,7 +28,7 @@ import os

 import logging
 import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse

 logger = logging.getLogger(__name__)
@@ -143,6 +143,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_name):
         respond_404(request)
         return

+    logger.debug("Responding to media request with responder %s", responder)
     add_file_headers(request, media_type, file_size, upload_name)
     with responder:
         yield responder.write_to_consumer(request)
View File
@@ -12,17 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import synapse.http.servlet
-
-from ._base import parse_media_id, respond_404
-from twisted.web.resource import Resource
-from synapse.http.server import request_handler, set_cors_headers
-
-from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer

 import logging

+from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.http.server import (
+    set_cors_headers,
+    wrap_json_request_handler,
+)
+import synapse.http.servlet
+from ._base import parse_media_id, respond_404

 logger = logging.getLogger(__name__)
@@ -35,15 +37,14 @@ class DownloadResource(Resource):
         self.media_repo = media_repo
         self.server_name = hs.hostname

-        # Both of these are expected by @request_handler()
+        # this is expected by @wrap_json_request_handler
         self.clock = hs.get_clock()
-        self.version_string = hs.version_string

     def render_GET(self, request):
         self._async_render_GET(request)
         return NOT_DONE_YET

-    @request_handler()
+    @wrap_json_request_handler
     @defer.inlineCallbacks
     def _async_render_GET(self, request):
         set_cors_headers(request)
View File
@@ -47,7 +47,7 @@ import shutil

 import cgi
 import logging
-import urlparse
+from six.moves.urllib import parse as urlparse

 logger = logging.getLogger(__name__)
View File
@@ -255,7 +255,9 @@ class FileResponder(Responder):
         self.open_file = open_file

     def write_to_consumer(self, consumer):
-        return FileSender().beginFileTransfer(self.open_file, consumer)
+        return make_deferred_yieldable(
+            FileSender().beginFileTransfer(self.open_file, consumer)
+        )

     def __exit__(self, exc_type, exc_val, exc_tb):
         self.open_file.close()
View File
@@ -35,13 +35,14 @@ from ._base import FileInfo
 from synapse.api.errors import (
     SynapseError, Codes,
 )
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.util.stringutils import random_string
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.http.client import SpiderHttpClient
 from synapse.http.server import (
-    request_handler, respond_with_json_bytes,
+    respond_with_json_bytes,
     respond_with_json,
+    wrap_json_request_handler,
 )
 from synapse.util.async import ObservableDeferred
 from synapse.util.stringutils import is_ascii
@@ -57,7 +58,6 @@ class PreviewUrlResource(Resource):
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
-        self.version_string = hs.version_string
         self.filepaths = media_repo.filepaths
         self.max_spider_size = hs.config.max_spider_size
         self.server_name = hs.hostname
@@ -90,7 +90,7 @@ class PreviewUrlResource(Resource):
         self._async_render_GET(request)
         return NOT_DONE_YET

-    @request_handler()
+    @wrap_json_request_handler
     @defer.inlineCallbacks
     def _async_render_GET(self, request):
@@ -144,7 +144,8 @@ class PreviewUrlResource(Resource):
         observable = self._cache.get(url)

         if not observable:
-            download = preserve_fn(self._do_preview)(
+            download = run_in_background(
+                self._do_preview,
                 url, requester.user, ts,
             )
             observable = ObservableDeferred(
View File
@@ -18,7 +18,7 @@ from twisted.internet import defer, threads
 from .media_storage import FileResponder

 from synapse.config._base import Config
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background

 import logging
 import os
@@ -87,7 +87,12 @@ class StorageProviderWrapper(StorageProvider):
             return self.backend.store_file(path, file_info)
         else:
             # TODO: Handle errors.
-            preserve_fn(self.backend.store_file)(path, file_info)
+            def store():
+                try:
+                    return self.backend.store_file(path, file_info)
+                except Exception:
+                    logger.exception("Error storing file")
+
+            run_in_background(store)
             return defer.succeed(None)

     def fetch(self, path, file_info):
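A simplified stand-in (not the real synapse.util.logcontext implementation, which also juggles logging contexts) for the run_in_background pattern used throughout these hunks: kick off a function that may return a Deferred without waiting for it, while making sure any failure is logged rather than dropped.

import logging

from twisted.internet import defer

logger = logging.getLogger(__name__)


def run_in_background(f, *args, **kwargs):
    d = defer.maybeDeferred(f, *args, **kwargs)
    d.addErrback(lambda failure: logger.error(
        "Background task %s failed: %s", f, failure,
    ))
    return d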
View File
@@ -14,18 +14,21 @@
 # limitations under the License.

+import logging
+
+from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.http.server import (
+    set_cors_headers,
+    wrap_json_request_handler,
+)
+from synapse.http.servlet import parse_integer, parse_string
 from ._base import (
-    parse_media_id, respond_404, respond_with_file, FileInfo,
+    FileInfo, parse_media_id, respond_404, respond_with_file,
     respond_with_responder,
 )
-from twisted.web.resource import Resource
-from synapse.http.servlet import parse_string, parse_integer
-from synapse.http.server import request_handler, set_cors_headers
-from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-
-import logging

 logger = logging.getLogger(__name__)
@@ -41,14 +44,13 @@ class ThumbnailResource(Resource):
         self.media_storage = media_storage
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
-        self.version_string = hs.version_string
         self.clock = hs.get_clock()

     def render_GET(self, request):
         self._async_render_GET(request)
         return NOT_DONE_YET

-    @request_handler()
+    @wrap_json_request_handler
     @defer.inlineCallbacks
     def _async_render_GET(self, request):
         set_cors_headers(request)
View File
@@ -13,16 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.http.server import respond_with_json, request_handler
+import logging
+
+from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET

 from synapse.api.errors import SynapseError
-
-from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-
-from twisted.web.resource import Resource
-
-import logging
+from synapse.http.server import (
+    respond_with_json,
+    wrap_json_request_handler,
+)

 logger = logging.getLogger(__name__)
@@ -40,7 +41,6 @@ class UploadResource(Resource):
         self.server_name = hs.hostname
         self.auth = hs.get_auth()
         self.max_upload_size = hs.config.max_upload_size
-        self.version_string = hs.version_string
         self.clock = hs.get_clock()

     def render_POST(self, request):
@@ -51,7 +51,7 @@ class UploadResource(Resource):
         respond_with_json(request, 200, {}, send_cors=True)
         return NOT_DONE_YET

-    @request_handler()
+    @wrap_json_request_handler
     @defer.inlineCallbacks
     def _async_render_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
@@ -81,15 +81,15 @@ class UploadResource(Resource):

         headers = request.requestHeaders

         if headers.hasHeader("Content-Type"):
-            media_type = headers.getRawHeaders("Content-Type")[0]
+            media_type = headers.getRawHeaders(b"Content-Type")[0]
         else:
             raise SynapseError(
                 msg="Upload request missing 'Content-Type'",
                 code=400,
             )

-        # if headers.hasHeader("Content-Disposition"):
-        #     disposition = headers.getRawHeaders("Content-Disposition")[0]
+        # if headers.hasHeader(b"Content-Disposition"):
+        #     disposition = headers.getRawHeaders(b"Content-Disposition")[0]
         # TODO(markjh): parse content-dispostion

         content_uri = yield self.media_repo.create_content(
View File
@@ -105,7 +105,6 @@ class HomeServer(object):
         'federation_client',
         'federation_server',
         'handlers',
-        'v1auth',
         'auth',
         'state_handler',
         'state_resolution_handler',
@@ -225,15 +224,6 @@ class HomeServer(object):
     def build_simple_http_client(self):
         return SimpleHttpClient(self)

-    def build_v1auth(self):
-        orf = Auth(self)
-        # Matrix spec makes no reference to what HTTP status code is returned,
-        # but the V1 API uses 403 where it means 401, and the webclient
-        # relies on this behaviour, so V1 gets its own copy of the auth
-        # with backwards compat behaviour.
-        orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
-        return orf
-
     def build_state_handler(self):
         return StateHandler(self)
View File
@@ -448,6 +448,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             "add_push_actions_to_staging", _add_push_actions_to_staging_txn
         )

+    @defer.inlineCallbacks
     def remove_push_actions_from_staging(self, event_id):
         """Called if we failed to persist the event to ensure that stale push
         actions don't build up in the DB
@@ -456,13 +457,22 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             event_id (str)
         """

-        return self._simple_delete(
-            table="event_push_actions_staging",
-            keyvalues={
-                "event_id": event_id,
-            },
-            desc="remove_push_actions_from_staging",
-        )
+        try:
+            res = yield self._simple_delete(
+                table="event_push_actions_staging",
+                keyvalues={
+                    "event_id": event_id,
+                },
+                desc="remove_push_actions_from_staging",
+            )
+            defer.returnValue(res)
+        except Exception:
+            # this method is called from an exception handler, so propagating
+            # another exception here really isn't helpful - there's nothing
+            # the caller can do about it. Just log the exception and move on.
+            logger.exception(
+                "Error removing push actions after event persistence failure",
+            )

     @defer.inlineCallbacks
     def _find_stream_orderings_for_times(self):
View File
@ -22,7 +22,6 @@ import logging
import simplejson as json import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
@ -425,7 +424,9 @@ class EventsStore(EventsWorkerStore):
) )
current_state = yield self._get_new_state_after_events( current_state = yield self._get_new_state_after_events(
room_id, room_id,
ev_ctx_rm, new_latest_event_ids, ev_ctx_rm,
latest_event_ids,
new_latest_event_ids,
) )
if current_state is not None: if current_state is not None:
current_state_for_room[room_id] = current_state current_state_for_room[room_id] = current_state
@ -513,7 +514,8 @@ class EventsStore(EventsWorkerStore):
defer.returnValue(new_latest_event_ids) defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids): def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
new_latest_event_ids):
"""Calculate the current state dict after adding some new events to """Calculate the current state dict after adding some new events to
a room a room
@ -524,6 +526,9 @@ class EventsStore(EventsWorkerStore):
events_context (list[(EventBase, EventContext)]): events_context (list[(EventBase, EventContext)]):
events and contexts which are being added to the room events and contexts which are being added to the room
old_latest_event_ids (iterable[str]):
the old forward extremities for the room.
new_latest_event_ids (iterable[str]): new_latest_event_ids (iterable[str]):
the new forward extremities for the room. the new forward extremities for the room.
@ -534,64 +539,89 @@ class EventsStore(EventsWorkerStore):
""" """
if not new_latest_event_ids: if not new_latest_event_ids:
defer.returnValue({}) return
# map from state_group to ((type, key) -> event_id) state map # map from state_group to ((type, key) -> event_id) state map
state_groups = {} state_groups_map = {}
missing_event_ids = [] for ev, ctx in events_context:
was_updated = False if ctx.state_group is None:
# I don't think this can happen, but let's double-check
raise Exception(
"Context for new extremity event %s has no state "
"group" % (ev.event_id, ),
)
if ctx.state_group in state_groups_map:
continue
state_groups_map[ctx.state_group] = ctx.current_state_ids
# We need to map the event_ids to their state groups. First, let's
# check if the event is one we're persisting, in which case we can
# pull the state group from its context.
# Otherwise we need to pull the state group from the database.
# Set of events we need to fetch groups for. (We know none of the old
# extremities are going to be in events_context).
missing_event_ids = set(old_latest_event_ids)
event_id_to_state_group = {}
for event_id in new_latest_event_ids: for event_id in new_latest_event_ids:
# First search in the list of new events we're adding, # First search in the list of new events we're adding.
# and then use the current state from that
for ev, ctx in events_context: for ev, ctx in events_context:
if event_id == ev.event_id: if event_id == ev.event_id:
if ctx.current_state_ids is None: event_id_to_state_group[event_id] = ctx.state_group
raise Exception("Unknown current state")
if ctx.state_group is None:
# I don't think this can happen, but let's double-check
raise Exception(
"Context for new extremity event %s has no state "
"group" % (event_id, ),
)
# If we've already seen the state group don't bother adding
# it to the state sets again
if ctx.state_group not in state_groups:
state_groups[ctx.state_group] = ctx.current_state_ids
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
break break
else: else:
# If we couldn't find it, then we'll need to pull # If we couldn't find it, then we'll need to pull
# the state from the database # the state from the database
was_updated = True missing_event_ids.add(event_id)
missing_event_ids.append(event_id)
if not was_updated:
return
if missing_event_ids: if missing_event_ids:
# Now pull out the state for any missing events from DB # Now pull out the state groups for any missing events from DB
event_to_groups = yield self._get_state_group_for_events( event_to_groups = yield self._get_state_group_for_events(
missing_event_ids, missing_event_ids,
) )
event_id_to_state_group.update(event_to_groups)
groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys()) # State groups of old_latest_event_ids
old_state_groups = set(
event_id_to_state_group[evid] for evid in old_latest_event_ids
)
if groups: # State groups of new_latest_event_ids
group_to_state = yield self._get_state_for_groups(groups) new_state_groups = set(
state_groups.update(group_to_state) event_id_to_state_group[evid] for evid in new_latest_event_ids
)
if len(state_groups) == 1: # If they old and new groups are the same then we don't need to do
# anything.
if old_state_groups == new_state_groups:
return
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
missing_state = new_state_groups - set(state_groups_map)
if missing_state:
group_to_state = yield self._get_state_for_groups(missing_state)
state_groups_map.update(group_to_state)
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current # If there is only one state group, then we know what the current
# state is. # state is.
defer.returnValue(state_groups.values()[0]) defer.returnValue(state_groups_map[new_state_groups.pop()])
# Ok, we need to defer to the state handler to resolve our state sets.
def get_events(ev_ids): def get_events(ev_ids):
return self.get_events( return self.get_events(
ev_ids, get_prev_content=False, check_redacted=False, ev_ids, get_prev_content=False, check_redacted=False,
) )
state_groups = {
sg: state_groups_map[sg] for sg in new_state_groups
}
events_map = {ev.event_id: ev for ev, _ in events_context} events_map = {ev.event_id: ev for ev, _ in events_context}
logger.debug("calling resolve_state_groups from preserve_events") logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups( res = yield self._state_resolution_handler.resolve_state_groups(
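The heart of this refactor is that state resolution can be skipped entirely whenever the old and new forward extremities map to the same set of state groups. A minimal sketch of that short-circuit, using plain dicts in place of the real storage layer (all names here are illustrative, not Synapse APIs):

    def current_state_after(event_to_group, group_to_state,
                            old_extremities, new_extremities):
        """Return the new current state, or None if it cannot have changed."""
        old_groups = set(event_to_group[e] for e in old_extremities)
        new_groups = set(event_to_group[e] for e in new_extremities)

        # Same groups on both sides: the current state cannot have changed,
        # so skip state resolution entirely.
        if old_groups == new_groups:
            return None

        if len(new_groups) == 1:
            # A single state group left: its state *is* the current state.
            return group_to_state[new_groups.pop()]

        # Otherwise a full state resolution across new_groups is required.
        raise NotImplementedError("defer to the state resolution handler")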
View File
@@ -20,7 +20,7 @@ from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
 from synapse.util.logcontext import (
-    preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+    PreserveLoggingContext, make_deferred_yieldable, run_in_background,
 )
 from synapse.util.metrics import Measure
 from synapse.api.errors import SynapseError
@@ -319,7 +319,8 @@ class EventsWorkerStore(SQLBaseStore):
         res = yield make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self._get_event_from_row)(
+                run_in_background(
+                    self._get_event_from_row,
                     row["internal_metadata"], row["json"], row["redacts"],
                     rejected_reason=row["rejects"],
                 )
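`preserve_fn(f)(*args)` and `run_in_background(f, *args)` do the same job here, namely start `f` immediately without losing the caller's log context, and the codebase is converging on the latter spelling. A rough sketch of the fan-out-and-wait pattern used above, assuming Synapse's logcontext helpers and a Twisted reactor:

    from twisted.internet import defer

    from synapse.util.logcontext import make_deferred_yieldable, run_in_background

    @defer.inlineCallbacks
    def fetch_all(rows, fetch_one):
        # Kick off one background fetch per row, then wait for all of them;
        # make_deferred_yieldable restores our log context while we wait.
        results = yield make_deferred_yieldable(defer.gatherResults(
            [run_in_background(fetch_one, row) for row in rows],
            consumeErrors=True,
        ))
        defer.returnValue(results)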
View File
@@ -22,6 +22,8 @@ from synapse.storage import background_updates
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
+from six.moves import range
+
 
 class RegistrationWorkerStore(SQLBaseStore):
     @cached()
@@ -469,7 +471,7 @@ class RegistrationStore(RegistrationWorkerStore,
             match = regex.search(user_id)
             if match:
                 found.add(int(match.group(1)))
-        for i in xrange(len(found) + 1):
+        for i in range(len(found) + 1):
             if i not in found:
                 return i
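The surrounding function collects the numeric localparts already registered and then picks the first free integer; scanning `range(len(found) + 1)` must succeed, since `len(found)` values can block at most the first `len(found)` slots. A standalone illustration:

    def first_free_id(found):
        """Return the smallest non-negative integer not in `found`."""
        for i in range(len(found) + 1):
            if i not in found:
                return i

    assert first_free_id({0, 1, 3}) == 2
    assert first_free_id(set()) == 0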
View File
@@ -530,7 +530,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
 
         # Convert the IDs to MXC URIs
         for media_id in local_mxcs:
-            local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
+            local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
         for hostname, media_id in remote_mxcs:
             remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
 
@@ -595,7 +595,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
         while next_token:
             sql = """
                 SELECT stream_ordering, json FROM events
-                JOIN event_json USING (event_id)
+                JOIN event_json USING (room_id, event_id)
                 WHERE room_id = ?
                     AND stream_ordering < ?
                     AND contains_url = ? AND outlier = ?
@@ -619,7 +619,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 if matches:
                     hostname = matches.group(1)
                     media_id = matches.group(2)
-                    if hostname == self.hostname:
+                    if hostname == self.hs.hostname:
                         local_media_mxcs.append(media_id)
                     else:
                         remote_media_mxcs.append((hostname, media_id))
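The local/remote split above works by comparing the server name embedded in an `mxc://` URI against this homeserver's own name, now read from `self.hs.hostname`. A hypothetical regex in the spirit of the surrounding code (the exact pattern Synapse uses is not shown in this hunk):

    import re

    MXC_RE = re.compile("^mxc://([^/]+)/([^/#?]+)$")  # illustrative pattern

    def split_mxc(uri, local_hostname):
        matches = MXC_RE.match(uri)
        if not matches:
            return None
        hostname = matches.group(1)
        media_id = matches.group(2)
        # Media hosted by this server is "local"; anything else is remote.
        if hostname == local_hostname:
            return ("local", media_id)
        return ("remote", hostname, media_id)

    assert split_mxc("mxc://example.com/abc123", "example.com") == ("local", "abc123")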
View File
@@ -14,6 +14,8 @@
 import logging
 
 from synapse.config.appservice import load_appservices
 
+from six.moves import range
+
 
 logger = logging.getLogger(__name__)
 
@@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
     for as_id, user_ids in owned.items():
         n = 100
-        user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n))
+        user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
         for chunk in user_chunks:
             cur.execute(
                 database_engine.convert_param_style(
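On Python 2, `six.moves.range` provides the lazy `xrange` behaviour while staying valid Python 3, and the generator expression then yields fixed-size slices without building all the chunks up front. For instance:

    from six.moves import range

    user_ids = ["@u%d:example.com" % i for i in range(250)]
    n = 100
    user_chunks = (user_ids[i:i + n] for i in range(0, len(user_ids), n))

    # Yields chunks of 100, 100 and 50 users respectively.
    assert [len(chunk) for chunk in user_chunks] == [100, 100, 50]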
View File
@ -0,0 +1,57 @@
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
from synapse.storage.prepare_database import get_statements
FIX_INDEXES = """
-- rebuild indexes as uniques
DROP INDEX groups_invites_g_idx;
CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id);
DROP INDEX groups_users_g_idx;
CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id);
-- rename other indexes to actually match their table names..
DROP INDEX groups_users_u_idx;
CREATE INDEX group_users_u_idx ON group_users(user_id);
DROP INDEX groups_invites_u_idx;
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
DROP INDEX groups_rooms_g_idx;
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
DROP INDEX groups_rooms_r_idx;
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
"""
def run_create(cur, database_engine, *args, **kwargs):
rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
# remove duplicates from group_users & group_invites tables
cur.execute("""
DELETE FROM group_users WHERE %s NOT IN (
SELECT min(%s) FROM group_users GROUP BY group_id, user_id
);
""" % (rowid, rowid))
cur.execute("""
DELETE FROM group_invites WHERE %s NOT IN (
SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
);
""" % (rowid, rowid))
for statement in get_statements(FIX_INDEXES.splitlines()):
cur.execute(statement)
def run_upgrade(*args, **kwargs):
pass
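The duplicates have to go before the UNIQUE indexes can be built, and the `DELETE ... NOT IN (SELECT min(...))` trick keeps exactly one copy of each pair by physical row identifier (`ctid` on PostgreSQL, `rowid` on SQLite). A self-contained SQLite sketch of the same dedup-then-index pattern:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE group_users (group_id TEXT, user_id TEXT)")
    cur.executemany(
        "INSERT INTO group_users VALUES (?, ?)",
        [("g1", "u1"), ("g1", "u1"), ("g1", "u2")],  # one duplicate pair
    )

    # Keep only the earliest copy of each (group_id, user_id) pair...
    cur.execute("""
        DELETE FROM group_users WHERE rowid NOT IN (
            SELECT min(rowid) FROM group_users GROUP BY group_id, user_id
        )
    """)
    # ...after which the unique index can be built safely.
    cur.execute(
        "CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id)"
    )
    assert cur.execute("SELECT count(*) FROM group_users").fetchone()[0] == 2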
View File
@@ -77,7 +77,7 @@ class SearchStore(BackgroundUpdateStore):
             sql = (
                 "SELECT stream_ordering, event_id, room_id, type, json, "
                 " origin_server_ts FROM events"
-                " JOIN event_json USING (event_id)"
+                " JOIN event_json USING (room_id, event_id)"
                 " WHERE ? <= stream_ordering AND stream_ordering < ?"
                 " AND (%s)"
                 " ORDER BY stream_ordering DESC"
View File
@@ -38,15 +38,17 @@ from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.events import EventsWorkerStore
-from synapse.util.caches.descriptors import cached
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.storage.engines import PostgresEngine
 
 import abc
 import logging
 
+from six.moves import range
+from collections import namedtuple
+
 
 logger = logging.getLogger(__name__)
 
@@ -58,6 +60,12 @@ _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
 
+# Used as return values for pagination APIs
+_EventDictReturn = namedtuple("_EventDictReturn", (
+    "event_id", "topological_ordering", "stream_ordering",
+))
+
+
 def lower_bound(token, engine, inclusive=False):
     inclusive = "=" if inclusive else ""
     if token.topological is None:
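`_EventDictReturn` replaces the `cursor_to_dict` dictionaries with a namedtuple built straight from the raw row tuples, so callers switch from `r["event_id"]` to attribute access. Roughly:

    from collections import namedtuple

    _EventDictReturn = namedtuple("_EventDictReturn", (
        "event_id", "topological_ordering", "stream_ordering",
    ))

    # A stream-ordered query only selects (event_id, stream_ordering), so
    # the topological_ordering slot is filled with None.
    db_rows = [("$ev1:hs", 10), ("$ev2:hs", 11)]
    rows = [_EventDictReturn(row[0], None, row[1]) for row in db_rows]

    assert rows[0].event_id == "$ev1:hs"
    assert rows[-1].stream_ordering == 11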
@@ -196,13 +204,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         results = {}
         room_ids = list(room_ids)
-        for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
+        for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
             res = yield make_deferred_yieldable(defer.gatherResults([
-                preserve_fn(self.get_room_events_stream_for_room)(
+                run_in_background(
+                    self.get_room_events_stream_for_room,
                     room_id, from_key, to_key, limit, order=order,
                 )
                 for room_id in rm_ids
-            ]))
+            ], consumeErrors=True))
             results.update(dict(zip(rm_ids, res)))
 
         defer.returnValue(results)
@@ -224,54 +233,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     @defer.inlineCallbacks
     def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
                                         order='DESC'):
-        # Note: If from_key is None then we return in topological order. This
-        # is because in that case we're using this as a "get the last few messages
-        # in a room" function, rather than "get new messages since last sync"
-        if from_key is not None:
-            from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        else:
-            from_id = None
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        """Get new room events in stream ordering since `from_key`.
+
+        Args:
+            room_id (str)
+            from_key (str): Token from which no events are returned before
+            to_key (str): Token from which no events are returned after. (This
+                is typically the current stream token)
+            limit (int): Maximum number of events to return
+            order (str): Either "DESC" or "ASC". Determines which events are
+                returned when the result is limited. If "DESC" then the most
+                recent `limit` events are returned, otherwise returns the
+                oldest `limit` events.
+
+        Returns:
+            Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
+            events (in ascending order) and the token from the start of
+            the chunk of events returned.
+        """
         if from_key == to_key:
             defer.returnValue(([], from_key))
 
-        if from_id:
-            has_changed = yield self._events_stream_cache.has_entity_changed(
-                room_id, from_id
-            )
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
-            if not has_changed:
-                defer.returnValue(([], from_key))
+        has_changed = yield self._events_stream_cache.has_entity_changed(
+            room_id, from_id
+        )
+
+        if not has_changed:
+            defer.returnValue(([], from_key))
 
         def f(txn):
-            if from_id is not None:
-                sql = (
-                    "SELECT event_id, stream_ordering FROM events WHERE"
-                    " room_id = ?"
-                    " AND not outlier"
-                    " AND stream_ordering > ? AND stream_ordering <= ?"
-                    " ORDER BY stream_ordering %s LIMIT ?"
-                ) % (order,)
-                txn.execute(sql, (room_id, from_id, to_id, limit))
-            else:
-                sql = (
-                    "SELECT event_id, stream_ordering FROM events WHERE"
-                    " room_id = ?"
-                    " AND not outlier"
-                    " AND stream_ordering <= ?"
-                    " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?"
-                ) % (order, order,)
-                txn.execute(sql, (room_id, to_id, limit))
-
-            rows = self.cursor_to_dict(txn)
-
+            sql = (
+                "SELECT event_id, stream_ordering FROM events WHERE"
+                " room_id = ?"
+                " AND not outlier"
+                " AND stream_ordering > ? AND stream_ordering <= ?"
+                " ORDER BY stream_ordering %s LIMIT ?"
+            ) % (order,)
+            txn.execute(sql, (room_id, from_id, to_id, limit))
+
+            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
             return rows
 
         rows = yield self.runInteraction("get_room_events_stream_for_room", f)
 
         ret = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )
 
@@ -281,7 +291,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             ret.reverse()
 
         if rows:
-            key = "s%d" % min(r["stream_ordering"] for r in rows)
+            key = "s%d" % min(r.stream_ordering for r in rows)
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
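Stream-only tokens take the form `s<n>`, where `n` is a position in the global stream ordering; taking the minimum `stream_ordering` of the rows just returned produces a token naming the start of the chunk. A toy round trip:

    def make_chunk_start_token(stream_orderings):
        # The chunk starts at its smallest stream ordering.
        return "s%d" % min(stream_orderings)

    def parse_stream_token(token):
        # Counterpart of RoomStreamToken.parse_stream_token for "s<n>".
        assert token.startswith("s")
        return int(token[1:])

    tok = make_chunk_start_token([103, 105, 104])
    assert parse_stream_token(tok) == 103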
@@ -291,10 +301,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_membership_changes_for_user(self, user_id, from_key, to_key):
-        if from_key is not None:
-            from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        else:
-            from_id = None
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
         if from_key == to_key:
@@ -308,34 +315,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             defer.returnValue([])
 
         def f(txn):
-            if from_id is not None:
-                sql = (
-                    "SELECT m.event_id, stream_ordering FROM events AS e,"
-                    " room_memberships AS m"
-                    " WHERE e.event_id = m.event_id"
-                    " AND m.user_id = ?"
-                    " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
-                    " ORDER BY e.stream_ordering ASC"
-                )
-                txn.execute(sql, (user_id, from_id, to_id,))
-            else:
-                sql = (
-                    "SELECT m.event_id, stream_ordering FROM events AS e,"
-                    " room_memberships AS m"
-                    " WHERE e.event_id = m.event_id"
-                    " AND m.user_id = ?"
-                    " AND stream_ordering <= ?"
-                    " ORDER BY stream_ordering ASC"
-                )
-                txn.execute(sql, (user_id, to_id,))
-
-            rows = self.cursor_to_dict(txn)
-
+            sql = (
+                "SELECT m.event_id, stream_ordering FROM events AS e,"
+                " room_memberships AS m"
+                " WHERE e.event_id = m.event_id"
+                " AND m.user_id = ?"
+                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+                " ORDER BY e.stream_ordering ASC"
+            )
+            txn.execute(sql, (user_id, from_id, to_id,))
+
+            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+
             return rows
 
         rows = yield self.runInteraction("get_membership_changes_for_user", f)
 
         ret = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )
 
@@ -344,14 +341,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
+    def get_recent_events_for_room(self, room_id, limit, end_token):
+        """Get the most recent events in the room in topological ordering.
+
+        Args:
+            room_id (str)
+            limit (int)
+            end_token (str): The stream token representing now.
+
+        Returns:
+            Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
+            events and a token pointing to the start of the returned
+            events.
+
+            The events returned are in ascending order.
+        """
         rows, token = yield self.get_recent_event_ids_for_room(
-            room_id, limit, end_token, from_token
+            room_id, limit, end_token,
         )
 
         logger.debug("stream before")
         events = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
        )
         logger.debug("stream after")
@@ -360,61 +371,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         defer.returnValue((events, token))
 
-    @cached(num_args=4)
-    def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None):
-        end_token = RoomStreamToken.parse_stream_token(end_token)
+    @defer.inlineCallbacks
+    def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+        """Get the most recent events in the room in topological ordering.
 
-        if from_token is None:
-            sql = (
-                "SELECT stream_ordering, topological_ordering, event_id"
-                " FROM events"
-                " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?"
-                " ORDER BY topological_ordering DESC, stream_ordering DESC"
-                " LIMIT ?"
-            )
-        else:
-            from_token = RoomStreamToken.parse_stream_token(from_token)
-            sql = (
-                "SELECT stream_ordering, topological_ordering, event_id"
-                " FROM events"
-                " WHERE room_id = ? AND stream_ordering > ?"
-                " AND stream_ordering <= ? AND outlier = ?"
-                " ORDER BY topological_ordering DESC, stream_ordering DESC"
-                " LIMIT ?"
-            )
+        Args:
+            room_id (str)
+            limit (int)
+            end_token (str): The stream token representing now.
 
-        def get_recent_events_for_room_txn(txn):
-            if from_token is None:
-                txn.execute(sql, (room_id, end_token.stream, False, limit,))
-            else:
-                txn.execute(sql, (
-                    room_id, from_token.stream, end_token.stream, False, limit
-                ))
+        Returns:
+            Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
+            _EventDictReturn and a token pointing to the start of the returned
+            events.
+
+            The events returned are in ascending order.
+        """
+        # Allow a zero limit here, and no-op.
+        if limit == 0:
+            defer.returnValue(([], end_token))
 
-            rows = self.cursor_to_dict(txn)
-            rows.reverse()  # As we selected with reverse ordering
-
-            if rows:
-                # Tokens are positions between events.
-                # This token points *after* the last event in the chunk.
-                # We need it to point to the event before it in the chunk
-                # since we are going backwards so we subtract one from the
-                # stream part.
-                topo = rows[0]["topological_ordering"]
-                toke = rows[0]["stream_ordering"] - 1
-                start_token = str(RoomStreamToken(topo, toke))
-
-                token = (start_token, str(end_token))
-            else:
-                token = (str(end_token), str(end_token))
-
-            return rows, token
-
-        return self.runInteraction(
-            "get_recent_events_for_room", get_recent_events_for_room_txn
+        end_token = RoomStreamToken.parse(end_token)
+        rows, token = yield self.runInteraction(
+            "get_recent_event_ids_for_room", self._paginate_room_events_txn,
+            room_id, from_token=end_token, limit=limit,
         )
+
+        # We want to return the results in ascending order.
+        rows.reverse()
+
+        defer.returnValue((rows, token))
 
     def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
         """Gets details of the first event in a room at or after a stream ordering
@@ -517,10 +504,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @staticmethod
     def _set_before_and_after(events, rows, topo_order=True):
+        """Inserts ordering information to events' internal metadata from
+        the DB rows.
+
+        Args:
+            events (list[FrozenEvent])
+            rows (list[_EventDictReturn])
+            topo_order (bool): Whether the events were ordered topologically
+                or by stream ordering. If true then all rows should have a non
+                null topological_ordering.
+        """
         for event, row in zip(events, rows):
-            stream = row["stream_ordering"]
-            if topo_order:
-                topo = event.depth
+            stream = row.stream_ordering
+            if topo_order and row.topological_ordering:
+                topo = row.topological_ordering
             else:
                 topo = None
             internal = event.internal_metadata
@@ -592,87 +589,27 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=["stream_ordering", "topological_ordering"],
         )
 
-        token = RoomStreamToken(
+        # Paginating backwards includes the event at the token, but paginating
+        # forward doesn't.
+        before_token = RoomStreamToken(
+            results["topological_ordering"] - 1,
+            results["stream_ordering"],
+        )
+
+        after_token = RoomStreamToken(
             results["topological_ordering"],
             results["stream_ordering"],
         )
 
-        if isinstance(self.database_engine, Sqlite3Engine):
-            # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
-            # So we give pass it to SQLite3 as the UNION ALL of the two queries.
-
-            query_before = (
-                "SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND topological_ordering < ?"
-                " UNION ALL"
-                " SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
-                " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
-            )
-            before_args = (
-                room_id, token.topological,
-                room_id, token.topological, token.stream,
-                before_limit,
-            )
-
-            query_after = (
-                "SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND topological_ordering > ?"
-                " UNION ALL"
-                " SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
-                " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
-            )
-            after_args = (
-                room_id, token.topological,
-                room_id, token.topological, token.stream,
-                after_limit,
-            )
-        else:
-            query_before = (
-                "SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND %s"
-                " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
-            ) % (upper_bound(token, self.database_engine, inclusive=False),)
-
-            before_args = (room_id, before_limit)
-
-            query_after = (
-                "SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND %s"
-                " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
-            ) % (lower_bound(token, self.database_engine, inclusive=False),)
-
-            after_args = (room_id, after_limit)
-
-        txn.execute(query_before, before_args)
-
-        rows = self.cursor_to_dict(txn)
-        events_before = [r["event_id"] for r in rows]
-
-        if rows:
-            start_token = str(RoomStreamToken(
-                rows[0]["topological_ordering"],
-                rows[0]["stream_ordering"] - 1,
-            ))
-        else:
-            start_token = str(RoomStreamToken(
-                token.topological,
-                token.stream - 1,
-            ))
-
-        txn.execute(query_after, after_args)
-
-        rows = self.cursor_to_dict(txn)
-        events_after = [r["event_id"] for r in rows]
-
-        if rows:
-            end_token = str(RoomStreamToken(
-                rows[-1]["topological_ordering"],
-                rows[-1]["stream_ordering"],
-            ))
-        else:
-            end_token = str(token)
+        rows, start_token = self._paginate_room_events_txn(
+            txn, room_id, before_token, direction='b', limit=before_limit,
+        )
+        events_before = [r.event_id for r in rows]
+
+        rows, end_token = self._paginate_room_events_txn(
+            txn, room_id, after_token, direction='f', limit=after_limit,
+        )
+        events_after = [r.event_id for r in rows]
 
         return {
             "before": {
@@ -735,17 +672,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def has_room_changed_since(self, room_id, stream_id):
         return self._events_stream_cache.has_entity_changed(room_id, stream_id)
 
-
-class StreamStore(StreamWorkerStore):
-    def get_room_max_stream_ordering(self):
-        return self._stream_id_gen.get_current_token()
-
-    def get_room_min_stream_ordering(self):
-        return self._backfill_id_gen.get_current_token()
-
-    @defer.inlineCallbacks
-    def paginate_room_events(self, room_id, from_key, to_key=None,
-                             direction='b', limit=-1, event_filter=None):
+    def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
+                                  direction='b', limit=-1, event_filter=None):
+        """Returns list of events before or after a given token.
+
+        Args:
+            txn
+            room_id (str)
+            from_token (RoomStreamToken): The token used to stream from
+            to_token (RoomStreamToken|None): A token which if given limits the
+                results to only those before
+            direction (char): Either 'b' or 'f' to indicate whether we are
+                paginating forwards or backwards from `from_token`.
+            limit (int): The maximum number of events to return. Zero or less
+                means no limit.
+            event_filter (Filter|None): If provided filters the events to
+                those that match the filter.
+
+        Returns:
+            tuple[list[_EventDictReturn], str]: Returns the results as a list
+            of _EventDictReturn and a token that points to the end of the
+            result set.
+        """
         # Tokens really represent positions between elements, but we use
         # the convention of pointing to the event before the gap. Hence
         # we have a bit of asymmetry when it comes to equalities.
@@ -753,20 +701,20 @@ class StreamStore(StreamWorkerStore):
         if direction == 'b':
             order = "DESC"
             bounds = upper_bound(
-                RoomStreamToken.parse(from_key), self.database_engine
+                from_token, self.database_engine
             )
-            if to_key:
+            if to_token:
                 bounds = "%s AND %s" % (bounds, lower_bound(
-                    RoomStreamToken.parse(to_key), self.database_engine
+                    to_token, self.database_engine
                 ))
         else:
             order = "ASC"
             bounds = lower_bound(
-                RoomStreamToken.parse(from_key), self.database_engine
+                from_token, self.database_engine
            )
-            if to_key:
+            if to_token:
                 bounds = "%s AND %s" % (bounds, upper_bound(
-                    RoomStreamToken.parse(to_key), self.database_engine
+                    to_token, self.database_engine
                 ))
 
         filter_clause, filter_args = filter_to_clause(event_filter)
@@ -782,7 +730,8 @@ class StreamStore(StreamWorkerStore):
             limit_str = ""
 
         sql = (
-            "SELECT * FROM events"
+            "SELECT event_id, topological_ordering, stream_ordering"
+            " FROM events"
             " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
             " ORDER BY topological_ordering %(order)s,"
             " stream_ordering %(order)s %(limit)s"
@@ -792,35 +741,72 @@ class StreamStore(StreamWorkerStore):
             "limit": limit_str
         }
 
-        def f(txn):
-            txn.execute(sql, args)
-
-            rows = self.cursor_to_dict(txn)
-
-            if rows:
-                topo = rows[-1]["topological_ordering"]
-                toke = rows[-1]["stream_ordering"]
-                if direction == 'b':
-                    # Tokens are positions between events.
-                    # This token points *after* the last event in the chunk.
-                    # We need it to point to the event before it in the chunk
-                    # when we are going backwards so we subtract one from the
-                    # stream part.
-                    toke -= 1
-                next_token = str(RoomStreamToken(topo, toke))
-            else:
-                # TODO (erikj): We should work out what to do here instead.
-                next_token = to_key if to_key else from_key
+        txn.execute(sql, args)
 
-            return rows, next_token,
+        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
 
-        rows, token = yield self.runInteraction("paginate_room_events", f)
+        if rows:
+            topo = rows[-1].topological_ordering
+            toke = rows[-1].stream_ordering
+            if direction == 'b':
+                # Tokens are positions between events.
+                # This token points *after* the last event in the chunk.
+                # We need it to point to the event before it in the chunk
+                # when we are going backwards so we subtract one from the
+                # stream part.
+                toke -= 1
+            next_token = RoomStreamToken(topo, toke)
+        else:
+            # TODO (erikj): We should work out what to do here instead.
+            next_token = to_token if to_token else from_token
+
+        return rows, str(next_token),
+
+    @defer.inlineCallbacks
+    def paginate_room_events(self, room_id, from_key, to_key=None,
+                             direction='b', limit=-1, event_filter=None):
+        """Returns list of events before or after a given token.
+
+        Args:
+            room_id (str)
+            from_key (str): The token used to stream from
+            to_key (str|None): A token which if given limits the results to
+                only those before
+            direction (char): Either 'b' or 'f' to indicate whether we are
+                paginating forwards or backwards from `from_key`.
+            limit (int): The maximum number of events to return. Zero or less
+                means no limit.
+            event_filter (Filter|None): If provided filters the events to
+                those that match the filter.
+
+        Returns:
+            Deferred[tuple[list[FrozenEvent], str]]: Returns the results as a
+            list of events and a token that points to the end of the result
+            set.
+        """
+
+        from_key = RoomStreamToken.parse(from_key)
+        if to_key:
+            to_key = RoomStreamToken.parse(to_key)
+
+        rows, token = yield self.runInteraction(
+            "paginate_room_events", self._paginate_room_events_txn,
+            room_id, from_key, to_key, direction, limit, event_filter,
+        )
 
         events = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )
 
         self._set_before_and_after(events, rows)
 
         defer.returnValue((events, token))
+
+
+class StreamStore(StreamWorkerStore):
+    def get_room_max_stream_ordering(self):
+        return self._stream_id_gen.get_current_token()
+
+    def get_room_min_stream_ordering(self):
+        return self._backfill_id_gen.get_current_token()
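After the split, `paginate_room_events` is a thin string-token wrapper around `_paginate_room_events_txn`, which is what lets `get_events_around` above run both the before and after queries inside a single transaction. A hypothetical caller, assuming a populated store and a current stream token:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def latest_page(store, room_id, now_token):
        # Walk backwards from "now", at most 20 events per page.
        events, next_token = yield store.paginate_room_events(
            room_id, from_key=now_token, direction='b', limit=20,
        )
        # Feeding next_token back in as from_key fetches the next page.
        defer.returnValue((events, next_token))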
View File
@@ -22,6 +22,8 @@ from twisted.internet import defer
 import simplejson as json
 import logging
 
+from six.moves import range
+
 logger = logging.getLogger(__name__)
 
@@ -98,7 +100,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         batch_size = 50
         results = []
 
-        for i in xrange(0, len(tag_ids), batch_size):
+        for i in range(0, len(tag_ids), batch_size):
             tags = yield self.runInteraction(
                 "get_all_updated_tag_content",
                 get_tag_content,
View File
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.api.errors import SynapseError
 from synapse.util.logcontext import PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
@@ -24,11 +23,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class DeferredTimedOutError(SynapseError):
-    def __init__(self):
-        super(DeferredTimedOutError, self).__init__(504, "Timed out")
-
-
 def unwrapFirstError(failure):
     # defer.gatherResults and DeferredLists wrap failures.
     failure.trap(defer.FirstError)
@@ -85,53 +79,3 @@ class Clock(object):
         except Exception:
             if not ignore_errs:
                 raise
-
-    def time_bound_deferred(self, given_deferred, time_out):
-        if given_deferred.called:
-            return given_deferred
-
-        ret_deferred = defer.Deferred()
-
-        def timed_out_fn():
-            e = DeferredTimedOutError()
-
-            try:
-                ret_deferred.errback(e)
-            except Exception:
-                pass
-
-            try:
-                given_deferred.cancel()
-            except Exception:
-                pass
-
-        timer = None
-
-        def cancel(res):
-            try:
-                self.cancel_call_later(timer)
-            except Exception:
-                pass
-            return res
-
-        ret_deferred.addBoth(cancel)
-
-        def success(res):
-            try:
-                ret_deferred.callback(res)
-            except Exception:
-                pass
-
-            return res
-
-        def err(res):
-            try:
-                ret_deferred.errback(res)
-            except Exception:
-                pass
-
-        given_deferred.addCallbacks(callback=success, errback=err)
-
-        timer = self.call_later(time_out, timed_out_fn)
-
-        return ret_deferred
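`time_bound_deferred` and its bespoke 504 error are removed outright; Twisted has shipped an equivalent primitive, `Deferred.addTimeout`, since 16.5, which is presumably the intended replacement (this hunk does not show the call sites). A minimal sketch of the built-in:

    from twisted.internet import defer, reactor

    d = defer.Deferred()
    # Cancel d, firing its errback with a TimeoutError, if it has not
    # completed within 30 seconds.
    d.addTimeout(30, reactor)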
Some files were not shown because too many files have changed in this diff