Merge branch 'develop' into pushrules2

This commit is contained in:
David Baker 2015-03-04 14:56:41 +00:00
commit 92b3dc3219
33 changed files with 1043 additions and 201 deletions

View File

@ -4,6 +4,8 @@ Changes in synapse vx.x.x (x-x-x)
* Add support for registration fallback. This is a page hosted on the server * Add support for registration fallback. This is a page hosted on the server
which allows a user to register for an account, regardless of what client which allows a user to register for an account, regardless of what client
they are using (e.g. mobile devices). they are using (e.g. mobile devices).
* Application services can now poll on the CS API ``/events`` for their events,
by providing their application service ``access_token``.
Changes in synapse v0.7.1 (2015-02-19) Changes in synapse v0.7.1 (2015-02-19)
====================================== ======================================

View File

@ -12,6 +12,10 @@ Servers which use captchas will need to add their public key to::
This is required in order to support registration fallback (typically used on This is required in order to support registration fallback (typically used on
mobile devices). mobile devices).
The format of stored application services has changed in Synapse. You will need
to run ``PYTHONPATH=. python scripts/upgrade_appservice_db.py <database file path>``
to convert to the new format.
Upgrading to v0.7.0 Upgrading to v0.7.0
=================== ===================

View File

@ -0,0 +1,54 @@
from synapse.storage import read_schema
import argparse
import json
import sqlite3
def do_other_deltas(cursor):
cursor.execute("PRAGMA user_version")
row = cursor.fetchone()
if row and row[0]:
user_version = row[0]
# Run every version since after the current version.
for v in range(user_version + 1, 10):
print "Running delta: %d" % (v,)
sql_script = read_schema("delta/v%d" % (v,))
cursor.executescript(sql_script)
def update_app_service_table(cur):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:
print "checking %s..." % row[0]
json.loads(row[1])
except ValueError:
# row isn't in json, make it so.
string_regex = row[1]
new_regex = json.dumps({
"regex": string_regex,
"exclusive": True
})
cur.execute(
"UPDATE application_services_regex SET regex=? WHERE id=?",
(new_regex, row[0])
)
def main(dbname):
con = sqlite3.connect(dbname)
cur = con.cursor()
do_other_deltas(cur)
update_app_service_table(cur)
cur.execute("PRAGMA user_version = 14")
cur.close()
con.commit()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("database")
args = parser.parse_args()
main(args.database)

View File

@ -53,6 +53,7 @@ import synapse
import logging import logging
import os import os
import re import re
import resource
import subprocess import subprocess
import sqlite3 import sqlite3
import syweb import syweb
@ -146,8 +147,8 @@ class SynapseHomeServer(HomeServer):
# instead, we'll store a copy of this mapping so we can actually add # instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key. # extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {} resource_mappings = {}
for (full_path, resource) in desired_tree: for full_path, res in desired_tree:
logger.info("Attaching %s to path %s", resource, full_path) logger.info("Attaching %s to path %s", res, full_path)
last_resource = self.root_resource last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]: for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames(): if path_seg not in last_resource.listNames():
@ -178,12 +179,12 @@ class SynapseHomeServer(HomeServer):
child_name) child_name)
child_resource = resource_mappings[child_res_id] child_resource = resource_mappings[child_res_id]
# steal the children # steal the children
resource.putChild(child_name, child_resource) res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place # finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, resource) last_resource.putChild(last_path_seg, res)
res_id = self._resource_id(last_resource, last_path_seg) res_id = self._resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = resource resource_mappings[res_id] = res
return self.root_resource return self.root_resource
@ -275,6 +276,20 @@ def get_version_string():
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii") return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
def change_resource_limit(soft_file_no):
try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if not soft_file_no:
soft_file_no = hard
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
logger.info("Set file limit to: %d", soft_file_no)
except (ValueError, resource.error) as e:
logger.warn("Failed to set file limit: %s", e)
def setup(): def setup():
config = HomeServerConfig.load_config( config = HomeServerConfig.load_config(
"Synapse Homeserver", "Synapse Homeserver",
@ -351,10 +366,11 @@ def setup():
if config.daemonize: if config.daemonize:
print config.pid_file print config.pid_file
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",
pid=config.pid_file, pid=config.pid_file,
action=run, action=lambda: run(config),
auto_close_fds=False, auto_close_fds=False,
verbose=True, verbose=True,
logger=logger, logger=logger,
@ -362,11 +378,13 @@ def setup():
daemon.start() daemon.start()
else: else:
reactor.run() run(config)
def run(): def run(config):
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(config.soft_file_limit)
reactor.run() reactor.run()

View File

@ -46,22 +46,34 @@ class ApplicationService(object):
def _check_namespaces(self, namespaces): def _check_namespaces(self, namespaces):
# Sanity check that it is of the form: # Sanity check that it is of the form:
# { # {
# users: ["regex",...], # users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# aliases: ["regex",...], # aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# rooms: ["regex",...], # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# } # }
if not namespaces: if not namespaces:
return None return None
for ns in ApplicationService.NS_LIST: for ns in ApplicationService.NS_LIST:
if ns not in namespaces:
namespaces[ns] = []
continue
if type(namespaces[ns]) != list: if type(namespaces[ns]) != list:
raise ValueError("Bad namespace value for '%s'", ns) raise ValueError("Bad namespace value for '%s'" % ns)
for regex in namespaces[ns]: for regex_obj in namespaces[ns]:
if not isinstance(regex, basestring): if not isinstance(regex_obj, dict):
raise ValueError("Expected string regex for ns '%s'", ns) raise ValueError("Expected dict regex for ns '%s'" % ns)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Expected string for 'regex' in ns '%s'" % ns
)
return namespaces return namespaces
def _matches_regex(self, test_string, namespace_key): def _matches_regex(self, test_string, namespace_key, return_obj=False):
if not isinstance(test_string, basestring): if not isinstance(test_string, basestring):
logger.error( logger.error(
"Expected a string to test regex against, but got %s", "Expected a string to test regex against, but got %s",
@ -69,11 +81,19 @@ class ApplicationService(object):
) )
return False return False
for regex in self.namespaces[namespace_key]: for regex_obj in self.namespaces[namespace_key]:
if re.match(regex, test_string): if re.match(regex_obj["regex"], test_string):
if return_obj:
return regex_obj
return True return True
return False return False
def _is_exclusive(self, ns_key, test_string):
regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
if regex_obj:
return regex_obj["exclusive"]
return False
def _matches_user(self, event, member_list): def _matches_user(self, event, member_list):
if (hasattr(event, "sender") and if (hasattr(event, "sender") and
self.is_interested_in_user(event.sender)): self.is_interested_in_user(event.sender)):
@ -143,5 +163,14 @@ class ApplicationService(object):
def is_interested_in_room(self, room_id): def is_interested_in_room(self, room_id):
return self._matches_regex(room_id, ApplicationService.NS_ROOMS) return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
def is_exclusive_user(self, user_id):
return self._is_exclusive(ApplicationService.NS_USERS, user_id)
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def __str__(self): def __str__(self):
return "ApplicationService: %s" % (self.__dict__,) return "ApplicationService: %s" % (self.__dict__,)

View File

@ -22,6 +22,12 @@ class RatelimitConfig(Config):
self.rc_messages_per_second = args.rc_messages_per_second self.rc_messages_per_second = args.rc_messages_per_second
self.rc_message_burst_count = args.rc_message_burst_count self.rc_message_burst_count = args.rc_message_burst_count
self.federation_rc_window_size = args.federation_rc_window_size
self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
self.federation_rc_reject_limit = args.federation_rc_reject_limit
self.federation_rc_concurrent = args.federation_rc_concurrent
@classmethod @classmethod
def add_arguments(cls, parser): def add_arguments(cls, parser):
super(RatelimitConfig, cls).add_arguments(parser) super(RatelimitConfig, cls).add_arguments(parser)
@ -34,3 +40,33 @@ class RatelimitConfig(Config):
"--rc-message-burst-count", type=float, default=10, "--rc-message-burst-count", type=float, default=10,
help="number of message a client can send before being throttled" help="number of message a client can send before being throttled"
) )
rc_group.add_argument(
"--federation-rc-window-size", type=int, default=10000,
help="The federation window size in milliseconds",
)
rc_group.add_argument(
"--federation-rc-sleep-limit", type=int, default=10,
help="The number of federation requests from a single server"
" in a window before the server will delay processing the"
" request.",
)
rc_group.add_argument(
"--federation-rc-sleep-delay", type=int, default=500,
help="The duration in milliseconds to delay processing events from"
" remote servers by if they go over the sleep limit.",
)
rc_group.add_argument(
"--federation-rc-reject-limit", type=int, default=50,
help="The maximum number of concurrent federation requests allowed"
" from a single server",
)
rc_group.add_argument(
"--federation-rc-concurrent", type=int, default=3,
help="The number of federation requests to concurrently process"
" from a single server",
)

View File

@ -31,6 +31,7 @@ class ServerConfig(Config):
self.webclient = True self.webclient = True
self.manhole = args.manhole self.manhole = args.manhole
self.no_tls = args.no_tls self.no_tls = args.no_tls
self.soft_file_limit = args.soft_file_limit
if not args.content_addr: if not args.content_addr:
host = args.server_name host = args.server_name
@ -77,6 +78,12 @@ class ServerConfig(Config):
"content repository") "content repository")
server_group.add_argument("--no-tls", action='store_true', server_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.") help="Don't bind to the https port.")
server_group.add_argument("--soft-file-limit", type=int, default=0,
help="Set the soft limit on the number of "
"file descriptors synapse can use. "
"Zero is used to indicate synapse "
"should set the soft limit to the hard"
"limit.")
def read_signing_key(self, signing_key_path): def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key") signing_keys = self.read_file(signing_key_path, "signing_key")

View File

@ -439,6 +439,25 @@ class FederationClient(FederationBase):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth):
content = yield self.transport_layer.get_missing_events(
destination, room_id, earliest_events, latest_events, limit,
min_depth,
)
events = [
self.event_from_pdu_json(e)
for e in content.get("events", [])
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=True
)
defer.returnValue(signed_events)
def event_from_pdu_json(self, pdu_json, outlier=False): def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent( event = FrozenEvent(
pdu_json pdu_json

View File

@ -112,17 +112,20 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext(): with PreserveLoggingContext():
dl = [] results = []
for pdu in pdu_list: for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu) d = self._handle_new_pdu(transaction.origin, pdu)
def handle_failure(failure): try:
failure.trap(FederationError) yield d
self.send_failure(failure.value, transaction.origin) results.append({})
except FederationError as e:
d.addErrback(handle_failure) self.send_failure(e, transaction.origin)
results.append({"error": str(e)})
dl.append(d) except Exception as e:
results.append({"error": str(e)})
logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in [Edu(**x) for x in transaction.edus]:
@ -135,21 +138,11 @@ class FederationServer(FederationBase):
for failure in getattr(transaction, "pdu_failures", []): for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure) logger.info("Got failure %r", failure)
results = yield defer.DeferredList(dl, consumeErrors=True) logger.debug("Returning: %s", str(results))
ret = []
for r in results:
if r[0]:
ret.append({})
else:
logger.exception(r[1])
ret.append({"error": str(r[1].value)})
logger.debug("Returning: %s", str(ret))
response = { response = {
"pdus": dict(zip( "pdus": dict(zip(
(p.event_id for p in pdu_list), ret (p.event_id for p in pdu_list), results
)), )),
} }
@ -305,6 +298,20 @@ class FederationServer(FederationBase):
(200, send_content) (200, send_content)
) )
@defer.inlineCallbacks
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
missing_events = yield self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit, min_depth
)
time_now = self._clock.time_msec()
defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
})
@log_function @log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True): def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id. """ Get a PDU from the database with given origin and id.
@ -331,7 +338,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_pdu(self, origin, pdu, max_recursion=10): def _handle_new_pdu(self, origin, pdu, get_missing=True):
# We reprocess pdus when we have seen them only as outliers # We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu( existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False origin, pdu.event_id, do_auth=False
@ -383,48 +390,50 @@ class FederationServer(FederationBase):
pdu.room_id, min_depth pdu.room_id, min_depth
) )
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if min_depth and pdu.depth < min_depth: if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this # This is so that we don't notify the user about this
# message, to work around the fact that some events will # message, to work around the fact that some events will
# reference really really old events we really don't want to # reference really really old events we really don't want to
# send to the clients. # send to the clients.
pdu.internal_metadata.outlier = True pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth and max_recursion > 0: elif min_depth and pdu.depth > min_depth:
for event_id, hashes in pdu.prev_events: if get_missing and prevs - seen:
if event_id not in have_seen: latest_tuples = yield self.store.get_latest_events_in_room(
logger.debug( pdu.room_id
"_handle_new_pdu requesting pdu %s",
event_id
) )
try: # We add the prev events that we have seen to the latest
new_pdu = yield self.federation_client.get_pdu( # list to ensure the remote server doesn't give them to us
[origin, pdu.origin], latest = set(e_id for e_id, _, _ in latest_tuples)
event_id=event_id, latest |= seen
missing_events = yield self.get_missing_events(
origin,
pdu.room_id,
earliest_events=list(latest),
latest_events=[pdu.event_id],
limit=10,
min_depth=min_depth,
) )
if new_pdu: for e in missing_events:
yield self._handle_new_pdu( yield self._handle_new_pdu(
origin, origin,
new_pdu, e,
max_recursion=max_recursion-1 get_missing=False
)
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
) )
logger.debug("Processed pdu %s", event_id)
else:
logger.warn("Failed to get PDU %s", event_id)
fetch_state = True
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
fetch_state = True
else:
prevs = {e_id for e_id, _ in pdu.prev_events} prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys()) seen = set(have_seen.keys())
if prevs - seen: if prevs - seen:
fetch_state = True fetch_state = True
else:
fetch_state = True
if fetch_state: if fetch_state:
# We need to get the state at this event, since we haven't # We need to get the state at this event, since we haven't

View File

@ -224,6 +224,8 @@ class TransactionQueue(object):
] ]
try: try:
self.pending_transactions[destination] = 1
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, destination,
self._clock, self._clock,
@ -239,8 +241,6 @@ class TransactionQueue(object):
len(pending_failures) len(pending_failures)
) )
self.pending_transactions[destination] = 1
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new( transaction = Transaction.create_new(
@ -287,7 +287,7 @@ class TransactionQueue(object):
code = 200 code = 200
if response: if response:
for e_id, r in getattr(response, "pdus", {}).items(): for e_id, r in response.get("pdus", {}).items():
if "error" in r: if "error" in r:
logger.warn( logger.warn(
"Transaction returned error for %s: %s", "Transaction returned error for %s: %s",

View File

@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol.
from .server import TransportLayerServer from .server import TransportLayerServer
from .client import TransportLayerClient from .client import TransportLayerClient
from synapse.util.ratelimitutils import FederationRateLimiter
class TransportLayer(TransportLayerServer, TransportLayerClient): class TransportLayer(TransportLayerServer, TransportLayerClient):
"""This is a basic implementation of the transport layer that translates """This is a basic implementation of the transport layer that translates
@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
send requests send requests
""" """
self.keyring = homeserver.get_keyring() self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name self.server_name = server_name
self.server = server self.server = server
self.client = client self.client = client
self.request_handler = None self.request_handler = None
self.received_handler = None self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=homeserver.config.federation_rc_window_size,
sleep_limit=homeserver.config.federation_rc_sleep_limit,
sleep_msec=homeserver.config.federation_rc_sleep_delay,
reject_limit=homeserver.config.federation_rc_reject_limit,
concurrent_requests=homeserver.config.federation_rc_concurrent,
)

View File

@ -219,3 +219,22 @@ class TransportLayerClient(object):
) )
defer.returnValue(content) defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth):
path = PREFIX + "/get_missing_events/%s" % (room_id,)
content = yield self.client.post_json(
destination=destination,
path=path,
data={
"limit": int(limit),
"min_depth": int(min_depth),
"earliest_events": earliest_events,
"latest_events": latest_events,
}
)
defer.returnValue(content)

View File

@ -98,6 +98,8 @@ class TransportLayerServer(object):
def new_handler(request, *args, **kwargs): def new_handler(request, *args, **kwargs):
try: try:
(origin, content) = yield self._authenticate_request(request) (origin, content) = yield self._authenticate_request(request)
with self.ratelimiter.ratelimit(origin) as d:
yield d
response = yield handler( response = yield handler(
origin, content, request.args, *args, **kwargs origin, content, request.args, *args, **kwargs
) )
@ -107,6 +109,12 @@ class TransportLayerServer(object):
defer.returnValue(response) defer.returnValue(response)
return new_handler return new_handler
def rate_limit_origin(self, handler):
def new_handler(origin, *args, **kwargs):
response = yield handler(origin, *args, **kwargs)
defer.returnValue(response)
return new_handler()
@log_function @log_function
def register_received_handler(self, handler): def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data. """ Register a handler that will be fired when we receive data.
@ -234,6 +242,7 @@ class TransportLayerServer(object):
) )
) )
) )
self.server.register_path( self.server.register_path(
"POST", "POST",
re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"), re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
@ -245,6 +254,17 @@ class TransportLayerServer(object):
) )
) )
self.server.register_path(
"POST",
re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
self._with_authentication(
lambda origin, content, query, room_id:
self._get_missing_events(
origin, content, room_id,
)
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _on_send_request(self, origin, content, query, transaction_id): def _on_send_request(self, origin, content, query, transaction_id):
@ -344,3 +364,22 @@ class TransportLayerServer(object):
) )
defer.returnValue((200, new_content)) defer.returnValue((200, new_content))
@defer.inlineCallbacks
@log_function
def _get_missing_events(self, origin, content, room_id):
limit = int(content.get("limit", 10))
min_depth = int(content.get("min_depth", 0))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
content = yield self.request_handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
latest_events=latest_events,
min_depth=min_depth,
limit=limit,
)
defer.returnValue((200, content))

