Merge branch 'develop' of github.com:matrix-org/synapse into anoa/dm_room_upgrade
This commit is contained in:
commit
821b65aeb5
|
@ -1,11 +1,7 @@
|
||||||
[run]
|
[run]
|
||||||
branch = True
|
branch = True
|
||||||
parallel = True
|
parallel = True
|
||||||
source = synapse
|
include = synapse/*
|
||||||
|
|
||||||
[paths]
|
|
||||||
source=
|
|
||||||
coverage
|
|
||||||
|
|
||||||
[report]
|
[report]
|
||||||
precision = 2
|
precision = 2
|
||||||
|
|
|
@ -25,9 +25,9 @@ homeserver*.pid
|
||||||
*.tls.dh
|
*.tls.dh
|
||||||
*.tls.key
|
*.tls.key
|
||||||
|
|
||||||
.coverage
|
.coverage*
|
||||||
.coverage.*
|
coverage.*
|
||||||
!.coverage.rc
|
!.coveragerc
|
||||||
htmlcov
|
htmlcov
|
||||||
|
|
||||||
demo/*/*.db
|
demo/*/*.db
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.
|
|
@ -0,0 +1 @@
|
||||||
|
Synapse can now automatically provision TLS certificates via ACME (the protocol used by CAs like Let's Encrypt).
|
|
@ -0,0 +1 @@
|
||||||
|
Fix bug when rejecting remote invites
|
|
@ -0,0 +1 @@
|
||||||
|
Search now includes results from predecessor rooms after a room upgrade.
|
|
@ -0,0 +1 @@
|
||||||
|
Config option to disable requesting MSISDN on registration.
|
|
@ -0,0 +1 @@
|
||||||
|
Move SRV logic into the Agent layer
|
|
@ -0,0 +1 @@
|
||||||
|
Apply a unique index to the user_ips table, preventing duplicates.
|
|
@ -0,0 +1 @@
|
||||||
|
debian package: symlink to explicit python version
|
|
@ -0,0 +1 @@
|
||||||
|
Apply a unique index to the user_ips table, preventing duplicates.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix None guard in calling config.server.is_threepid_reserved
|
|
@ -0,0 +1 @@
|
||||||
|
Add infrastructure to support different event formats
|
|
@ -0,0 +1 @@
|
||||||
|
Generate the debian config during build
|
|
@ -0,0 +1 @@
|
||||||
|
Add a metric for tracking event stream position of the user directory.
|
|
@ -0,0 +1 @@
|
||||||
|
Add infrastructure to support different event formats
|
|
@ -0,0 +1 @@
|
||||||
|
Add infrastructure to support different event formats
|
|
@ -0,0 +1 @@
|
||||||
|
Don't send IP addresses as SNI
|
|
@ -0,0 +1 @@
|
||||||
|
Clarify documentation for the `public_baseurl` config param
|
|
@ -0,0 +1 @@
|
||||||
|
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix UnboundLocalError in post_urlencoded_get_json
|
|
@ -0,0 +1 @@
|
||||||
|
Add a timeout to filtered room directory queries.
|
|
@ -0,0 +1 @@
|
||||||
|
Move SRV logic into the Agent layer
|
|
@ -6,7 +6,16 @@
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
export DH_VIRTUALENV_INSTALL_ROOT=/opt/venvs
|
export DH_VIRTUALENV_INSTALL_ROOT=/opt/venvs
|
||||||
SNAKE=/usr/bin/python3
|
|
||||||
|
# make sure that the virtualenv links to the specific version of python, by
|
||||||
|
# dereferencing the python3 symlink.
|
||||||
|
#
|
||||||
|
# Otherwise, if somebody tries to install (say) the stretch package on buster,
|
||||||
|
# they will get a confusing error about "No module named 'synapse'", because
|
||||||
|
# python won't look in the right directory. At least this way, the error will
|
||||||
|
# be a *bit* more obvious.
|
||||||
|
#
|
||||||
|
SNAKE=`readlink -e /usr/bin/python3`
|
||||||
|
|
||||||
# try to set the CFLAGS so any compiled C extensions are compiled with the most
|
# try to set the CFLAGS so any compiled C extensions are compiled with the most
|
||||||
# generic as possible x64 instructions, so that compiling it on a new Intel chip
|
# generic as possible x64 instructions, so that compiling it on a new Intel chip
|
||||||
|
@ -36,6 +45,10 @@ dh_virtualenv \
|
||||||
--extra-pip-arg="--compile" \
|
--extra-pip-arg="--compile" \
|
||||||
--extras="all"
|
--extras="all"
|
||||||
|
|
||||||
|
PACKAGE_BUILD_DIR="debian/matrix-synapse-py3"
|
||||||
|
VIRTUALENV_DIR="${PACKAGE_BUILD_DIR}${DH_VIRTUALENV_INSTALL_ROOT}/matrix-synapse"
|
||||||
|
TARGET_PYTHON="${VIRTUALENV_DIR}/bin/python"
|
||||||
|
|
||||||
# we copy the tests to a temporary directory so that we can put them on the
|
# we copy the tests to a temporary directory so that we can put them on the
|
||||||
# PYTHONPATH without putting the uninstalled synapse on the pythonpath.
|
# PYTHONPATH without putting the uninstalled synapse on the pythonpath.
|
||||||
tmpdir=`mktemp -d`
|
tmpdir=`mktemp -d`
|
||||||
|
@ -44,5 +57,35 @@ trap "rm -r $tmpdir" EXIT
|
||||||
cp -r tests "$tmpdir"
|
cp -r tests "$tmpdir"
|
||||||
|
|
||||||
PYTHONPATH="$tmpdir" \
|
PYTHONPATH="$tmpdir" \
|
||||||
debian/matrix-synapse-py3/opt/venvs/matrix-synapse/bin/python \
|
"${TARGET_PYTHON}" -B -m twisted.trial --reporter=text -j2 tests
|
||||||
-B -m twisted.trial --reporter=text -j2 tests
|
|
||||||
|
# build the config file
|
||||||
|
"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_config" \
|
||||||
|
--config-dir="/etc/matrix-synapse" \
|
||||||
|
--data-dir="/var/lib/matrix-synapse" |
|
||||||
|
perl -pe '
|
||||||
|
# tweak the paths to the tls certs and signing keys
|
||||||
|
/^tls_.*_path:/ and s/SERVERNAME/homeserver/;
|
||||||
|
/^signing_key_path:/ and s/SERVERNAME/homeserver/;
|
||||||
|
|
||||||
|
# tweak the pid file location
|
||||||
|
/^pid_file:/ and s#:.*#: "/var/run/matrix-synapse.pid"#;
|
||||||
|
|
||||||
|
# tweak the path to the log config
|
||||||
|
/^log_config:/ and s/SERVERNAME\.log\.config/log.yaml/;
|
||||||
|
|
||||||
|
# tweak the path to the media store
|
||||||
|
/^media_store_path:/ and s#/media_store#/media#;
|
||||||
|
|
||||||
|
# remove the server_name setting, which is set in a separate file
|
||||||
|
/^server_name:/ and $_ = "#\n# This is set in /etc/matrix-synapse/conf.d/server_name.yaml for Debian installations.\n# $_";
|
||||||
|
|
||||||
|
# remove the report_stats setting, which is set in a separate file
|
||||||
|
/^# report_stats:/ and $_ = "";
|
||||||
|
|
||||||
|
' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
# add a dependency on the right version of python to substvars.
|
||||||
|
PYPKG=`basename $SNAKE`
|
||||||
|
echo "synapse:pydepends=$PYPKG" >> debian/matrix-synapse-py3.substvars
|
||||||
|
|
|
@ -27,8 +27,8 @@ Depends:
|
||||||
adduser,
|
adduser,
|
||||||
debconf,
|
debconf,
|
||||||
python3-distutils|libpython3-stdlib (<< 3.6),
|
python3-distutils|libpython3-stdlib (<< 3.6),
|
||||||
python3,
|
|
||||||
${misc:Depends},
|
${misc:Depends},
|
||||||
|
${synapse:pydepends},
|
||||||
# some of our scripts use perl, but none of them are important,
|
# some of our scripts use perl, but none of them are important,
|
||||||
# so we put perl:Depends in Suggests rather than Depends.
|
# so we put perl:Depends in Suggests rather than Depends.
|
||||||
Suggests:
|
Suggests:
|
||||||
|
|
|
@ -1,614 +0,0 @@
|
||||||
# vim:ft=yaml
|
|
||||||
# PEM encoded X509 certificate for TLS.
|
|
||||||
# You can replace the self-signed certificate that synapse
|
|
||||||
# autogenerates on launch with your own SSL certificate + key pair
|
|
||||||
# if you like. Any required intermediary certificates can be
|
|
||||||
# appended after the primary certificate in hierarchical order.
|
|
||||||
tls_certificate_path: "/etc/matrix-synapse/homeserver.tls.crt"
|
|
||||||
|
|
||||||
# PEM encoded private key for TLS
|
|
||||||
tls_private_key_path: "/etc/matrix-synapse/homeserver.tls.key"
|
|
||||||
|
|
||||||
# Don't bind to the https port
|
|
||||||
no_tls: False
|
|
||||||
|
|
||||||
# List of allowed TLS fingerprints for this server to publish along
|
|
||||||
# with the signing keys for this server. Other matrix servers that
|
|
||||||
# make HTTPS requests to this server will check that the TLS
|
|
||||||
# certificates returned by this server match one of the fingerprints.
|
|
||||||
#
|
|
||||||
# Synapse automatically adds the fingerprint of its own certificate
|
|
||||||
# to the list. So if federation traffic is handled directly by synapse
|
|
||||||
# then no modification to the list is required.
|
|
||||||
#
|
|
||||||
# If synapse is run behind a load balancer that handles the TLS then it
|
|
||||||
# will be necessary to add the fingerprints of the certificates used by
|
|
||||||
# the loadbalancers to this list if they are different to the one
|
|
||||||
# synapse is using.
|
|
||||||
#
|
|
||||||
# Homeservers are permitted to cache the list of TLS fingerprints
|
|
||||||
# returned in the key responses up to the "valid_until_ts" returned in
|
|
||||||
# key. It may be necessary to publish the fingerprints of a new
|
|
||||||
# certificate and wait until the "valid_until_ts" of the previous key
|
|
||||||
# responses have passed before deploying it.
|
|
||||||
#
|
|
||||||
# You can calculate a fingerprint from a given TLS listener via:
|
|
||||||
# openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
|
|
||||||
# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
|
|
||||||
# or by checking matrix.org/federationtester/api/report?server_name=$host
|
|
||||||
#
|
|
||||||
tls_fingerprints: []
|
|
||||||
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
|
||||||
|
|
||||||
|
|
||||||
## Server ##
|
|
||||||
|
|
||||||
# When running as a daemon, the file to store the pid in
|
|
||||||
pid_file: "/var/run/matrix-synapse.pid"
|
|
||||||
|
|
||||||
# CPU affinity mask. Setting this restricts the CPUs on which the
|
|
||||||
# process will be scheduled. It is represented as a bitmask, with the
|
|
||||||
# lowest order bit corresponding to the first logical CPU and the
|
|
||||||
# highest order bit corresponding to the last logical CPU. Not all CPUs
|
|
||||||
# may exist on a given system but a mask may specify more CPUs than are
|
|
||||||
# present.
|
|
||||||
#
|
|
||||||
# For example:
|
|
||||||
# 0x00000001 is processor #0,
|
|
||||||
# 0x00000003 is processors #0 and #1,
|
|
||||||
# 0xFFFFFFFF is all processors (#0 through #31).
|
|
||||||
#
|
|
||||||
# Pinning a Python process to a single CPU is desirable, because Python
|
|
||||||
# is inherently single-threaded due to the GIL, and can suffer a
|
|
||||||
# 30-40% slowdown due to cache blow-out and thread context switching
|
|
||||||
# if the scheduler happens to schedule the underlying threads across
|
|
||||||
# different cores. See
|
|
||||||
# https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
|
|
||||||
#
|
|
||||||
# cpu_affinity: 0xFFFFFFFF
|
|
||||||
|
|
||||||
# The path to the web client which will be served at /_matrix/client/
|
|
||||||
# if 'webclient' is configured under the 'listeners' configuration.
|
|
||||||
#
|
|
||||||
# web_client_location: "/path/to/web/root"
|
|
||||||
|
|
||||||
# The public-facing base URL for the client API (not including _matrix/...)
|
|
||||||
# public_baseurl: https://example.com:8448/
|
|
||||||
|
|
||||||
# Set the soft limit on the number of file descriptors synapse can use
|
|
||||||
# Zero is used to indicate synapse should set the soft limit to the
|
|
||||||
# hard limit.
|
|
||||||
soft_file_limit: 0
|
|
||||||
|
|
||||||
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
|
|
||||||
# gc_thresholds: [700, 10, 10]
|
|
||||||
|
|
||||||
# Set the limit on the returned events in the timeline in the get
|
|
||||||
# and sync operations. The default value is -1, means no upper limit.
|
|
||||||
# filter_timeline_limit: 5000
|
|
||||||
|
|
||||||
# Whether room invites to users on this server should be blocked
|
|
||||||
# (except those sent by local server admins). The default is False.
|
|
||||||
# block_non_admin_invites: True
|
|
||||||
|
|
||||||
# Restrict federation to the following whitelist of domains.
|
|
||||||
# N.B. we recommend also firewalling your federation listener to limit
|
|
||||||
# inbound federation traffic as early as possible, rather than relying
|
|
||||||
# purely on this application-layer restriction. If not specified, the
|
|
||||||
# default is to whitelist everything.
|
|
||||||
#
|
|
||||||
# federation_domain_whitelist:
|
|
||||||
# - lon.example.com
|
|
||||||
# - nyc.example.com
|
|
||||||
# - syd.example.com
|
|
||||||
|
|
||||||
# List of ports that Synapse should listen on, their purpose and their
|
|
||||||
# configuration.
|
|
||||||
listeners:
|
|
||||||
# Main HTTPS listener
|
|
||||||
# For when matrix traffic is sent directly to synapse.
|
|
||||||
-
|
|
||||||
# The port to listen for HTTPS requests on.
|
|
||||||
port: 8448
|
|
||||||
|
|
||||||
# Local addresses to listen on.
|
|
||||||
# On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
|
|
||||||
# addresses by default. For most other OSes, this will only listen
|
|
||||||
# on IPv6.
|
|
||||||
bind_addresses:
|
|
||||||
- '::'
|
|
||||||
- '0.0.0.0'
|
|
||||||
|
|
||||||
# This is a 'http' listener, allows us to specify 'resources'.
|
|
||||||
type: http
|
|
||||||
|
|
||||||
tls: true
|
|
||||||
|
|
||||||
# Use the X-Forwarded-For (XFF) header as the client IP and not the
|
|
||||||
# actual client IP.
|
|
||||||
x_forwarded: false
|
|
||||||
|
|
||||||
# List of HTTP resources to serve on this listener.
|
|
||||||
resources:
|
|
||||||
-
|
|
||||||
# List of resources to host on this listener.
|
|
||||||
names:
|
|
||||||
- client # The client-server APIs, both v1 and v2
|
|
||||||
- webclient # The bundled webclient.
|
|
||||||
|
|
||||||
# Should synapse compress HTTP responses to clients that support it?
|
|
||||||
# This should be disabled if running synapse behind a load balancer
|
|
||||||
# that can do automatic compression.
|
|
||||||
compress: true
|
|
||||||
|
|
||||||
- names: [federation] # Federation APIs
|
|
||||||
compress: false
|
|
||||||
|
|
||||||
# optional list of additional endpoints which can be loaded via
|
|
||||||
# dynamic modules
|
|
||||||
# additional_resources:
|
|
||||||
# "/_matrix/my/custom/endpoint":
|
|
||||||
# module: my_module.CustomRequestHandler
|
|
||||||
# config: {}
|
|
||||||
|
|
||||||
# Unsecure HTTP listener,
|
|
||||||
# For when matrix traffic passes through loadbalancer that unwraps TLS.
|
|
||||||
- port: 8008
|
|
||||||
tls: false
|
|
||||||
bind_addresses: ['::', '0.0.0.0']
|
|
||||||
type: http
|
|
||||||
|
|
||||||
x_forwarded: false
|
|
||||||
|
|
||||||
resources:
|
|
||||||
- names: [client, webclient]
|
|
||||||
compress: true
|
|
||||||
- names: [federation]
|
|
||||||
compress: false
|
|
||||||
|
|
||||||
# Turn on the twisted ssh manhole service on localhost on the given
|
|
||||||
# port.
|
|
||||||
# - port: 9000
|
|
||||||
# bind_addresses: ['::1', '127.0.0.1']
|
|
||||||
# type: manhole
|
|
||||||
|
|
||||||
|
|
||||||
# Database configuration
|
|
||||||
database:
|
|
||||||
# The database engine name
|
|
||||||
name: "sqlite3"
|
|
||||||
# Arguments to pass to the engine
|
|
||||||
args:
|
|
||||||
# Path to the database
|
|
||||||
database: "/var/lib/matrix-synapse/homeserver.db"
|
|
||||||
|
|
||||||
# Number of events to cache in memory.
|
|
||||||
event_cache_size: "10K"
|
|
||||||
|
|
||||||
|
|
||||||
# A yaml python logging config file
|
|
||||||
log_config: "/etc/matrix-synapse/log.yaml"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Ratelimiting ##
|
|
||||||
|
|
||||||
# Number of messages a client can send per second
|
|
||||||
rc_messages_per_second: 0.2
|
|
||||||
|
|
||||||
# Number of message a client can send before being throttled
|
|
||||||
rc_message_burst_count: 10.0
|
|
||||||
|
|
||||||
# The federation window size in milliseconds
|
|
||||||
federation_rc_window_size: 1000
|
|
||||||
|
|
||||||
# The number of federation requests from a single server in a window
|
|
||||||
# before the server will delay processing the request.
|
|
||||||
federation_rc_sleep_limit: 10
|
|
||||||
|
|
||||||
# The duration in milliseconds to delay processing events from
|
|
||||||
# remote servers by if they go over the sleep limit.
|
|
||||||
federation_rc_sleep_delay: 500
|
|
||||||
|
|
||||||
# The maximum number of concurrent federation requests allowed
|
|
||||||
# from a single server
|
|
||||||
federation_rc_reject_limit: 50
|
|
||||||
|
|
||||||
# The number of federation requests to concurrently process from a
|
|
||||||
# single server
|
|
||||||
federation_rc_concurrent: 3
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Directory where uploaded images and attachments are stored.
|
|
||||||
media_store_path: "/var/lib/matrix-synapse/media"
|
|
||||||
|
|
||||||
# Media storage providers allow media to be stored in different
|
|
||||||
# locations.
|
|
||||||
# media_storage_providers:
|
|
||||||
# - module: file_system
|
|
||||||
# # Whether to write new local files.
|
|
||||||
# store_local: false
|
|
||||||
# # Whether to write new remote media
|
|
||||||
# store_remote: false
|
|
||||||
# # Whether to block upload requests waiting for write to this
|
|
||||||
# # provider to complete
|
|
||||||
# store_synchronous: false
|
|
||||||
# config:
|
|
||||||
# directory: /mnt/some/other/directory
|
|
||||||
|
|
||||||
# Directory where in-progress uploads are stored.
|
|
||||||
uploads_path: "/var/lib/matrix-synapse/uploads"
|
|
||||||
|
|
||||||
# The largest allowed upload size in bytes
|
|
||||||
max_upload_size: "10M"
|
|
||||||
|
|
||||||
# Maximum number of pixels that will be thumbnailed
|
|
||||||
max_image_pixels: "32M"
|
|
||||||
|
|
||||||
# Whether to generate new thumbnails on the fly to precisely match
|
|
||||||
# the resolution requested by the client. If true then whenever
|
|
||||||
# a new resolution is requested by the client the server will
|
|
||||||
# generate a new thumbnail. If false the server will pick a thumbnail
|
|
||||||
# from a precalculated list.
|
|
||||||
dynamic_thumbnails: false
|
|
||||||
|
|
||||||
# List of thumbnail to precalculate when an image is uploaded.
|
|
||||||
thumbnail_sizes:
|
|
||||||
- width: 32
|
|
||||||
height: 32
|
|
||||||
method: crop
|
|
||||||
- width: 96
|
|
||||||
height: 96
|
|
||||||
method: crop
|
|
||||||
- width: 320
|
|
||||||
height: 240
|
|
||||||
method: scale
|
|
||||||
- width: 640
|
|
||||||
height: 480
|
|
||||||
method: scale
|
|
||||||
- width: 800
|
|
||||||
height: 600
|
|
||||||
method: scale
|
|
||||||
|
|
||||||
# Is the preview URL API enabled? If enabled, you *must* specify
|
|
||||||
# an explicit url_preview_ip_range_blacklist of IPs that the spider is
|
|
||||||
# denied from accessing.
|
|
||||||
url_preview_enabled: False
|
|
||||||
|
|
||||||
# List of IP address CIDR ranges that the URL preview spider is denied
|
|
||||||
# from accessing. There are no defaults: you must explicitly
|
|
||||||
# specify a list for URL previewing to work. You should specify any
|
|
||||||
# internal services in your network that you do not want synapse to try
|
|
||||||
# to connect to, otherwise anyone in any Matrix room could cause your
|
|
||||||
# synapse to issue arbitrary GET requests to your internal services,
|
|
||||||
# causing serious security issues.
|
|
||||||
#
|
|
||||||
# url_preview_ip_range_blacklist:
|
|
||||||
# - '127.0.0.0/8'
|
|
||||||
# - '10.0.0.0/8'
|
|
||||||
# - '172.16.0.0/12'
|
|
||||||
# - '192.168.0.0/16'
|
|
||||||
# - '100.64.0.0/10'
|
|
||||||
# - '169.254.0.0/16'
|
|
||||||
#
|
|
||||||
# List of IP address CIDR ranges that the URL preview spider is allowed
|
|
||||||
# to access even if they are specified in url_preview_ip_range_blacklist.
|
|
||||||
# This is useful for specifying exceptions to wide-ranging blacklisted
|
|
||||||
# target IP ranges - e.g. for enabling URL previews for a specific private
|
|
||||||
# website only visible in your network.
|
|
||||||
#
|
|
||||||
# url_preview_ip_range_whitelist:
|
|
||||||
# - '192.168.1.1'
|
|
||||||
|
|
||||||
# Optional list of URL matches that the URL preview spider is
|
|
||||||
# denied from accessing. You should use url_preview_ip_range_blacklist
|
|
||||||
# in preference to this, otherwise someone could define a public DNS
|
|
||||||
# entry that points to a private IP address and circumvent the blacklist.
|
|
||||||
# This is more useful if you know there is an entire shape of URL that
|
|
||||||
# you know that will never want synapse to try to spider.
|
|
||||||
#
|
|
||||||
# Each list entry is a dictionary of url component attributes as returned
|
|
||||||
# by urlparse.urlsplit as applied to the absolute form of the URL. See
|
|
||||||
# https://docs.python.org/2/library/urlparse.html#urlparse.urlsplit
|
|
||||||
# The values of the dictionary are treated as an filename match pattern
|
|
||||||
# applied to that component of URLs, unless they start with a ^ in which
|
|
||||||
# case they are treated as a regular expression match. If all the
|
|
||||||
# specified component matches for a given list item succeed, the URL is
|
|
||||||
# blacklisted.
|
|
||||||
#
|
|
||||||
# url_preview_url_blacklist:
|
|
||||||
# # blacklist any URL with a username in its URI
|
|
||||||
# - username: '*'
|
|
||||||
#
|
|
||||||
# # blacklist all *.google.com URLs
|
|
||||||
# - netloc: 'google.com'
|
|
||||||
# - netloc: '*.google.com'
|
|
||||||
#
|
|
||||||
# # blacklist all plain HTTP URLs
|
|
||||||
# - scheme: 'http'
|
|
||||||
#
|
|
||||||
# # blacklist http(s)://www.acme.com/foo
|
|
||||||
# - netloc: 'www.acme.com'
|
|
||||||
# path: '/foo'
|
|
||||||
#
|
|
||||||
# # blacklist any URL with a literal IPv4 address
|
|
||||||
# - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$'
|
|
||||||
|
|
||||||
# The largest allowed URL preview spidering size in bytes
|
|
||||||
max_spider_size: "10M"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Captcha ##
|
|
||||||
# See docs/CAPTCHA_SETUP for full details of configuring this.
|
|
||||||
|
|
||||||
# This Home Server's ReCAPTCHA public key.
|
|
||||||
recaptcha_public_key: "YOUR_PUBLIC_KEY"
|
|
||||||
|
|
||||||
# This Home Server's ReCAPTCHA private key.
|
|
||||||
recaptcha_private_key: "YOUR_PRIVATE_KEY"
|
|
||||||
|
|
||||||
# Enables ReCaptcha checks when registering, preventing signup
|
|
||||||
# unless a captcha is answered. Requires a valid ReCaptcha
|
|
||||||
# public/private key.
|
|
||||||
enable_registration_captcha: False
|
|
||||||
|
|
||||||
# A secret key used to bypass the captcha test entirely.
|
|
||||||
#captcha_bypass_secret: "YOUR_SECRET_HERE"
|
|
||||||
|
|
||||||
# The API endpoint to use for verifying m.login.recaptcha responses.
|
|
||||||
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
|
|
||||||
|
|
||||||
|
|
||||||
## Turn ##
|
|
||||||
|
|
||||||
# The public URIs of the TURN server to give to clients
|
|
||||||
turn_uris: []
|
|
||||||
|
|
||||||
# The shared secret used to compute passwords for the TURN server
|
|
||||||
turn_shared_secret: "YOUR_SHARED_SECRET"
|
|
||||||
|
|
||||||
# The Username and password if the TURN server needs them and
|
|
||||||
# does not use a token
|
|
||||||
#turn_username: "TURNSERVER_USERNAME"
|
|
||||||
#turn_password: "TURNSERVER_PASSWORD"
|
|
||||||
|
|
||||||
# How long generated TURN credentials last
|
|
||||||
turn_user_lifetime: "1h"
|
|
||||||
|
|
||||||
# Whether guests should be allowed to use the TURN server.
|
|
||||||
# This defaults to True, otherwise VoIP will be unreliable for guests.
|
|
||||||
# However, it does introduce a slight security risk as it allows users to
|
|
||||||
# connect to arbitrary endpoints without having first signed up for a
|
|
||||||
# valid account (e.g. by passing a CAPTCHA).
|
|
||||||
turn_allow_guests: False
|
|
||||||
|
|
||||||
|
|
||||||
## Registration ##
|
|
||||||
|
|
||||||
# Enable registration for new users.
|
|
||||||
enable_registration: False
|
|
||||||
|
|
||||||
# The user must provide all of the below types of 3PID when registering.
|
|
||||||
#
|
|
||||||
# registrations_require_3pid:
|
|
||||||
# - email
|
|
||||||
# - msisdn
|
|
||||||
|
|
||||||
# Mandate that users are only allowed to associate certain formats of
|
|
||||||
# 3PIDs with accounts on this server.
|
|
||||||
#
|
|
||||||
# allowed_local_3pids:
|
|
||||||
# - medium: email
|
|
||||||
# pattern: ".*@matrix\.org"
|
|
||||||
# - medium: email
|
|
||||||
# pattern: ".*@vector\.im"
|
|
||||||
# - medium: msisdn
|
|
||||||
# pattern: "\+44"
|
|
||||||
|
|
||||||
# If set, allows registration by anyone who also has the shared
|
|
||||||
# secret, even if registration is otherwise disabled.
|
|
||||||
# registration_shared_secret: <PRIVATE STRING>
|
|
||||||
|
|
||||||
# Set the number of bcrypt rounds used to generate password hash.
|
|
||||||
# Larger numbers increase the work factor needed to generate the hash.
|
|
||||||
# The default number is 12 (which equates to 2^12 rounds).
|
|
||||||
# N.B. that increasing this will exponentially increase the time required
|
|
||||||
# to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
|
|
||||||
bcrypt_rounds: 12
|
|
||||||
|
|
||||||
# Allows users to register as guests without a password/email/etc, and
|
|
||||||
# participate in rooms hosted on this server which have been made
|
|
||||||
# accessible to anonymous users.
|
|
||||||
allow_guest_access: False
|
|
||||||
|
|
||||||
# The list of identity servers trusted to verify third party
|
|
||||||
# identifiers by this server.
|
|
||||||
trusted_third_party_id_servers:
|
|
||||||
- matrix.org
|
|
||||||
- vector.im
|
|
||||||
- riot.im
|
|
||||||
|
|
||||||
# Users who register on this homeserver will automatically be joined
|
|
||||||
# to these rooms
|
|
||||||
#auto_join_rooms:
|
|
||||||
# - "#example:example.com"
|
|
||||||
|
|
||||||
|
|
||||||
## Metrics ###
|
|
||||||
|
|
||||||
# Enable collection and rendering of performance metrics
|
|
||||||
enable_metrics: False
|
|
||||||
|
|
||||||
## API Configuration ##
|
|
||||||
|
|
||||||
# A list of event types that will be included in the room_invite_state
|
|
||||||
room_invite_state_types:
|
|
||||||
- "m.room.join_rules"
|
|
||||||
- "m.room.canonical_alias"
|
|
||||||
- "m.room.avatar"
|
|
||||||
- "m.room.name"
|
|
||||||
|
|
||||||
|
|
||||||
# A list of application service config file to use
|
|
||||||
app_service_config_files: []
|
|
||||||
|
|
||||||
|
|
||||||
# macaroon_secret_key: <PRIVATE STRING>
|
|
||||||
|
|
||||||
# Used to enable access token expiration.
|
|
||||||
expire_access_token: False
|
|
||||||
|
|
||||||
## Signing Keys ##
|
|
||||||
|
|
||||||
# Path to the signing key to sign messages with
|
|
||||||
signing_key_path: "/etc/matrix-synapse/homeserver.signing.key"
|
|
||||||
|
|
||||||
# The keys that the server used to sign messages with but won't use
|
|
||||||
# to sign new messages. E.g. it has lost its private key
|
|
||||||
old_signing_keys: {}
|
|
||||||
# "ed25519:auto":
|
|
||||||
# # Base64 encoded public key
|
|
||||||
# key: "The public part of your old signing key."
|
|
||||||
# # Millisecond POSIX timestamp when the key expired.
|
|
||||||
# expired_ts: 123456789123
|
|
||||||
|
|
||||||
# How long key response published by this server is valid for.
|
|
||||||
# Used to set the valid_until_ts in /key/v2 APIs.
|
|
||||||
# Determines how quickly servers will query to check which keys
|
|
||||||
# are still valid.
|
|
||||||
key_refresh_interval: "1d" # 1 Day.
|
|
||||||
|
|
||||||
# The trusted servers to download signing keys from.
|
|
||||||
perspectives:
|
|
||||||
servers:
|
|
||||||
"matrix.org":
|
|
||||||
verify_keys:
|
|
||||||
"ed25519:auto":
|
|
||||||
key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Enable SAML2 for registration and login. Uses pysaml2
|
|
||||||
# config_path: Path to the sp_conf.py configuration file
|
|
||||||
# idp_redirect_url: Identity provider URL which will redirect
|
|
||||||
# the user back to /login/saml2 with proper info.
|
|
||||||
# See pysaml2 docs for format of config.
|
|
||||||
#saml2_config:
|
|
||||||
# enabled: true
|
|
||||||
# config_path: "/home/erikj/git/synapse/sp_conf.py"
|
|
||||||
# idp_redirect_url: "http://test/idp"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Enable CAS for registration and login.
|
|
||||||
#cas_config:
|
|
||||||
# enabled: true
|
|
||||||
# server_url: "https://cas-server.com"
|
|
||||||
# service_url: "https://homeserver.domain.com:8448"
|
|
||||||
# #required_attributes:
|
|
||||||
# # name: value
|
|
||||||
|
|
||||||
|
|
||||||
# The JWT needs to contain a globally unique "sub" (subject) claim.
|
|
||||||
#
|
|
||||||
# jwt_config:
|
|
||||||
# enabled: true
|
|
||||||
# secret: "a secret"
|
|
||||||
# algorithm: "HS256"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Enable password for login.
|
|
||||||
password_config:
|
|
||||||
enabled: true
|
|
||||||
# Uncomment and change to a secret random string for extra security.
|
|
||||||
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
|
|
||||||
#pepper: ""
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Enable sending emails for notification events
|
|
||||||
# Defining a custom URL for Riot is only needed if email notifications
|
|
||||||
# should contain links to a self-hosted installation of Riot; when set
|
|
||||||
# the "app_name" setting is ignored.
|
|
||||||
#
|
|
||||||
# If your SMTP server requires authentication, the optional smtp_user &
|
|
||||||
# smtp_pass variables should be used
|
|
||||||
#
|
|
||||||
#email:
|
|
||||||
# enable_notifs: false
|
|
||||||
# smtp_host: "localhost"
|
|
||||||
# smtp_port: 25
|
|
||||||
# smtp_user: "exampleusername"
|
|
||||||
# smtp_pass: "examplepassword"
|
|
||||||
# require_transport_security: False
|
|
||||||
# notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
|
|
||||||
# app_name: Matrix
|
|
||||||
# template_dir: res/templates
|
|
||||||
# notif_template_html: notif_mail.html
|
|
||||||
# notif_template_text: notif_mail.txt
|
|
||||||
# notif_for_new_users: True
|
|
||||||
# riot_base_url: "http://localhost/riot"
|
|
||||||
|
|
||||||
|
|
||||||
# password_providers:
|
|
||||||
# - module: "ldap_auth_provider.LdapAuthProvider"
|
|
||||||
# config:
|
|
||||||
# enabled: true
|
|
||||||
# uri: "ldap://ldap.example.com:389"
|
|
||||||
# start_tls: true
|
|
||||||
# base: "ou=users,dc=example,dc=com"
|
|
||||||
# attributes:
|
|
||||||
# uid: "cn"
|
|
||||||
# mail: "email"
|
|
||||||
# name: "givenName"
|
|
||||||
# #bind_dn:
|
|
||||||
# #bind_password:
|
|
||||||
# #filter: "(objectClass=posixAccount)"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Clients requesting push notifications can either have the body of
|
|
||||||
# the message sent in the notification poke along with other details
|
|
||||||
# like the sender, or just the event ID and room ID (`event_id_only`).
|
|
||||||
# If clients choose the former, this option controls whether the
|
|
||||||
# notification request includes the content of the event (other details
|
|
||||||
# like the sender are still included). For `event_id_only` push, it
|
|
||||||
# has no effect.
|
|
||||||
|
|
||||||
# For modern android devices the notification content will still appear
|
|
||||||
# because it is loaded by the app. iPhone, however will send a
|
|
||||||
# notification saying only that a message arrived and who it came from.
|
|
||||||
#
|
|
||||||
#push:
|
|
||||||
# include_content: true
|
|
||||||
|
|
||||||
|
|
||||||
# spam_checker:
|
|
||||||
# module: "my_custom_project.SuperSpamChecker"
|
|
||||||
# config:
|
|
||||||
# example_option: 'things'
|
|
||||||
|
|
||||||
|
|
||||||
# Whether to allow non server admins to create groups on this server
|
|
||||||
enable_group_creation: false
|
|
||||||
|
|
||||||
# If enabled, non server admins can only create groups with local parts
|
|
||||||
# starting with this prefix
|
|
||||||
# group_creation_prefix: "unofficial/"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# User Directory configuration
|
|
||||||
#
|
|
||||||
# 'search_all_users' defines whether to search all users visible to your HS
|
|
||||||
# when searching the user directory, rather than limiting to users visible
|
|
||||||
# in public rooms. Defaults to false. If you set it True, you'll have to run
|
|
||||||
# UPDATE user_directory_stream_pos SET stream_id = NULL;
|
|
||||||
# on your database to tell it to rebuild the user_directory search indexes.
|
|
||||||
#
|
|
||||||
#user_directory:
|
|
||||||
# search_all_users: false
|
|
|
@ -1,2 +1 @@
|
||||||
debian/homeserver.yaml etc/matrix-synapse
|
|
||||||
debian/log.yaml etc/matrix-synapse
|
debian/log.yaml etc/matrix-synapse
|
||||||
|
|
|
@ -10,12 +10,12 @@
|
||||||
# can be passed on the commandline for debugging.
|
# can be passed on the commandline for debugging.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
DISTS = (
|
DISTS = (
|
||||||
"debian:stretch",
|
"debian:stretch",
|
||||||
|
|
|
@ -819,7 +819,9 @@ class Auth(object):
|
||||||
elif threepid:
|
elif threepid:
|
||||||
# If the user does not exist yet, but is signing up with a
|
# If the user does not exist yet, but is signing up with a
|
||||||
# reserved threepid then pass auth check
|
# reserved threepid then pass auth check
|
||||||
if is_threepid_reserved(self.hs.config, threepid):
|
if is_threepid_reserved(
|
||||||
|
self.hs.config.mau_limits_reserved_threepids, threepid
|
||||||
|
):
|
||||||
return
|
return
|
||||||
# Else if there is no room in the MAU bucket, bail
|
# Else if there is no room in the MAU bucket, bail
|
||||||
current_mau = yield self.store.get_monthly_active_count()
|
current_mau = yield self.store.get_monthly_active_count()
|
||||||
|
|
|
@ -120,6 +120,19 @@ KNOWN_ROOM_VERSIONS = {
|
||||||
RoomVersions.STATE_V2_TEST,
|
RoomVersions.STATE_V2_TEST,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EventFormatVersions(object):
|
||||||
|
"""This is an internal enum for tracking the version of the event format,
|
||||||
|
independently from the room version.
|
||||||
|
"""
|
||||||
|
V1 = 1
|
||||||
|
|
||||||
|
|
||||||
|
KNOWN_EVENT_FORMAT_VERSIONS = {
|
||||||
|
EventFormatVersions.V1,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
ServerNoticeMsgType = "m.server_notice"
|
ServerNoticeMsgType = "m.server_notice"
|
||||||
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
|
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
|
||||||
|
|
||||||
|
|
|
@ -444,6 +444,20 @@ class Filter(object):
|
||||||
def include_redundant_members(self):
|
def include_redundant_members(self):
|
||||||
return self.filter_json.get("include_redundant_members", False)
|
return self.filter_json.get("include_redundant_members", False)
|
||||||
|
|
||||||
|
def with_room_ids(self, room_ids):
|
||||||
|
"""Returns a new filter with the given room IDs appended.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_ids (iterable[unicode]): The room_ids to add
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
filter: A new filter including the given rooms and the old
|
||||||
|
filter's rooms.
|
||||||
|
"""
|
||||||
|
newFilter = Filter(self.filter_json)
|
||||||
|
newFilter.rooms += room_ids
|
||||||
|
return newFilter
|
||||||
|
|
||||||
|
|
||||||
def _matches_wildcard(actual_value, filter_value):
|
def _matches_wildcard(actual_value, filter_value):
|
||||||
if filter_value.endswith("*"):
|
if filter_value.endswith("*"):
|
||||||
|
|
|
@ -13,10 +13,12 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
|
||||||
|
@ -324,17 +326,12 @@ def setup(config_options):
|
||||||
|
|
||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
|
||||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
|
||||||
|
|
||||||
database_engine = create_engine(config.database_config)
|
database_engine = create_engine(config.database_config)
|
||||||
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
|
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
|
||||||
|
|
||||||
hs = SynapseHomeServer(
|
hs = SynapseHomeServer(
|
||||||
config.server_name,
|
config.server_name,
|
||||||
db_config=config.database_config,
|
db_config=config.database_config,
|
||||||
tls_server_context_factory=tls_server_context_factory,
|
|
||||||
tls_client_options_factory=tls_client_options_factory,
|
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/" + get_version_string(synapse),
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
database_engine=database_engine,
|
database_engine=database_engine,
|
||||||
|
@ -361,12 +358,53 @@ def setup(config_options):
|
||||||
logger.info("Database prepared in %s.", config.database_config['name'])
|
logger.info("Database prepared in %s.", config.database_config['name'])
|
||||||
|
|
||||||
hs.setup()
|
hs.setup()
|
||||||
hs.start_listening()
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def start():
|
def start():
|
||||||
hs.get_pusherpool().start()
|
try:
|
||||||
hs.get_datastore().start_profiling()
|
# Check if the certificate is still valid.
|
||||||
hs.get_datastore().start_doing_background_updates()
|
cert_days_remaining = hs.config.is_disk_cert_valid()
|
||||||
|
|
||||||
|
if hs.config.acme_enabled:
|
||||||
|
# If ACME is enabled, we might need to provision a certificate
|
||||||
|
# before starting.
|
||||||
|
acme = hs.get_acme_handler()
|
||||||
|
|
||||||
|
# Start up the webservices which we will respond to ACME
|
||||||
|
# challenges with.
|
||||||
|
yield acme.start_listening()
|
||||||
|
|
||||||
|
# We want to reprovision if cert_days_remaining is None (meaning no
|
||||||
|
# certificate exists), or the days remaining number it returns
|
||||||
|
# is less than our re-registration threshold.
|
||||||
|
if (cert_days_remaining is None) or (
|
||||||
|
not cert_days_remaining > hs.config.acme_reprovision_threshold
|
||||||
|
):
|
||||||
|
yield acme.provision_certificate()
|
||||||
|
|
||||||
|
# Read the certificate from disk and build the context factories for
|
||||||
|
# TLS.
|
||||||
|
hs.config.read_certificate_from_disk()
|
||||||
|
hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
|
||||||
|
config
|
||||||
|
)
|
||||||
|
|
||||||
|
# It is now safe to start your Synapse.
|
||||||
|
hs.start_listening()
|
||||||
|
hs.get_pusherpool().start()
|
||||||
|
hs.get_datastore().start_profiling()
|
||||||
|
hs.get_datastore().start_doing_background_updates()
|
||||||
|
except Exception as e:
|
||||||
|
# If a DeferredList failed (like in listening on the ACME listener),
|
||||||
|
# we need to print the subfailure explicitly.
|
||||||
|
if isinstance(e, defer.FirstError):
|
||||||
|
e.subFailure.printTraceback(sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Something else went wrong when starting. Print it and bail out.
|
||||||
|
traceback.print_exc(file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
reactor.callWhenRunning(start)
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
|
|
|
@ -367,7 +367,7 @@ class Config(object):
|
||||||
if not keys_directory:
|
if not keys_directory:
|
||||||
keys_directory = os.path.dirname(config_files[-1])
|
keys_directory = os.path.dirname(config_files[-1])
|
||||||
|
|
||||||
config_dir_path = os.path.abspath(keys_directory)
|
self.config_dir_path = os.path.abspath(keys_directory)
|
||||||
|
|
||||||
specified_config = {}
|
specified_config = {}
|
||||||
for config_file in config_files:
|
for config_file in config_files:
|
||||||
|
@ -379,7 +379,7 @@ class Config(object):
|
||||||
|
|
||||||
server_name = specified_config["server_name"]
|
server_name = specified_config["server_name"]
|
||||||
config_string = self.generate_config(
|
config_string = self.generate_config(
|
||||||
config_dir_path=config_dir_path,
|
config_dir_path=self.config_dir_path,
|
||||||
data_dir_path=os.getcwd(),
|
data_dir_path=os.getcwd(),
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
generate_secrets=False,
|
generate_secrets=False,
|
||||||
|
|
|
@ -50,6 +50,10 @@ class RegistrationConfig(Config):
|
||||||
raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
|
raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
|
||||||
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
|
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
|
||||||
|
|
||||||
|
self.disable_msisdn_registration = (
|
||||||
|
config.get("disable_msisdn_registration", False)
|
||||||
|
)
|
||||||
|
|
||||||
def default_config(self, generate_secrets=False, **kwargs):
|
def default_config(self, generate_secrets=False, **kwargs):
|
||||||
if generate_secrets:
|
if generate_secrets:
|
||||||
registration_shared_secret = 'registration_shared_secret: "%s"' % (
|
registration_shared_secret = 'registration_shared_secret: "%s"' % (
|
||||||
|
@ -70,6 +74,11 @@ class RegistrationConfig(Config):
|
||||||
# - email
|
# - email
|
||||||
# - msisdn
|
# - msisdn
|
||||||
|
|
||||||
|
# Explicitly disable asking for MSISDNs from the registration
|
||||||
|
# flow (overrides registrations_require_3pid if MSISDNs are set as required)
|
||||||
|
#
|
||||||
|
# disable_msisdn_registration = True
|
||||||
|
|
||||||
# Mandate that users are only allowed to associate certain formats of
|
# Mandate that users are only allowed to associate certain formats of
|
||||||
# 3PIDs with accounts on this server.
|
# 3PIDs with accounts on this server.
|
||||||
#
|
#
|
||||||
|
|
|
@ -256,7 +256,11 @@ class ServerConfig(Config):
|
||||||
#
|
#
|
||||||
# web_client_location: "/path/to/web/root"
|
# web_client_location: "/path/to/web/root"
|
||||||
|
|
||||||
# The public-facing base URL for the client API (not including _matrix/...)
|
# The public-facing base URL that clients use to access this HS
|
||||||
|
# (not including _matrix/...). This is the same URL a user would
|
||||||
|
# enter into the 'custom HS URL' field on their client. If you
|
||||||
|
# use synapse with a reverse proxy, this should be the URL to reach
|
||||||
|
# synapse via the proxy.
|
||||||
# public_baseurl: https://example.com:8448/
|
# public_baseurl: https://example.com:8448/
|
||||||
|
|
||||||
# Set the soft limit on the number of file descriptors synapse can use
|
# Set the soft limit on the number of file descriptors synapse can use
|
||||||
|
@ -420,19 +424,18 @@ class ServerConfig(Config):
|
||||||
" service on the given port.")
|
" service on the given port.")
|
||||||
|
|
||||||
|
|
||||||
def is_threepid_reserved(config, threepid):
|
def is_threepid_reserved(reserved_threepids, threepid):
|
||||||
"""Check the threepid against the reserved threepid config
|
"""Check the threepid against the reserved threepid config
|
||||||
Args:
|
Args:
|
||||||
config(ServerConfig) - to access server config attributes
|
reserved_threepids([dict]) - list of reserved threepids
|
||||||
threepid(dict) - The threepid to test for
|
threepid(dict) - The threepid to test for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
boolean Is the threepid undertest reserved_user
|
boolean Is the threepid undertest reserved_user
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for tp in config.mau_limits_reserved_threepids:
|
for tp in reserved_threepids:
|
||||||
if (threepid['medium'] == tp['medium']
|
if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']):
|
||||||
and threepid['address'] == tp['address']):
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -13,45 +13,38 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
|
||||||
from unpaddedbase64 import encode_base64
|
from unpaddedbase64 import encode_base64
|
||||||
|
|
||||||
from OpenSSL import crypto
|
from OpenSSL import crypto
|
||||||
|
|
||||||
from ._base import Config
|
from synapse.config._base import Config
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
class TlsConfig(Config):
|
class TlsConfig(Config):
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.tls_certificate = self.read_tls_certificate(
|
|
||||||
config.get("tls_certificate_path")
|
|
||||||
)
|
|
||||||
self.tls_certificate_file = config.get("tls_certificate_path")
|
|
||||||
|
|
||||||
|
acme_config = config.get("acme", {})
|
||||||
|
self.acme_enabled = acme_config.get("enabled", False)
|
||||||
|
self.acme_url = acme_config.get(
|
||||||
|
"url", "https://acme-v01.api.letsencrypt.org/directory"
|
||||||
|
)
|
||||||
|
self.acme_port = acme_config.get("port", 8449)
|
||||||
|
self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
|
||||||
|
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
|
||||||
|
|
||||||
|
self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path"))
|
||||||
|
self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path"))
|
||||||
|
self._original_tls_fingerprints = config["tls_fingerprints"]
|
||||||
|
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||||
self.no_tls = config.get("no_tls", False)
|
self.no_tls = config.get("no_tls", False)
|
||||||
|
|
||||||
if self.no_tls:
|
|
||||||
self.tls_private_key = None
|
|
||||||
else:
|
|
||||||
self.tls_private_key = self.read_tls_private_key(
|
|
||||||
config.get("tls_private_key_path")
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tls_fingerprints = config["tls_fingerprints"]
|
|
||||||
|
|
||||||
# Check that our own certificate is included in the list of fingerprints
|
|
||||||
# and include it if it is not.
|
|
||||||
x509_certificate_bytes = crypto.dump_certificate(
|
|
||||||
crypto.FILETYPE_ASN1,
|
|
||||||
self.tls_certificate
|
|
||||||
)
|
|
||||||
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
|
||||||
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
|
|
||||||
if sha256_fingerprint not in sha256_fingerprints:
|
|
||||||
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
|
|
||||||
|
|
||||||
# This config option applies to non-federation HTTP clients
|
# This config option applies to non-federation HTTP clients
|
||||||
# (e.g. for talking to recaptcha, identity servers, and such)
|
# (e.g. for talking to recaptcha, identity servers, and such)
|
||||||
# It should never be used in production, and is intended for
|
# It should never be used in production, and is intended for
|
||||||
|
@ -60,13 +53,70 @@ class TlsConfig(Config):
|
||||||
"use_insecure_ssl_client_just_for_testing_do_not_use"
|
"use_insecure_ssl_client_just_for_testing_do_not_use"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.tls_certificate = None
|
||||||
|
self.tls_private_key = None
|
||||||
|
|
||||||
|
def is_disk_cert_valid(self):
|
||||||
|
"""
|
||||||
|
Is the certificate we have on disk valid, and if so, for how long?
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Days remaining of certificate validity.
|
||||||
|
None: No certificate exists.
|
||||||
|
"""
|
||||||
|
if not os.path.exists(self.tls_certificate_file):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.tls_certificate_file, 'rb') as f:
|
||||||
|
cert_pem = f.read()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to read existing certificate off disk!")
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to parse existing certificate off disk!")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# YYYYMMDDhhmmssZ -- in UTC
|
||||||
|
expires_on = datetime.strptime(
|
||||||
|
tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
|
||||||
|
)
|
||||||
|
now = datetime.utcnow()
|
||||||
|
days_remaining = (expires_on - now).days
|
||||||
|
return days_remaining
|
||||||
|
|
||||||
|
def read_certificate_from_disk(self):
|
||||||
|
"""
|
||||||
|
Read the certificates from disk.
|
||||||
|
"""
|
||||||
|
self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
|
||||||
|
|
||||||
|
if not self.no_tls:
|
||||||
|
self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
|
||||||
|
|
||||||
|
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||||
|
|
||||||
|
# Check that our own certificate is included in the list of fingerprints
|
||||||
|
# and include it if it is not.
|
||||||
|
x509_certificate_bytes = crypto.dump_certificate(
|
||||||
|
crypto.FILETYPE_ASN1, self.tls_certificate
|
||||||
|
)
|
||||||
|
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
||||||
|
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
|
||||||
|
if sha256_fingerprint not in sha256_fingerprints:
|
||||||
|
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||||
base_key_name = os.path.join(config_dir_path, server_name)
|
base_key_name = os.path.join(config_dir_path, server_name)
|
||||||
|
|
||||||
tls_certificate_path = base_key_name + ".tls.crt"
|
tls_certificate_path = base_key_name + ".tls.crt"
|
||||||
tls_private_key_path = base_key_name + ".tls.key"
|
tls_private_key_path = base_key_name + ".tls.key"
|
||||||
|
|
||||||
return """\
|
return (
|
||||||
|
"""\
|
||||||
# PEM encoded X509 certificate for TLS.
|
# PEM encoded X509 certificate for TLS.
|
||||||
# You can replace the self-signed certificate that synapse
|
# You can replace the self-signed certificate that synapse
|
||||||
# autogenerates on launch with your own SSL certificate + key pair
|
# autogenerates on launch with your own SSL certificate + key pair
|
||||||
|
@ -107,7 +157,24 @@ class TlsConfig(Config):
|
||||||
#
|
#
|
||||||
tls_fingerprints: []
|
tls_fingerprints: []
|
||||||
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
||||||
""" % locals()
|
|
||||||
|
## Support for ACME certificate auto-provisioning.
|
||||||
|
# acme:
|
||||||
|
# enabled: false
|
||||||
|
## ACME path.
|
||||||
|
## If you only want to test, use the staging url:
|
||||||
|
## https://acme-staging.api.letsencrypt.org/directory
|
||||||
|
# url: 'https://acme-v01.api.letsencrypt.org/directory'
|
||||||
|
## Port number (to listen for the HTTP-01 challenge).
|
||||||
|
## Using port 80 requires utilising something like authbind, or proxying to it.
|
||||||
|
# port: 8449
|
||||||
|
## Hosts to bind to.
|
||||||
|
# bind_addresses: ['127.0.0.1']
|
||||||
|
## How many days remaining on a certificate before it is renewed.
|
||||||
|
# reprovision_threshold: 30
|
||||||
|
"""
|
||||||
|
% locals()
|
||||||
|
)
|
||||||
|
|
||||||
def read_tls_certificate(self, cert_path):
|
def read_tls_certificate(self, cert_path):
|
||||||
cert_pem = self.read_file(cert_path, "tls_certificate")
|
cert_pem = self.read_file(cert_path, "tls_certificate")
|
||||||
|
|
|
@ -17,6 +17,7 @@ from zope.interface import implementer
|
||||||
|
|
||||||
from OpenSSL import SSL, crypto
|
from OpenSSL import SSL, crypto
|
||||||
from twisted.internet._sslverify import _defaultCurveName
|
from twisted.internet._sslverify import _defaultCurveName
|
||||||
|
from twisted.internet.abstract import isIPAddress, isIPv6Address
|
||||||
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
|
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
|
||||||
from twisted.internet.ssl import CertificateOptions, ContextFactory
|
from twisted.internet.ssl import CertificateOptions, ContextFactory
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
@ -98,8 +99,14 @@ class ClientTLSOptions(object):
|
||||||
|
|
||||||
def __init__(self, hostname, ctx):
|
def __init__(self, hostname, ctx):
|
||||||
self._ctx = ctx
|
self._ctx = ctx
|
||||||
self._hostname = hostname
|
|
||||||
self._hostnameBytes = _idnaBytes(hostname)
|
if isIPAddress(hostname) or isIPv6Address(hostname):
|
||||||
|
self._hostnameBytes = hostname.encode('ascii')
|
||||||
|
self._sendSNI = False
|
||||||
|
else:
|
||||||
|
self._hostnameBytes = _idnaBytes(hostname)
|
||||||
|
self._sendSNI = True
|
||||||
|
|
||||||
ctx.set_info_callback(
|
ctx.set_info_callback(
|
||||||
_tolerateErrors(self._identityVerifyingInfoCallback)
|
_tolerateErrors(self._identityVerifyingInfoCallback)
|
||||||
)
|
)
|
||||||
|
@ -111,7 +118,9 @@ class ClientTLSOptions(object):
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
def _identityVerifyingInfoCallback(self, connection, where, ret):
|
def _identityVerifyingInfoCallback(self, connection, where, ret):
|
||||||
if where & SSL.SSL_CB_HANDSHAKE_START:
|
# Literal IPv4 and IPv6 addresses are not permitted
|
||||||
|
# as host names according to the RFCs
|
||||||
|
if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
|
||||||
connection.set_tlsext_host_name(self._hostnameBytes)
|
connection.set_tlsext_host_name(self._hostnameBytes)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ from distutils.util import strtobool
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions
|
||||||
from synapse.util.caches import intern_dict
|
from synapse.util.caches import intern_dict
|
||||||
from synapse.util.frozenutils import freeze
|
from synapse.util.frozenutils import freeze
|
||||||
|
|
||||||
|
@ -41,8 +42,13 @@ class _EventInternalMetadata(object):
|
||||||
def is_outlier(self):
|
def is_outlier(self):
|
||||||
return getattr(self, "outlier", False)
|
return getattr(self, "outlier", False)
|
||||||
|
|
||||||
def is_invite_from_remote(self):
|
def is_out_of_band_membership(self):
|
||||||
return getattr(self, "invite_from_remote", False)
|
"""Whether this is an out of band membership, like an invite or an invite
|
||||||
|
rejection. This is needed as those events are marked as outliers, but
|
||||||
|
they still need to be processed as if they're new events (e.g. updating
|
||||||
|
invite state in the database, relaying to clients, etc).
|
||||||
|
"""
|
||||||
|
return getattr(self, "out_of_band_membership", False)
|
||||||
|
|
||||||
def get_send_on_behalf_of(self):
|
def get_send_on_behalf_of(self):
|
||||||
"""Whether this server should send the event on behalf of another server.
|
"""Whether this server should send the event on behalf of another server.
|
||||||
|
@ -179,6 +185,8 @@ class EventBase(object):
|
||||||
|
|
||||||
|
|
||||||
class FrozenEvent(EventBase):
|
class FrozenEvent(EventBase):
|
||||||
|
format_version = EventFormatVersions.V1 # All events of this type are V1
|
||||||
|
|
||||||
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
|
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
|
||||||
event_dict = dict(event_dict)
|
event_dict = dict(event_dict)
|
||||||
|
|
||||||
|
@ -232,3 +240,19 @@ class FrozenEvent(EventBase):
|
||||||
self.get("type", None),
|
self.get("type", None),
|
||||||
self.get("state_key", None),
|
self.get("state_key", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def room_version_to_event_format(room_version):
|
||||||
|
"""Converts a room version string to the event format
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_version (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int
|
||||||
|
"""
|
||||||
|
if room_version not in KNOWN_ROOM_VERSIONS:
|
||||||
|
# We should have already checked version, so this should not happen
|
||||||
|
raise RuntimeError("Unrecognized room version %s" % (room_version,))
|
||||||
|
|
||||||
|
return EventFormatVersions.V1
|
||||||
|
|
|
@ -43,8 +43,8 @@ class FederationBase(object):
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
|
||||||
include_none=False):
|
outlier=False, include_none=False):
|
||||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||||
one. If a PDU fails its signature check then we check if we have it in
|
one. If a PDU fails its signature check then we check if we have it in
|
||||||
the database and if not then request if from the originating server of
|
the database and if not then request if from the originating server of
|
||||||
|
@ -56,8 +56,12 @@ class FederationBase(object):
|
||||||
a new list.
|
a new list.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
origin (str)
|
||||||
pdu (list)
|
pdu (list)
|
||||||
outlier (bool)
|
room_version (str)
|
||||||
|
outlier (bool): Whether the events are outliers or not
|
||||||
|
include_none (str): Whether to include None in the returned list
|
||||||
|
for events that have failed their checks
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||||
|
@ -84,6 +88,7 @@ class FederationBase(object):
|
||||||
res = yield self.get_pdu(
|
res = yield self.get_pdu(
|
||||||
destinations=[pdu.origin],
|
destinations=[pdu.origin],
|
||||||
event_id=pdu.event_id,
|
event_id=pdu.event_id,
|
||||||
|
room_version=room_version,
|
||||||
outlier=outlier,
|
outlier=outlier,
|
||||||
timeout=10000,
|
timeout=10000,
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,14 +25,20 @@ from prometheus_client import Counter
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
|
from synapse.api.constants import (
|
||||||
|
KNOWN_ROOM_VERSIONS,
|
||||||
|
EventTypes,
|
||||||
|
Membership,
|
||||||
|
RoomVersions,
|
||||||
|
)
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException,
|
CodeMessageException,
|
||||||
FederationDeniedError,
|
FederationDeniedError,
|
||||||
HttpResponseException,
|
HttpResponseException,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.events import builder
|
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||||
|
from synapse.events import room_version_to_event_format
|
||||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||||
from synapse.util import logcontext, unwrapFirstError
|
from synapse.util import logcontext, unwrapFirstError
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
@ -66,6 +72,8 @@ class FederationClient(FederationBase):
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.transport_layer = hs.get_federation_transport_client()
|
self.transport_layer = hs.get_federation_transport_client()
|
||||||
|
|
||||||
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
|
|
||||||
self._get_pdu_cache = ExpiringCache(
|
self._get_pdu_cache = ExpiringCache(
|
||||||
cache_name="get_pdu_cache",
|
cache_name="get_pdu_cache",
|
||||||
clock=self._clock,
|
clock=self._clock,
|
||||||
|
@ -202,7 +210,8 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
|
def get_pdu(self, destinations, event_id, room_version, outlier=False,
|
||||||
|
timeout=None):
|
||||||
"""Requests the PDU with given origin and ID from the remote home
|
"""Requests the PDU with given origin and ID from the remote home
|
||||||
servers.
|
servers.
|
||||||
|
|
||||||
|
@ -212,6 +221,7 @@ class FederationClient(FederationBase):
|
||||||
Args:
|
Args:
|
||||||
destinations (list): Which home servers to query
|
destinations (list): Which home servers to query
|
||||||
event_id (str): event to fetch
|
event_id (str): event to fetch
|
||||||
|
room_version (str): version of the room
|
||||||
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
|
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
|
||||||
it's from an arbitary point in the context as opposed to part
|
it's from an arbitary point in the context as opposed to part
|
||||||
of the current block of PDUs. Defaults to `False`
|
of the current block of PDUs. Defaults to `False`
|
||||||
|
@ -352,10 +362,13 @@ class FederationClient(FederationBase):
|
||||||
ev.event_id for ev in itertools.chain(pdus, auth_chain)
|
ev.event_id for ev in itertools.chain(pdus, auth_chain)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
|
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination,
|
destination,
|
||||||
[p for p in pdus if p.event_id not in seen_events],
|
[p for p in pdus if p.event_id not in seen_events],
|
||||||
outlier=True
|
outlier=True,
|
||||||
|
room_version=room_version,
|
||||||
)
|
)
|
||||||
signed_pdus.extend(
|
signed_pdus.extend(
|
||||||
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
|
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
|
||||||
|
@ -364,7 +377,8 @@ class FederationClient(FederationBase):
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination,
|
destination,
|
||||||
[p for p in auth_chain if p.event_id not in seen_events],
|
[p for p in auth_chain if p.event_id not in seen_events],
|
||||||
outlier=True
|
outlier=True,
|
||||||
|
room_version=room_version,
|
||||||
)
|
)
|
||||||
signed_auth.extend(
|
signed_auth.extend(
|
||||||
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
|
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
|
||||||
|
@ -411,6 +425,8 @@ class FederationClient(FederationBase):
|
||||||
random.shuffle(srvs)
|
random.shuffle(srvs)
|
||||||
return srvs
|
return srvs
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
batch_size = 20
|
batch_size = 20
|
||||||
missing_events = list(missing_events)
|
missing_events = list(missing_events)
|
||||||
for i in range(0, len(missing_events), batch_size):
|
for i in range(0, len(missing_events), batch_size):
|
||||||
|
@ -421,6 +437,7 @@ class FederationClient(FederationBase):
|
||||||
self.get_pdu,
|
self.get_pdu,
|
||||||
destinations=random_server_list(),
|
destinations=random_server_list(),
|
||||||
event_id=e_id,
|
event_id=e_id,
|
||||||
|
room_version=room_version,
|
||||||
)
|
)
|
||||||
for e_id in batch
|
for e_id in batch
|
||||||
]
|
]
|
||||||
|
@ -450,8 +467,11 @@ class FederationClient(FederationBase):
|
||||||
for p in res["auth_chain"]
|
for p in res["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, auth_chain, outlier=True
|
destination, auth_chain,
|
||||||
|
outlier=True, room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
signed_auth.sort(key=lambda e: e.depth)
|
signed_auth.sort(key=lambda e: e.depth)
|
||||||
|
@ -522,6 +542,8 @@ class FederationClient(FederationBase):
|
||||||
Does so by asking one of the already participating servers to create an
|
Does so by asking one of the already participating servers to create an
|
||||||
event with proper context.
|
event with proper context.
|
||||||
|
|
||||||
|
Returns a fully signed and hashed event.
|
||||||
|
|
||||||
Note that this does not append any events to any graphs.
|
Note that this does not append any events to any graphs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -536,8 +558,10 @@ class FederationClient(FederationBase):
|
||||||
params (dict[str, str|Iterable[str]]): Query parameters to include in the
|
params (dict[str, str|Iterable[str]]): Query parameters to include in the
|
||||||
request.
|
request.
|
||||||
Return:
|
Return:
|
||||||
Deferred: resolves to a tuple of (origin (str), event (object))
|
Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of
|
||||||
where origin is the remote homeserver which generated the event.
|
`(origin, event, event_format)` where origin is the remote
|
||||||
|
homeserver which generated the event, and event_format is one of
|
||||||
|
`synapse.api.constants.EventFormatVersions`.
|
||||||
|
|
||||||
Fails with a ``SynapseError`` if the chosen remote server
|
Fails with a ``SynapseError`` if the chosen remote server
|
||||||
returns a 300/400 code.
|
returns a 300/400 code.
|
||||||
|
@ -557,6 +581,11 @@ class FederationClient(FederationBase):
|
||||||
destination, room_id, user_id, membership, params,
|
destination, room_id, user_id, membership, params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Note: If not supplied, the room version may be either v1 or v2,
|
||||||
|
# however either way the event format version will be v1.
|
||||||
|
room_version = ret.get("room_version", RoomVersions.V1)
|
||||||
|
event_format = room_version_to_event_format(room_version)
|
||||||
|
|
||||||
pdu_dict = ret.get("event", None)
|
pdu_dict = ret.get("event", None)
|
||||||
if not isinstance(pdu_dict, dict):
|
if not isinstance(pdu_dict, dict):
|
||||||
raise InvalidResponseError("Bad 'event' field in response")
|
raise InvalidResponseError("Bad 'event' field in response")
|
||||||
|
@ -571,10 +600,21 @@ class FederationClient(FederationBase):
|
||||||
if "prev_state" not in pdu_dict:
|
if "prev_state" not in pdu_dict:
|
||||||
pdu_dict["prev_state"] = []
|
pdu_dict["prev_state"] = []
|
||||||
|
|
||||||
ev = builder.EventBuilder(pdu_dict)
|
# Strip off the fields that we want to clobber.
|
||||||
|
pdu_dict.pop("origin", None)
|
||||||
|
pdu_dict.pop("origin_server_ts", None)
|
||||||
|
pdu_dict.pop("unsigned", None)
|
||||||
|
|
||||||
|
builder = self.event_builder_factory.new(pdu_dict)
|
||||||
|
add_hashes_and_signatures(
|
||||||
|
builder,
|
||||||
|
self.hs.hostname,
|
||||||
|
self.hs.config.signing_key[0]
|
||||||
|
)
|
||||||
|
ev = builder.build()
|
||||||
|
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
(destination, ev)
|
(destination, ev, event_format)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._try_destination_list(
|
return self._try_destination_list(
|
||||||
|
@ -650,9 +690,21 @@ class FederationClient(FederationBase):
|
||||||
for p in itertools.chain(state, auth_chain)
|
for p in itertools.chain(state, auth_chain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
room_version = None
|
||||||
|
for e in state:
|
||||||
|
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||||
|
room_version = e.content.get("room_version", RoomVersions.V1)
|
||||||
|
break
|
||||||
|
|
||||||
|
if room_version is None:
|
||||||
|
# If the state doesn't have a create event then the room is
|
||||||
|
# invalid, and it would fail auth checks anyway.
|
||||||
|
raise SynapseError(400, "No create event in state")
|
||||||
|
|
||||||
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, list(pdus.values()),
|
destination, list(pdus.values()),
|
||||||
outlier=True,
|
outlier=True,
|
||||||
|
room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_pdus_map = {
|
valid_pdus_map = {
|
||||||
|
@ -790,8 +842,10 @@ class FederationClient(FederationBase):
|
||||||
for e in content["auth_chain"]
|
for e in content["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, auth_chain, outlier=True
|
destination, auth_chain, outlier=True, room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
signed_auth.sort(key=lambda e: e.depth)
|
signed_auth.sort(key=lambda e: e.depth)
|
||||||
|
@ -838,8 +892,10 @@ class FederationClient(FederationBase):
|
||||||
for e in content.get("events", [])
|
for e in content.get("events", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_events = yield self._check_sigs_and_hash_and_fetch(
|
signed_events = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, events, outlier=False
|
destination, events, outlier=False, room_version=room_version,
|
||||||
)
|
)
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
if not e.code == 400:
|
if not e.code == 400:
|
||||||
|
|
|
@ -400,8 +400,14 @@ class FederationServer(FederationBase):
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
yield self.check_server_matches_acl(origin_host, room_id)
|
||||||
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
|
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
|
defer.returnValue({
|
||||||
|
"event": pdu.get_pdu_json(time_now),
|
||||||
|
"room_version": room_version,
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_send_leave_request(self, origin, content):
|
def on_send_leave_request(self, origin, content):
|
||||||
|
@ -457,8 +463,10 @@ class FederationServer(FederationBase):
|
||||||
for e in content["auth_chain"]
|
for e in content["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||||
origin, auth_chain, outlier=True
|
origin, auth_chain, outlier=True, room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = yield self.handler.on_query_auth(
|
ret = yield self.handler.on_query_auth(
|
||||||
|
|
|
@ -0,0 +1,147 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.endpoints import serverFromString
|
||||||
|
from twisted.python.filepath import FilePath
|
||||||
|
from twisted.python.url import URL
|
||||||
|
from twisted.web import server, static
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from txacme.interfaces import ICertificateStore
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
@implementer(ICertificateStore)
|
||||||
|
class ErsatzStore(object):
|
||||||
|
"""
|
||||||
|
A store that only stores in memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
certs = attr.ib(default=attr.Factory(dict))
|
||||||
|
|
||||||
|
def store(self, server_name, pem_objects):
|
||||||
|
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
|
||||||
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# txacme is missing
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AcmeHandler(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
self.reactor = hs.get_reactor()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def start_listening(self):
|
||||||
|
|
||||||
|
# Configure logging for txacme, if you need to debug
|
||||||
|
# from eliot import add_destinations
|
||||||
|
# from eliot.twisted import TwistedDestination
|
||||||
|
#
|
||||||
|
# add_destinations(TwistedDestination())
|
||||||
|
|
||||||
|
from txacme.challenges import HTTP01Responder
|
||||||
|
from txacme.service import AcmeIssuingService
|
||||||
|
from txacme.endpoint import load_or_create_client_key
|
||||||
|
from txacme.client import Client
|
||||||
|
from josepy.jwa import RS256
|
||||||
|
|
||||||
|
self._store = ErsatzStore()
|
||||||
|
responder = HTTP01Responder()
|
||||||
|
|
||||||
|
self._issuer = AcmeIssuingService(
|
||||||
|
cert_store=self._store,
|
||||||
|
client_creator=(
|
||||||
|
lambda: Client.from_url(
|
||||||
|
reactor=self.reactor,
|
||||||
|
url=URL.from_text(self.hs.config.acme_url),
|
||||||
|
key=load_or_create_client_key(
|
||||||
|
FilePath(self.hs.config.config_dir_path)
|
||||||
|
),
|
||||||
|
alg=RS256,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
clock=self.reactor,
|
||||||
|
responders=[responder],
|
||||||
|
)
|
||||||
|
|
||||||
|
well_known = Resource()
|
||||||
|
well_known.putChild(b'acme-challenge', responder.resource)
|
||||||
|
responder_resource = Resource()
|
||||||
|
responder_resource.putChild(b'.well-known', well_known)
|
||||||
|
responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
|
||||||
|
|
||||||
|
srv = server.Site(responder_resource)
|
||||||
|
|
||||||
|
listeners = []
|
||||||
|
|
||||||
|
for host in self.hs.config.acme_bind_addresses:
|
||||||
|
logger.info(
|
||||||
|
"Listening for ACME requests on %s:%s", host, self.hs.config.acme_port
|
||||||
|
)
|
||||||
|
endpoint = serverFromString(
|
||||||
|
self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host)
|
||||||
|
)
|
||||||
|
listeners.append(endpoint.listen(srv))
|
||||||
|
|
||||||
|
# Make sure we are registered to the ACME server. There's no public API
|
||||||
|
# for this, it is usually triggered by startService, but since we don't
|
||||||
|
# want it to control where we save the certificates, we have to reach in
|
||||||
|
# and trigger the registration machinery ourselves.
|
||||||
|
self._issuer._registered = False
|
||||||
|
yield self._issuer._ensure_registered()
|
||||||
|
|
||||||
|
# Return a Deferred that will fire when all the servers have started up.
|
||||||
|
yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def provision_certificate(self):
|
||||||
|
|
||||||
|
logger.warning("Reprovisioning %s", self.hs.hostname)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self._issuer.issue_cert(self.hs.hostname)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Fail!")
|
||||||
|
raise
|
||||||
|
logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
|
||||||
|
cert_chain = self._store.certs[self.hs.hostname]
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
|
||||||
|
for x in cert_chain:
|
||||||
|
if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
|
||||||
|
private_key_file.write(x)
|
||||||
|
|
||||||
|
with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
|
||||||
|
for x in cert_chain:
|
||||||
|
if x.startswith(b"-----BEGIN CERTIFICATE-----"):
|
||||||
|
certificate_file.write(x)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed saving!")
|
||||||
|
raise
|
||||||
|
|
||||||
|
defer.returnValue(True)
|
|
@ -34,6 +34,7 @@ from synapse.api.constants import (
|
||||||
EventTypes,
|
EventTypes,
|
||||||
Membership,
|
Membership,
|
||||||
RejectedReason,
|
RejectedReason,
|
||||||
|
RoomVersions,
|
||||||
)
|
)
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
|
@ -43,10 +44,7 @@ from synapse.api.errors import (
|
||||||
StoreError,
|
StoreError,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.crypto.event_signing import (
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
add_hashes_and_signatures,
|
|
||||||
compute_event_signature,
|
|
||||||
)
|
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.replication.http.federation import (
|
from synapse.replication.http.federation import (
|
||||||
ReplicationCleanRoomRestServlet,
|
ReplicationCleanRoomRestServlet,
|
||||||
|
@ -58,7 +56,6 @@ from synapse.types import UserID, get_domain_from_id
|
||||||
from synapse.util import logcontext, unwrapFirstError
|
from synapse.util import logcontext, unwrapFirstError
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.distributor import user_joined_room
|
from synapse.util.distributor import user_joined_room
|
||||||
from synapse.util.frozenutils import unfreeze
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
from synapse.visibility import filter_events_for_server
|
from synapse.visibility import filter_events_for_server
|
||||||
|
@ -342,6 +339,8 @@ class FederationHandler(BaseHandler):
|
||||||
room_id, event_id, p,
|
room_id, event_id, p,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
with logcontext.nested_logging_context(p):
|
with logcontext.nested_logging_context(p):
|
||||||
# note that if any of the missing prevs share missing state or
|
# note that if any of the missing prevs share missing state or
|
||||||
# auth events, the requests to fetch those events are deduped
|
# auth events, the requests to fetch those events are deduped
|
||||||
|
@ -355,7 +354,7 @@ class FederationHandler(BaseHandler):
|
||||||
# we want the state *after* p; get_state_for_room returns the
|
# we want the state *after* p; get_state_for_room returns the
|
||||||
# state *before* p.
|
# state *before* p.
|
||||||
remote_event = yield self.federation_client.get_pdu(
|
remote_event = yield self.federation_client.get_pdu(
|
||||||
[origin], p, outlier=True,
|
[origin], p, room_version, outlier=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if remote_event is None:
|
if remote_event is None:
|
||||||
|
@ -379,7 +378,6 @@ class FederationHandler(BaseHandler):
|
||||||
for x in remote_state:
|
for x in remote_state:
|
||||||
event_map[x.event_id] = x
|
event_map[x.event_id] = x
|
||||||
|
|
||||||
room_version = yield self.store.get_room_version(room_id)
|
|
||||||
state_map = yield resolve_events_with_store(
|
state_map = yield resolve_events_with_store(
|
||||||
room_version, state_maps, event_map,
|
room_version, state_maps, event_map,
|
||||||
state_res_store=StateResolutionStore(self.store),
|
state_res_store=StateResolutionStore(self.store),
|
||||||
|
@ -655,6 +653,8 @@ class FederationHandler(BaseHandler):
|
||||||
if dest == self.server_name:
|
if dest == self.server_name:
|
||||||
raise SynapseError(400, "Can't backfill from self.")
|
raise SynapseError(400, "Can't backfill from self.")
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
events = yield self.federation_client.backfill(
|
events = yield self.federation_client.backfill(
|
||||||
dest,
|
dest,
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -748,6 +748,7 @@ class FederationHandler(BaseHandler):
|
||||||
self.federation_client.get_pdu,
|
self.federation_client.get_pdu,
|
||||||
[dest],
|
[dest],
|
||||||
event_id,
|
event_id,
|
||||||
|
room_version=room_version,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
timeout=10000,
|
timeout=10000,
|
||||||
)
|
)
|
||||||
|
@ -1083,7 +1084,6 @@ class FederationHandler(BaseHandler):
|
||||||
handled_events = set()
|
handled_events = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event = self._sign_event(event)
|
|
||||||
# Try the host we successfully got a response to /make_join/
|
# Try the host we successfully got a response to /make_join/
|
||||||
# request first.
|
# request first.
|
||||||
try:
|
try:
|
||||||
|
@ -1287,7 +1287,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
event.internal_metadata.outlier = True
|
event.internal_metadata.outlier = True
|
||||||
event.internal_metadata.invite_from_remote = True
|
event.internal_metadata.out_of_band_membership = True
|
||||||
|
|
||||||
event.signatures.update(
|
event.signatures.update(
|
||||||
compute_event_signature(
|
compute_event_signature(
|
||||||
|
@ -1313,7 +1313,7 @@ class FederationHandler(BaseHandler):
|
||||||
# Mark as outlier as we don't have any state for this event; we're not
|
# Mark as outlier as we don't have any state for this event; we're not
|
||||||
# even in the room.
|
# even in the room.
|
||||||
event.internal_metadata.outlier = True
|
event.internal_metadata.outlier = True
|
||||||
event = self._sign_event(event)
|
event.internal_metadata.out_of_band_membership = True
|
||||||
|
|
||||||
# Try the host that we succesfully called /make_leave/ on first for
|
# Try the host that we succesfully called /make_leave/ on first for
|
||||||
# the /send_leave/ request.
|
# the /send_leave/ request.
|
||||||
|
@ -1336,7 +1336,7 @@ class FederationHandler(BaseHandler):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
|
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
|
||||||
content={}, params=None):
|
content={}, params=None):
|
||||||
origin, pdu = yield self.federation_client.make_membership_event(
|
origin, pdu, _ = yield self.federation_client.make_membership_event(
|
||||||
target_hosts,
|
target_hosts,
|
||||||
room_id,
|
room_id,
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -1357,27 +1357,6 @@ class FederationHandler(BaseHandler):
|
||||||
assert(event.room_id == room_id)
|
assert(event.room_id == room_id)
|
||||||
defer.returnValue((origin, event))
|
defer.returnValue((origin, event))
|
||||||
|
|
||||||
def _sign_event(self, event):
|
|
||||||
event.internal_metadata.outlier = False
|
|
||||||
|
|
||||||
builder = self.event_builder_factory.new(
|
|
||||||
unfreeze(event.get_pdu_json())
|
|
||||||
)
|
|
||||||
|
|
||||||
builder.event_id = self.event_builder_factory.create_event_id()
|
|
||||||
builder.origin = self.hs.hostname
|
|
||||||
|
|
||||||
if not hasattr(event, "signatures"):
|
|
||||||
builder.signatures = {}
|
|
||||||
|
|
||||||
add_hashes_and_signatures(
|
|
||||||
builder,
|
|
||||||
self.hs.hostname,
|
|
||||||
self.hs.config.signing_key[0],
|
|
||||||
)
|
|
||||||
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_make_leave_request(self, room_id, user_id):
|
def on_make_leave_request(self, room_id, user_id):
|
||||||
|
@ -1659,6 +1638,13 @@ class FederationHandler(BaseHandler):
|
||||||
create_event = e
|
create_event = e
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if create_event is None:
|
||||||
|
# If the state doesn't have a create event then the room is
|
||||||
|
# invalid, and it would fail auth checks anyway.
|
||||||
|
raise SynapseError(400, "No create event in state")
|
||||||
|
|
||||||
|
room_version = create_event.content.get("room_version", RoomVersions.V1)
|
||||||
|
|
||||||
missing_auth_events = set()
|
missing_auth_events = set()
|
||||||
for e in itertools.chain(auth_events, state, [event]):
|
for e in itertools.chain(auth_events, state, [event]):
|
||||||
for e_id in e.auth_event_ids():
|
for e_id in e.auth_event_ids():
|
||||||
|
@ -1669,6 +1655,7 @@ class FederationHandler(BaseHandler):
|
||||||
m_ev = yield self.federation_client.get_pdu(
|
m_ev = yield self.federation_client.get_pdu(
|
||||||
[origin],
|
[origin],
|
||||||
e_id,
|
e_id,
|
||||||
|
room_version=room_version,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
timeout=10000,
|
timeout=10000,
|
||||||
)
|
)
|
||||||
|
|
|
@ -73,8 +73,14 @@ class RoomListHandler(BaseHandler):
|
||||||
# We explicitly don't bother caching searches or requests for
|
# We explicitly don't bother caching searches or requests for
|
||||||
# appservice specific lists.
|
# appservice specific lists.
|
||||||
logger.info("Bypassing cache as search request.")
|
logger.info("Bypassing cache as search request.")
|
||||||
|
|
||||||
|
# XXX: Quick hack to stop room directory queries taking too long.
|
||||||
|
# Timeout request after 60s. Probably want a more fundamental
|
||||||
|
# solution at some point
|
||||||
|
timeout = self.clock.time() + 60
|
||||||
return self._get_public_room_list(
|
return self._get_public_room_list(
|
||||||
limit, since_token, search_filter, network_tuple=network_tuple,
|
limit, since_token, search_filter,
|
||||||
|
network_tuple=network_tuple, timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
key = (limit, since_token, network_tuple)
|
key = (limit, since_token, network_tuple)
|
||||||
|
@ -87,7 +93,8 @@ class RoomListHandler(BaseHandler):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_public_room_list(self, limit=None, since_token=None,
|
def _get_public_room_list(self, limit=None, since_token=None,
|
||||||
search_filter=None,
|
search_filter=None,
|
||||||
network_tuple=EMPTY_THIRD_PARTY_ID,):
|
network_tuple=EMPTY_THIRD_PARTY_ID,
|
||||||
|
timeout=None,):
|
||||||
if since_token and since_token != "END":
|
if since_token and since_token != "END":
|
||||||
since_token = RoomListNextBatch.from_token(since_token)
|
since_token = RoomListNextBatch.from_token(since_token)
|
||||||
else:
|
else:
|
||||||
|
@ -202,6 +209,9 @@ class RoomListHandler(BaseHandler):
|
||||||
|
|
||||||
chunk = []
|
chunk = []
|
||||||
for i in range(0, len(rooms_to_scan), step):
|
for i in range(0, len(rooms_to_scan), step):
|
||||||
|
if timeout and self.clock.time() > timeout:
|
||||||
|
raise Exception("Timed out searching room directory")
|
||||||
|
|
||||||
batch = rooms_to_scan[i:i + step]
|
batch = rooms_to_scan[i:i + step]
|
||||||
logger.info("Processing %i rooms for result", len(batch))
|
logger.info("Processing %i rooms for result", len(batch))
|
||||||
yield concurrently_execute(
|
yield concurrently_execute(
|
||||||
|
|
|
@ -37,6 +37,41 @@ class SearchHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(SearchHandler, self).__init__(hs)
|
super(SearchHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_old_rooms_from_upgraded_room(self, room_id):
|
||||||
|
"""Retrieves room IDs of old rooms in the history of an upgraded room.
|
||||||
|
|
||||||
|
We do so by checking the m.room.create event of the room for a
|
||||||
|
`predecessor` key. If it exists, we add the room ID to our return
|
||||||
|
list and then check that room for a m.room.create event and so on
|
||||||
|
until we can no longer find any more previous rooms.
|
||||||
|
|
||||||
|
The full list of all found rooms in then returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str): id of the room to search through.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[iterable[unicode]]: predecessor room ids
|
||||||
|
"""
|
||||||
|
|
||||||
|
historical_room_ids = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
predecessor = yield self.store.get_room_predecessor(room_id)
|
||||||
|
|
||||||
|
# If no predecessor, assume we've hit a dead end
|
||||||
|
if not predecessor:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Add predecessor's room ID
|
||||||
|
historical_room_ids.append(predecessor["room_id"])
|
||||||
|
|
||||||
|
# Scan through the old room for further predecessors
|
||||||
|
room_id = predecessor["room_id"]
|
||||||
|
|
||||||
|
defer.returnValue(historical_room_ids)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def search(self, user, content, batch=None):
|
def search(self, user, content, batch=None):
|
||||||
"""Performs a full text search for a user.
|
"""Performs a full text search for a user.
|
||||||
|
@ -137,6 +172,18 @@ class SearchHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
room_ids = set(r.room_id for r in rooms)
|
room_ids = set(r.room_id for r in rooms)
|
||||||
|
|
||||||
|
# If doing a subset of all rooms seearch, check if any of the rooms
|
||||||
|
# are from an upgraded room, and search their contents as well
|
||||||
|
if search_filter.rooms:
|
||||||
|
historical_room_ids = []
|
||||||
|
for room_id in search_filter.rooms:
|
||||||
|
# Add any previous rooms to the search if they exist
|
||||||
|
ids = yield self.get_old_rooms_from_upgraded_room(room_id)
|
||||||
|
historical_room_ids += ids
|
||||||
|
|
||||||
|
# Prevent any historical events from being filtered
|
||||||
|
search_filter = search_filter.with_room_ids(historical_room_ids)
|
||||||
|
|
||||||
room_ids = search_filter.filter_rooms(room_ids)
|
room_ids = search_filter.filter_rooms(room_ids)
|
||||||
|
|
||||||
if batch_group == "room_id":
|
if batch_group == "room_id":
|
||||||
|
|
|
@ -19,6 +19,7 @@ from six import iteritems
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.metrics
|
||||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.roommember import ProfileInfo
|
from synapse.storage.roommember import ProfileInfo
|
||||||
|
@ -163,6 +164,11 @@ class UserDirectoryHandler(object):
|
||||||
yield self._handle_deltas(deltas)
|
yield self._handle_deltas(deltas)
|
||||||
|
|
||||||
self.pos = deltas[-1]["stream_id"]
|
self.pos = deltas[-1]["stream_id"]
|
||||||
|
|
||||||
|
# Expose current event processing position to prometheus
|
||||||
|
synapse.metrics.event_processing_positions.labels(
|
||||||
|
"user_dir").set(self.pos)
|
||||||
|
|
||||||
yield self.store.update_user_directory_stream_pos(self.pos)
|
yield self.store.update_user_directory_stream_pos(self.pos)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -333,9 +333,10 @@ class SimpleHttpClient(object):
|
||||||
"POST", uri, headers=Headers(actual_headers), data=query_bytes
|
"POST", uri, headers=Headers(actual_headers), data=query_bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
body = yield make_deferred_yieldable(readBody(response))
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
body = yield make_deferred_yieldable(treq.json_content(response))
|
defer.returnValue(json.loads(body))
|
||||||
defer.returnValue(body)
|
|
||||||
else:
|
else:
|
||||||
raise HttpResponseException(response.code, response.phrase, body)
|
raise HttpResponseException(response.code, response.phrase, body)
|
||||||
|
|
||||||
|
|
|
@ -13,15 +13,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import random
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
|
||||||
from twisted.internet.error import ConnectError
|
|
||||||
|
|
||||||
from synapse.http.federation.srv_resolver import Server, resolve_service
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -88,140 +81,3 @@ def parse_and_validate_server_name(server_name):
|
||||||
))
|
))
|
||||||
|
|
||||||
return host, port
|
return host, port
|
||||||
|
|
||||||
|
|
||||||
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
|
|
||||||
timeout=None):
|
|
||||||
"""Construct an endpoint for the given matrix destination.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
reactor: Twisted reactor.
|
|
||||||
destination (unicode): The name of the server to connect to.
|
|
||||||
tls_client_options_factory
|
|
||||||
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
|
|
||||||
Factory which generates TLS options for client connections.
|
|
||||||
timeout (int): connection timeout in seconds
|
|
||||||
"""
|
|
||||||
|
|
||||||
domain, port = parse_server_name(destination)
|
|
||||||
|
|
||||||
endpoint_kw_args = {}
|
|
||||||
|
|
||||||
if timeout is not None:
|
|
||||||
endpoint_kw_args.update(timeout=timeout)
|
|
||||||
|
|
||||||
if tls_client_options_factory is None:
|
|
||||||
transport_endpoint = HostnameEndpoint
|
|
||||||
default_port = 8008
|
|
||||||
else:
|
|
||||||
# the SNI string should be the same as the Host header, minus the port.
|
|
||||||
# as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
|
|
||||||
# the Host header and SNI should therefore be the server_name of the remote
|
|
||||||
# server.
|
|
||||||
tls_options = tls_client_options_factory.get_options(domain)
|
|
||||||
|
|
||||||
def transport_endpoint(reactor, host, port, timeout):
|
|
||||||
return wrapClientTLS(
|
|
||||||
tls_options,
|
|
||||||
HostnameEndpoint(reactor, host, port, timeout=timeout),
|
|
||||||
)
|
|
||||||
default_port = 8448
|
|
||||||
|
|
||||||
if port is None:
|
|
||||||
return SRVClientEndpoint(
|
|
||||||
reactor, "matrix", domain, protocol="tcp",
|
|
||||||
default_port=default_port, endpoint=transport_endpoint,
|
|
||||||
endpoint_kw_args=endpoint_kw_args
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return transport_endpoint(
|
|
||||||
reactor, domain, port, **endpoint_kw_args
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SRVClientEndpoint(object):
|
|
||||||
"""An endpoint which looks up SRV records for a service.
|
|
||||||
Cycles through the list of servers starting with each call to connect
|
|
||||||
picking the next server.
|
|
||||||
Implements twisted.internet.interfaces.IStreamClientEndpoint.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, reactor, service, domain, protocol="tcp",
|
|
||||||
default_port=None, endpoint=HostnameEndpoint,
|
|
||||||
endpoint_kw_args={}):
|
|
||||||
self.reactor = reactor
|
|
||||||
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
|
|
||||||
|
|
||||||
if default_port is not None:
|
|
||||||
self.default_server = Server(
|
|
||||||
host=domain,
|
|
||||||
port=default_port,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.default_server = None
|
|
||||||
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.endpoint_kw_args = endpoint_kw_args
|
|
||||||
|
|
||||||
self.servers = None
|
|
||||||
self.used_servers = None
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def fetch_servers(self):
|
|
||||||
self.used_servers = []
|
|
||||||
self.servers = yield resolve_service(self.service_name)
|
|
||||||
|
|
||||||
def pick_server(self):
|
|
||||||
if not self.servers:
|
|
||||||
if self.used_servers:
|
|
||||||
self.servers = self.used_servers
|
|
||||||
self.used_servers = []
|
|
||||||
self.servers.sort()
|
|
||||||
elif self.default_server:
|
|
||||||
return self.default_server
|
|
||||||
else:
|
|
||||||
raise ConnectError(
|
|
||||||
"No server available for %s" % self.service_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# look for all servers with the same priority
|
|
||||||
min_priority = self.servers[0].priority
|
|
||||||
weight_indexes = list(
|
|
||||||
(index, server.weight + 1)
|
|
||||||
for index, server in enumerate(self.servers)
|
|
||||||
if server.priority == min_priority
|
|
||||||
)
|
|
||||||
|
|
||||||
total_weight = sum(weight for index, weight in weight_indexes)
|
|
||||||
target_weight = random.randint(0, total_weight)
|
|
||||||
for index, weight in weight_indexes:
|
|
||||||
target_weight -= weight
|
|
||||||
if target_weight <= 0:
|
|
||||||
server = self.servers[index]
|
|
||||||
# XXX: this looks totally dubious:
|
|
||||||
#
|
|
||||||
# (a) we never reuse a server until we have been through
|
|
||||||
# all of the servers at the same priority, so if the
|
|
||||||
# weights are A: 100, B:1, we always do ABABAB instead of
|
|
||||||
# AAAA...AAAB (approximately).
|
|
||||||
#
|
|
||||||
# (b) After using all the servers at the lowest priority,
|
|
||||||
# we move onto the next priority. We should only use the
|
|
||||||
# second priority if servers at the top priority are
|
|
||||||
# unreachable.
|
|
||||||
#
|
|
||||||
del self.servers[index]
|
|
||||||
self.used_servers.append(server)
|
|
||||||
return server
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def connect(self, protocolFactory):
|
|
||||||
if self.servers is None:
|
|
||||||
yield self.fetch_servers()
|
|
||||||
server = self.pick_server()
|
|
||||||
logger.info("Connecting to %s:%s", server.host, server.port)
|
|
||||||
endpoint = self.endpoint(
|
|
||||||
self.reactor, server.host, server.port, **self.endpoint_kw_args
|
|
||||||
)
|
|
||||||
connection = yield endpoint.connect(protocolFactory)
|
|
||||||
defer.returnValue(connection)
|
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||||
|
from twisted.web.client import URI, Agent, HTTPConnectionPool
|
||||||
|
from twisted.web.iweb import IAgent
|
||||||
|
|
||||||
|
from synapse.http.endpoint import parse_server_name
|
||||||
|
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IAgent)
|
||||||
|
class MatrixFederationAgent(object):
|
||||||
|
"""An Agent-like thing which provides a `request` method which will look up a matrix
|
||||||
|
server and send an HTTP request to it.
|
||||||
|
|
||||||
|
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reactor (IReactor): twisted reactor to use for underlying requests
|
||||||
|
|
||||||
|
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
||||||
|
factory to use for fetching client tls options, or none to disable TLS.
|
||||||
|
|
||||||
|
srv_resolver (SrvResolver|None):
|
||||||
|
SRVResolver impl to use for looking up SRV records. None to use a default
|
||||||
|
implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, reactor, tls_client_options_factory, _srv_resolver=None,
|
||||||
|
):
|
||||||
|
self._reactor = reactor
|
||||||
|
self._tls_client_options_factory = tls_client_options_factory
|
||||||
|
if _srv_resolver is None:
|
||||||
|
_srv_resolver = SrvResolver()
|
||||||
|
self._srv_resolver = _srv_resolver
|
||||||
|
|
||||||
|
self._pool = HTTPConnectionPool(reactor)
|
||||||
|
self._pool.retryAutomatically = False
|
||||||
|
self._pool.maxPersistentPerHost = 5
|
||||||
|
self._pool.cachedConnectionTimeout = 2 * 60
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def request(self, method, uri, headers=None, bodyProducer=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
method (bytes): HTTP method: GET/POST/etc
|
||||||
|
|
||||||
|
uri (bytes): Absolute URI to be retrieved
|
||||||
|
|
||||||
|
headers (twisted.web.http_headers.Headers|None):
|
||||||
|
HTTP headers to send with the request, or None to
|
||||||
|
send no extra headers.
|
||||||
|
|
||||||
|
bodyProducer (twisted.web.iweb.IBodyProducer|None):
|
||||||
|
An object which can generate bytes to make up the
|
||||||
|
body of this request (for example, the properly encoded contents of
|
||||||
|
a file for a file upload). Or None if the request is to have
|
||||||
|
no body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[twisted.web.iweb.IResponse]:
|
||||||
|
fires when the header of the response has been received (regardless of the
|
||||||
|
response status code). Fails if there is any problem which prevents that
|
||||||
|
response from being received (including problems that prevent the request
|
||||||
|
from being sent).
|
||||||
|
"""
|
||||||
|
|
||||||
|
parsed_uri = URI.fromBytes(uri)
|
||||||
|
server_name_bytes = parsed_uri.netloc
|
||||||
|
host, port = parse_server_name(server_name_bytes.decode("ascii"))
|
||||||
|
|
||||||
|
# XXX disabling TLS is really only supported here for the benefit of the
|
||||||
|
# unit tests. We should make the UTs cope with TLS rather than having to make
|
||||||
|
# the code support the unit tests.
|
||||||
|
if self._tls_client_options_factory is None:
|
||||||
|
tls_options = None
|
||||||
|
else:
|
||||||
|
tls_options = self._tls_client_options_factory.get_options(host)
|
||||||
|
|
||||||
|
if port is not None:
|
||||||
|
target = (host, port)
|
||||||
|
else:
|
||||||
|
service_name = b"_matrix._tcp.%s" % (server_name_bytes, )
|
||||||
|
server_list = yield self._srv_resolver.resolve_service(service_name)
|
||||||
|
if not server_list:
|
||||||
|
target = (host, 8448)
|
||||||
|
logger.debug("No SRV record for %s, using %s", host, target)
|
||||||
|
else:
|
||||||
|
target = pick_server_from_list(server_list)
|
||||||
|
|
||||||
|
class EndpointFactory(object):
|
||||||
|
@staticmethod
|
||||||
|
def endpointForURI(_uri):
|
||||||
|
logger.info("Connecting to %s:%s", target[0], target[1])
|
||||||
|
ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
|
||||||
|
if tls_options is not None:
|
||||||
|
ep = wrapClientTLS(tls_options, ep)
|
||||||
|
return ep
|
||||||
|
|
||||||
|
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
|
||||||
|
res = yield make_deferred_yieldable(
|
||||||
|
agent.request(method, uri, headers, bodyProducer)
|
||||||
|
)
|
||||||
|
defer.returnValue(res)
|
|
@ -15,6 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -51,74 +52,118 @@ class Server(object):
|
||||||
expires = attr.ib(default=0)
|
expires = attr.ib(default=0)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def pick_server_from_list(server_list):
|
||||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
|
"""Randomly choose a server from the server list
|
||||||
"""Look up a SRV record, with caching
|
|
||||||
|
Args:
|
||||||
|
server_list (list[Server]): list of candidate servers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bytes, int]: (host, port) pair for the chosen server
|
||||||
|
"""
|
||||||
|
if not server_list:
|
||||||
|
raise RuntimeError("pick_server_from_list called with empty list")
|
||||||
|
|
||||||
|
# TODO: currently we only use the lowest-priority servers. We should maintain a
|
||||||
|
# cache of servers known to be "down" and filter them out
|
||||||
|
|
||||||
|
min_priority = min(s.priority for s in server_list)
|
||||||
|
eligible_servers = list(s for s in server_list if s.priority == min_priority)
|
||||||
|
total_weight = sum(s.weight for s in eligible_servers)
|
||||||
|
target_weight = random.randint(0, total_weight)
|
||||||
|
|
||||||
|
for s in eligible_servers:
|
||||||
|
target_weight -= s.weight
|
||||||
|
|
||||||
|
if target_weight <= 0:
|
||||||
|
return s.host, s.port
|
||||||
|
|
||||||
|
# this should be impossible.
|
||||||
|
raise RuntimeError(
|
||||||
|
"pick_server_from_list got to end of eligible server list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SrvResolver(object):
|
||||||
|
"""Interface to the dns client to do SRV lookups, with result caching.
|
||||||
|
|
||||||
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
|
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
|
||||||
but the cache never gets populated), so we add our own caching layer here.
|
but the cache never gets populated), so we add our own caching layer here.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_name (unicode|bytes): record to look up
|
|
||||||
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
||||||
cache (dict): cache object
|
cache (dict): cache object
|
||||||
clock (object): clock implementation. must provide a time() method.
|
get_time (callable): clock implementation. Should return seconds since the epoch
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
|
|
||||||
"""
|
"""
|
||||||
# TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
|
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
|
||||||
# byteses; however they will obviously end up as separate entries in the cache. We
|
self._dns_client = dns_client
|
||||||
# should pick one form and stick with it.
|
self._cache = cache
|
||||||
cache_entry = cache.get(service_name, None)
|
self._get_time = get_time
|
||||||
if cache_entry:
|
|
||||||
if all(s.expires > int(clock.time()) for s in cache_entry):
|
|
||||||
servers = list(cache_entry)
|
|
||||||
defer.returnValue(servers)
|
|
||||||
|
|
||||||
try:
|
@defer.inlineCallbacks
|
||||||
answers, _, _ = yield make_deferred_yieldable(
|
def resolve_service(self, service_name):
|
||||||
dns_client.lookupService(service_name),
|
"""Look up a SRV record
|
||||||
)
|
|
||||||
except DNSNameError:
|
Args:
|
||||||
# TODO: cache this. We can get the SOA out of the exception, and use
|
service_name (bytes): record to look up
|
||||||
# the negative-TTL value.
|
|
||||||
defer.returnValue([])
|
Returns:
|
||||||
except DomainError as e:
|
Deferred[list[Server]]:
|
||||||
# We failed to resolve the name (other than a NameError)
|
a list of the SRV records, or an empty list if none found
|
||||||
# Try something in the cache, else rereaise
|
"""
|
||||||
cache_entry = cache.get(service_name, None)
|
now = int(self._get_time())
|
||||||
|
|
||||||
|
if not isinstance(service_name, bytes):
|
||||||
|
raise TypeError("%r is not a byte string" % (service_name,))
|
||||||
|
|
||||||
|
cache_entry = self._cache.get(service_name, None)
|
||||||
if cache_entry:
|
if cache_entry:
|
||||||
logger.warn(
|
if all(s.expires > now for s in cache_entry):
|
||||||
"Failed to resolve %r, falling back to cache. %r",
|
servers = list(cache_entry)
|
||||||
service_name, e
|
defer.returnValue(servers)
|
||||||
|
|
||||||
|
try:
|
||||||
|
answers, _, _ = yield make_deferred_yieldable(
|
||||||
|
self._dns_client.lookupService(service_name),
|
||||||
)
|
)
|
||||||
defer.returnValue(list(cache_entry))
|
except DNSNameError:
|
||||||
else:
|
# TODO: cache this. We can get the SOA out of the exception, and use
|
||||||
raise e
|
# the negative-TTL value.
|
||||||
|
defer.returnValue([])
|
||||||
|
except DomainError as e:
|
||||||
|
# We failed to resolve the name (other than a NameError)
|
||||||
|
# Try something in the cache, else rereaise
|
||||||
|
cache_entry = self._cache.get(service_name, None)
|
||||||
|
if cache_entry:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to resolve %r, falling back to cache. %r",
|
||||||
|
service_name, e
|
||||||
|
)
|
||||||
|
defer.returnValue(list(cache_entry))
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
if (len(answers) == 1
|
if (len(answers) == 1
|
||||||
and answers[0].type == dns.SRV
|
and answers[0].type == dns.SRV
|
||||||
and answers[0].payload
|
and answers[0].payload
|
||||||
and answers[0].payload.target == dns.Name(b'.')):
|
and answers[0].payload.target == dns.Name(b'.')):
|
||||||
raise ConnectError("Service %s unavailable" % service_name)
|
raise ConnectError("Service %s unavailable" % service_name)
|
||||||
|
|
||||||
servers = []
|
servers = []
|
||||||
|
|
||||||
for answer in answers:
|
for answer in answers:
|
||||||
if answer.type != dns.SRV or not answer.payload:
|
if answer.type != dns.SRV or not answer.payload:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload = answer.payload
|
payload = answer.payload
|
||||||
|
|
||||||
servers.append(Server(
|
servers.append(Server(
|
||||||
host=payload.target.name,
|
host=payload.target.name,
|
||||||
port=payload.port,
|
port=payload.port,
|
||||||
priority=payload.priority,
|
priority=payload.priority,
|
||||||
weight=payload.weight,
|
weight=payload.weight,
|
||||||
expires=int(clock.time()) + answer.ttl,
|
expires=now + answer.ttl,
|
||||||
))
|
))
|
||||||
|
|
||||||
servers.sort() # FIXME: get rid of this (it's broken by the attrs change)
|
self._cache[service_name] = list(servers)
|
||||||
cache[service_name] = list(servers)
|
defer.returnValue(servers)
|
||||||
defer.returnValue(servers)
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ from twisted.internet import defer, protocol
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.internet.task import _EPSILON, Cooperator
|
from twisted.internet.task import _EPSILON, Cooperator
|
||||||
from twisted.web._newclient import ResponseDone
|
from twisted.web._newclient import ResponseDone
|
||||||
from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
|
from twisted.web.client import FileBodyProducer
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
@ -44,7 +44,7 @@ from synapse.api.errors import (
|
||||||
RequestSendFailed,
|
RequestSendFailed,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.http.endpoint import matrix_federation_endpoint
|
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||||
from synapse.util.async_helpers import timeout_deferred
|
from synapse.util.async_helpers import timeout_deferred
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
@ -66,20 +66,6 @@ else:
|
||||||
MAXINT = sys.maxint
|
MAXINT = sys.maxint
|
||||||
|
|
||||||
|
|
||||||
class MatrixFederationEndpointFactory(object):
|
|
||||||
def __init__(self, hs):
|
|
||||||
self.reactor = hs.get_reactor()
|
|
||||||
self.tls_client_options_factory = hs.tls_client_options_factory
|
|
||||||
|
|
||||||
def endpointForURI(self, uri):
|
|
||||||
destination = uri.netloc.decode('ascii')
|
|
||||||
|
|
||||||
return matrix_federation_endpoint(
|
|
||||||
self.reactor, destination, timeout=10,
|
|
||||||
tls_client_options_factory=self.tls_client_options_factory
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_next_id = 1
|
_next_id = 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object):
|
||||||
self.signing_key = hs.config.signing_key[0]
|
self.signing_key = hs.config.signing_key[0]
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
reactor = hs.get_reactor()
|
reactor = hs.get_reactor()
|
||||||
pool = HTTPConnectionPool(reactor)
|
|
||||||
pool.retryAutomatically = False
|
self.agent = MatrixFederationAgent(
|
||||||
pool.maxPersistentPerHost = 5
|
hs.get_reactor(),
|
||||||
pool.cachedConnectionTimeout = 2 * 60
|
hs.tls_client_options_factory,
|
||||||
self.agent = Agent.usingEndpointFactory(
|
|
||||||
reactor, MatrixFederationEndpointFactory(hs), pool=pool
|
|
||||||
)
|
)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
|
@ -316,9 +300,9 @@ class MatrixFederationHttpClient(object):
|
||||||
headers_dict[b"Authorization"] = auth_headers
|
headers_dict[b"Authorization"] = auth_headers
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"{%s} [%s] Sending request: %s %s",
|
"{%s} [%s] Sending request: %s %s; timeout %fs",
|
||||||
request.txn_id, request.destination, request.method,
|
request.txn_id, request.destination, request.method,
|
||||||
url_str,
|
url_str, _sec_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -338,12 +322,11 @@ class MatrixFederationHttpClient(object):
|
||||||
reactor=self.hs.get_reactor(),
|
reactor=self.hs.get_reactor(),
|
||||||
)
|
)
|
||||||
|
|
||||||
response = yield make_deferred_yieldable(
|
response = yield request_deferred
|
||||||
request_deferred,
|
|
||||||
)
|
|
||||||
except DNSLookupError as e:
|
except DNSLookupError as e:
|
||||||
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
|
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.info("Failed to send request: %s", e)
|
||||||
raise_from(RequestSendFailed(e, can_retry=True), e)
|
raise_from(RequestSendFailed(e, can_retry=True), e)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -79,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = {
|
||||||
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
|
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
|
||||||
"resources.consent": ["Jinja2>=2.9"],
|
"resources.consent": ["Jinja2>=2.9"],
|
||||||
|
|
||||||
|
# ACME support is required to provision TLS certificates from authorities
|
||||||
|
# that use the protocol, such as Let's Encrypt.
|
||||||
|
"acme": ["txacme>=0.9.2"],
|
||||||
|
|
||||||
"saml2": ["pysaml2>=4.5.0"],
|
"saml2": ["pysaml2>=4.5.0"],
|
||||||
"url_preview": ["lxml>=3.5.0"],
|
"url_preview": ["lxml>=3.5.0"],
|
||||||
"test": ["mock>=2.0"],
|
"test": ["mock>=2.0"],
|
||||||
|
|
|
@ -309,22 +309,16 @@ class RegisterRestServlet(RestServlet):
|
||||||
assigned_user_id=registered_user_id,
|
assigned_user_id=registered_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only give msisdn flows if the x_show_msisdn flag is given:
|
|
||||||
# this is a hack to work around the fact that clients were shipped
|
|
||||||
# that use fallback registration if they see any flows that they don't
|
|
||||||
# recognise, which means we break registration for these clients if we
|
|
||||||
# advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
|
|
||||||
# Android <=0.6.9 have fallen below an acceptable threshold, this
|
|
||||||
# parameter should go away and we should always advertise msisdn flows.
|
|
||||||
show_msisdn = False
|
|
||||||
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
|
||||||
show_msisdn = True
|
|
||||||
|
|
||||||
# FIXME: need a better error than "no auth flow found" for scenarios
|
# FIXME: need a better error than "no auth flow found" for scenarios
|
||||||
# where we required 3PID for registration but the user didn't give one
|
# where we required 3PID for registration but the user didn't give one
|
||||||
require_email = 'email' in self.hs.config.registrations_require_3pid
|
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||||
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||||
|
|
||||||
|
show_msisdn = True
|
||||||
|
if self.hs.config.disable_msisdn_registration:
|
||||||
|
show_msisdn = False
|
||||||
|
require_msisdn = False
|
||||||
|
|
||||||
flows = []
|
flows = []
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
# only support 3PIDless registration if no 3PIDs are required
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
|
@ -422,8 +416,11 @@ class RegisterRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
# Necessary due to auth checks prior to the threepid being
|
# Necessary due to auth checks prior to the threepid being
|
||||||
# written to the db
|
# written to the db
|
||||||
if is_threepid_reserved(self.hs.config, threepid):
|
if threepid:
|
||||||
yield self.store.upsert_monthly_active_user(registered_user_id)
|
if is_threepid_reserved(
|
||||||
|
self.hs.config.mau_limits_reserved_threepids, threepid
|
||||||
|
):
|
||||||
|
yield self.store.upsert_monthly_active_user(registered_user_id)
|
||||||
|
|
||||||
# remember that we've now registered that user account, and with
|
# remember that we've now registered that user account, and with
|
||||||
# what user ID (since the user may not have specified)
|
# what user ID (since the user may not have specified)
|
||||||
|
|
|
@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient
|
||||||
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
|
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
|
||||||
from synapse.groups.groups_server import GroupsServerHandler
|
from synapse.groups.groups_server import GroupsServerHandler
|
||||||
from synapse.handlers import Handlers
|
from synapse.handlers import Handlers
|
||||||
|
from synapse.handlers.acme import AcmeHandler
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
||||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||||
|
@ -129,6 +130,7 @@ class HomeServer(object):
|
||||||
'sync_handler',
|
'sync_handler',
|
||||||
'typing_handler',
|
'typing_handler',
|
||||||
'room_list_handler',
|
'room_list_handler',
|
||||||
|
'acme_handler',
|
||||||
'auth_handler',
|
'auth_handler',
|
||||||
'device_handler',
|
'device_handler',
|
||||||
'e2e_keys_handler',
|
'e2e_keys_handler',
|
||||||
|
@ -310,6 +312,9 @@ class HomeServer(object):
|
||||||
def build_e2e_room_keys_handler(self):
|
def build_e2e_room_keys_handler(self):
|
||||||
return E2eRoomKeysHandler(self)
|
return E2eRoomKeysHandler(self)
|
||||||
|
|
||||||
|
def build_acme_handler(self):
|
||||||
|
return AcmeHandler(self)
|
||||||
|
|
||||||
def build_application_service_api(self):
|
def build_application_service_api(self):
|
||||||
return ApplicationServiceApi(self)
|
return ApplicationServiceApi(self)
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ from prometheus_client import Histogram
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.util.caches.descriptors import Cache
|
from synapse.util.caches.descriptors import Cache
|
||||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
|
@ -192,6 +193,51 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
self.database_engine = hs.database_engine
|
self.database_engine = hs.database_engine
|
||||||
|
|
||||||
|
# A set of tables that are not safe to use native upserts in.
|
||||||
|
self._unsafe_to_upsert_tables = {"user_ips"}
|
||||||
|
|
||||||
|
if self.database_engine.can_native_upsert:
|
||||||
|
# Check ASAP (and then later, every 1s) to see if we have finished
|
||||||
|
# background updates of tables that aren't safe to update.
|
||||||
|
self._clock.call_later(
|
||||||
|
0.0,
|
||||||
|
run_as_background_process,
|
||||||
|
"upsert_safety_check",
|
||||||
|
self._check_safe_to_upsert
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _check_safe_to_upsert(self):
|
||||||
|
"""
|
||||||
|
Is it safe to use native UPSERT?
|
||||||
|
|
||||||
|
If there are background updates, we will need to wait, as they may be
|
||||||
|
the addition of indexes that set the UNIQUE constraint that we require.
|
||||||
|
|
||||||
|
If the background updates have not completed, wait 15 sec and check again.
|
||||||
|
"""
|
||||||
|
updates = yield self._simple_select_list(
|
||||||
|
"background_updates",
|
||||||
|
keyvalues=None,
|
||||||
|
retcols=["update_name"],
|
||||||
|
desc="check_background_updates",
|
||||||
|
)
|
||||||
|
updates = [x["update_name"] for x in updates]
|
||||||
|
|
||||||
|
# The User IPs table in schema #53 was missing a unique index, which we
|
||||||
|
# run as a background update.
|
||||||
|
if "user_ips_device_unique_index" not in updates:
|
||||||
|
self._unsafe_to_upsert_tables.discard("user_ips")
|
||||||
|
|
||||||
|
# If there's any tables left to check, reschedule to run.
|
||||||
|
if self._unsafe_to_upsert_tables:
|
||||||
|
self._clock.call_later(
|
||||||
|
15.0,
|
||||||
|
run_as_background_process,
|
||||||
|
"upsert_safety_check",
|
||||||
|
self._check_safe_to_upsert
|
||||||
|
)
|
||||||
|
|
||||||
def start_profiling(self):
|
def start_profiling(self):
|
||||||
self._previous_loop_ts = self._clock.time_msec()
|
self._previous_loop_ts = self._clock.time_msec()
|
||||||
|
|
||||||
|
@ -494,8 +540,15 @@ class SQLBaseStore(object):
|
||||||
txn.executemany(sql, vals)
|
txn.executemany(sql, vals)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _simple_upsert(self, table, keyvalues, values,
|
def _simple_upsert(
|
||||||
insertion_values={}, desc="_simple_upsert", lock=True):
|
self,
|
||||||
|
table,
|
||||||
|
keyvalues,
|
||||||
|
values,
|
||||||
|
insertion_values={},
|
||||||
|
desc="_simple_upsert",
|
||||||
|
lock=True
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
`lock` should generally be set to True (the default), but can be set
|
`lock` should generally be set to True (the default), but can be set
|
||||||
|
@ -516,16 +569,21 @@ class SQLBaseStore(object):
|
||||||
inserting
|
inserting
|
||||||
lock (bool): True to lock the table when doing the upsert.
|
lock (bool): True to lock the table when doing the upsert.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(bool): True if a new entry was created, False if an
|
Deferred(None or bool): Native upserts always return None. Emulated
|
||||||
existing one was updated.
|
upserts return True if a new entry was created, False if an existing
|
||||||
|
one was updated.
|
||||||
"""
|
"""
|
||||||
attempts = 0
|
attempts = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
result = yield self.runInteraction(
|
result = yield self.runInteraction(
|
||||||
desc,
|
desc,
|
||||||
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
|
self._simple_upsert_txn,
|
||||||
lock=lock
|
table,
|
||||||
|
keyvalues,
|
||||||
|
values,
|
||||||
|
insertion_values,
|
||||||
|
lock=lock,
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
except self.database_engine.module.IntegrityError as e:
|
except self.database_engine.module.IntegrityError as e:
|
||||||
|
@ -537,12 +595,71 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
# presumably we raced with another transaction: let's retry.
|
# presumably we raced with another transaction: let's retry.
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"IntegrityError when upserting into %s; retrying: %s",
|
"%s when upserting into %s; retrying: %s", e.__name__, table, e
|
||||||
table, e
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
|
def _simple_upsert_txn(
|
||||||
lock=True):
|
self,
|
||||||
|
txn,
|
||||||
|
table,
|
||||||
|
keyvalues,
|
||||||
|
values,
|
||||||
|
insertion_values={},
|
||||||
|
lock=True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Pick the UPSERT method which works best on the platform. Either the
|
||||||
|
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn: The transaction to use.
|
||||||
|
table (str): The table to upsert into
|
||||||
|
keyvalues (dict): The unique key tables and their new values
|
||||||
|
values (dict): The nonunique columns and their new values
|
||||||
|
insertion_values (dict): additional key/values to use only when
|
||||||
|
inserting
|
||||||
|
lock (bool): True to lock the table when doing the upsert.
|
||||||
|
Returns:
|
||||||
|
None or bool: Native upserts always return None. Emulated
|
||||||
|
upserts return True if a new entry was created, False if an existing
|
||||||
|
one was updated.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
self.database_engine.can_native_upsert
|
||||||
|
and table not in self._unsafe_to_upsert_tables
|
||||||
|
):
|
||||||
|
return self._simple_upsert_txn_native_upsert(
|
||||||
|
txn,
|
||||||
|
table,
|
||||||
|
keyvalues,
|
||||||
|
values,
|
||||||
|
insertion_values=insertion_values,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._simple_upsert_txn_emulated(
|
||||||
|
txn,
|
||||||
|
table,
|
||||||
|
keyvalues,
|
||||||
|
values,
|
||||||
|
insertion_values=insertion_values,
|
||||||
|
lock=lock,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _simple_upsert_txn_emulated(
|
||||||
|
self, txn, table, keyvalues, values, insertion_values={}, lock=True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
table (str): The table to upsert into
|
||||||
|
keyvalues (dict): The unique key tables and their new values
|
||||||
|
values (dict): The nonunique columns and their new values
|
||||||
|
insertion_values (dict): additional key/values to use only when
|
||||||
|
inserting
|
||||||
|
lock (bool): True to lock the table when doing the upsert.
|
||||||
|
Returns:
|
||||||
|
bool: Return True if a new entry was created, False if an existing
|
||||||
|
one was updated.
|
||||||
|
"""
|
||||||
# We need to lock the table :(, unless we're *really* careful
|
# We need to lock the table :(, unless we're *really* careful
|
||||||
if lock:
|
if lock:
|
||||||
self.database_engine.lock_table(txn, table)
|
self.database_engine.lock_table(txn, table)
|
||||||
|
@ -577,12 +694,44 @@ class SQLBaseStore(object):
|
||||||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||||
table,
|
table,
|
||||||
", ".join(k for k in allvalues),
|
", ".join(k for k in allvalues),
|
||||||
", ".join("?" for _ in allvalues)
|
", ".join("?" for _ in allvalues),
|
||||||
)
|
)
|
||||||
txn.execute(sql, list(allvalues.values()))
|
txn.execute(sql, list(allvalues.values()))
|
||||||
# successfully inserted
|
# successfully inserted
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _simple_upsert_txn_native_upsert(
|
||||||
|
self, txn, table, keyvalues, values, insertion_values={}
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Use the native UPSERT functionality in recent PostgreSQL versions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table (str): The table to upsert into
|
||||||
|
keyvalues (dict): The unique key tables and their new values
|
||||||
|
values (dict): The nonunique columns and their new values
|
||||||
|
insertion_values (dict): additional key/values to use only when
|
||||||
|
inserting
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
allvalues = {}
|
||||||
|
allvalues.update(keyvalues)
|
||||||
|
allvalues.update(values)
|
||||||
|
allvalues.update(insertion_values)
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"INSERT INTO %s (%s) VALUES (%s) "
|
||||||
|
"ON CONFLICT (%s) DO UPDATE SET %s"
|
||||||
|
) % (
|
||||||
|
table,
|
||||||
|
", ".join(k for k in allvalues),
|
||||||
|
", ".join("?" for _ in allvalues),
|
||||||
|
", ".join(k for k in keyvalues),
|
||||||
|
", ".join(k + "=EXCLUDED." + k for k in values),
|
||||||
|
)
|
||||||
|
txn.execute(sql, list(allvalues.values()))
|
||||||
|
|
||||||
def _simple_select_one(self, table, keyvalues, retcols,
|
def _simple_select_one(self, table, keyvalues, retcols,
|
||||||
allow_none=False, desc="_simple_select_one"):
|
allow_none=False, desc="_simple_select_one"):
|
||||||
"""Executes a SELECT query on the named table, which is expected to
|
"""Executes a SELECT query on the named table, which is expected to
|
||||||
|
|
|
@ -110,8 +110,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _remove_user_ip_dupes(self, progress, batch_size):
|
def _remove_user_ip_dupes(self, progress, batch_size):
|
||||||
|
# This works function works by scanning the user_ips table in batches
|
||||||
|
# based on `last_seen`. For each row in a batch it searches the rest of
|
||||||
|
# the table to see if there are any duplicates, if there are then they
|
||||||
|
# are removed and replaced with a suitable row.
|
||||||
|
|
||||||
last_seen_progress = progress.get("last_seen", 0)
|
# Fetch the start of the batch
|
||||||
|
begin_last_seen = progress.get("last_seen", 0)
|
||||||
|
|
||||||
def get_last_seen(txn):
|
def get_last_seen(txn):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
@ -122,29 +127,28 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
OFFSET ?
|
OFFSET ?
|
||||||
""",
|
""",
|
||||||
(last_seen_progress, batch_size)
|
(begin_last_seen, batch_size)
|
||||||
)
|
)
|
||||||
results = txn.fetchone()
|
row = txn.fetchone()
|
||||||
return results
|
if row:
|
||||||
|
return row[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
# Get a last seen that's sufficiently far away enough from the last one
|
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
|
||||||
last_seen = yield self.runInteraction(
|
end_last_seen = yield self.runInteraction(
|
||||||
"user_ips_dups_get_last_seen", get_last_seen
|
"user_ips_dups_get_last_seen", get_last_seen
|
||||||
)
|
)
|
||||||
|
|
||||||
if not last_seen:
|
# If it returns None, then we're processing the last batch
|
||||||
# If we get a None then we're reaching the end and just need to
|
last = end_last_seen is None
|
||||||
# delete the last batch.
|
|
||||||
last = True
|
|
||||||
|
|
||||||
# We fake not having an upper bound by using a future date, by
|
logger.info(
|
||||||
# just multiplying the current time by two....
|
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
|
||||||
last_seen = int(self.clock.time_msec()) * 2
|
begin_last_seen, end_last_seen,
|
||||||
else:
|
)
|
||||||
last = False
|
|
||||||
last_seen = last_seen[0]
|
|
||||||
|
|
||||||
def remove(txn, last_seen_progress, last_seen):
|
def remove(txn):
|
||||||
# This works by looking at all entries in the given time span, and
|
# This works by looking at all entries in the given time span, and
|
||||||
# then for each (user_id, access_token, ip) tuple in that range
|
# then for each (user_id, access_token, ip) tuple in that range
|
||||||
# checking for any duplicates in the rest of the table (via a join).
|
# checking for any duplicates in the rest of the table (via a join).
|
||||||
|
@ -153,6 +157,16 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
# all other duplicates.
|
# all other duplicates.
|
||||||
# It is efficient due to the existence of (user_id, access_token,
|
# It is efficient due to the existence of (user_id, access_token,
|
||||||
# ip) and (last_seen) indices.
|
# ip) and (last_seen) indices.
|
||||||
|
|
||||||
|
# Define the search space, which requires handling the last batch in
|
||||||
|
# a different way
|
||||||
|
if last:
|
||||||
|
clause = "? <= last_seen"
|
||||||
|
args = (begin_last_seen,)
|
||||||
|
else:
|
||||||
|
clause = "? <= last_seen AND last_seen < ?"
|
||||||
|
args = (begin_last_seen, end_last_seen)
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT user_id, access_token, ip,
|
SELECT user_id, access_token, ip,
|
||||||
|
@ -160,13 +174,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
FROM (
|
FROM (
|
||||||
SELECT user_id, access_token, ip
|
SELECT user_id, access_token, ip
|
||||||
FROM user_ips
|
FROM user_ips
|
||||||
WHERE ? <= last_seen AND last_seen < ?
|
WHERE {}
|
||||||
ORDER BY last_seen
|
|
||||||
) c
|
) c
|
||||||
INNER JOIN user_ips USING (user_id, access_token, ip)
|
INNER JOIN user_ips USING (user_id, access_token, ip)
|
||||||
GROUP BY user_id, access_token, ip
|
GROUP BY user_id, access_token, ip
|
||||||
HAVING count(*) > 1""",
|
HAVING count(*) > 1
|
||||||
(last_seen_progress, last_seen)
|
""".format(clause),
|
||||||
|
args
|
||||||
)
|
)
|
||||||
res = txn.fetchall()
|
res = txn.fetchall()
|
||||||
|
|
||||||
|
@ -194,12 +208,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
self._background_update_progress_txn(
|
self._background_update_progress_txn(
|
||||||
txn, "user_ips_remove_dupes", {"last_seen": last_seen}
|
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.runInteraction(
|
yield self.runInteraction("user_ips_dups_remove", remove)
|
||||||
"user_ips_dups_remove", remove, last_seen_progress, last_seen
|
|
||||||
)
|
|
||||||
if last:
|
if last:
|
||||||
yield self._end_background_update("user_ips_remove_dupes")
|
yield self._end_background_update("user_ips_remove_dupes")
|
||||||
|
|
||||||
|
@ -244,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _update_client_ips_batch_txn(self, txn, to_update):
|
def _update_client_ips_batch_txn(self, txn, to_update):
|
||||||
self.database_engine.lock_table(txn, "user_ips")
|
if "user_ips" in self._unsafe_to_upsert_tables or (
|
||||||
|
not self.database_engine.can_native_upsert
|
||||||
|
):
|
||||||
|
self.database_engine.lock_table(txn, "user_ips")
|
||||||
|
|
||||||
for entry in iteritems(to_update):
|
for entry in iteritems(to_update):
|
||||||
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
|
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
|
||||||
|
|
|
@ -18,7 +18,7 @@ import platform
|
||||||
|
|
||||||
from ._base import IncorrectDatabaseSetup
|
from ._base import IncorrectDatabaseSetup
|
||||||
from .postgres import PostgresEngine
|
from .postgres import PostgresEngine
|
||||||
from .sqlite3 import Sqlite3Engine
|
from .sqlite import Sqlite3Engine
|
||||||
|
|
||||||
SUPPORTED_MODULE = {
|
SUPPORTED_MODULE = {
|
||||||
"sqlite3": Sqlite3Engine,
|
"sqlite3": Sqlite3Engine,
|
||||||
|
|
|
@ -38,6 +38,13 @@ class PostgresEngine(object):
|
||||||
return sql.replace("?", "%s")
|
return sql.replace("?", "%s")
|
||||||
|
|
||||||
def on_new_connection(self, db_conn):
|
def on_new_connection(self, db_conn):
|
||||||
|
|
||||||
|
# Get the version of PostgreSQL that we're using. As per the psycopg2
|
||||||
|
# docs: The number is formed by converting the major, minor, and
|
||||||
|
# revision numbers into two-decimal-digit numbers and appending them
|
||||||
|
# together. For example, version 8.1.5 will be returned as 80105
|
||||||
|
self._version = db_conn.server_version
|
||||||
|
|
||||||
db_conn.set_isolation_level(
|
db_conn.set_isolation_level(
|
||||||
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||||
)
|
)
|
||||||
|
@ -54,6 +61,13 @@ class PostgresEngine(object):
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_native_upsert(self):
|
||||||
|
"""
|
||||||
|
Can we use native UPSERTs? This requires PostgreSQL 9.5+.
|
||||||
|
"""
|
||||||
|
return self._version >= 90500
|
||||||
|
|
||||||
def is_deadlock(self, error):
|
def is_deadlock(self, error):
|
||||||
if isinstance(error, self.module.DatabaseError):
|
if isinstance(error, self.module.DatabaseError):
|
||||||
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
|
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
import struct
|
import struct
|
||||||
import threading
|
import threading
|
||||||
|
from sqlite3 import sqlite_version_info
|
||||||
|
|
||||||
from synapse.storage.prepare_database import prepare_database
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
|
||||||
|
@ -30,6 +31,14 @@ class Sqlite3Engine(object):
|
||||||
self._current_state_group_id = None
|
self._current_state_group_id = None
|
||||||
self._current_state_group_id_lock = threading.Lock()
|
self._current_state_group_id_lock = threading.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_native_upsert(self):
|
||||||
|
"""
|
||||||
|
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
|
||||||
|
more work we haven't done yet to tell what was inserted vs updated.
|
||||||
|
"""
|
||||||
|
return sqlite_version_info >= (3, 24, 0)
|
||||||
|
|
||||||
def check_database(self, txn):
|
def check_database(self, txn):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1268,6 +1268,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
|
||||||
event.internal_metadata.get_dict()
|
event.internal_metadata.get_dict()
|
||||||
),
|
),
|
||||||
"json": encode_json(event_dict(event)),
|
"json": encode_json(event_dict(event)),
|
||||||
|
"format_version": event.format_version,
|
||||||
}
|
}
|
||||||
for event, _ in events_and_contexts
|
for event, _ in events_and_contexts
|
||||||
],
|
],
|
||||||
|
|
|
@ -21,10 +21,10 @@ from canonicaljson import json
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.constants import EventFormatVersions
|
||||||
from synapse.api.errors import NotFoundError
|
from synapse.api.errors import NotFoundError
|
||||||
# these are only included to make the type annotations work
|
|
||||||
from synapse.events import EventBase # noqa: F401
|
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
|
# these are only included to make the type annotations work
|
||||||
from synapse.events.snapshot import EventContext # noqa: F401
|
from synapse.events.snapshot import EventContext # noqa: F401
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
@ -353,6 +353,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
self._get_event_from_row,
|
self._get_event_from_row,
|
||||||
row["internal_metadata"], row["json"], row["redacts"],
|
row["internal_metadata"], row["json"], row["redacts"],
|
||||||
rejected_reason=row["rejects"],
|
rejected_reason=row["rejects"],
|
||||||
|
format_version=row["format_version"],
|
||||||
)
|
)
|
||||||
for row in rows
|
for row in rows
|
||||||
],
|
],
|
||||||
|
@ -377,6 +378,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
" e.event_id as event_id, "
|
" e.event_id as event_id, "
|
||||||
" e.internal_metadata,"
|
" e.internal_metadata,"
|
||||||
" e.json,"
|
" e.json,"
|
||||||
|
" e.format_version, "
|
||||||
" r.redacts as redacts,"
|
" r.redacts as redacts,"
|
||||||
" rej.event_id as rejects "
|
" rej.event_id as rejects "
|
||||||
" FROM event_json as e"
|
" FROM event_json as e"
|
||||||
|
@ -392,7 +394,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_event_from_row(self, internal_metadata, js, redacted,
|
def _get_event_from_row(self, internal_metadata, js, redacted,
|
||||||
rejected_reason=None):
|
format_version, rejected_reason=None):
|
||||||
with Measure(self._clock, "_get_event_from_row"):
|
with Measure(self._clock, "_get_event_from_row"):
|
||||||
d = json.loads(js)
|
d = json.loads(js)
|
||||||
internal_metadata = json.loads(internal_metadata)
|
internal_metadata = json.loads(internal_metadata)
|
||||||
|
@ -405,8 +407,17 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
desc="_get_event_from_row_rejected_reason",
|
desc="_get_event_from_row_rejected_reason",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if format_version is None:
|
||||||
|
# This means that we stored the event before we had the concept
|
||||||
|
# of a event format version, so it must be a V1 event.
|
||||||
|
format_version = EventFormatVersions.V1
|
||||||
|
|
||||||
|
# TODO: When we implement new event formats we'll need to use a
|
||||||
|
# different event python type
|
||||||
|
assert format_version == EventFormatVersions.V1
|
||||||
|
|
||||||
original_ev = FrozenEvent(
|
original_ev = FrozenEvent(
|
||||||
d,
|
event_dict=d,
|
||||||
internal_metadata_dict=internal_metadata,
|
internal_metadata_dict=internal_metadata,
|
||||||
rejected_reason=rejected_reason,
|
rejected_reason=rejected_reason,
|
||||||
)
|
)
|
||||||
|
|
|
@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore):
|
||||||
with self._pushers_id_gen.get_next() as stream_id:
|
with self._pushers_id_gen.get_next() as stream_id:
|
||||||
# no need to lock because `pushers` has a unique key on
|
# no need to lock because `pushers` has a unique key on
|
||||||
# (app_id, pushkey, user_name) so _simple_upsert will retry
|
# (app_id, pushkey, user_name) so _simple_upsert will retry
|
||||||
newly_inserted = yield self._simple_upsert(
|
yield self._simple_upsert(
|
||||||
table="pushers",
|
table="pushers",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"app_id": app_id,
|
"app_id": app_id,
|
||||||
|
@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore):
|
||||||
lock=False,
|
lock=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if newly_inserted:
|
user_has_pusher = self.get_if_user_has_pusher.cache.get(
|
||||||
|
(user_id,), None, update_metrics=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_has_pusher is not True:
|
||||||
|
# invalidate, since we the user might not have had a pusher before
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"add_pusher",
|
"add_pusher",
|
||||||
self._invalidate_cache_and_stream,
|
self._invalidate_cache_and_stream,
|
||||||
|
|
|
@ -588,12 +588,12 @@ class RoomMemberStore(RoomMemberWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
# We update the local_invites table only if the event is "current",
|
# We update the local_invites table only if the event is "current",
|
||||||
# i.e., its something that has just happened.
|
# i.e., its something that has just happened. If the event is an
|
||||||
# The only current event that can also be an outlier is if its an
|
# outlier it is only current if its an "out of band membership",
|
||||||
# invite that has come in across federation.
|
# like a remote invite or a rejection of a remote invite.
|
||||||
is_new_state = not backfilled and (
|
is_new_state = not backfilled and (
|
||||||
not event.internal_metadata.is_outlier()
|
not event.internal_metadata.is_outlier()
|
||||||
or event.internal_metadata.is_invite_from_remote()
|
or event.internal_metadata.is_out_of_band_membership()
|
||||||
)
|
)
|
||||||
is_mine = self.hs.is_mine_id(event.state_key)
|
is_mine = self.hs.is_mine_id(event.state_key)
|
||||||
if is_new_state and is_mine:
|
if is_new_state and is_mine:
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
/* Copyright 2019 New Vector Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
ALTER TABLE event_json ADD COLUMN format_version INTEGER;
|
|
@ -437,6 +437,30 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
create_event = yield self.get_event(create_id)
|
create_event = yield self.get_event(create_id)
|
||||||
defer.returnValue(create_event.content.get("room_version", "1"))
|
defer.returnValue(create_event.content.get("room_version", "1"))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_room_predecessor(self, room_id):
|
||||||
|
"""Get the predecessor room of an upgraded room if one exists.
|
||||||
|
Otherwise return None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[unicode|None]: predecessor room id
|
||||||
|
"""
|
||||||
|
state_ids = yield self.get_current_state_ids(room_id)
|
||||||
|
create_id = state_ids.get((EventTypes.Create, ""))
|
||||||
|
|
||||||
|
# If we can't find the create event, assume we've hit a dead end
|
||||||
|
if not create_id:
|
||||||
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
# Retrieve the room's create event
|
||||||
|
create_event = yield self.get_event(create_id)
|
||||||
|
|
||||||
|
# Return predecessor if present
|
||||||
|
defer.returnValue(create_event.content.get("predecessor", None))
|
||||||
|
|
||||||
@cached(max_entries=100000, iterable=True)
|
@cached(max_entries=100000, iterable=True)
|
||||||
def get_current_state_ids(self, room_id):
|
def get_current_state_ids(self, room_id):
|
||||||
"""Get the current state event ids for a room based on the
|
"""Get the current state event ids for a room based on the
|
||||||
|
|
|
@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
# We weight the localpart most highly, then display name and finally
|
# We weight the localpart most highly, then display name and finally
|
||||||
# server name
|
# server name
|
||||||
if new_entry:
|
if self.database_engine.can_native_upsert:
|
||||||
sql = """
|
sql = """
|
||||||
INSERT INTO user_directory_search(user_id, vector)
|
INSERT INTO user_directory_search(user_id, vector)
|
||||||
VALUES (?,
|
VALUES (?,
|
||||||
setweight(to_tsvector('english', ?), 'A')
|
setweight(to_tsvector('english', ?), 'A')
|
||||||
|| setweight(to_tsvector('english', ?), 'D')
|
|| setweight(to_tsvector('english', ?), 'D')
|
||||||
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||||
)
|
) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
|
||||||
"""
|
"""
|
||||||
txn.execute(
|
txn.execute(
|
||||||
sql,
|
sql,
|
||||||
|
@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sql = """
|
# TODO: Remove this code after we've bumped the minimum version
|
||||||
UPDATE user_directory_search
|
# of postgres to always support upserts, so we can get rid of
|
||||||
SET vector = setweight(to_tsvector('english', ?), 'A')
|
# `new_entry` usage
|
||||||
|| setweight(to_tsvector('english', ?), 'D')
|
if new_entry is True:
|
||||||
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
sql = """
|
||||||
WHERE user_id = ?
|
INSERT INTO user_directory_search(user_id, vector)
|
||||||
"""
|
VALUES (?,
|
||||||
txn.execute(
|
setweight(to_tsvector('english', ?), 'A')
|
||||||
sql,
|
|| setweight(to_tsvector('english', ?), 'D')
|
||||||
(
|
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||||
get_localpart_from_id(user_id), get_domain_from_id(user_id),
|
)
|
||||||
display_name, user_id,
|
"""
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
user_id, get_localpart_from_id(user_id),
|
||||||
|
get_domain_from_id(user_id), display_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif new_entry is False:
|
||||||
|
sql = """
|
||||||
|
UPDATE user_directory_search
|
||||||
|
SET vector = setweight(to_tsvector('english', ?), 'A')
|
||||||
|
|| setweight(to_tsvector('english', ?), 'D')
|
||||||
|
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||||
|
WHERE user_id = ?
|
||||||
|
"""
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
get_localpart_from_id(user_id),
|
||||||
|
get_domain_from_id(user_id),
|
||||||
|
display_name, user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"upsert returned None when 'can_native_upsert' is False"
|
||||||
)
|
)
|
||||||
)
|
|
||||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
value = "%s %s" % (user_id, display_name,) if display_name else user_id
|
value = "%s %s" % (user_id, display_name,) if display_name else user_id
|
||||||
self._simple_upsert_txn(
|
self._simple_upsert_txn(
|
||||||
|
|
|
@ -0,0 +1,315 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
import treq
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.protocol import Factory
|
||||||
|
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||||
|
from twisted.test.ssl_helpers import ServerTLSContext
|
||||||
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
|
from synapse.crypto.context_factory import ClientTLSOptionsFactory
|
||||||
|
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||||
|
from synapse.http.federation.srv_resolver import Server
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
|
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||||
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixFederationAgentTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
|
|
||||||
|
self.mock_resolver = Mock()
|
||||||
|
|
||||||
|
self.agent = MatrixFederationAgent(
|
||||||
|
reactor=self.reactor,
|
||||||
|
tls_client_options_factory=ClientTLSOptionsFactory(None),
|
||||||
|
_srv_resolver=self.mock_resolver,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_connection(self, client_factory, expected_sni):
|
||||||
|
"""Builds a test server, and completes the outgoing client connection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HTTPChannel: the test server
|
||||||
|
"""
|
||||||
|
|
||||||
|
# build the test server
|
||||||
|
server_tls_protocol = _build_test_server()
|
||||||
|
|
||||||
|
# now, tell the client protocol factory to build the client protocol (it will be a
|
||||||
|
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
||||||
|
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
|
||||||
|
# a FakeTransport.
|
||||||
|
#
|
||||||
|
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||||
|
# stubbing that out here.
|
||||||
|
client_protocol = client_factory.buildProtocol(None)
|
||||||
|
client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
|
||||||
|
|
||||||
|
# tell the server tls protocol to send its stuff back to the client, too
|
||||||
|
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
|
||||||
|
|
||||||
|
# give the reactor a pump to get the TLS juices flowing.
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
|
# check the SNI
|
||||||
|
server_name = server_tls_protocol._tlsConnection.get_servername()
|
||||||
|
self.assertEqual(
|
||||||
|
server_name,
|
||||||
|
expected_sni,
|
||||||
|
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
# fish the test server back out of the server-side TLS protocol.
|
||||||
|
return server_tls_protocol.wrappedProtocol
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _make_get_request(self, uri):
|
||||||
|
"""
|
||||||
|
Sends a simple GET request via the agent, and checks its logcontext management
|
||||||
|
"""
|
||||||
|
with LoggingContext("one") as context:
|
||||||
|
fetch_d = self.agent.request(b'GET', uri)
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(fetch_d)
|
||||||
|
|
||||||
|
# should have reset logcontext to the sentinel
|
||||||
|
_check_logcontext(LoggingContext.sentinel)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fetch_res = yield fetch_d
|
||||||
|
defer.returnValue(fetch_res)
|
||||||
|
finally:
|
||||||
|
_check_logcontext(context)
|
||||||
|
|
||||||
|
def test_get(self):
|
||||||
|
"""
|
||||||
|
happy-path test of a GET request with an explicit port
|
||||||
|
"""
|
||||||
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
# Make sure treq is trying to connect
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory,
|
||||||
|
expected_sni=b"testserv",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b'GET')
|
||||||
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
|
content = request.content.read()
|
||||||
|
self.assertEqual(content, b'')
|
||||||
|
|
||||||
|
# Deferred is still without a result
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
# send the headers
|
||||||
|
request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
|
||||||
|
request.write('')
|
||||||
|
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
|
response = self.successResultOf(test_d)
|
||||||
|
|
||||||
|
# that should give us a Response object
|
||||||
|
self.assertEqual(response.code, 200)
|
||||||
|
|
||||||
|
# Send the body
|
||||||
|
request.write('{ "a": 1 }'.encode('ascii'))
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
|
# check it can be read
|
||||||
|
json = self.successResultOf(treq.json_content(response))
|
||||||
|
self.assertEqual(json, {"a": 1})
|
||||||
|
|
||||||
|
def test_get_ip_address(self):
|
||||||
|
"""
|
||||||
|
Test the behaviour when the server name contains an explicit IP (with no port)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
|
||||||
|
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
||||||
|
|
||||||
|
# then there will be a getaddrinfo on the IP
|
||||||
|
self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
|
||||||
|
|
||||||
|
test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
|
b"_matrix._tcp.1.2.3.4",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure treq is trying to connect
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory,
|
||||||
|
expected_sni=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b'GET')
|
||||||
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
|
|
||||||
|
# finish the request
|
||||||
|
request.finish()
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
|
def test_get_hostname_no_srv(self):
|
||||||
|
"""
|
||||||
|
Test the behaviour when the server name has no port, and no SRV record
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
||||||
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
|
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
|
b"_matrix._tcp.testserv",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure treq is trying to connect
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory,
|
||||||
|
expected_sni=b'testserv',
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b'GET')
|
||||||
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
|
|
||||||
|
# finish the request
|
||||||
|
request.finish()
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
|
def test_get_hostname_srv(self):
|
||||||
|
"""
|
||||||
|
Test the behaviour when there is a single SRV record
|
||||||
|
"""
|
||||||
|
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
||||||
|
Server(host="srvtarget", port=8443)
|
||||||
|
]
|
||||||
|
self.reactor.lookups["srvtarget"] = "1.2.3.4"
|
||||||
|
|
||||||
|
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
|
b"_matrix._tcp.testserv",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure treq is trying to connect
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8443)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory,
|
||||||
|
expected_sni=b'testserv',
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b'GET')
|
||||||
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
|
|
||||||
|
# finish the request
|
||||||
|
request.finish()
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_logcontext(context):
|
||||||
|
current = LoggingContext.current_context()
|
||||||
|
if current is not context:
|
||||||
|
raise AssertionError(
|
||||||
|
"Expected logcontext %s but was %s" % (context, current),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_test_server():
|
||||||
|
"""Construct a test server
|
||||||
|
|
||||||
|
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TLSMemoryBIOProtocol
|
||||||
|
"""
|
||||||
|
server_factory = Factory.forProtocol(HTTPChannel)
|
||||||
|
# Request.finish expects the factory to have a 'log' method.
|
||||||
|
server_factory.log = _log_request
|
||||||
|
|
||||||
|
server_tls_factory = TLSMemoryBIOFactory(
|
||||||
|
ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
return server_tls_factory.buildProtocol(None)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_request(request):
|
||||||
|
"""Implements Factory.log, which is expected by Request.finish"""
|
||||||
|
logger.info("Completed request %s", request)
|
|
@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
|
||||||
from twisted.internet.error import ConnectError
|
from twisted.internet.error import ConnectError
|
||||||
from twisted.names import dns, error
|
from twisted.names import dns, error
|
||||||
|
|
||||||
from synapse.http.federation.srv_resolver import resolve_service
|
from synapse.http.federation.srv_resolver import SrvResolver
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
dns_client_mock.lookupService.return_value = result_deferred
|
dns_client_mock.lookupService.return_value = result_deferred
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_lookup():
|
def do_lookup():
|
||||||
|
|
||||||
with LoggingContext("one") as ctx:
|
with LoggingContext("one") as ctx:
|
||||||
resolve_d = resolve_service(
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
|
@ -83,16 +83,15 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||||
|
|
||||||
service_name = "test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
entry = Mock(spec_set=["expires"])
|
entry = Mock(spec_set=["expires"])
|
||||||
entry.expires = 0
|
entry.expires = 0
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield resolve_service(
|
servers = yield resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
|
|
||||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||||
|
|
||||||
|
@ -106,17 +105,18 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
dns_client_mock = Mock(spec_set=['lookupService'])
|
dns_client_mock = Mock(spec_set=['lookupService'])
|
||||||
dns_client_mock.lookupService = Mock(spec_set=[])
|
dns_client_mock.lookupService = Mock(spec_set=[])
|
||||||
|
|
||||||
service_name = "test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
entry = Mock(spec_set=["expires"])
|
entry = Mock(spec_set=["expires"])
|
||||||
entry.expires = 999999999
|
entry.expires = 999999999
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
|
resolver = SrvResolver(
|
||||||
servers = yield resolve_service(
|
dns_client=dns_client_mock, cache=cache, get_time=clock.time,
|
||||||
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
servers = yield resolver.resolve_service(service_name)
|
||||||
|
|
||||||
self.assertFalse(dns_client_mock.lookupService.called)
|
self.assertFalse(dns_client_mock.lookupService.called)
|
||||||
|
|
||||||
self.assertEquals(len(servers), 1)
|
self.assertEquals(len(servers), 1)
|
||||||
|
@ -128,12 +128,13 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
|
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||||
|
|
||||||
service_name = "test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
with self.assertRaises(error.DNSServerError):
|
with self.assertRaises(error.DNSServerError):
|
||||||
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
|
yield resolver.resolve_service(service_name)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_name_error(self):
|
def test_name_error(self):
|
||||||
|
@ -141,13 +142,12 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
|
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
||||||
|
|
||||||
service_name = "test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield resolve_service(
|
servers = yield resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEquals(len(servers), 0)
|
self.assertEquals(len(servers), 0)
|
||||||
self.assertEquals(len(cache), 0)
|
self.assertEquals(len(cache), 0)
|
||||||
|
@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
resolve_d = resolve_service(
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
# returning a single "." should make the lookup fail with a ConenctError
|
# returning a single "." should make the lookup fail with a ConenctError
|
||||||
|
@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
resolve_d = resolve_service(
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
lookup_deferred.callback((
|
lookup_deferred.callback((
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import TimeoutError
|
from twisted.internet.defer import TimeoutError
|
||||||
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
||||||
from twisted.test.proto_helpers import StringTransport
|
from twisted.test.proto_helpers import StringTransport
|
||||||
|
@ -26,11 +27,20 @@ from synapse.http.matrixfederationclient import (
|
||||||
MatrixFederationHttpClient,
|
MatrixFederationHttpClient,
|
||||||
MatrixFederationRequest,
|
MatrixFederationRequest,
|
||||||
)
|
)
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
from tests.server import FakeTransport
|
from tests.server import FakeTransport
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
|
def check_logcontext(context):
|
||||||
|
current = LoggingContext.current_context()
|
||||||
|
if current is not context:
|
||||||
|
raise AssertionError(
|
||||||
|
"Expected logcontext %s but was %s" % (context, current),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FederationClientTests(HomeserverTestCase):
|
class FederationClientTests(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
|
@ -43,6 +53,70 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
self.cl = MatrixFederationHttpClient(self.hs)
|
self.cl = MatrixFederationHttpClient(self.hs)
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
|
def test_client_get(self):
|
||||||
|
"""
|
||||||
|
happy-path test of a GET request
|
||||||
|
"""
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_request():
|
||||||
|
with LoggingContext("one") as context:
|
||||||
|
fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(fetch_d)
|
||||||
|
|
||||||
|
# should have reset logcontext to the sentinel
|
||||||
|
check_logcontext(LoggingContext.sentinel)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fetch_res = yield fetch_d
|
||||||
|
defer.returnValue(fetch_res)
|
||||||
|
finally:
|
||||||
|
check_logcontext(context)
|
||||||
|
|
||||||
|
test_d = do_request()
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
# Make sure treq is trying to connect
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8008)
|
||||||
|
|
||||||
|
# complete the connection and wire it up to a fake transport
|
||||||
|
protocol = factory.buildProtocol(None)
|
||||||
|
transport = StringTransport()
|
||||||
|
protocol.makeConnection(transport)
|
||||||
|
|
||||||
|
# that should have made it send the request to the transport
|
||||||
|
self.assertRegex(transport.value(), b"^GET /foo/bar")
|
||||||
|
|
||||||
|
# Deferred is still without a result
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
# Send it the HTTP response
|
||||||
|
res_json = '{ "a": 1 }'.encode('ascii')
|
||||||
|
protocol.dataReceived(
|
||||||
|
b"HTTP/1.1 200 OK\r\n"
|
||||||
|
b"Server: Fake\r\n"
|
||||||
|
b"Content-Type: application/json\r\n"
|
||||||
|
b"Content-Length: %i\r\n"
|
||||||
|
b"\r\n"
|
||||||
|
b"%s" % (len(res_json), res_json)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
res = self.successResultOf(test_d)
|
||||||
|
|
||||||
|
# check the response is as expected
|
||||||
|
self.assertEqual(res, {"a": 1})
|
||||||
|
|
||||||
def test_dns_error(self):
|
def test_dns_error(self):
|
||||||
"""
|
"""
|
||||||
If the DNS lookup returns an error, it will bubble up.
|
If the DNS lookup returns an error, it will bubble up.
|
||||||
|
@ -54,6 +128,28 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
self.assertIsInstance(f.value, RequestSendFailed)
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
|
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
|
||||||
|
|
||||||
|
def test_client_connection_refused(self):
|
||||||
|
d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(d)
|
||||||
|
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8008)
|
||||||
|
e = Exception("go away")
|
||||||
|
factory.clientConnectionFailed(None, e)
|
||||||
|
self.pump(0.5)
|
||||||
|
|
||||||
|
f = self.failureResultOf(d)
|
||||||
|
|
||||||
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
|
self.assertIs(f.value.inner_exception, e)
|
||||||
|
|
||||||
def test_client_never_connect(self):
|
def test_client_never_connect(self):
|
||||||
"""
|
"""
|
||||||
If the HTTP request is not connected and is timed out, it'll give a
|
If the HTTP request is not connected and is timed out, it'll give a
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from six import text_type
|
from six import text_type
|
||||||
|
@ -22,6 +23,8 @@ from synapse.util import Clock
|
||||||
|
|
||||||
from tests.utils import setup_test_homeserver as _sth
|
from tests.utils import setup_test_homeserver as _sth
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TimedOutException(Exception):
|
class TimedOutException(Exception):
|
||||||
"""
|
"""
|
||||||
|
@ -339,7 +342,7 @@ def get_clock():
|
||||||
return (clock, hs_clock)
|
return (clock, hs_clock)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s(cmp=False)
|
||||||
class FakeTransport(object):
|
class FakeTransport(object):
|
||||||
"""
|
"""
|
||||||
A twisted.internet.interfaces.ITransport implementation which sends all its data
|
A twisted.internet.interfaces.ITransport implementation which sends all its data
|
||||||
|
@ -414,6 +417,11 @@ class FakeTransport(object):
|
||||||
self.buffer = self.buffer + byt
|
self.buffer = self.buffer + byt
|
||||||
|
|
||||||
def _write():
|
def _write():
|
||||||
|
if not self.buffer:
|
||||||
|
# nothing to do. Don't write empty buffers: it upsets the
|
||||||
|
# TLSMemoryBIOProtocol
|
||||||
|
return
|
||||||
|
|
||||||
if getattr(self.other, "transport") is not None:
|
if getattr(self.other, "transport") is not None:
|
||||||
self.other.dataReceived(self.buffer)
|
self.other.dataReceived(self.buffer)
|
||||||
self.buffer = b""
|
self.buffer = b""
|
||||||
|
@ -421,7 +429,10 @@ class FakeTransport(object):
|
||||||
|
|
||||||
self._reactor.callLater(0.0, _write)
|
self._reactor.callLater(0.0, _write)
|
||||||
|
|
||||||
_write()
|
# always actually do the write asynchronously. Some protocols (notably the
|
||||||
|
# TLSMemoryBIOProtocol) get very confused if a read comes back while they are
|
||||||
|
# still doing a write. Doing a callLater here breaks the cycle.
|
||||||
|
self._reactor.callLater(0.0, _write)
|
||||||
|
|
||||||
def writeSequence(self, seq):
|
def writeSequence(self, seq):
|
||||||
for x in seq:
|
for x in seq:
|
||||||
|
|
|
@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
self.db_pool.runWithConnection = runWithConnection
|
self.db_pool.runWithConnection = runWithConnection
|
||||||
|
|
||||||
config = Mock()
|
config = Mock()
|
||||||
|
config._enable_native_upserts = False
|
||||||
config.event_cache_size = 1
|
config.event_cache_size = 1
|
||||||
config.database_config = {"name": "sqlite3"}
|
config.database_config = {"name": "sqlite3"}
|
||||||
hs = TestHomeServer(
|
hs = TestHomeServer(
|
||||||
|
|
|
@ -19,7 +19,7 @@ from six import StringIO
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
|
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
|
||||||
|
@ -30,12 +30,18 @@ from synapse.util import Clock
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeTransport, make_request, render, setup_test_homeserver
|
from tests.server import (
|
||||||
|
FakeTransport,
|
||||||
|
ThreadedMemoryReactorClock,
|
||||||
|
make_request,
|
||||||
|
render,
|
||||||
|
setup_test_homeserver,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class JsonResourceTests(unittest.TestCase):
|
class JsonResourceTests(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.reactor = MemoryReactorClock()
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
self.hs_clock = Clock(self.reactor)
|
self.hs_clock = Clock(self.reactor)
|
||||||
self.homeserver = setup_test_homeserver(
|
self.homeserver = setup_test_homeserver(
|
||||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
|
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
|
||||||
|
|
|
@ -96,7 +96,7 @@ class TestCase(unittest.TestCase):
|
||||||
|
|
||||||
method = getattr(self, methodName)
|
method = getattr(self, methodName)
|
||||||
|
|
||||||
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
|
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
|
||||||
|
|
||||||
@around(self)
|
@around(self)
|
||||||
def setUp(orig):
|
def setUp(orig):
|
||||||
|
@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase):
|
||||||
"""
|
"""
|
||||||
kwargs = dict(kwargs)
|
kwargs = dict(kwargs)
|
||||||
kwargs.update(self._hs_args)
|
kwargs.update(self._hs_args)
|
||||||
return setup_test_homeserver(self.addCleanup, *args, **kwargs)
|
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
|
||||||
|
stor = hs.get_datastore()
|
||||||
|
|
||||||
|
# Run the database background updates.
|
||||||
|
if hasattr(stor, "do_next_background_update"):
|
||||||
|
while not self.get_success(stor.has_completed_background_updates()):
|
||||||
|
self.get_success(stor.do_next_background_update(1))
|
||||||
|
|
||||||
|
return hs
|
||||||
|
|
||||||
def pump(self, by=0.0):
|
def pump(self, by=0.0):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -154,7 +154,9 @@ def default_config(name):
|
||||||
config.update_user_directory = False
|
config.update_user_directory = False
|
||||||
|
|
||||||
def is_threepid_reserved(threepid):
|
def is_threepid_reserved(threepid):
|
||||||
return ServerConfig.is_threepid_reserved(config, threepid)
|
return ServerConfig.is_threepid_reserved(
|
||||||
|
config.mau_limits_reserved_threepids, threepid
|
||||||
|
)
|
||||||
|
|
||||||
config.is_threepid_reserved.side_effect = is_threepid_reserved
|
config.is_threepid_reserved.side_effect = is_threepid_reserved
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue