Merge branch 'develop' into travis/admin-list-media
This commit is contained in:
commit
6e87b34f7b
|
@ -0,0 +1,133 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Moves a list of remote media from one media store to another.
|
||||||
|
|
||||||
|
The input should be a list of media files to be moved, one per line. Each line
|
||||||
|
should be formatted::
|
||||||
|
|
||||||
|
<origin server>|<file id>
|
||||||
|
|
||||||
|
This can be extracted from postgres with::
|
||||||
|
|
||||||
|
psql --tuples-only -A -c "select media_origin, filesystem_id from
|
||||||
|
matrix.remote_media_cache where ..."
|
||||||
|
|
||||||
|
To use, pipe the above into::
|
||||||
|
|
||||||
|
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from synapse.rest.media.v1.filepath import MediaFilePaths
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def main(src_repo, dest_repo):
|
||||||
|
src_paths = MediaFilePaths(src_repo)
|
||||||
|
dest_paths = MediaFilePaths(dest_repo)
|
||||||
|
for line in sys.stdin:
|
||||||
|
line = line.strip()
|
||||||
|
parts = line.split('|')
|
||||||
|
if len(parts) != 2:
|
||||||
|
print("Unable to parse input line %s" % line, file=sys.stderr)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
move_media(parts[0], parts[1], src_paths, dest_paths)
|
||||||
|
|
||||||
|
|
||||||
|
def move_media(origin_server, file_id, src_paths, dest_paths):
|
||||||
|
"""Move the given file, and any thumbnails, to the dest repo
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin_server (str):
|
||||||
|
file_id (str):
|
||||||
|
src_paths (MediaFilePaths):
|
||||||
|
dest_paths (MediaFilePaths):
|
||||||
|
"""
|
||||||
|
logger.info("%s/%s", origin_server, file_id)
|
||||||
|
|
||||||
|
# check that the original exists
|
||||||
|
original_file = src_paths.remote_media_filepath(origin_server, file_id)
|
||||||
|
if not os.path.exists(original_file):
|
||||||
|
logger.warn(
|
||||||
|
"Original for %s/%s (%s) does not exist",
|
||||||
|
origin_server, file_id, original_file,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mkdir_and_move(
|
||||||
|
original_file,
|
||||||
|
dest_paths.remote_media_filepath(origin_server, file_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# now look for thumbnails
|
||||||
|
original_thumb_dir = src_paths.remote_media_thumbnail_dir(
|
||||||
|
origin_server, file_id,
|
||||||
|
)
|
||||||
|
if not os.path.exists(original_thumb_dir):
|
||||||
|
return
|
||||||
|
|
||||||
|
mkdir_and_move(
|
||||||
|
original_thumb_dir,
|
||||||
|
dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir_and_move(original_file, dest_file):
|
||||||
|
dirname = os.path.dirname(dest_file)
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
logger.debug("mkdir %s", dirname)
|
||||||
|
os.makedirs(dirname)
|
||||||
|
logger.debug("mv %s %s", original_file, dest_file)
|
||||||
|
shutil.move(original_file, dest_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=__doc__,
|
||||||
|
formatter_class = argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-v", action='store_true', help='enable debug logging')
|
||||||
|
parser.add_argument(
|
||||||
|
"src_repo",
|
||||||
|
help="Path to source content repo",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"dest_repo",
|
||||||
|
help="Path to source content repo",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging_config = {
|
||||||
|
"level": logging.DEBUG if args.v else logging.INFO,
|
||||||
|
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
|
||||||
|
}
|
||||||
|
logging.basicConfig(**logging_config)
|
||||||
|
|
||||||
|
main(args.src_repo, args.dest_repo)
|
|
@ -46,6 +46,7 @@ class Codes(object):
|
||||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||||
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||||
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
||||||
|
THREEPID_DENIED = "M_THREEPID_DENIED"
|
||||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||||
|
|
||||||
|
@ -140,6 +141,32 @@ class RegistrationError(SynapseError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FederationDeniedError(SynapseError):
|
||||||
|
"""An error raised when the server tries to federate with a server which
|
||||||
|
is not on its federation whitelist.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
destination (str): The destination which has been denied
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, destination):
|
||||||
|
"""Raised by federation client or server to indicate that we are
|
||||||
|
are deliberately not attempting to contact a given server because it is
|
||||||
|
not on our federation whitelist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): the domain in question
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.destination = destination
|
||||||
|
|
||||||
|
super(FederationDeniedError, self).__init__(
|
||||||
|
code=403,
|
||||||
|
msg="Federation denied with %s." % (self.destination,),
|
||||||
|
errcode=Codes.FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InteractiveAuthIncompleteError(Exception):
|
class InteractiveAuthIncompleteError(Exception):
|
||||||
"""An error raised when UI auth is not yet complete
|
"""An error raised when UI auth is not yet complete
|
||||||
|
|
||||||
|
|
|
@ -49,19 +49,6 @@ class AppserviceSlaveStore(
|
||||||
|
|
||||||
|
|
||||||
class AppserviceServer(HomeServer):
|
class AppserviceServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
|
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
|
||||||
|
|
|
@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
|
||||||
|
|
||||||
|
|
||||||
class ClientReaderServer(HomeServer):
|
class ClientReaderServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
|
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
|
||||||
|
|
|
@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
|
||||||
|
|
||||||
|
|
||||||
class FederationReaderServer(HomeServer):
|
class FederationReaderServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
|
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
|
||||||
|
|
|
@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
|
||||||
|
|
||||||
|
|
||||||
class FederationSenderServer(HomeServer):
|
class FederationSenderServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
|
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
|
||||||
|
|
|
@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
|
||||||
|
|
||||||
|
|
||||||
class FrontendProxyServer(HomeServer):
|
class FrontendProxyServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
|
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
|
||||||
|
|
|
@ -266,19 +266,6 @@ class SynapseHomeServer(HomeServer):
|
||||||
except IncorrectDatabaseSetup as e:
|
except IncorrectDatabaseSetup as e:
|
||||||
quit_with_error(e.message)
|
quit_with_error(e.message)
|
||||||
|
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
|
|
||||||
def setup(config_options):
|
def setup(config_options):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -60,19 +60,6 @@ class MediaRepositorySlavedStore(
|
||||||
|
|
||||||
|
|
||||||
class MediaRepositoryServer(HomeServer):
|
class MediaRepositoryServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
|
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
|
||||||
|
|
|
@ -81,19 +81,6 @@ class PusherSlaveStore(
|
||||||
|
|
||||||
|
|
||||||
class PusherServer(HomeServer):
|
class PusherServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
|
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
|
||||||
|
|
|
@ -246,19 +246,6 @@ class SynchrotronApplicationService(object):
|
||||||
|
|
||||||
|
|
||||||
class SynchrotronServer(HomeServer):
|
class SynchrotronServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
|
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
|
||||||
|
|
|
@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
|
||||||
|
|
||||||
|
|
||||||
class UserDirectoryServer(HomeServer):
|
class UserDirectoryServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
|
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
|
||||||
|
|
|
@ -31,6 +31,8 @@ class RegistrationConfig(Config):
|
||||||
strtobool(str(config["disable_registration"]))
|
strtobool(str(config["disable_registration"]))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||||
|
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
|
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
|
@ -52,6 +54,23 @@ class RegistrationConfig(Config):
|
||||||
# Enable registration for new users.
|
# Enable registration for new users.
|
||||||
enable_registration: False
|
enable_registration: False
|
||||||
|
|
||||||
|
# The user must provide all of the below types of 3PID when registering.
|
||||||
|
#
|
||||||
|
# registrations_require_3pid:
|
||||||
|
# - email
|
||||||
|
# - msisdn
|
||||||
|
|
||||||
|
# Mandate that users are only allowed to associate certain formats of
|
||||||
|
# 3PIDs with accounts on this server.
|
||||||
|
#
|
||||||
|
# allowed_local_3pids:
|
||||||
|
# - medium: email
|
||||||
|
# pattern: ".*@matrix\\.org"
|
||||||
|
# - medium: email
|
||||||
|
# pattern: ".*@vector\\.im"
|
||||||
|
# - medium: msisdn
|
||||||
|
# pattern: "\\+44"
|
||||||
|
|
||||||
# If set, allows registration by anyone who also has the shared
|
# If set, allows registration by anyone who also has the shared
|
||||||
# secret, even if registration is otherwise disabled.
|
# secret, even if registration is otherwise disabled.
|
||||||
registration_shared_secret: "%(registration_shared_secret)s"
|
registration_shared_secret: "%(registration_shared_secret)s"
|
||||||
|
|
|
@ -55,6 +55,17 @@ class ServerConfig(Config):
|
||||||
"block_non_admin_invites", False,
|
"block_non_admin_invites", False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FIXME: federation_domain_whitelist needs sytests
|
||||||
|
self.federation_domain_whitelist = None
|
||||||
|
federation_domain_whitelist = config.get(
|
||||||
|
"federation_domain_whitelist", None
|
||||||
|
)
|
||||||
|
# turn the whitelist into a hash for speed of lookup
|
||||||
|
if federation_domain_whitelist is not None:
|
||||||
|
self.federation_domain_whitelist = {}
|
||||||
|
for domain in federation_domain_whitelist:
|
||||||
|
self.federation_domain_whitelist[domain] = True
|
||||||
|
|
||||||
if self.public_baseurl is not None:
|
if self.public_baseurl is not None:
|
||||||
if self.public_baseurl[-1] != '/':
|
if self.public_baseurl[-1] != '/':
|
||||||
self.public_baseurl += '/'
|
self.public_baseurl += '/'
|
||||||
|
@ -210,6 +221,17 @@ class ServerConfig(Config):
|
||||||
# (except those sent by local server admins). The default is False.
|
# (except those sent by local server admins). The default is False.
|
||||||
# block_non_admin_invites: True
|
# block_non_admin_invites: True
|
||||||
|
|
||||||
|
# Restrict federation to the following whitelist of domains.
|
||||||
|
# N.B. we recommend also firewalling your federation listener to limit
|
||||||
|
# inbound federation traffic as early as possible, rather than relying
|
||||||
|
# purely on this application-layer restriction. If not specified, the
|
||||||
|
# default is to whitelist everything.
|
||||||
|
#
|
||||||
|
# federation_domain_whitelist:
|
||||||
|
# - lon.example.com
|
||||||
|
# - nyc.example.com
|
||||||
|
# - syd.example.com
|
||||||
|
|
||||||
# List of ports that Synapse should listen on, their purpose and their
|
# List of ports that Synapse should listen on, their purpose and their
|
||||||
# configuration.
|
# configuration.
|
||||||
listeners:
|
listeners:
|
||||||
|
|
|
@ -23,7 +23,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException, HttpResponseException, SynapseError,
|
CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
|
||||||
)
|
)
|
||||||
from synapse.events import builder
|
from synapse.events import builder
|
||||||
from synapse.federation.federation_base import (
|
from synapse.federation.federation_base import (
|
||||||
|
@ -266,6 +266,9 @@ class FederationClient(FederationBase):
|
||||||
except NotRetryingDestination as e:
|
except NotRetryingDestination as e:
|
||||||
logger.info(e.message)
|
logger.info(e.message)
|
||||||
continue
|
continue
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
logger.info(e.message)
|
||||||
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pdu_attempts[destination] = now
|
pdu_attempts[destination] = now
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from twisted.internet import defer
|
||||||
from .persistence import TransactionActions
|
from .persistence import TransactionActions
|
||||||
from .units import Transaction, Edu
|
from .units import Transaction, Edu
|
||||||
|
|
||||||
from synapse.api.errors import HttpResponseException
|
from synapse.api.errors import HttpResponseException, FederationDeniedError
|
||||||
from synapse.util import logcontext, PreserveLoggingContext
|
from synapse.util import logcontext, PreserveLoggingContext
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||||
|
@ -490,6 +490,8 @@ class TransactionQueue(object):
|
||||||
(e.retry_last_ts + e.retry_interval) / 1000.0
|
(e.retry_last_ts + e.retry_interval) / 1000.0
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
logger.info(e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"TX [%s] Failed to send transaction: %s",
|
"TX [%s] Failed to send transaction: %s",
|
||||||
|
|
|
@ -212,6 +212,9 @@ class TransportLayerClient(object):
|
||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if the remote destination
|
||||||
|
is not in our federation whitelist
|
||||||
"""
|
"""
|
||||||
valid_memberships = {Membership.JOIN, Membership.LEAVE}
|
valid_memberships = {Membership.JOIN, Membership.LEAVE}
|
||||||
if membership not in valid_memberships:
|
if membership not in valid_memberships:
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
||||||
|
@ -81,6 +81,7 @@ class Authenticator(object):
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||||
|
|
||||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -92,6 +93,12 @@ class Authenticator(object):
|
||||||
"signatures": {},
|
"signatures": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.federation_domain_whitelist is not None and
|
||||||
|
self.server_name not in self.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
raise FederationDeniedError(self.server_name)
|
||||||
|
|
||||||
if content is not None:
|
if content is not None:
|
||||||
json_request["content"] = content
|
json_request["content"] = content
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.api import errors
|
from synapse.api import errors
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.api.errors import FederationDeniedError
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
@ -513,6 +514,9 @@ class DeviceListEduUpdater(object):
|
||||||
# This makes it more likely that the device lists will
|
# This makes it more likely that the device lists will
|
||||||
# eventually become consistent.
|
# eventually become consistent.
|
||||||
return
|
return
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
logger.info(e)
|
||||||
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
# TODO: Remember that we are now out of sync and try again
|
# TODO: Remember that we are now out of sync and try again
|
||||||
# later
|
# later
|
||||||
|
|
|
@ -19,7 +19,9 @@ import logging
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, CodeMessageException
|
from synapse.api.errors import (
|
||||||
|
SynapseError, CodeMessageException, FederationDeniedError,
|
||||||
|
)
|
||||||
from synapse.types import get_domain_from_id, UserID
|
from synapse.types import get_domain_from_id, UserID
|
||||||
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
@ -140,6 +142,10 @@ class E2eKeysHandler(object):
|
||||||
failures[destination] = {
|
failures[destination] = {
|
||||||
"status": 503, "message": "Not ready for retry",
|
"status": 503, "message": "Not ready for retry",
|
||||||
}
|
}
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
failures[destination] = {
|
||||||
|
"status": 403, "message": "Federation Denied",
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# include ConnectionRefused and other errors
|
# include ConnectionRefused and other errors
|
||||||
failures[destination] = {
|
failures[destination] = {
|
||||||
|
|
|
@ -22,6 +22,7 @@ from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
|
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
|
||||||
|
FederationDeniedError,
|
||||||
)
|
)
|
||||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
|
@ -782,6 +783,9 @@ class FederationHandler(BaseHandler):
|
||||||
except NotRetryingDestination as e:
|
except NotRetryingDestination as e:
|
||||||
logger.info(e.message)
|
logger.info(e.message)
|
||||||
continue
|
continue
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
logger.info(e)
|
||||||
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to backfill from %s because %s",
|
"Failed to backfill from %s because %s",
|
||||||
|
|
|
@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
|
||||||
from synapse import types
|
from synapse import types
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for c in threepidCreds:
|
for c in threepidCreds:
|
||||||
logger.info("validating theeepidcred sid %s on id server %s",
|
logger.info("validating threepidcred sid %s on id server %s",
|
||||||
c['sid'], c['idServer'])
|
c['sid'], c['idServer'])
|
||||||
try:
|
try:
|
||||||
identity_handler = self.hs.get_handlers().identity_handler
|
identity_handler = self.hs.get_handlers().identity_handler
|
||||||
|
@ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler):
|
||||||
logger.info("got threepid with medium '%s' and address '%s'",
|
logger.info("got threepid with medium '%s' and address '%s'",
|
||||||
threepid['medium'], threepid['address'])
|
threepid['medium'], threepid['address'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
|
||||||
|
raise RegistrationError(
|
||||||
|
403, "Third party identifier is not allowed"
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def bind_emails(self, user_id, threepidCreds):
|
def bind_emails(self, user_id, threepidCreds):
|
||||||
"""Links emails with a user ID and informs an identity server.
|
"""Links emails with a user ID and informs an identity server.
|
||||||
|
|
|
@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
|
||||||
if limit:
|
if limit:
|
||||||
step = limit + 1
|
step = limit + 1
|
||||||
else:
|
else:
|
||||||
step = len(rooms_to_scan)
|
# step cannot be zero
|
||||||
|
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
|
||||||
|
|
||||||
chunk = []
|
chunk = []
|
||||||
for i in xrange(0, len(rooms_to_scan), step):
|
for i in xrange(0, len(rooms_to_scan), step):
|
||||||
|
|
|
@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
|
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
|
||||||
)
|
)
|
||||||
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
from synapse.util import logcontext
|
from synapse.util import logcontext
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||||
from twisted.web.client import (
|
from twisted.web.client import (
|
||||||
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
|
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
|
||||||
readBody, PartialDownloadError,
|
readBody, PartialDownloadError,
|
||||||
|
HTTPConnectionPool,
|
||||||
)
|
)
|
||||||
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
||||||
from twisted.web.http import PotentialDataLoss
|
from twisted.web.http import PotentialDataLoss
|
||||||
|
@ -64,13 +66,23 @@ class SimpleHttpClient(object):
|
||||||
"""
|
"""
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
|
pool = HTTPConnectionPool(reactor)
|
||||||
|
|
||||||
|
# the pusher makes lots of concurrent SSL connections to sygnal, and
|
||||||
|
# tends to do so in batches, so we need to allow the pool to keep lots
|
||||||
|
# of idle connections around.
|
||||||
|
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
|
||||||
|
pool.cachedConnectionTimeout = 2 * 60
|
||||||
|
|
||||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||||
# 'like a browser'
|
# 'like a browser'
|
||||||
self.agent = Agent(
|
self.agent = Agent(
|
||||||
reactor,
|
reactor,
|
||||||
connectTimeout=15,
|
connectTimeout=15,
|
||||||
contextFactory=hs.get_http_client_context_factory()
|
contextFactory=hs.get_http_client_context_factory(),
|
||||||
|
pool=pool,
|
||||||
)
|
)
|
||||||
self.user_agent = hs.version_string
|
self.user_agent = hs.version_string
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
|
@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
|
||||||
def eb(res, record_type):
|
def eb(res, record_type):
|
||||||
if res.check(DNSNameError):
|
if res.check(DNSNameError):
|
||||||
return []
|
return []
|
||||||
logger.warn("Error looking up %s for %s: %s",
|
logger.warn("Error looking up %s for %s: %s", record_type, host, res)
|
||||||
record_type, host, res, res.value)
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
# no logcontexts here, so we can safely fire these off and gatherResults
|
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||||
|
|
|
@ -27,7 +27,7 @@ import synapse.metrics
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
SynapseError, Codes, HttpResponseException,
|
SynapseError, Codes, HttpResponseException, FederationDeniedError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from signedjson.sign import sign_json
|
from signedjson.sign import sign_json
|
||||||
|
@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
Fails with ``HTTPRequestException``: if we get an HTTP response
|
Fails with ``HTTPRequestException``: if we get an HTTP response
|
||||||
code >= 300.
|
code >= 300.
|
||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
|
|
||||||
(May also fail with plenty of other Exceptions for things like DNS
|
(May also fail with plenty of other Exceptions for things like DNS
|
||||||
failures, connection failures, SSL failures.)
|
failures, connection failures, SSL failures.)
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
|
self.hs.config.federation_domain_whitelist and
|
||||||
|
destination not in self.hs.config.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
raise FederationDeniedError(destination)
|
||||||
|
|
||||||
limiter = yield synapse.util.retryutils.get_retry_limiter(
|
limiter = yield synapse.util.retryutils.get_retry_limiter(
|
||||||
destination,
|
destination,
|
||||||
self.clock,
|
self.clock,
|
||||||
|
@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not json_data_callback:
|
if not json_data_callback:
|
||||||
|
@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def body_callback(method, url_bytes, headers_dict):
|
def body_callback(method, url_bytes, headers_dict):
|
||||||
|
@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
logger.debug("get_json args: %s", args)
|
logger.debug("get_json args: %s", args)
|
||||||
|
|
||||||
|
@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response = yield self._request(
|
response = yield self._request(
|
||||||
|
@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
encoded_args = {}
|
encoded_args = {}
|
||||||
|
|
|
@ -146,10 +146,15 @@ def runUntilCurrentTimer(func):
|
||||||
num_pending += 1
|
num_pending += 1
|
||||||
|
|
||||||
num_pending += len(reactor.threadCallQueue)
|
num_pending += len(reactor.threadCallQueue)
|
||||||
|
|
||||||
start = time.time() * 1000
|
start = time.time() * 1000
|
||||||
ret = func(*args, **kwargs)
|
ret = func(*args, **kwargs)
|
||||||
end = time.time() * 1000
|
end = time.time() * 1000
|
||||||
|
|
||||||
|
# record the amount of wallclock time spent running pending calls.
|
||||||
|
# This is a proxy for the actual amount of time between reactor polls,
|
||||||
|
# since about 25% of time is actually spent running things triggered by
|
||||||
|
# I/O events, but that is harder to capture without rewriting half the
|
||||||
|
# reactor.
|
||||||
tick_time.inc_by(end - start)
|
tick_time.inc_by(end - start)
|
||||||
pending_calls_metric.inc_by(num_pending)
|
pending_calls_metric.inc_by(num_pending)
|
||||||
|
|
||||||
|
|
|
@ -13,21 +13,30 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
from synapse.push import PusherConfigException
|
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||||
|
|
||||||
import logging
|
|
||||||
import push_rule_evaluator
|
import push_rule_evaluator
|
||||||
import push_tools
|
import push_tools
|
||||||
|
import synapse
|
||||||
|
from synapse.push import PusherConfigException
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
|
http_push_processed_counter = metrics.register_counter(
|
||||||
|
"http_pushes_processed",
|
||||||
|
)
|
||||||
|
|
||||||
|
http_push_failed_counter = metrics.register_counter(
|
||||||
|
"http_pushes_failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HttpPusher(object):
|
class HttpPusher(object):
|
||||||
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
||||||
|
@ -152,9 +161,16 @@ class HttpPusher(object):
|
||||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Processing %i unprocessed push actions for %s starting at "
|
||||||
|
"stream_ordering %s",
|
||||||
|
len(unprocessed), self.name, self.last_stream_ordering,
|
||||||
|
)
|
||||||
|
|
||||||
for push_action in unprocessed:
|
for push_action in unprocessed:
|
||||||
processed = yield self._process_one(push_action)
|
processed = yield self._process_one(push_action)
|
||||||
if processed:
|
if processed:
|
||||||
|
http_push_processed_counter.inc()
|
||||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||||
self.last_stream_ordering = push_action['stream_ordering']
|
self.last_stream_ordering = push_action['stream_ordering']
|
||||||
yield self.store.update_pusher_last_stream_ordering_and_success(
|
yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||||
|
@ -169,6 +185,7 @@ class HttpPusher(object):
|
||||||
self.failing_since
|
self.failing_since
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
http_push_failed_counter.inc()
|
||||||
if not self.failing_since:
|
if not self.failing_since:
|
||||||
self.failing_since = self.clock.time_msec()
|
self.failing_since = self.clock.time_msec()
|
||||||
yield self.store.update_pusher_failing_since(
|
yield self.store.update_pusher_failing_since(
|
||||||
|
@ -316,7 +333,10 @@ class HttpPusher(object):
|
||||||
try:
|
try:
|
||||||
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
|
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warn("Failed to push %s ", self.url)
|
logger.warn(
|
||||||
|
"Failed to push event %s to %s",
|
||||||
|
event.event_id, self.name, exc_info=True,
|
||||||
|
)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
rejected = []
|
rejected = []
|
||||||
if 'rejected' in resp:
|
if 'rejected' in resp:
|
||||||
|
@ -325,7 +345,7 @@ class HttpPusher(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _send_badge(self, badge):
|
def _send_badge(self, badge):
|
||||||
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
|
logger.info("Sending updated badge count %d to %s", badge, self.name)
|
||||||
d = {
|
d = {
|
||||||
'notification': {
|
'notification': {
|
||||||
'id': '',
|
'id': '',
|
||||||
|
@ -347,7 +367,10 @@ class HttpPusher(object):
|
||||||
try:
|
try:
|
||||||
resp = yield self.http_client.post_json_get_json(self.url, d)
|
resp = yield self.http_client.post_json_get_json(self.url, d)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to push %s ", self.url)
|
logger.warn(
|
||||||
|
"Failed to send badge count to %s",
|
||||||
|
self.name, exc_info=True,
|
||||||
|
)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
rejected = []
|
rejected = []
|
||||||
if 'rejected' in resp:
|
if 'rejected' in resp:
|
||||||
|
|
|
@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
|
|
||||||
|
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||||
|
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||||
|
|
||||||
|
flows = []
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
return (
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
200,
|
if not require_msisdn:
|
||||||
{"flows": [
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.RECAPTCHA,
|
"type": LoginType.RECAPTCHA,
|
||||||
"stages": [
|
"stages": [
|
||||||
|
@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
LoginType.PASSWORD
|
LoginType.PASSWORD
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
])
|
||||||
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
|
if not require_email and not require_msisdn:
|
||||||
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.RECAPTCHA,
|
"type": LoginType.RECAPTCHA,
|
||||||
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
|
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
|
||||||
}
|
}
|
||||||
]}
|
])
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return (
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
200,
|
if require_email or not require_msisdn:
|
||||||
{"flows": [
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.EMAIL_IDENTITY,
|
"type": LoginType.EMAIL_IDENTITY,
|
||||||
"stages": [
|
"stages": [
|
||||||
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
|
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
|
||||||
]
|
]
|
||||||
},
|
}
|
||||||
|
])
|
||||||
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
|
if not require_email and not require_msisdn:
|
||||||
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.PASSWORD
|
"type": LoginType.PASSWORD
|
||||||
}
|
}
|
||||||
]}
|
])
|
||||||
)
|
return (200, {"flows": flows})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
|
|
@ -195,15 +195,20 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
event_dict = {
|
||||||
event = yield msg_handler.create_and_send_nonmember_event(
|
|
||||||
requester,
|
|
||||||
{
|
|
||||||
"type": event_type,
|
"type": event_type,
|
||||||
"content": content,
|
"content": content,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"sender": requester.user.to_string(),
|
"sender": requester.user.to_string(),
|
||||||
},
|
}
|
||||||
|
|
||||||
|
if 'ts' in request.args and requester.app_service:
|
||||||
|
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
|
||||||
|
|
||||||
|
msg_handler = self.handlers.message_handler
|
||||||
|
event = yield msg_handler.create_and_send_nonmember_event(
|
||||||
|
requester,
|
||||||
|
event_dict,
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ from synapse.http.servlet import (
|
||||||
)
|
)
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
from ._base import client_v2_patterns, interactive_auth_handler
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||||
])
|
])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
|
@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
||||||
|
|
||||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'msisdn', msisdn
|
'msisdn', msisdn
|
||||||
)
|
)
|
||||||
|
@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||||
if absent:
|
if absent:
|
||||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
|
@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||||
|
|
||||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'msisdn', msisdn
|
'msisdn', msisdn
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,6 +26,7 @@ from synapse.http.servlet import (
|
||||||
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
||||||
)
|
)
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
|
|
||||||
from ._base import client_v2_patterns, interactive_auth_handler
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
|
@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||||
])
|
])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
|
@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||||
|
|
||||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'msisdn', msisdn
|
'msisdn', msisdn
|
||||||
)
|
)
|
||||||
|
@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
|
||||||
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
||||||
show_msisdn = True
|
show_msisdn = True
|
||||||
|
|
||||||
|
# FIXME: need a better error than "no auth flow found" for scenarios
|
||||||
|
# where we required 3PID for registration but the user didn't give one
|
||||||
|
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||||
|
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||||
|
|
||||||
|
flows = []
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
flows = [
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
[LoginType.RECAPTCHA],
|
if not require_email and not require_msisdn:
|
||||||
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
flows.extend([[LoginType.RECAPTCHA]])
|
||||||
]
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
|
if not require_msisdn:
|
||||||
|
flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
|
||||||
|
|
||||||
if show_msisdn:
|
if show_msisdn:
|
||||||
|
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||||
|
if not require_email:
|
||||||
|
flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
|
||||||
|
# always let users provide both MSISDN & email
|
||||||
flows.extend([
|
flows.extend([
|
||||||
[LoginType.MSISDN, LoginType.RECAPTCHA],
|
|
||||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
flows = [
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
[LoginType.DUMMY],
|
if not require_email and not require_msisdn:
|
||||||
[LoginType.EMAIL_IDENTITY],
|
flows.extend([[LoginType.DUMMY]])
|
||||||
]
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
|
if not require_msisdn:
|
||||||
|
flows.extend([[LoginType.EMAIL_IDENTITY]])
|
||||||
|
|
||||||
if show_msisdn:
|
if show_msisdn:
|
||||||
|
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||||
|
if not require_email or require_msisdn:
|
||||||
|
flows.extend([[LoginType.MSISDN]])
|
||||||
|
# always let users provide both MSISDN & email
|
||||||
flows.extend([
|
flows.extend([
|
||||||
[LoginType.MSISDN],
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
|
||||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
|
||||||
])
|
])
|
||||||
|
|
||||||
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check that we're not trying to register a denied 3pid.
|
||||||
|
#
|
||||||
|
# the user-facing checks will probably already have happened in
|
||||||
|
# /register/email/requestToken when we requested a 3pid, but that's not
|
||||||
|
# guaranteed.
|
||||||
|
|
||||||
|
if auth_result:
|
||||||
|
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
|
||||||
|
if login_type in auth_result:
|
||||||
|
medium = auth_result[login_type]['medium']
|
||||||
|
address = auth_result[login_type]['address']
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, medium, address):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed",
|
||||||
|
Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
if registered_user_id is not None:
|
if registered_user_id is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Already registered user ID %r for this session",
|
"Already registered user ID %r for this session",
|
||||||
|
|
|
@ -93,6 +93,7 @@ class RemoteKey(Resource):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self.async_render_GET(request)
|
self.async_render_GET(request)
|
||||||
|
@ -137,6 +138,13 @@ class RemoteKey(Resource):
|
||||||
logger.info("Handling query for keys %r", query)
|
logger.info("Handling query for keys %r", query)
|
||||||
store_queries = []
|
store_queries = []
|
||||||
for server_name, key_ids in query.items():
|
for server_name, key_ids in query.items():
|
||||||
|
if (
|
||||||
|
self.federation_domain_whitelist is not None and
|
||||||
|
server_name not in self.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
logger.debug("Federation denied with %s", server_name)
|
||||||
|
continue
|
||||||
|
|
||||||
if not key_ids:
|
if not key_ids:
|
||||||
key_ids = (None,)
|
key_ids = (None,)
|
||||||
for key_id in key_ids:
|
for key_id in key_ids:
|
||||||
|
|
|
@ -32,8 +32,9 @@ from .media_storage import MediaStorage
|
||||||
|
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
from synapse.api.errors import SynapseError, HttpResponseException, \
|
from synapse.api.errors import (
|
||||||
NotFoundError
|
SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
|
@ -75,6 +76,8 @@ class MediaRepository(object):
|
||||||
self.recently_accessed_remotes = set()
|
self.recently_accessed_remotes = set()
|
||||||
self.recently_accessed_locals = set()
|
self.recently_accessed_locals = set()
|
||||||
|
|
||||||
|
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||||
|
|
||||||
# List of StorageProviders where we should search for media and
|
# List of StorageProviders where we should search for media and
|
||||||
# potentially upload to.
|
# potentially upload to.
|
||||||
storage_providers = []
|
storage_providers = []
|
||||||
|
@ -216,6 +219,12 @@ class MediaRepository(object):
|
||||||
Deferred: Resolves once a response has successfully been written
|
Deferred: Resolves once a response has successfully been written
|
||||||
to request
|
to request
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
|
self.federation_domain_whitelist is not None and
|
||||||
|
server_name not in self.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
raise FederationDeniedError(server_name)
|
||||||
|
|
||||||
self.mark_recently_accessed(server_name, media_id)
|
self.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
||||||
# We linearize here to ensure that we don't try and download remote
|
# We linearize here to ensure that we don't try and download remote
|
||||||
|
@ -250,6 +259,12 @@ class MediaRepository(object):
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict]: The media_info of the file
|
Deferred[dict]: The media_info of the file
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
|
self.federation_domain_whitelist is not None and
|
||||||
|
server_name not in self.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
raise FederationDeniedError(server_name)
|
||||||
|
|
||||||
# We linearize here to ensure that we don't try and download remote
|
# We linearize here to ensure that we don't try and download remote
|
||||||
# media multiple times concurrently
|
# media multiple times concurrently
|
||||||
key = (server_name, media_id)
|
key = (server_name, media_id)
|
||||||
|
|
|
@ -307,6 +307,23 @@ class HomeServer(object):
|
||||||
**self.db_config.get("args", {})
|
**self.db_config.get("args", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_db_conn(self, run_new_connection=True):
|
||||||
|
"""Makes a new connection to the database, skipping the db pool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Connection: a connection object implementing the PEP-249 spec
|
||||||
|
"""
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
if run_new_connection:
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
def build_media_repository_resource(self):
|
def build_media_repository_resource(self):
|
||||||
# build the media repo resource. This indirects through the HomeServer
|
# build the media repo resource. This indirects through the HomeServer
|
||||||
# to ensure that we only have a single instance of
|
# to ensure that we only have a single instance of
|
||||||
|
|
|
@ -146,8 +146,20 @@ class StateHandler(object):
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_current_state_ids(self, room_id, event_type=None, state_key="",
|
def get_current_state_ids(self, room_id, latest_event_ids=None):
|
||||||
latest_event_ids=None):
|
"""Get the current state, or the state at a set of events, for a room
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str):
|
||||||
|
|
||||||
|
latest_event_ids (iterable[str]|None): if given, the forward
|
||||||
|
extremities to resolve. If None, we look them up from the
|
||||||
|
database (via a cache)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[(str, str), str)]]: the state dict, mapping from
|
||||||
|
(event_type, state_key) -> event_id
|
||||||
|
"""
|
||||||
if not latest_event_ids:
|
if not latest_event_ids:
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
|
||||||
|
@ -155,10 +167,6 @@ class StateHandler(object):
|
||||||
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
state = ret.state
|
state = ret.state
|
||||||
|
|
||||||
if event_type:
|
|
||||||
defer.returnValue(state.get((event_type, state_key)))
|
|
||||||
return
|
|
||||||
|
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -110,7 +110,7 @@ class _EventPeristenceQueue(object):
|
||||||
end_item.events_and_contexts.extend(events_and_contexts)
|
end_item.events_and_contexts.extend(events_and_contexts)
|
||||||
return end_item.deferred.observe()
|
return end_item.deferred.observe()
|
||||||
|
|
||||||
deferred = ObservableDeferred(defer.Deferred())
|
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
||||||
|
|
||||||
queue.append(self._EventPersistQueueItem(
|
queue.append(self._EventPersistQueueItem(
|
||||||
events_and_contexts=events_and_contexts,
|
events_and_contexts=events_and_contexts,
|
||||||
|
@ -152,8 +152,8 @@ class _EventPeristenceQueue(object):
|
||||||
try:
|
try:
|
||||||
ret = yield per_item_callback(item)
|
ret = yield per_item_callback(item)
|
||||||
item.deferred.callback(ret)
|
item.deferred.callback(ret)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
item.deferred.errback(e)
|
item.deferred.errback()
|
||||||
finally:
|
finally:
|
||||||
queue = self._event_persist_queues.pop(room_id, None)
|
queue = self._event_persist_queues.pop(room_id, None)
|
||||||
if queue:
|
if queue:
|
||||||
|
|
|
@ -577,7 +577,7 @@ class RoomStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
UPDATE remote_media_cache
|
UPDATE remote_media_cache
|
||||||
SET quarantined_by = ?
|
SET quarantined_by = ?
|
||||||
WHERE media_origin AND media_id = ?
|
WHERE media_origin = ? AND media_id = ?
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
(quarantined_by, origin, media_id)
|
(quarantined_by, origin, media_id)
|
||||||
|
|
|
@ -641,13 +641,12 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.hs.config.user_directory_search_all_users:
|
if self.hs.config.user_directory_search_all_users:
|
||||||
# dummy to keep the number of binds & aliases the same
|
# make s.user_id null to keep the ordering algorithm happy
|
||||||
join_clause = """
|
join_clause = """
|
||||||
LEFT JOIN (
|
CROSS JOIN (SELECT NULL as user_id) AS s
|
||||||
SELECT NULL as user_id WHERE NULL = ?
|
|
||||||
) AS s USING (user_id)"
|
|
||||||
"""
|
"""
|
||||||
where_clause = ""
|
join_args = ()
|
||||||
|
where_clause = "1=1"
|
||||||
else:
|
else:
|
||||||
join_clause = """
|
join_clause = """
|
||||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||||
|
@ -656,6 +655,7 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
WHERE user_id = ? AND share_private
|
WHERE user_id = ? AND share_private
|
||||||
) AS s USING (user_id)
|
) AS s USING (user_id)
|
||||||
"""
|
"""
|
||||||
|
join_args = (user_id,)
|
||||||
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
|
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
@ -697,7 +697,7 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
avatar_url IS NULL
|
avatar_url IS NULL
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""" % (join_clause, where_clause)
|
""" % (join_clause, where_clause)
|
||||||
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
|
args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
|
||||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
search_query = _parse_query_sqlite(search_term)
|
search_query = _parse_query_sqlite(search_term)
|
||||||
|
|
||||||
|
@ -715,7 +715,7 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
avatar_url IS NULL
|
avatar_url IS NULL
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""" % (join_clause, where_clause)
|
""" % (join_clause, where_clause)
|
||||||
args = (user_id, search_query, limit + 1)
|
args = join_args + (search_query, limit + 1)
|
||||||
else:
|
else:
|
||||||
# This should be unreachable.
|
# This should be unreachable.
|
||||||
raise Exception("Unrecognized database engine")
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
|
@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class NotRetryingDestination(Exception):
|
class NotRetryingDestination(Exception):
|
||||||
def __init__(self, retry_last_ts, retry_interval, destination):
|
def __init__(self, retry_last_ts, retry_interval, destination):
|
||||||
|
"""Raised by the limiter (and federation client) to indicate that we are
|
||||||
|
are deliberately not attempting to contact a given server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retry_last_ts (int): the unix ts in milliseconds of our last attempt
|
||||||
|
to contact the server. 0 indicates that the last attempt was
|
||||||
|
successful or that we've never actually attempted to connect.
|
||||||
|
retry_interval (int): the time in milliseconds to wait until the next
|
||||||
|
attempt.
|
||||||
|
destination (str): the domain in question
|
||||||
|
"""
|
||||||
|
|
||||||
msg = "Not retrying server %s." % (destination,)
|
msg = "Not retrying server %s." % (destination,)
|
||||||
super(NotRetryingDestination, self).__init__(msg)
|
super(NotRetryingDestination, self).__init__(msg)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_3pid_allowed(hs, medium, address):
|
||||||
|
"""Checks whether a given format of 3PID is allowed to be used on this HS
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
medium (str): 3pid medium - e.g. email, msisdn
|
||||||
|
address (str): address within that medium (e.g. "wotan@matrix.org")
|
||||||
|
msisdns need to first have been canonicalised
|
||||||
|
Returns:
|
||||||
|
bool: whether the 3PID medium/address is allowed to be added to this HS
|
||||||
|
"""
|
||||||
|
|
||||||
|
if hs.config.allowed_local_3pids:
|
||||||
|
for constraint in hs.config.allowed_local_3pids:
|
||||||
|
logger.debug(
|
||||||
|
"Checking 3PID %s (%s) against %s (%s)",
|
||||||
|
address, medium, constraint['pattern'], constraint['medium'],
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
medium == constraint['medium'] and
|
||||||
|
re.match(constraint['pattern'], address)
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
|
@ -167,7 +167,7 @@ class KeyringTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# wait a tick for it to send the request to the perspectives server
|
# wait a tick for it to send the request to the perspectives server
|
||||||
# (it first tries the datastore)
|
# (it first tries the datastore)
|
||||||
yield async.sleep(0.005)
|
yield async.sleep(1) # XXX find out why this takes so long!
|
||||||
self.http_client.post_json.assert_called_once()
|
self.http_client.post_json.assert_called_once()
|
||||||
|
|
||||||
self.assertIs(LoggingContext.current_context(), context_11)
|
self.assertIs(LoggingContext.current_context(), context_11)
|
||||||
|
@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase):
|
||||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1)],
|
[("server10", json1)],
|
||||||
)
|
)
|
||||||
yield async.sleep(0.005)
|
yield async.sleep(01)
|
||||||
self.http_client.post_json.assert_not_called()
|
self.http_client.post_json.assert_not_called()
|
||||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||||
|
|
||||||
|
|
|
@ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||||
except errors.SynapseError:
|
except errors.SynapseError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.DEBUG
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_claim_one_time_key(self):
|
def test_claim_one_time_key(self):
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
|
@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||||
listener = reactor.listenUNIX("\0xxx", server_factory)
|
# XXX: mktemp is unsafe and should never be used. but we're just a test.
|
||||||
|
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
|
||||||
|
listener = reactor.listenUNIX(path, server_factory)
|
||||||
self.addCleanup(listener.stopListening)
|
self.addCleanup(listener.stopListening)
|
||||||
self.streamer = server_factory.streamer
|
self.streamer = server_factory.streamer
|
||||||
|
|
||||||
|
@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
|
||||||
client_factory = ReplicationClientFactory(
|
client_factory = ReplicationClientFactory(
|
||||||
self.hs, "client_name", self.replication_handler
|
self.hs, "client_name", self.replication_handler
|
||||||
)
|
)
|
||||||
client_connector = reactor.connectUNIX("\0xxx", client_factory)
|
client_connector = reactor.connectUNIX(path, client_factory)
|
||||||
self.addCleanup(client_factory.stopTrying)
|
self.addCleanup(client_factory.stopTrying)
|
||||||
self.addCleanup(client_connector.disconnect)
|
self.addCleanup(client_connector.disconnect)
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
||||||
self.hs.config.enable_registration = True
|
self.hs.config.enable_registration = True
|
||||||
|
self.hs.config.registrations_require_3pid = []
|
||||||
self.hs.config.auto_join_rooms = []
|
self.hs.config.auto_join_rooms = []
|
||||||
|
|
||||||
# init the thing we're testing
|
# init the thing we're testing
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.storage import UserDirectoryStore
|
||||||
|
from synapse.storage.roommember import ProfileInfo
|
||||||
|
from tests import unittest
|
||||||
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
ALICE = "@alice:a"
|
||||||
|
BOB = "@bob:b"
|
||||||
|
BOBBY = "@bobby:a"
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectoryStoreTestCase(unittest.TestCase):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
self.hs = yield setup_test_homeserver()
|
||||||
|
self.store = UserDirectoryStore(None, self.hs)
|
||||||
|
|
||||||
|
# alice and bob are both in !room_id. bobby is not but shares
|
||||||
|
# a homeserver with alice.
|
||||||
|
yield self.store.add_profiles_to_user_dir(
|
||||||
|
"!room:id",
|
||||||
|
{
|
||||||
|
ALICE: ProfileInfo(None, "alice"),
|
||||||
|
BOB: ProfileInfo(None, "bob"),
|
||||||
|
BOBBY: ProfileInfo(None, "bobby")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield self.store.add_users_to_public_room(
|
||||||
|
"!room:id",
|
||||||
|
[ALICE, BOB],
|
||||||
|
)
|
||||||
|
yield self.store.add_users_who_share_room(
|
||||||
|
"!room:id",
|
||||||
|
False,
|
||||||
|
(
|
||||||
|
(ALICE, BOB),
|
||||||
|
(BOB, ALICE),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_search_user_dir(self):
|
||||||
|
# normally when alice searches the directory she should just find
|
||||||
|
# bob because bobby doesn't share a room with her.
|
||||||
|
r = yield self.store.search_user_dir(ALICE, "bob", 10)
|
||||||
|
self.assertFalse(r["limited"])
|
||||||
|
self.assertEqual(1, len(r["results"]))
|
||||||
|
self.assertDictEqual(r["results"][0], {
|
||||||
|
"user_id": BOB,
|
||||||
|
"display_name": "bob",
|
||||||
|
"avatar_url": None,
|
||||||
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_search_user_dir_all_users(self):
|
||||||
|
self.hs.config.user_directory_search_all_users = True
|
||||||
|
try:
|
||||||
|
r = yield self.store.search_user_dir(ALICE, "bob", 10)
|
||||||
|
self.assertFalse(r["limited"])
|
||||||
|
self.assertEqual(2, len(r["results"]))
|
||||||
|
self.assertDictEqual(r["results"][0], {
|
||||||
|
"user_id": BOB,
|
||||||
|
"display_name": "bob",
|
||||||
|
"avatar_url": None,
|
||||||
|
})
|
||||||
|
self.assertDictEqual(r["results"][1], {
|
||||||
|
"user_id": BOBBY,
|
||||||
|
"display_name": "bobby",
|
||||||
|
"avatar_url": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
self.hs.config.user_directory_search_all_users = False
|
245
tests/utils.py
245
tests/utils.py
|
@ -13,27 +13,28 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
|
||||||
from synapse.api.errors import cs_error, CodeMessageException, StoreError
|
|
||||||
from synapse.api.constants import EventTypes
|
|
||||||
from synapse.storage.prepare_database import prepare_database
|
|
||||||
from synapse.storage.engines import create_engine
|
|
||||||
from synapse.server import HomeServer
|
|
||||||
from synapse.federation.transport import server
|
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
|
||||||
|
|
||||||
from synapse.util.logcontext import LoggingContext
|
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
|
||||||
from twisted.enterprise.adbapi import ConnectionPool
|
|
||||||
|
|
||||||
from collections import namedtuple
|
|
||||||
from mock import patch, Mock
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from inspect import getcallargs
|
||||||
import urllib
|
import urllib
|
||||||
import urlparse
|
import urlparse
|
||||||
|
|
||||||
from inspect import getcallargs
|
from mock import Mock, patch
|
||||||
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
|
from synapse.api.errors import CodeMessageException, cs_error
|
||||||
|
from synapse.federation.transport import server
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage import PostgresEngine
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
|
||||||
|
# set this to True to run the tests against postgres instead of sqlite.
|
||||||
|
# It requires you to have a local postgres database called synapse_test, within
|
||||||
|
# which ALL TABLES WILL BE DROPPED
|
||||||
|
USE_POSTGRES_FOR_TESTS = False
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -57,36 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
config.worker_app = None
|
config.worker_app = None
|
||||||
config.email_enable_notifs = False
|
config.email_enable_notifs = False
|
||||||
config.block_non_admin_invites = False
|
config.block_non_admin_invites = False
|
||||||
|
config.federation_domain_whitelist = None
|
||||||
|
config.user_directory_search_all_users = False
|
||||||
|
|
||||||
# disable user directory updates, because they get done in the
|
# disable user directory updates, because they get done in the
|
||||||
# background, which upsets the test runner.
|
# background, which upsets the test runner.
|
||||||
config.update_user_directory = False
|
config.update_user_directory = False
|
||||||
|
|
||||||
config.use_frozen_dicts = True
|
config.use_frozen_dicts = True
|
||||||
config.database_config = {"name": "sqlite3"}
|
|
||||||
config.ldap_enabled = False
|
config.ldap_enabled = False
|
||||||
|
|
||||||
if "clock" not in kargs:
|
if "clock" not in kargs:
|
||||||
kargs["clock"] = MockClock()
|
kargs["clock"] = MockClock()
|
||||||
|
|
||||||
|
if USE_POSTGRES_FOR_TESTS:
|
||||||
|
config.database_config = {
|
||||||
|
"name": "psycopg2",
|
||||||
|
"args": {
|
||||||
|
"database": "synapse_test",
|
||||||
|
"cp_min": 1,
|
||||||
|
"cp_max": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config.database_config = {
|
||||||
|
"name": "sqlite3",
|
||||||
|
"args": {
|
||||||
|
"database": ":memory:",
|
||||||
|
"cp_min": 1,
|
||||||
|
"cp_max": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
|
# we need to configure the connection pool to run the on_new_connection
|
||||||
|
# function, so that we can test code that uses custom sqlite functions
|
||||||
|
# (like rank).
|
||||||
|
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
|
||||||
|
|
||||||
if datastore is None:
|
if datastore is None:
|
||||||
db_pool = SQLiteMemoryDbPool()
|
|
||||||
yield db_pool.prepare()
|
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
name, db_pool=db_pool, config=config,
|
name, config=config,
|
||||||
|
db_config=config.database_config,
|
||||||
version_string="Synapse/tests",
|
version_string="Synapse/tests",
|
||||||
database_engine=create_engine(config.database_config),
|
database_engine=db_engine,
|
||||||
get_db_conn=db_pool.get_db_conn,
|
|
||||||
room_list_handler=object(),
|
room_list_handler=object(),
|
||||||
tls_server_context_factory=Mock(),
|
tls_server_context_factory=Mock(),
|
||||||
**kargs
|
**kargs
|
||||||
)
|
)
|
||||||
|
db_conn = hs.get_db_conn()
|
||||||
|
# make sure that the database is empty
|
||||||
|
if isinstance(db_engine, PostgresEngine):
|
||||||
|
cur = db_conn.cursor()
|
||||||
|
cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
|
||||||
|
rows = cur.fetchall()
|
||||||
|
for r in rows:
|
||||||
|
cur.execute("DROP TABLE %s CASCADE" % r[0])
|
||||||
|
yield prepare_database(db_conn, db_engine, config)
|
||||||
hs.setup()
|
hs.setup()
|
||||||
else:
|
else:
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
name, db_pool=None, datastore=datastore, config=config,
|
name, db_pool=None, datastore=datastore, config=config,
|
||||||
version_string="Synapse/tests",
|
version_string="Synapse/tests",
|
||||||
database_engine=create_engine(config.database_config),
|
database_engine=db_engine,
|
||||||
room_list_handler=object(),
|
room_list_handler=object(),
|
||||||
tls_server_context_factory=Mock(),
|
tls_server_context_factory=Mock(),
|
||||||
**kargs
|
**kargs
|
||||||
|
@ -305,168 +340,6 @@ class MockClock(object):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
class SQLiteMemoryDbPool(ConnectionPool, object):
|
|
||||||
def __init__(self):
|
|
||||||
super(SQLiteMemoryDbPool, self).__init__(
|
|
||||||
"sqlite3", ":memory:",
|
|
||||||
cp_min=1,
|
|
||||||
cp_max=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.config = Mock()
|
|
||||||
self.config.password_providers = []
|
|
||||||
self.config.database_config = {"name": "sqlite3"}
|
|
||||||
|
|
||||||
def prepare(self):
|
|
||||||
engine = self.create_engine()
|
|
||||||
return self.runWithConnection(
|
|
||||||
lambda conn: prepare_database(conn, engine, self.config)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_db_conn(self):
|
|
||||||
conn = self.connect()
|
|
||||||
engine = self.create_engine()
|
|
||||||
prepare_database(conn, engine, self.config)
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def create_engine(self):
|
|
||||||
return create_engine(self.config.database_config)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryDataStore(object):
|
|
||||||
|
|
||||||
Room = namedtuple(
|
|
||||||
"Room",
|
|
||||||
["room_id", "is_public", "creator"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.tokens_to_users = {}
|
|
||||||
self.paths_to_content = {}
|
|
||||||
|
|
||||||
self.members = {}
|
|
||||||
self.rooms = {}
|
|
||||||
|
|
||||||
self.current_state = {}
|
|
||||||
self.events = []
|
|
||||||
|
|
||||||
class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
|
|
||||||
def fill_out_prev_events(self, event):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
|
|
||||||
return self.Snapshot(
|
|
||||||
room_id, user_id, self.get_room_member(user_id, room_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
def register(self, user_id, token, password_hash):
|
|
||||||
if user_id in self.tokens_to_users.values():
|
|
||||||
raise StoreError(400, "User in use.")
|
|
||||||
self.tokens_to_users[token] = user_id
|
|
||||||
|
|
||||||
def get_user_by_access_token(self, token):
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
"name": self.tokens_to_users[token],
|
|
||||||
}
|
|
||||||
except Exception:
|
|
||||||
raise StoreError(400, "User does not exist.")
|
|
||||||
|
|
||||||
def get_room(self, room_id):
|
|
||||||
try:
|
|
||||||
return self.rooms[room_id]
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def store_room(self, room_id, room_creator_user_id, is_public):
|
|
||||||
if room_id in self.rooms:
|
|
||||||
raise StoreError(409, "Conflicting room!")
|
|
||||||
|
|
||||||
room = MemoryDataStore.Room(
|
|
||||||
room_id=room_id,
|
|
||||||
is_public=is_public,
|
|
||||||
creator=room_creator_user_id
|
|
||||||
)
|
|
||||||
self.rooms[room_id] = room
|
|
||||||
|
|
||||||
def get_room_member(self, user_id, room_id):
|
|
||||||
return self.members.get(room_id, {}).get(user_id)
|
|
||||||
|
|
||||||
def get_room_members(self, room_id, membership=None):
|
|
||||||
if membership:
|
|
||||||
return [
|
|
||||||
v for k, v in self.members.get(room_id, {}).items()
|
|
||||||
if v.membership == membership
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
return self.members.get(room_id, {}).values()
|
|
||||||
|
|
||||||
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
|
|
||||||
return [
|
|
||||||
m[user_id] for m in self.members.values()
|
|
||||||
if user_id in m and m[user_id].membership in membership_list
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
|
|
||||||
limit=0, with_feedback=False):
|
|
||||||
return ([], from_key) # TODO
|
|
||||||
|
|
||||||
def get_joined_hosts_for_room(self, room_id):
|
|
||||||
return defer.succeed([])
|
|
||||||
|
|
||||||
def persist_event(self, event):
|
|
||||||
if event.type == EventTypes.Member:
|
|
||||||
room_id = event.room_id
|
|
||||||
user = event.state_key
|
|
||||||
self.members.setdefault(room_id, {})[user] = event
|
|
||||||
|
|
||||||
if hasattr(event, "state_key"):
|
|
||||||
key = (event.room_id, event.type, event.state_key)
|
|
||||||
self.current_state[key] = event
|
|
||||||
|
|
||||||
self.events.append(event)
|
|
||||||
|
|
||||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
|
||||||
if event_type:
|
|
||||||
key = (room_id, event_type, state_key)
|
|
||||||
if self.current_state.get(key):
|
|
||||||
return [self.current_state.get(key)]
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return [
|
|
||||||
e for e in self.current_state
|
|
||||||
if e[0] == room_id
|
|
||||||
]
|
|
||||||
|
|
||||||
def set_presence_state(self, user_localpart, state):
|
|
||||||
return defer.succeed({"state": 0})
|
|
||||||
|
|
||||||
def get_presence_list(self, user_localpart, accepted):
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_room_events_max_id(self):
|
|
||||||
return "s0" # TODO (erikj)
|
|
||||||
|
|
||||||
def get_send_event_level(self, room_id):
|
|
||||||
return defer.succeed(0)
|
|
||||||
|
|
||||||
def get_power_level(self, room_id, user_id):
|
|
||||||
return defer.succeed(0)
|
|
||||||
|
|
||||||
def get_add_state_level(self, room_id):
|
|
||||||
return defer.succeed(0)
|
|
||||||
|
|
||||||
def get_room_join_rule(self, room_id):
|
|
||||||
# TODO (erikj): This should be configurable
|
|
||||||
return defer.succeed("invite")
|
|
||||||
|
|
||||||
def get_ops_levels(self, room_id):
|
|
||||||
return defer.succeed((5, 5, 5))
|
|
||||||
|
|
||||||
def insert_client_ip(self, user, access_token, ip, user_agent):
|
|
||||||
return defer.succeed(None)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_call(args, kwargs):
|
def _format_call(args, kwargs):
|
||||||
return ", ".join(
|
return ", ".join(
|
||||||
["%r" % (a) for a in args] +
|
["%r" % (a) for a in args] +
|
||||||
|
|
Loading…
Reference in New Issue