View File

@ -232,13 +232,23 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def can_modify_alias(self, alias, user_id=None): def can_modify_alias(self, alias, user_id=None):
# Any application service "interested" in an alias they are regexing on
# can modify the alias.
# Users can only modify the alias if ALL the interested services have
# non-exclusive locks on the alias (or there are no interested services)
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
interested_services = [ interested_services = [
s for s in services if s.is_interested_in_alias(alias.to_string()) s for s in services if s.is_interested_in_alias(alias.to_string())
] ]
for service in interested_services: for service in interested_services:
if user_id == service.sender: if user_id == service.sender:
# this user IS the app service # this user IS the app service so they can do whatever they like
defer.returnValue(True) defer.returnValue(True)
return return
defer.returnValue(len(interested_services) == 0) elif service.is_exclusive_alias(alias.to_string()):
# another service has an exclusive lock on this alias.
defer.returnValue(False)
return
# either no interested services, or no service with an exclusive lock
defer.returnValue(True)

View File

@ -69,9 +69,6 @@ class EventStreamHandler(BaseHandler):
) )
self._streams_per_user[auth_user] += 1 self._streams_per_user[auth_user] += 1
if pagin_config.from_token is None:
pagin_config.from_token = None
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(auth_user) room_ids = yield rm_handler.get_rooms_for_user(auth_user)

View File

@ -581,9 +581,10 @@ class FederationHandler(BaseHandler):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id): def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor() yield run_on_reactor()
if do_auth:
in_room = yield self.auth.check_host_in_room(room_id, origin) in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
@ -788,6 +789,29 @@ class FederationHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
in_room = yield self.auth.check_host_in_room(
room_id,
origin
)
if not in_room:
raise AuthError(403, "Host not in room.")
limit = min(limit, 20)
min_depth = max(min_depth, 0)
missing_events = yield self.store.get_missing_events(
room_id=room_id,
earliest_events=earliest_events,
latest_events=latest_events,
limit=limit,
min_depth=min_depth,
)
defer.returnValue(missing_events)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def do_auth(self, origin, event, context, auth_events): def do_auth(self, origin, event, context, auth_events):

View File

@ -201,7 +201,8 @@ class RegistrationHandler(BaseHandler):
interested_services = [ interested_services = [
s for s in services if s.is_interested_in_user(user_id) s for s in services if s.is_interested_in_user(user_id)
] ]
if len(interested_services) > 0: for service in interested_services:
if service.is_exclusive_user(user_id):
raise SynapseError( raise SynapseError(
400, "This user ID is reserved by an application service.", 400, "This user ID is reserved by an application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE

View File

@ -510,6 +510,13 @@ class RoomMemberHandler(BaseHandler):
def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]): def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given
membership states in.""" membership states in."""
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
else:
rooms = yield self.store.get_rooms_for_user_where_membership_is( rooms = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user.to_string(), membership_list=membership_list user_id=user.to_string(), membership_list=membership_list
) )
@ -559,6 +566,17 @@ class RoomEventSource(object):
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
if app_service:
events, end_key = yield self.store.get_appservice_room_stream(
service=app_service,
from_key=from_key,
to_key=to_key,
limit=limit,
)
else:
events, end_key = yield self.store.get_room_events_stream( events, end_key = yield self.store.get_room_events_stream(
user_id=user.to_string(), user_id=user.to_string(),
from_key=from_key, from_key=from_key,

View File

@ -143,7 +143,7 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes) uri = "%s?%s" % (uri, query_bytes)
json_str = json.dumps(json_body) json_str = encode_canonical_json(json_body)
response = yield self.agent.request( response = yield self.agent.request(
"PUT", "PUT",

View File

@ -36,8 +36,10 @@ class _NotificationListener(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user, rooms, from_token, limit, timeout, deferred): def __init__(self, user, rooms, from_token, limit, timeout, deferred,
appservice=None):
self.user = user self.user = user
self.appservice = appservice
self.from_token = from_token self.from_token = from_token
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
@ -65,6 +67,10 @@ class _NotificationListener(object):
lst.discard(self) lst.discard(self)
notifier.user_to_listeners.get(self.user, set()).discard(self) notifier.user_to_listeners.get(self.user, set()).discard(self)
if self.appservice:
notifier.appservice_to_listeners.get(
self.appservice, set()
).discard(self)
class Notifier(object): class Notifier(object):
@ -79,6 +85,7 @@ class Notifier(object):
self.rooms_to_listeners = {} self.rooms_to_listeners = {}
self.user_to_listeners = {} self.user_to_listeners = {}
self.appservice_to_listeners = {}
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
@ -114,6 +121,17 @@ class Notifier(object):
for user in extra_users: for user in extra_users:
listeners |= self.user_to_listeners.get(user, set()).copy() listeners |= self.user_to_listeners.get(user, set()).copy()
for appservice in self.appservice_to_listeners:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the rooms_to_listeners set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the rooms_to_listeners check somewhat obselete?
if appservice.is_interested(event):
listeners |= self.appservice_to_listeners.get(
appservice, set()
).copy()
logger.debug("on_new_room_event listeners %s", listeners) logger.debug("on_new_room_event listeners %s", listeners)
# TODO (erikj): Can we make this more efficient by hitting the # TODO (erikj): Can we make this more efficient by hitting the
@ -280,6 +298,10 @@ class Notifier(object):
if not from_token: if not from_token:
from_token = yield self.event_sources.get_current_token() from_token = yield self.event_sources.get_current_token()
appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
user.to_string()
)
listener = _NotificationListener( listener = _NotificationListener(
user, user,
rooms, rooms,
@ -287,6 +309,7 @@ class Notifier(object):
limit, limit,
timeout, timeout,
deferred, deferred,
appservice=appservice
) )
def _timeout_listener(): def _timeout_listener():
@ -319,6 +342,11 @@ class Notifier(object):
self.user_to_listeners.setdefault(listener.user, set()).add(listener) self.user_to_listeners.setdefault(listener.user, set()).add(listener)
if listener.appservice:
self.appservice_to_listeners.setdefault(
listener.appservice, set()
).add(listener)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _check_for_updates(self, listener): def _check_for_updates(self, listener):

View File

@ -5,7 +5,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil>=0.0.3": ["syutil"], "syutil>=0.0.3": ["syutil"],
"matrix_angular_sdk>=0.6.3": ["syweb>=0.6.3"], "matrix_angular_sdk>=0.6.4": ["syweb>=0.6.4"],
"Twisted==14.0.2": ["twisted==14.0.2"], "Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
@ -36,8 +36,8 @@ DEPENDENCY_LINKS = [
), ),
github_link( github_link(
project="matrix-org/matrix-angular-sdk", project="matrix-org/matrix-angular-sdk",
version="v0.6.3", version="v0.6.4",
egg="matrix_angular_sdk-0.6.3", egg="matrix_angular_sdk-0.6.4",
), ),
] ]

View File

@ -48,18 +48,12 @@ class RegisterRestServlet(AppServiceRestServlet):
400, "Missed required keys: as_token(str) / url(str)." 400, "Missed required keys: as_token(str) / url(str)."
) )
namespaces = { try:
"users": [], app_service = ApplicationService(
"rooms": [], as_token, as_url, params["namespaces"]
"aliases": [] )
} except ValueError as e:
raise SynapseError(400, e.message)
if "namespaces" in params:
self._parse_namespace(namespaces, params["namespaces"], "users")
self._parse_namespace(namespaces, params["namespaces"], "rooms")
self._parse_namespace(namespaces, params["namespaces"], "aliases")
app_service = ApplicationService(as_token, as_url, namespaces)
app_service = yield self.handler.register(app_service) app_service = yield self.handler.register(app_service)
hs_token = app_service.hs_token hs_token = app_service.hs_token
@ -68,23 +62,6 @@ class RegisterRestServlet(AppServiceRestServlet):
"hs_token": hs_token "hs_token": hs_token
})) }))
def _parse_namespace(self, target_ns, origin_ns, ns):
if ns not in target_ns or ns not in origin_ns:
return # nothing to parse / map through to.
possible_regex_list = origin_ns[ns]
if not type(possible_regex_list) == list:
raise SynapseError(400, "Namespace %s isn't an array." % ns)
for regex in possible_regex_list:
if not isinstance(regex, basestring):
raise SynapseError(
400, "Regex '%s' isn't a string in namespace %s" %
(regex, ns)
)
target_ns[ns] = origin_ns[ns]
class UnregisterRestServlet(AppServiceRestServlet): class UnregisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server. """Handles AS registration with the home server.

View File

@ -74,7 +74,7 @@ SCHEMAS = [
# Remember to update this number every time an incompatible change is made to # Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts. # database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 13 SCHEMA_VERSION = 14
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -633,11 +633,11 @@ def prepare_database(db_conn):
# Run every version since after the current version. # Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1): for v in range(user_version + 1, SCHEMA_VERSION + 1):
if v == 10: if v in (10, 14,):
raise UpgradeDatabaseException( raise UpgradeDatabaseException(
"No delta for version 10" "No delta for version 10"
) )
sql_script = read_schema("delta/v%d" % (v)) sql_script = read_schema("delta/v%d" % (v,))
c.executescript(sql_script) c.executescript(sql_script)
db_conn.commit() db_conn.commit()

View File

@ -450,7 +450,8 @@ class SQLBaseStore(object):
Args: Args:
table : string giving the table name table : string giving the table name
keyvalues : dict of column names and values to select the rows with keyvalues : dict of column names and values to select the rows with,
or None to not apply a WHERE clause.
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
return self.runInteraction( return self.runInteraction(
@ -469,13 +470,20 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the rows with keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
else:
sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
", ".join(retcols),
table
)
txn.execute(sql)
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(self, table, keyvalues, updatevalues,

View File

@ -13,22 +13,32 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import simplejson
from simplejson import JSONDecodeError
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.storage.roommember import RoomsForUser
from ._base import SQLBaseStore from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def log_failure(failure):
logger.error("Failed to detect application services: %s", failure.value)
logger.error(failure.getTraceback())
class ApplicationServiceStore(SQLBaseStore): class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs) super(ApplicationServiceStore, self).__init__(hs)
self.services_cache = [] self.services_cache = []
self.cache_defer = self._populate_cache() self.cache_defer = self._populate_cache()
self.cache_defer.addErrback(log_failure)
@defer.inlineCallbacks @defer.inlineCallbacks
def unregister_app_service(self, token): def unregister_app_service(self, token):
@ -128,11 +138,11 @@ class ApplicationServiceStore(SQLBaseStore):
) )
for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST): for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST):
if ns_str in service.namespaces: if ns_str in service.namespaces:
for regex in service.namespaces[ns_str]: for regex_obj in service.namespaces[ns_str]:
txn.execute( txn.execute(
"INSERT INTO application_services_regex(" "INSERT INTO application_services_regex("
"as_id, namespace, regex) values(?,?,?)", "as_id, namespace, regex) values(?,?,?)",
(as_id, ns_int, regex) (as_id, ns_int, simplejson.dumps(regex_obj))
) )
return True return True
@ -150,9 +160,32 @@ class ApplicationServiceStore(SQLBaseStore):
yield self.cache_defer # make sure the cache is ready yield self.cache_defer # make sure the cache is ready
defer.returnValue(self.services_cache) defer.returnValue(self.services_cache)
@defer.inlineCallbacks
def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID.
All application services have associated with them a particular user ID.
There is no distinguishing feature on the user ID which indicates it
represents an application service. This function allows you to map from
a user ID to an application service.
Args:
user_id(str): The user ID to see if it is an application service.
Returns:
synapse.appservice.ApplicationService or None.
"""
yield self.cache_defer # make sure the cache is ready
for service in self.services_cache:
if service.sender == user_id:
defer.returnValue(service)
return
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_app_service_by_token(self, token, from_cache=True): def get_app_service_by_token(self, token, from_cache=True):
"""Get the application service with the given token. """Get the application service with the given appservice token.
Args: Args:
token (str): The application service token. token (str): The application service token.
@ -173,6 +206,77 @@ class ApplicationServiceStore(SQLBaseStore):
# TODO: The from_cache=False impl # TODO: The from_cache=False impl
# TODO: This should be JOINed with the application_services_regex table. # TODO: This should be JOINed with the application_services_regex table.
def get_app_service_rooms(self, service):
"""Get a list of RoomsForUser for this application service.
Application services may be "interested" in lots of rooms depending on
the room ID, the room aliases, or the members in the room. This function
takes all of these into account and returns a list of RoomsForUser which
represent the entire list of room IDs that this application service
wants to know about.
Args:
service: The application service to get a room list for.
Returns:
A list of RoomsForUser.
"""
return self.runInteraction(
"get_app_service_rooms",
self._get_app_service_rooms_txn,
service,
)
def _get_app_service_rooms_txn(self, txn, service):
# get all rooms matching the room ID regex.
room_entries = self._simple_select_list_txn(
txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
)
matching_room_list = set([
r["room_id"] for r in room_entries if
service.is_interested_in_room(r["room_id"])
])
# resolve room IDs for matching room alias regex.
room_alias_mappings = self._simple_select_list_txn(
txn=txn, table="room_aliases", keyvalues=None,
retcols=["room_id", "room_alias"]
)
matching_room_list |= set([
r["room_id"] for r in room_alias_mappings if
service.is_interested_in_alias(r["room_alias"])
])
# get all rooms for every user for this AS. This is scoped to users on
# this HS only.
user_list = self._simple_select_list_txn(
txn=txn, table="users", keyvalues=None, retcols=["name"]
)
user_list = [
u["name"] for u in user_list if
service.is_interested_in_user(u["name"])
]
rooms_for_user_matching_user_id = set() # RoomsForUser list
for user_id in user_list:
# FIXME: This assumes this store is linked with RoomMemberStore :(
rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
txn=txn,
user_id=user_id,
membership_list=[Membership.JOIN]
)
rooms_for_user_matching_user_id |= set(rooms_for_user)
# make RoomsForUser tuples for room ids and aliases which are not in the
# main rooms_for_user_list - e.g. they are rooms which do not have AS
# registered users in it.
known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
missing_rooms_for_user = [
RoomsForUser(r, service.sender, "join") for r in
matching_room_list if r not in known_room_ids
]
rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
return rooms_for_user_matching_user_id
@defer.inlineCallbacks @defer.inlineCallbacks
def _populate_cache(self): def _populate_cache(self):
"""Populates the ApplicationServiceCache from the database.""" """Populates the ApplicationServiceCache from the database."""
@ -215,10 +319,12 @@ class ApplicationServiceStore(SQLBaseStore):
try: try:
services[as_token]["namespaces"][ services[as_token]["namespaces"][
ApplicationService.NS_LIST[ns_int]].append( ApplicationService.NS_LIST[ns_int]].append(
res["regex"] simplejson.loads(res["regex"])
) )
except IndexError: except IndexError:
logger.error("Bad namespace enum '%s'. %s", ns_int, res) logger.error("Bad namespace enum '%s'. %s", ns_int, res)
except JSONDecodeError:
logger.error("Bad regex object '%s'", res["regex"])
# TODO get last successful txn id f.e. service # TODO get last successful txn id f.e. service
for service in services.values(): for service in services.values():

View File

@ -64,6 +64,9 @@ class EventFederationStore(SQLBaseStore):
for f in front: for f in front:
txn.execute(base_sql, (f,)) txn.execute(base_sql, (f,))
new_front.update([r[0] for r in txn.fetchall()]) new_front.update([r[0] for r in txn.fetchall()])
new_front -= results
front = new_front front = new_front
results.update(front) results.update(front)
@ -378,3 +381,51 @@ class EventFederationStore(SQLBaseStore):
event_results += new_front event_results += new_front
return self._get_events_txn(txn, event_results) return self._get_events_txn(txn, event_results)
def get_missing_events(self, room_id, earliest_events, latest_events,
limit, min_depth):
return self.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id, earliest_events, latest_events, limit, min_depth
)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
limit, min_depth):
earliest_events = set(earliest_events)
front = set(latest_events) - earliest_events
event_results = set()
query = (
"SELECT prev_event_id FROM event_edges "
"WHERE room_id = ? AND event_id = ? AND is_state = 0 "
"LIMIT ?"
)
while front and len(event_results) < limit:
new_front = set()
for event_id in front:
txn.execute(
query,
(room_id, event_id, limit - len(event_results))
)
for e_id, in txn.fetchall():
new_front.add(e_id)
new_front -= earliest_events
new_front -= event_results
front = new_front
event_results |= new_front
events = self._get_events_txn(txn, event_results)
events = sorted(
[ev for ev in events if ev.depth >= min_depth],
key=lambda e: e.depth,
)
return events[:limit]

View File

@ -180,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
if not membership_list: if not membership_list:
return defer.succeed(None) return defer.succeed(None)
return self.runInteraction(
"get_rooms_for_user_where_membership_is",
self._get_rooms_for_user_where_membership_is_txn,
user_id, membership_list
)
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list):
where_clause = "user_id = ? AND (%s)" % ( where_clause = "user_id = ? AND (%s)" % (
" OR ".join(["membership = ?" for _ in membership_list]), " OR ".join(["membership = ?" for _ in membership_list]),
) )
@ -187,7 +195,6 @@ class RoomMemberStore(SQLBaseStore):
args = [user_id] args = [user_id]
args.extend(membership_list) args.extend(membership_list)
def f(txn):
sql = ( sql = (
"SELECT m.room_id, m.sender, m.membership" "SELECT m.room_id, m.sender, m.membership"
" FROM room_memberships as m" " FROM room_memberships as m"
@ -201,11 +208,6 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn) RoomsForUser(**r) for r in self.cursor_to_dict(txn)
] ]
return self.runInteraction(
"get_rooms_for_user_where_membership_is",
f
)
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):
return self._simple_select_onecol( return self._simple_select_onecol(
"room_hosts", "room_hosts",

View File

@ -36,6 +36,7 @@ what sort order was used:
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -127,6 +128,85 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
class StreamStore(SQLBaseStore): class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
# NB this lives here instead of appservice.py so we can reuse the
# 'private' StreamToken class in this file.
if limit:
limit = max(limit, MAX_STREAM_SIZE)
else:
limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering.
from_id = _StreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key)
if from_key == to_key:
defer.returnValue(([], to_key))
return
# select all the events between from/to with a sensible limit
sql = (
"SELECT e.event_id, e.room_id, e.type, s.state_key, "
"e.stream_ordering FROM events AS e LEFT JOIN state_events as s ON "
"e.event_id = s.event_id "
"WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
) % {
"limit": limit
}
def f(txn):
# pull out all the events between the tokens
txn.execute(sql, (from_id.stream, to_id.stream,))
rows = self.cursor_to_dict(txn)
# Logic:
# - We want ALL events which match the AS room_id regex
# - We want ALL events which match the rooms represented by the AS
# room_alias regex
# - We want ALL events for rooms that AS users have joined.
# This is currently supported via get_app_service_rooms (which is
# used for the Notifier listener rooms). We can't reasonably make a
# SQL query for these room IDs, so we'll pull all the events between
# from/to and filter in python.
rooms_for_as = self._get_app_service_rooms_txn(txn, service)
room_ids_for_as = [r.room_id for r in rooms_for_as]
def app_service_interested(row):
if row["room_id"] in room_ids_for_as:
return True
if row["type"] == EventTypes.Member:
if service.is_interested_in_user(row.get("state_key")):
return True
return False
ret = self._get_events_txn(
txn,
# apply the filter on the room id list
[
r["event_id"] for r in rows
if app_service_interested(r)
],
get_prev_content=True
)
self._set_before_and_after(ret, rows)
if rows:
key = "s%d" % max(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = to_key
return ret, key
results = yield self.runInteraction("get_appservice_room_stream", f)
defer.returnValue(results)
@log_function @log_function
def get_room_events_stream(self, user_id, from_key, to_key, room_id, def get_room_events_stream(self, user_id, from_key, to_key, room_id,
limit=0, with_feedback=False): limit=0, with_feedback=False):
@ -184,8 +264,7 @@ class StreamStore(SQLBaseStore):
self._set_before_and_after(ret, rows) self._set_before_and_after(ret, rows)
if rows: if rows:
key = "s%d" % max([r["stream_ordering"] for r in rows]) key = "s%d" % max(r["stream_ordering"] for r in rows)
else: else:
# Assume we didn't get anything because there was nothing to # Assume we didn't get anything because there was nothing to
# get. # get.

View File

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
import collections
import contextlib
import logging
logger = logging.getLogger(__name__)
class FederationRateLimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
"""
Args:
clock (Clock)
window_size (int): The window size in milliseconds.
sleep_limit (int): The number of requests received in the last
`window_size` milliseconds before we artificially start
delaying processing of requests.
sleep_msec (int): The number of milliseconds to delay processing
of incoming requests by.
reject_limit (int): The maximum number of requests that are can be
queued for processing before we start rejecting requests with
a 429 Too Many Requests response.
concurrent_requests (int): The number of concurrent requests to
process.
"""
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.ratelimiters = {}
def ratelimit(self, host):
"""Used to ratelimit an incoming request from given host
Example usage:
with rate_limiter.ratelimit(origin) as wait_deferred:
yield wait_deferred
# Handle request ...
Args:
host (str): Origin of incoming request.
Returns:
_PerHostRatelimiter
"""
return self.ratelimiters.setdefault(
host,
_PerHostRatelimiter(
clock=self.clock,
window_size=self.window_size,
sleep_limit=self.sleep_limit,
sleep_msec=self.sleep_msec,
reject_limit=self.reject_limit,
concurrent_requests=self.concurrent_requests,
)
).ratelimit()
class _PerHostRatelimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.sleeping_requests = set()
self.ready_request_queue = collections.OrderedDict()
self.current_processing = set()
self.request_times = []
def is_empty(self):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
return not (
self.ready_request_queue
or self.sleeping_requests
or self.current_processing
or self.request_times
)
@contextlib.contextmanager
def ratelimit(self):
# `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value
# to be returned by manager.
# Exceptions will be reraised at the yield.
request_id = object()
ret = self._on_enter(request_id)
try:
yield ret
finally:
self._on_exit(request_id)
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
retry_after_ms=int(
self.window_size / self.sleep_limit
),
)
self.request_times.append(time_now)
def queue_request():
if len(self.current_processing) > self.concurrent_requests:
logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
return queue_defer
else:
return defer.succeed(None)
logger.debug(
"Ratelimit [%s]: len(self.request_times)=%d",
id(request_id), len(self.request_times),
)
if len(self.request_times) > self.sleep_limit:
logger.debug(
"Ratelimit [%s]: sleeping req",
id(request_id),
)
ret_defer = sleep(self.sleep_msec/1000.0)
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
logger.debug(
"Ratelimit [%s]: Finished sleeping",
id(request_id),
)
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
ret_defer.addBoth(on_wait_finished)
else:
ret_defer = queue_request()
def on_start(r):
logger.debug(
"Ratelimit [%s]: Processing req",
id(request_id),
)
self.current_processing.add(request_id)
return r
def on_err(r):
self.current_processing.discard(request_id)
return r
def on_both(r):
# Ensure that we've properly cleaned up.
self.sleeping_requests.discard(request_id)
self.ready_request_queue.pop(request_id, None)
return r
ret_defer.addCallbacks(on_start, on_err)
ret_defer.addBoth(on_both)
return ret_defer
def _on_exit(self, request_id):
logger.debug(
"Ratelimit [%s]: Processed req",
id(request_id),
)
self.current_processing.discard(request_id)
try:
request_id, deferred = self.ready_request_queue.popitem()
self.current_processing.add(request_id)
deferred.callback(None)
except KeyError:
pass

View File

@ -18,6 +18,13 @@ from mock import Mock, PropertyMock
from tests import unittest from tests import unittest
def _regex(regex, exclusive=True):
return {
"regex": regex,
"exclusive": exclusive
}
class ApplicationServiceTestCase(unittest.TestCase): class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
@ -36,21 +43,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_user_id_prefix_match(self): def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue(self.service.is_interested(self.event))
def test_regex_user_id_prefix_no_match(self): def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse(self.service.is_interested(self.event)) self.assertFalse(self.service.is_interested(self.event))
def test_regex_room_member_is_checked(self): def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
@ -59,30 +66,78 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_room_id_match(self): def test_regex_room_id_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
"!some_prefix.*some_suffix:matrix.org" _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue(self.service.is_interested(self.event))
def test_regex_room_id_no_match(self): def test_regex_room_id_no_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
"!some_prefix.*some_suffix:matrix.org" _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(self.service.is_interested(self.event)) self.assertFalse(self.service.is_interested(self.event))
def test_regex_alias_match(self): def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#irc_.*:matrix.org" _regex("#irc_.*:matrix.org")
) )
self.assertTrue(self.service.is_interested( self.assertTrue(self.service.is_interested(
self.event, self.event,
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"] aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
)) ))
def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=False)
)
self.assertFalse(self.service.is_exclusive_alias(
"#irc_foobar:matrix.org"
))
def test_non_exclusive_room(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=False)
)
self.assertFalse(self.service.is_exclusive_room(
"!irc_foobar:matrix.org"
))
def test_non_exclusive_user(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=False)
)
self.assertFalse(self.service.is_exclusive_user(
"@irc_foobar:matrix.org"
))
def test_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=True)
)
self.assertTrue(self.service.is_exclusive_alias(
"#irc_foobar:matrix.org"
))
def test_exclusive_user(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=True)
)
self.assertTrue(self.service.is_exclusive_user(
"@irc_foobar:matrix.org"
))
def test_exclusive_room(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=True)
)
self.assertTrue(self.service.is_exclusive_room(
"!irc_foobar:matrix.org"
))
def test_regex_alias_no_match(self): def test_regex_alias_no_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#irc_.*:matrix.org" _regex("#irc_.*:matrix.org")
) )
self.assertFalse(self.service.is_interested( self.assertFalse(self.service.is_interested(
self.event, self.event,
@ -91,10 +146,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_multiple_matches(self): def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#irc_.*:matrix.org" _regex("#irc_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested( self.assertTrue(self.service.is_interested(
@ -104,10 +159,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_restrict_to_rooms(self): def test_restrict_to_rooms(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
"!flibble_.*:matrix.org" _regex("!flibble_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.event.room_id = "!wibblewoo:matrix.org" self.event.room_id = "!wibblewoo:matrix.org"
@ -118,10 +173,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_restrict_to_aliases(self): def test_restrict_to_aliases(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#xmpp_.*:matrix.org" _regex("#xmpp_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertFalse(self.service.is_interested( self.assertFalse(self.service.is_interested(
@ -132,10 +187,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_restrict_to_senders(self): def test_restrict_to_senders(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#xmpp_.*:matrix.org" _regex("#xmpp_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertFalse(self.service.is_interested( self.assertFalse(self.service.is_interested(
@ -146,7 +201,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_member_list_match(self): def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
join_list = [ join_list = [
Mock( Mock(

View File

@ -389,14 +389,18 @@ class PresenceInvitesTestCase(PresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite_remote(self): def test_invite_remote(self):
# Use a different destination, otherwise retry logic might fail the
# request
u_rocket = UserID.from_string("@rocket:there")
put_json = self.mock_http_client.put_json put_json = self.mock_http_client.put_json
put_json.expect_call_and_return( put_json.expect_call_and_return(
call("elsewhere", call("there",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu("elsewhere", "m.presence_invite", data=_expect_edu("there", "m.presence_invite",
content={ content={
"observer_user": "@apple:test", "observer_user": "@apple:test",
"observed_user": "@cabbage:elsewhere", "observed_user": "@rocket:there",
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
@ -405,10 +409,10 @@ class PresenceInvitesTestCase(PresenceTestCase):
) )
yield self.handler.send_invite( yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_cabbage) observer_user=self.u_apple, observed_user=u_rocket)
self.assertEquals( self.assertEquals(
[{"observed_user_id": "@cabbage:elsewhere", "accepted": 0}], [{"observed_user_id": "@rocket:there", "accepted": 0}],
(yield self.datastore.get_presence_list(self.u_apple.localpart)) (yield self.datastore.get_presence_list(self.u_apple.localpart))
) )
@ -418,13 +422,18 @@ class PresenceInvitesTestCase(PresenceTestCase):
def test_accept_remote(self): def test_accept_remote(self):
# TODO(paul): This test will likely break if/when real auth permissions # TODO(paul): This test will likely break if/when real auth permissions
# are added; for now the HS will always accept any invite # are added; for now the HS will always accept any invite
# Use a different destination, otherwise retry logic might fail the
# request
u_rocket = UserID.from_string("@rocket:moon")
put_json = self.mock_http_client.put_json put_json = self.mock_http_client.put_json
put_json.expect_call_and_return( put_json.expect_call_and_return(
call("elsewhere", call("moon",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu("elsewhere", "m.presence_accept", data=_expect_edu("moon", "m.presence_accept",
content={ content={
"observer_user": "@cabbage:elsewhere", "observer_user": "@rocket:moon",
"observed_user": "@apple:test", "observed_user": "@apple:test",
} }
), ),
@ -437,7 +446,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_invite", _make_edu_json("elsewhere", "m.presence_invite",
content={ content={
"observer_user": "@cabbage:elsewhere", "observer_user": "@rocket:moon",
"observed_user": "@apple:test", "observed_user": "@apple:test",
} }
) )
@ -446,7 +455,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
self.assertTrue( self.assertTrue(
(yield self.datastore.is_presence_visible( (yield self.datastore.is_presence_visible(
observed_localpart=self.u_apple.localpart, observed_localpart=self.u_apple.localpart,
observer_userid=self.u_cabbage.to_string(), observer_userid=u_rocket.to_string(),
)) ))
) )
@ -454,13 +463,17 @@ class PresenceInvitesTestCase(PresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invited_remote_nonexistant(self): def test_invited_remote_nonexistant(self):
# Use a different destination, otherwise retry logic might fail the
# request
u_rocket = UserID.from_string("@rocket:sun")
put_json = self.mock_http_client.put_json put_json = self.mock_http_client.put_json
put_json.expect_call_and_return( put_json.expect_call_and_return(
call("elsewhere", call("sun",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu("elsewhere", "m.presence_deny", data=_expect_edu("sun", "m.presence_deny",
content={ content={
"observer_user": "@cabbage:elsewhere", "observer_user": "@rocket:sun",
"observed_user": "@durian:test", "observed_user": "@durian:test",
} }
), ),
@ -471,9 +484,9 @@ class PresenceInvitesTestCase(PresenceTestCase):
yield self.mock_federation_resource.trigger("PUT", yield self.mock_federation_resource.trigger("PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_invite", _make_edu_json("sun", "m.presence_invite",
content={ content={
"observer_user": "@cabbage:elsewhere", "observer_user": "@rocket:sun",
"observed_user": "@durian:test", "observed_user": "@durian:test",
} }
) )

View File

@ -295,6 +295,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
self.mock_datastore = hs.get_datastore() self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None) self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
self.mock_datastore.get_app_service_by_user_id = Mock(
return_value=defer.succeed(None)
)
def get_profile_displayname(user_id): def get_profile_displayname(user_id):
return defer.succeed("Frank") return defer.succeed("Frank")

View File

@ -50,9 +50,15 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
def test_update_and_retrieval_of_service(self): def test_update_and_retrieval_of_service(self):
url = "https://matrix.org/appservices/foobar" url = "https://matrix.org/appservices/foobar"
hs_token = "hstok" hs_token = "hstok"
user_regex = ["@foobar_.*:matrix.org"] user_regex = [
alias_regex = ["#foobar_.*:matrix.org"] {"regex": "@foobar_.*:matrix.org", "exclusive": True}
room_regex = [] ]
alias_regex = [
{"regex": "#foobar_.*:matrix.org", "exclusive": False}
]
room_regex = [
]
service = ApplicationService( service = ApplicationService(
url=url, hs_token=hs_token, token=self.as_token, namespaces={ url=url, hs_token=hs_token, token=self.as_token, namespaces={
ApplicationService.NS_USERS: user_regex, ApplicationService.NS_USERS: user_regex,