Merge branch 'develop' of github.com:matrix-org/synapse into batched_get_pdu
This commit is contained in:
commit
0a036944bd
|
@ -41,3 +41,4 @@ media_store/
|
|||
build/
|
||||
|
||||
localhost-800*/
|
||||
static/client/register/register_config.js
|
||||
|
|
|
@ -1,3 +1,12 @@
|
|||
Changes in synapse vx.x.x (x-x-x)
|
||||
=================================
|
||||
|
||||
* Add support for registration fallback. This is a page hosted on the server
|
||||
which allows a user to register for an account, regardless of what client
|
||||
they are using (e.g. mobile devices).
|
||||
* Application services can now poll on the CS API ``/events`` for their events,
|
||||
by providing their application service ``access_token``.
|
||||
|
||||
Changes in synapse v0.7.1 (2015-02-19)
|
||||
======================================
|
||||
|
||||
|
|
14
UPGRADE.rst
14
UPGRADE.rst
|
@ -1,3 +1,17 @@
|
|||
Upgrading to vx.xx
|
||||
==================
|
||||
|
||||
Servers which use captchas will need to add their public key to::
|
||||
|
||||
static/client/register/register_config.js
|
||||
|
||||
window.matrixRegistrationConfig = {
|
||||
recaptcha_public_key: "YOUR_PUBLIC_KEY"
|
||||
};
|
||||
|
||||
This is required in order to support registration fallback (typically used on
|
||||
mobile devices).
|
||||
|
||||
Upgrading to v0.7.0
|
||||
===================
|
||||
|
||||
|
|
|
@ -81,7 +81,7 @@ Your home server configuration file needs the following extra keys:
|
|||
As an example, here is the relevant section of the config file for
|
||||
matrix.org::
|
||||
|
||||
turn_uris: turn:turn.matrix.org:3478?transport=udp,turn:turn.matrix.org:3478?transport=tcp
|
||||
turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ]
|
||||
turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons
|
||||
turn_user_lifetime: 86400000
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
<html>
|
||||
<head>
|
||||
<title> Registration </title>
|
||||
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||
<link rel="stylesheet" href="style.css">
|
||||
<script src="js/jquery-2.1.3.min.js"></script>
|
||||
<script src="js/recaptcha_ajax.js"></script>
|
||||
<script src="register_config.js"></script>
|
||||
<script src="js/register.js"></script>
|
||||
</head>
|
||||
<body onload="matrixRegistration.onLoad()">
|
||||
<form id="registrationForm" onsubmit="matrixRegistration.signUp(); return false;">
|
||||
<div>
|
||||
Create account:<br/>
|
||||
|
||||
<div style="text-align: center">
|
||||
<input id="desired_user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" />
|
||||
<br/>
|
||||
<input id="pwd1" size="32" type="password" placeholder="Type a password"/>
|
||||
<br/>
|
||||
<input id="pwd2" size="32" type="password" placeholder="Confirm your password"/>
|
||||
<br/>
|
||||
<span id="feedback" style="color: #f00"></span>
|
||||
<br/>
|
||||
<div id="regcaptcha"></div>
|
||||
|
||||
<button type="submit" style="margin: 10px">Sign up</button>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,117 @@
|
|||
window.matrixRegistration = {
|
||||
endpoint: location.origin + "/_matrix/client/api/v1/register"
|
||||
};
|
||||
|
||||
var setupCaptcha = function() {
|
||||
if (!window.matrixRegistrationConfig) {
|
||||
return;
|
||||
}
|
||||
$.get(matrixRegistration.endpoint, function(response) {
|
||||
var serverExpectsCaptcha = false;
|
||||
for (var i=0; i<response.flows.length; i++) {
|
||||
var flow = response.flows[i];
|
||||
if ("m.login.recaptcha" === flow.type) {
|
||||
serverExpectsCaptcha = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!serverExpectsCaptcha) {
|
||||
console.log("This server does not require a captcha.");
|
||||
return;
|
||||
}
|
||||
console.log("Setting up ReCaptcha for "+matrixRegistration.endpoint);
|
||||
var public_key = window.matrixRegistrationConfig.recaptcha_public_key;
|
||||
if (public_key === undefined) {
|
||||
console.error("No public key defined for captcha!");
|
||||
setFeedbackString("Misconfigured captcha for server. Contact server admin.");
|
||||
return;
|
||||
}
|
||||
Recaptcha.create(public_key,
|
||||
"regcaptcha",
|
||||
{
|
||||
theme: "red",
|
||||
callback: Recaptcha.focus_response_field
|
||||
});
|
||||
window.matrixRegistration.isUsingRecaptcha = true;
|
||||
}).error(errorFunc);
|
||||
|
||||
};
|
||||
|
||||
var submitCaptcha = function(user, pwd) {
|
||||
var challengeToken = Recaptcha.get_challenge();
|
||||
var captchaEntry = Recaptcha.get_response();
|
||||
var data = {
|
||||
type: "m.login.recaptcha",
|
||||
challenge: challengeToken,
|
||||
response: captchaEntry
|
||||
};
|
||||
console.log("Submitting captcha");
|
||||
$.post(matrixRegistration.endpoint, JSON.stringify(data), function(response) {
|
||||
console.log("Success -> "+JSON.stringify(response));
|
||||
submitPassword(user, pwd, response.session);
|
||||
}).error(function(err) {
|
||||
Recaptcha.reload();
|
||||
errorFunc(err);
|
||||
});
|
||||
};
|
||||
|
||||
var submitPassword = function(user, pwd, session) {
|
||||
console.log("Registering...");
|
||||
var data = {
|
||||
type: "m.login.password",
|
||||
user: user,
|
||||
password: pwd,
|
||||
session: session
|
||||
};
|
||||
$.post(matrixRegistration.endpoint, JSON.stringify(data), function(response) {
|
||||
matrixRegistration.onRegistered(
|
||||
response.home_server, response.user_id, response.access_token
|
||||
);
|
||||
}).error(errorFunc);
|
||||
};
|
||||
|
||||
var errorFunc = function(err) {
|
||||
if (err.responseJSON && err.responseJSON.error) {
|
||||
setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
|
||||
}
|
||||
else {
|
||||
setFeedbackString("Request failed: " + err.status);
|
||||
}
|
||||
};
|
||||
|
||||
var setFeedbackString = function(text) {
|
||||
$("#feedback").text(text);
|
||||
};
|
||||
|
||||
matrixRegistration.onLoad = function() {
|
||||
setupCaptcha();
|
||||
};
|
||||
|
||||
matrixRegistration.signUp = function() {
|
||||
var user = $("#desired_user_id").val();
|
||||
if (user.length == 0) {
|
||||
setFeedbackString("Must specify a username.");
|
||||
return;
|
||||
}
|
||||
var pwd1 = $("#pwd1").val();
|
||||
var pwd2 = $("#pwd2").val();
|
||||
if (pwd1.length < 6) {
|
||||
setFeedbackString("Password: min. 6 characters.");
|
||||
return;
|
||||
}
|
||||
if (pwd1 != pwd2) {
|
||||
setFeedbackString("Passwords do not match.");
|
||||
return;
|
||||
}
|
||||
if (window.matrixRegistration.isUsingRecaptcha) {
|
||||
submitCaptcha(user, pwd1);
|
||||
}
|
||||
else {
|
||||
submitPassword(user, pwd1);
|
||||
}
|
||||
};
|
||||
|
||||
matrixRegistration.onRegistered = function(hs_url, user_id, access_token) {
|
||||
// clobber this function
|
||||
console.log("onRegistered - This function should be replaced to proceed.");
|
||||
};
|
|
@ -0,0 +1,3 @@
|
|||
window.matrixRegistrationConfig = {
|
||||
recaptcha_public_key: "YOUR_PUBLIC_KEY"
|
||||
};
|
|
@ -0,0 +1,56 @@
|
|||
html {
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
body {
|
||||
height: 100%;
|
||||
font-family: "Myriad Pro", "Myriad", Helvetica, Arial, sans-serif;
|
||||
font-size: 12pt;
|
||||
margin: 0px;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 20pt;
|
||||
}
|
||||
|
||||
a:link { color: #666; }
|
||||
a:visited { color: #666; }
|
||||
a:hover { color: #000; }
|
||||
a:active { color: #000; }
|
||||
|
||||
input {
|
||||
width: 100%
|
||||
}
|
||||
|
||||
textarea, input {
|
||||
font-family: inherit;
|
||||
font-size: inherit;
|
||||
}
|
||||
|
||||
.smallPrint {
|
||||
color: #888;
|
||||
font-size: 9pt ! important;
|
||||
font-style: italic ! important;
|
||||
}
|
||||
|
||||
#recaptcha_area {
|
||||
margin: auto
|
||||
}
|
||||
|
||||
#registrationForm {
|
||||
text-align: left;
|
||||
padding: 1em;
|
||||
margin-bottom: 40px;
|
||||
display: inline-block;
|
||||
|
||||
-webkit-border-radius: 10px;
|
||||
-moz-border-radius: 10px;
|
||||
border-radius: 10px;
|
||||
|
||||
-webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
|
||||
-moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
|
||||
box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
|
||||
|
||||
background-color: #f8f8f8;
|
||||
border: 1px #ccc solid;
|
||||
}
|
|
@ -18,6 +18,7 @@
|
|||
CLIENT_PREFIX = "/_matrix/client/api/v1"
|
||||
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
|
||||
FEDERATION_PREFIX = "/_matrix/federation/v1"
|
||||
STATIC_PREFIX = "/_matrix/static"
|
||||
WEB_CLIENT_PREFIX = "/_matrix/client"
|
||||
CONTENT_REPO_PREFIX = "/_matrix/content"
|
||||
SERVER_KEY_PREFIX = "/_matrix/key/v1"
|
||||
|
|
|
@ -36,7 +36,8 @@ from synapse.http.server_key_resource import LocalKey
|
|||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.api.urls import (
|
||||
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
|
||||
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX
|
||||
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX,
|
||||
STATIC_PREFIX
|
||||
)
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
|
@ -52,6 +53,7 @@ import synapse
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import resource
|
||||
import subprocess
|
||||
import sqlite3
|
||||
import syweb
|
||||
|
@ -81,6 +83,9 @@ class SynapseHomeServer(HomeServer):
|
|||
webclient_path = os.path.join(syweb_path, "webclient")
|
||||
return File(webclient_path) # TODO configurable?
|
||||
|
||||
def build_resource_for_static_content(self):
|
||||
return File("static")
|
||||
|
||||
def build_resource_for_content_repo(self):
|
||||
return ContentRepoResource(
|
||||
self, self.upload_dir, self.auth, self.content_addr
|
||||
|
@ -124,7 +129,9 @@ class SynapseHomeServer(HomeServer):
|
|||
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
|
||||
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
|
||||
(APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
|
||||
(STATIC_PREFIX, self.get_resource_for_static_content()),
|
||||
]
|
||||
|
||||
if web_client:
|
||||
logger.info("Adding the web client.")
|
||||
desired_tree.append((WEB_CLIENT_PREFIX,
|
||||
|
@ -140,8 +147,8 @@ class SynapseHomeServer(HomeServer):
|
|||
# instead, we'll store a copy of this mapping so we can actually add
|
||||
# extra resources to existing nodes. See self._resource_id for the key.
|
||||
resource_mappings = {}
|
||||
for (full_path, resource) in desired_tree:
|
||||
logger.info("Attaching %s to path %s", resource, full_path)
|
||||
for full_path, res in desired_tree:
|
||||
logger.info("Attaching %s to path %s", res, full_path)
|
||||
last_resource = self.root_resource
|
||||
for path_seg in full_path.split('/')[1:-1]:
|
||||
if path_seg not in last_resource.listNames():
|
||||
|
@ -172,12 +179,12 @@ class SynapseHomeServer(HomeServer):
|
|||
child_name)
|
||||
child_resource = resource_mappings[child_res_id]
|
||||
# steal the children
|
||||
resource.putChild(child_name, child_resource)
|
||||
res.putChild(child_name, child_resource)
|
||||
|
||||
# finally, insert the desired resource in the right place
|
||||
last_resource.putChild(last_path_seg, resource)
|
||||
last_resource.putChild(last_path_seg, res)
|
||||
res_id = self._resource_id(last_resource, last_path_seg)
|
||||
resource_mappings[res_id] = resource
|
||||
resource_mappings[res_id] = res
|
||||
|
||||
return self.root_resource
|
||||
|
||||
|
@ -269,6 +276,20 @@ def get_version_string():
|
|||
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
|
||||
|
||||
|
||||
def change_resource_limit(soft_file_no):
|
||||
try:
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
|
||||
if not soft_file_no:
|
||||
soft_file_no = hard
|
||||
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
|
||||
|
||||
logger.info("Set file limit to: %d", soft_file_no)
|
||||
except (ValueError, resource.error) as e:
|
||||
logger.warn("Failed to set file limit: %s", e)
|
||||
|
||||
|
||||
def setup():
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse Homeserver",
|
||||
|
@ -345,10 +366,11 @@ def setup():
|
|||
|
||||
if config.daemonize:
|
||||
print config.pid_file
|
||||
|
||||
daemon = Daemonize(
|
||||
app="synapse-homeserver",
|
||||
pid=config.pid_file,
|
||||
action=run,
|
||||
action=lambda: run(config),
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
|
@ -356,11 +378,13 @@ def setup():
|
|||
|
||||
daemon.start()
|
||||
else:
|
||||
reactor.run()
|
||||
run(config)
|
||||
|
||||
|
||||
def run():
|
||||
def run(config):
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
|
||||
reactor.run()
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,12 @@ class RatelimitConfig(Config):
|
|||
self.rc_messages_per_second = args.rc_messages_per_second
|
||||
self.rc_message_burst_count = args.rc_message_burst_count
|
||||
|
||||
self.federation_rc_window_size = args.federation_rc_window_size
|
||||
self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
|
||||
self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
|
||||
self.federation_rc_reject_limit = args.federation_rc_reject_limit
|
||||
self.federation_rc_concurrent = args.federation_rc_concurrent
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(RatelimitConfig, cls).add_arguments(parser)
|
||||
|
@ -34,3 +40,33 @@ class RatelimitConfig(Config):
|
|||
"--rc-message-burst-count", type=float, default=10,
|
||||
help="number of message a client can send before being throttled"
|
||||
)
|
||||
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-window-size", type=int, default=10000,
|
||||
help="The federation window size in milliseconds",
|
||||
)
|
||||
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-sleep-limit", type=int, default=10,
|
||||
help="The number of federation requests from a single server"
|
||||
" in a window before the server will delay processing the"
|
||||
" request.",
|
||||
)
|
||||
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-sleep-delay", type=int, default=500,
|
||||
help="The duration in milliseconds to delay processing events from"
|
||||
" remote servers by if they go over the sleep limit.",
|
||||
)
|
||||
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-reject-limit", type=int, default=50,
|
||||
help="The maximum number of concurrent federation requests allowed"
|
||||
" from a single server",
|
||||
)
|
||||
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-concurrent", type=int, default=3,
|
||||
help="The number of federation requests to concurrently process"
|
||||
" from a single server",
|
||||
)
|
||||
|
|
|
@ -31,6 +31,7 @@ class ServerConfig(Config):
|
|||
self.webclient = True
|
||||
self.manhole = args.manhole
|
||||
self.no_tls = args.no_tls
|
||||
self.soft_file_limit = args.soft_file_limit
|
||||
|
||||
if not args.content_addr:
|
||||
host = args.server_name
|
||||
|
@ -77,6 +78,12 @@ class ServerConfig(Config):
|
|||
"content repository")
|
||||
server_group.add_argument("--no-tls", action='store_true',
|
||||
help="Don't bind to the https port.")
|
||||
server_group.add_argument("--soft-file-limit", type=int, default=0,
|
||||
help="Set the soft limit on the number of "
|
||||
"file descriptors synapse can use. "
|
||||
"Zero is used to indicate synapse "
|
||||
"should set the soft limit to the hard"
|
||||
"limit.")
|
||||
|
||||
def read_signing_key(self, signing_key_path):
|
||||
signing_keys = self.read_file(signing_key_path, "signing_key")
|
||||
|
|
|
@ -28,7 +28,7 @@ class VoipConfig(Config):
|
|||
super(VoipConfig, cls).add_arguments(parser)
|
||||
group = parser.add_argument_group("voip")
|
||||
group.add_argument(
|
||||
"--turn-uris", type=str, default=None,
|
||||
"--turn-uris", type=str, default=None, action='append',
|
||||
help="The public URIs of the TURN server to give to clients"
|
||||
)
|
||||
group.add_argument(
|
||||
|
|
|
@ -112,17 +112,20 @@ class FederationServer(FederationBase):
|
|||
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
dl = []
|
||||
results = []
|
||||
|
||||
for pdu in pdu_list:
|
||||
d = self._handle_new_pdu(transaction.origin, pdu)
|
||||
|
||||
def handle_failure(failure):
|
||||
failure.trap(FederationError)
|
||||
self.send_failure(failure.value, transaction.origin)
|
||||
|
||||
d.addErrback(handle_failure)
|
||||
|
||||
dl.append(d)
|
||||
try:
|
||||
yield d
|
||||
results.append({})
|
||||
except FederationError as e:
|
||||
self.send_failure(e, transaction.origin)
|
||||
results.append({"error": str(e)})
|
||||
except Exception as e:
|
||||
results.append({"error": str(e)})
|
||||
logger.exception("Failed to handle PDU")
|
||||
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in [Edu(**x) for x in transaction.edus]:
|
||||
|
@ -135,29 +138,11 @@ class FederationServer(FederationBase):
|
|||
for failure in getattr(transaction, "pdu_failures", []):
|
||||
logger.info("Got failure %r", failure)
|
||||
|
||||
results = yield defer.DeferredList(dl, consumeErrors=True)
|
||||
|
||||
ret = []
|
||||
for r in results:
|
||||
if r[0]:
|
||||
ret.append({})
|
||||
else:
|
||||
failure = r[1]
|
||||
logger.error(
|
||||
"Failed to handle PDU",
|
||||
exc_info=(
|
||||
failure.type,
|
||||
failure.value,
|
||||
failure.getTracebackObject()
|
||||
)
|
||||
)
|
||||
ret.append({"error": str(r[1].value)})
|
||||
|
||||
logger.debug("Returning: %s", str(ret))
|
||||
logger.debug("Returning: %s", str(results))
|
||||
|
||||
response = {
|
||||
"pdus": dict(zip(
|
||||
(p.event_id for p in pdu_list), ret
|
||||
(p.event_id for p in pdu_list), results
|
||||
)),
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol.
|
|||
from .server import TransportLayerServer
|
||||
from .client import TransportLayerClient
|
||||
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
|
||||
|
||||
class TransportLayer(TransportLayerServer, TransportLayerClient):
|
||||
"""This is a basic implementation of the transport layer that translates
|
||||
|
@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
|
|||
send requests
|
||||
"""
|
||||
self.keyring = homeserver.get_keyring()
|
||||
self.clock = homeserver.get_clock()
|
||||
self.server_name = server_name
|
||||
self.server = server
|
||||
self.client = client
|
||||
self.request_handler = None
|
||||
self.received_handler = None
|
||||
|
||||
self.ratelimiter = FederationRateLimiter(
|
||||
self.clock,
|
||||
window_size=homeserver.config.federation_rc_window_size,
|
||||
sleep_limit=homeserver.config.federation_rc_sleep_limit,
|
||||
sleep_msec=homeserver.config.federation_rc_sleep_delay,
|
||||
reject_limit=homeserver.config.federation_rc_reject_limit,
|
||||
concurrent_requests=homeserver.config.federation_rc_concurrent,
|
||||
)
|
||||
|
|
|
@ -98,15 +98,23 @@ class TransportLayerServer(object):
|
|||
def new_handler(request, *args, **kwargs):
|
||||
try:
|
||||
(origin, content) = yield self._authenticate_request(request)
|
||||
response = yield handler(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
with self.ratelimiter.ratelimit(origin) as d:
|
||||
yield d
|
||||
response = yield handler(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
except:
|
||||
logger.exception("_authenticate_request failed")
|
||||
raise
|
||||
defer.returnValue(response)
|
||||
return new_handler
|
||||
|
||||
def rate_limit_origin(self, handler):
|
||||
def new_handler(origin, *args, **kwargs):
|
||||
response = yield handler(origin, *args, **kwargs)
|
||||
defer.returnValue(response)
|
||||
return new_handler()
|
||||
|
||||
@log_function
|
||||
def register_received_handler(self, handler):
|
||||
""" Register a handler that will be fired when we receive data.
|
||||
|
|
|
@ -160,7 +160,7 @@ class DirectoryHandler(BaseHandler):
|
|||
if not room_id:
|
||||
raise SynapseError(
|
||||
404,
|
||||
"Room alias %r not found" % (room_alias.to_string(),),
|
||||
"Room alias %s not found" % (room_alias.to_string(),),
|
||||
Codes.NOT_FOUND
|
||||
)
|
||||
|
||||
|
|
|
@ -69,9 +69,6 @@ class EventStreamHandler(BaseHandler):
|
|||
)
|
||||
self._streams_per_user[auth_user] += 1
|
||||
|
||||
if pagin_config.from_token is None:
|
||||
pagin_config.from_token = None
|
||||
|
||||
rm_handler = self.hs.get_handlers().room_member_handler
|
||||
room_ids = yield rm_handler.get_rooms_for_user(auth_user)
|
||||
|
||||
|
|
|
@ -510,9 +510,16 @@ class RoomMemberHandler(BaseHandler):
|
|||
def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
|
||||
"""Returns a list of roomids that the user has any of the given
|
||||
membership states in."""
|
||||
rooms = yield self.store.get_rooms_for_user_where_membership_is(
|
||||
user_id=user.to_string(), membership_list=membership_list
|
||||
|
||||
app_service = yield self.store.get_app_service_by_user_id(
|
||||
user.to_string()
|
||||
)
|
||||
if app_service:
|
||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||
else:
|
||||
rooms = yield self.store.get_rooms_for_user_where_membership_is(
|
||||
user_id=user.to_string(), membership_list=membership_list
|
||||
)
|
||||
|
||||
# For some reason the list of events contains duplicates
|
||||
# TODO(paul): work out why because I really don't think it should
|
||||
|
@ -559,13 +566,24 @@ class RoomEventSource(object):
|
|||
|
||||
to_key = yield self.get_current_key()
|
||||
|
||||
events, end_key = yield self.store.get_room_events_stream(
|
||||
user_id=user.to_string(),
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
room_id=None,
|
||||
limit=limit,
|
||||
app_service = yield self.store.get_app_service_by_user_id(
|
||||
user.to_string()
|
||||
)
|
||||
if app_service:
|
||||
events, end_key = yield self.store.get_appservice_room_stream(
|
||||
service=app_service,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
limit=limit,
|
||||
)
|
||||
else:
|
||||
events, end_key = yield self.store.get_room_events_stream(
|
||||
user_id=user.to_string(),
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
room_id=None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
defer.returnValue((events, end_key))
|
||||
|
||||
|
|
|
@ -36,8 +36,10 @@ class _NotificationListener(object):
|
|||
so that it can remove itself from the indexes in the Notifier class.
|
||||
"""
|
||||
|
||||
def __init__(self, user, rooms, from_token, limit, timeout, deferred):
|
||||
def __init__(self, user, rooms, from_token, limit, timeout, deferred,
|
||||
appservice=None):
|
||||
self.user = user
|
||||
self.appservice = appservice
|
||||
self.from_token = from_token
|
||||
self.limit = limit
|
||||
self.timeout = timeout
|
||||
|
@ -65,6 +67,10 @@ class _NotificationListener(object):
|
|||
lst.discard(self)
|
||||
|
||||
notifier.user_to_listeners.get(self.user, set()).discard(self)
|
||||
if self.appservice:
|
||||
notifier.appservice_to_listeners.get(
|
||||
self.appservice, set()
|
||||
).discard(self)
|
||||
|
||||
|
||||
class Notifier(object):
|
||||
|
@ -79,6 +85,7 @@ class Notifier(object):
|
|||
|
||||
self.rooms_to_listeners = {}
|
||||
self.user_to_listeners = {}
|
||||
self.appservice_to_listeners = {}
|
||||
|
||||
self.event_sources = hs.get_event_sources()
|
||||
|
||||
|
@ -114,6 +121,17 @@ class Notifier(object):
|
|||
for user in extra_users:
|
||||
listeners |= self.user_to_listeners.get(user, set()).copy()
|
||||
|
||||
for appservice in self.appservice_to_listeners:
|
||||
# TODO (kegan): Redundant appservice listener checks?
|
||||
# App services will already be in the rooms_to_listeners set, but
|
||||
# that isn't enough. They need to be checked here in order to
|
||||
# receive *invites* for users they are interested in. Does this
|
||||
# make the rooms_to_listeners check somewhat obselete?
|
||||
if appservice.is_interested(event):
|
||||
listeners |= self.appservice_to_listeners.get(
|
||||
appservice, set()
|
||||
).copy()
|
||||
|
||||
logger.debug("on_new_room_event listeners %s", listeners)
|
||||
|
||||
# TODO (erikj): Can we make this more efficient by hitting the
|
||||
|
@ -280,6 +298,10 @@ class Notifier(object):
|
|||
if not from_token:
|
||||
from_token = yield self.event_sources.get_current_token()
|
||||
|
||||
appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
|
||||
user.to_string()
|
||||
)
|
||||
|
||||
listener = _NotificationListener(
|
||||
user,
|
||||
rooms,
|
||||
|
@ -287,6 +309,7 @@ class Notifier(object):
|
|||
limit,
|
||||
timeout,
|
||||
deferred,
|
||||
appservice=appservice
|
||||
)
|
||||
|
||||
def _timeout_listener():
|
||||
|
@ -319,6 +342,11 @@ class Notifier(object):
|
|||
|
||||
self.user_to_listeners.setdefault(listener.user, set()).add(listener)
|
||||
|
||||
if listener.appservice:
|
||||
self.appservice_to_listeners.setdefault(
|
||||
listener.appservice, set()
|
||||
).add(listener)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _check_for_updates(self, listener):
|
||||
|
|
|
@ -214,7 +214,7 @@ def _rule_spec_from_path(path):
|
|||
template = path[0]
|
||||
path = path[1:]
|
||||
|
||||
if len(path) == 0:
|
||||
if len(path) == 0 or len(path[0]) == 0:
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
rule_id = path[0]
|
||||
|
|
|
@ -73,6 +73,7 @@ class BaseHomeServer(object):
|
|||
'resource_for_client',
|
||||
'resource_for_client_v2_alpha',
|
||||
'resource_for_federation',
|
||||
'resource_for_static_content',
|
||||
'resource_for_web_client',
|
||||
'resource_for_content_repo',
|
||||
'resource_for_server_key',
|
||||
|
|
|
@ -23,7 +23,7 @@ from synapse.util.lrucache import LruCache
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import collections
|
||||
from collections import namedtuple, OrderedDict
|
||||
import simplejson as json
|
||||
import sys
|
||||
import time
|
||||
|
@ -35,6 +35,52 @@ sql_logger = logging.getLogger("synapse.storage.SQL")
|
|||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||
|
||||
|
||||
# TODO(paul):
|
||||
# * more generic key management
|
||||
# * export monitoring stats
|
||||
# * consider other eviction strategies - LRU?
|
||||
def cached(max_entries=1000):
|
||||
""" A method decorator that applies a memoizing cache around the function.
|
||||
|
||||
The function is presumed to take one additional argument, which is used as
|
||||
the key for the cache. Cache hits are served directly from the cache;
|
||||
misses use the function body to generate the value.
|
||||
|
||||
The wrapped function has an additional member, a callable called
|
||||
"invalidate". This can be used to remove individual entries from the cache.
|
||||
|
||||
The wrapped function has another additional callable, called "prefill",
|
||||
which can be used to insert values into the cache specifically, without
|
||||
calling the calculation function.
|
||||
"""
|
||||
def wrap(orig):
|
||||
cache = OrderedDict()
|
||||
|
||||
def prefill(key, value):
|
||||
while len(cache) > max_entries:
|
||||
cache.popitem(last=False)
|
||||
|
||||
cache[key] = value
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wrapped(self, key):
|
||||
if key in cache:
|
||||
defer.returnValue(cache[key])
|
||||
|
||||
ret = yield orig(self, key)
|
||||
prefill(key, ret)
|
||||
defer.returnValue(ret)
|
||||
|
||||
def invalidate(key):
|
||||
cache.pop(key, None)
|
||||
|
||||
wrapped.invalidate = invalidate
|
||||
wrapped.prefill = prefill
|
||||
return wrapped
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
class LoggingTransaction(object):
|
||||
"""An object that almost-transparently proxies for the 'txn' object
|
||||
passed to the constructor. Adds logging to the .execute() method."""
|
||||
|
@ -404,7 +450,8 @@ class SQLBaseStore(object):
|
|||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the rows with
|
||||
keyvalues : dict of column names and values to select the rows with,
|
||||
or None to not apply a WHERE clause.
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
return self.runInteraction(
|
||||
|
@ -423,13 +470,20 @@ class SQLBaseStore(object):
|
|||
keyvalues : dict of column names and values to select the rows with
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
)
|
||||
if keyvalues:
|
||||
sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
)
|
||||
txn.execute(sql, keyvalues.values())
|
||||
else:
|
||||
sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
|
||||
", ".join(retcols),
|
||||
table
|
||||
)
|
||||
txn.execute(sql)
|
||||
|
||||
txn.execute(sql, keyvalues.values())
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
|
@ -586,8 +640,9 @@ class SQLBaseStore(object):
|
|||
start_time = time.time() * 1000
|
||||
update_counter = self._get_event_counters.update
|
||||
|
||||
cache = self._get_event_cache.setdefault(event_id, {})
|
||||
|
||||
try:
|
||||
cache = self._get_event_cache.setdefault(event_id, {})
|
||||
# Separate cache entries for each way to invoke _get_event_txn
|
||||
return cache[(check_redacted, get_prev_content, allow_rejected)]
|
||||
except KeyError:
|
||||
|
@ -786,7 +841,7 @@ class JoinHelper(object):
|
|||
for table in self.tables:
|
||||
res += [f for f in table.fields if f not in res]
|
||||
|
||||
self.EntryType = collections.namedtuple("JoinHelperEntry", res)
|
||||
self.EntryType = namedtuple("JoinHelperEntry", res)
|
||||
|
||||
def get_fields(self, **prefixes):
|
||||
"""Get a string representing a list of fields for use in SELECT
|
||||
|
|
|
@ -15,31 +15,21 @@
|
|||
import logging
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApplicationServiceCache(object):
|
||||
"""Caches ApplicationServices and provides utility functions on top.
|
||||
|
||||
This class is designed to be invoked on incoming events in order to avoid
|
||||
hammering the database every time to extract a list of application service
|
||||
regexes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.services = []
|
||||
|
||||
|
||||
class ApplicationServiceStore(SQLBaseStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ApplicationServiceStore, self).__init__(hs)
|
||||
self.cache = ApplicationServiceCache()
|
||||
self.services_cache = []
|
||||
self.cache_defer = self._populate_cache()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -56,7 +46,7 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
token,
|
||||
)
|
||||
# update cache TODO: Should this be in the txn?
|
||||
for service in self.cache.services:
|
||||
for service in self.services_cache:
|
||||
if service.token == token:
|
||||
service.url = None
|
||||
service.namespaces = None
|
||||
|
@ -110,13 +100,13 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
# update cache TODO: Should this be in the txn?
|
||||
for (index, cache_service) in enumerate(self.cache.services):
|
||||
for (index, cache_service) in enumerate(self.services_cache):
|
||||
if service.token == cache_service.token:
|
||||
self.cache.services[index] = service
|
||||
self.services_cache[index] = service
|
||||
logger.info("Updated: %s", service)
|
||||
return
|
||||
# new entry
|
||||
self.cache.services.append(service)
|
||||
self.services_cache.append(service)
|
||||
logger.info("Updated(new): %s", service)
|
||||
|
||||
def _update_app_service_txn(self, txn, service):
|
||||
|
@ -160,11 +150,34 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
@defer.inlineCallbacks
|
||||
def get_app_services(self):
|
||||
yield self.cache_defer # make sure the cache is ready
|
||||
defer.returnValue(self.cache.services)
|
||||
defer.returnValue(self.services_cache)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_app_service_by_user_id(self, user_id):
|
||||
"""Retrieve an application service from their user ID.
|
||||
|
||||
All application services have associated with them a particular user ID.
|
||||
There is no distinguishing feature on the user ID which indicates it
|
||||
represents an application service. This function allows you to map from
|
||||
a user ID to an application service.
|
||||
|
||||
Args:
|
||||
user_id(str): The user ID to see if it is an application service.
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
"""
|
||||
|
||||
yield self.cache_defer # make sure the cache is ready
|
||||
|
||||
for service in self.services_cache:
|
||||
if service.sender == user_id:
|
||||
defer.returnValue(service)
|
||||
return
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_app_service_by_token(self, token, from_cache=True):
|
||||
"""Get the application service with the given token.
|
||||
"""Get the application service with the given appservice token.
|
||||
|
||||
Args:
|
||||
token (str): The application service token.
|
||||
|
@ -176,7 +189,7 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
yield self.cache_defer # make sure the cache is ready
|
||||
|
||||
if from_cache:
|
||||
for service in self.cache.services:
|
||||
for service in self.services_cache:
|
||||
if service.token == token:
|
||||
defer.returnValue(service)
|
||||
return
|
||||
|
@ -185,6 +198,77 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
# TODO: The from_cache=False impl
|
||||
# TODO: This should be JOINed with the application_services_regex table.
|
||||
|
||||
def get_app_service_rooms(self, service):
|
||||
"""Get a list of RoomsForUser for this application service.
|
||||
|
||||
Application services may be "interested" in lots of rooms depending on
|
||||
the room ID, the room aliases, or the members in the room. This function
|
||||
takes all of these into account and returns a list of RoomsForUser which
|
||||
represent the entire list of room IDs that this application service
|
||||
wants to know about.
|
||||
|
||||
Args:
|
||||
service: The application service to get a room list for.
|
||||
Returns:
|
||||
A list of RoomsForUser.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_app_service_rooms",
|
||||
self._get_app_service_rooms_txn,
|
||||
service,
|
||||
)
|
||||
|
||||
def _get_app_service_rooms_txn(self, txn, service):
|
||||
# get all rooms matching the room ID regex.
|
||||
room_entries = self._simple_select_list_txn(
|
||||
txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
|
||||
)
|
||||
matching_room_list = set([
|
||||
r["room_id"] for r in room_entries if
|
||||
service.is_interested_in_room(r["room_id"])
|
||||
])
|
||||
|
||||
# resolve room IDs for matching room alias regex.
|
||||
room_alias_mappings = self._simple_select_list_txn(
|
||||
txn=txn, table="room_aliases", keyvalues=None,
|
||||
retcols=["room_id", "room_alias"]
|
||||
)
|
||||
matching_room_list |= set([
|
||||
r["room_id"] for r in room_alias_mappings if
|
||||
service.is_interested_in_alias(r["room_alias"])
|
||||
])
|
||||
|
||||
# get all rooms for every user for this AS. This is scoped to users on
|
||||
# this HS only.
|
||||
user_list = self._simple_select_list_txn(
|
||||
txn=txn, table="users", keyvalues=None, retcols=["name"]
|
||||
)
|
||||
user_list = [
|
||||
u["name"] for u in user_list if
|
||||
service.is_interested_in_user(u["name"])
|
||||
]
|
||||
rooms_for_user_matching_user_id = set() # RoomsForUser list
|
||||
for user_id in user_list:
|
||||
# FIXME: This assumes this store is linked with RoomMemberStore :(
|
||||
rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
|
||||
txn=txn,
|
||||
user_id=user_id,
|
||||
membership_list=[Membership.JOIN]
|
||||
)
|
||||
rooms_for_user_matching_user_id |= set(rooms_for_user)
|
||||
|
||||
# make RoomsForUser tuples for room ids and aliases which are not in the
|
||||
# main rooms_for_user_list - e.g. they are rooms which do not have AS
|
||||
# registered users in it.
|
||||
known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
|
||||
missing_rooms_for_user = [
|
||||
RoomsForUser(r, service.sender, "join") for r in
|
||||
matching_room_list if r not in known_room_ids
|
||||
]
|
||||
rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
|
||||
|
||||
return rooms_for_user_matching_user_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_cache(self):
|
||||
"""Populates the ApplicationServiceCache from the database."""
|
||||
|
@ -235,7 +319,7 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
# TODO get last successful txn id f.e. service
|
||||
for service in services.values():
|
||||
logger.info("Found application service: %s", service)
|
||||
self.cache.services.append(ApplicationService(
|
||||
self.services_cache.append(ApplicationService(
|
||||
token=service["token"],
|
||||
url=service["url"],
|
||||
namespaces=service["namespaces"],
|
||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
|||
|
||||
from collections import namedtuple
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, cached
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.types import UserID
|
||||
|
@ -35,11 +35,6 @@ RoomsForUser = namedtuple(
|
|||
|
||||
class RoomMemberStore(SQLBaseStore):
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
super(RoomMemberStore, self).__init__(*args, **kw)
|
||||
|
||||
self._user_rooms_cache = {}
|
||||
|
||||
def _store_room_member_txn(self, txn, event):
|
||||
"""Store a room member in the database.
|
||||
"""
|
||||
|
@ -103,7 +98,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||
|
||||
txn.execute(sql, (event.room_id, domain))
|
||||
|
||||
self.invalidate_rooms_for_user(target_user_id)
|
||||
self.get_rooms_for_user.invalidate(target_user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_member(self, user_id, room_id):
|
||||
|
@ -185,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
|
|||
if not membership_list:
|
||||
return defer.succeed(None)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_rooms_for_user_where_membership_is",
|
||||
self._get_rooms_for_user_where_membership_is_txn,
|
||||
user_id, membership_list
|
||||
)
|
||||
|
||||
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
|
||||
membership_list):
|
||||
where_clause = "user_id = ? AND (%s)" % (
|
||||
" OR ".join(["membership = ?" for _ in membership_list]),
|
||||
)
|
||||
|
@ -192,24 +195,18 @@ class RoomMemberStore(SQLBaseStore):
|
|||
args = [user_id]
|
||||
args.extend(membership_list)
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT m.room_id, m.sender, m.membership"
|
||||
" FROM room_memberships as m"
|
||||
" INNER JOIN current_state_events as c"
|
||||
" ON m.event_id = c.event_id"
|
||||
" WHERE %s"
|
||||
) % (where_clause,)
|
||||
sql = (
|
||||
"SELECT m.room_id, m.sender, m.membership"
|
||||
" FROM room_memberships as m"
|
||||
" INNER JOIN current_state_events as c"
|
||||
" ON m.event_id = c.event_id"
|
||||
" WHERE %s"
|
||||
) % (where_clause,)
|
||||
|
||||
txn.execute(sql, args)
|
||||
return [
|
||||
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
|
||||
]
|
||||
|
||||
return self.runInteraction(
|
||||
"get_rooms_for_user_where_membership_is",
|
||||
f
|
||||
)
|
||||
txn.execute(sql, args)
|
||||
return [
|
||||
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
|
||||
]
|
||||
|
||||
def get_joined_hosts_for_room(self, room_id):
|
||||
return self._simple_select_onecol(
|
||||
|
@ -247,33 +244,12 @@ class RoomMemberStore(SQLBaseStore):
|
|||
results = self._parse_events_txn(txn, rows)
|
||||
return results
|
||||
|
||||
# TODO(paul): Create a nice @cached decorator to do this
|
||||
# @cached
|
||||
# def get_foo(...)
|
||||
# ...
|
||||
# invalidate_foo = get_foo.invalidator
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@cached()
|
||||
def get_rooms_for_user(self, user_id):
|
||||
# TODO(paul): put some performance counters in here so we can easily
|
||||
# track what impact this cache is having
|
||||
if user_id in self._user_rooms_cache:
|
||||
defer.returnValue(self._user_rooms_cache[user_id])
|
||||
|
||||
rooms = yield self.get_rooms_for_user_where_membership_is(
|
||||
return self.get_rooms_for_user_where_membership_is(
|
||||
user_id, membership_list=[Membership.JOIN],
|
||||
)
|
||||
|
||||
# TODO(paul): Consider applying a maximum size; just evict things at
|
||||
# random, or consider LRU?
|
||||
|
||||
self._user_rooms_cache[user_id] = rooms
|
||||
defer.returnValue(rooms)
|
||||
|
||||
def invalidate_rooms_for_user(self, user_id):
|
||||
if user_id in self._user_rooms_cache:
|
||||
del self._user_rooms_cache[user_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_rooms_intersect(self, user_id_list):
|
||||
""" Checks whether all the users whose IDs are given in a list share a
|
||||
|
|
|
@ -36,6 +36,7 @@ what sort order was used:
|
|||
from twisted.internet import defer
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
|
@ -127,6 +128,85 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||
|
||||
|
||||
class StreamStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
|
||||
# NB this lives here instead of appservice.py so we can reuse the
|
||||
# 'private' StreamToken class in this file.
|
||||
if limit:
|
||||
limit = max(limit, MAX_STREAM_SIZE)
|
||||
else:
|
||||
limit = MAX_STREAM_SIZE
|
||||
|
||||
# From and to keys should be integers from ordering.
|
||||
from_id = _StreamToken.parse_stream_token(from_key)
|
||||
to_id = _StreamToken.parse_stream_token(to_key)
|
||||
|
||||
if from_key == to_key:
|
||||
defer.returnValue(([], to_key))
|
||||
return
|
||||
|
||||
# select all the events between from/to with a sensible limit
|
||||
sql = (
|
||||
"SELECT e.event_id, e.room_id, e.type, s.state_key, "
|
||||
"e.stream_ordering FROM events AS e LEFT JOIN state_events as s ON "
|
||||
"e.event_id = s.event_id "
|
||||
"WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
|
||||
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
|
||||
) % {
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
def f(txn):
|
||||
# pull out all the events between the tokens
|
||||
txn.execute(sql, (from_id.stream, to_id.stream,))
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
# Logic:
|
||||
# - We want ALL events which match the AS room_id regex
|
||||
# - We want ALL events which match the rooms represented by the AS
|
||||
# room_alias regex
|
||||
# - We want ALL events for rooms that AS users have joined.
|
||||
# This is currently supported via get_app_service_rooms (which is
|
||||
# used for the Notifier listener rooms). We can't reasonably make a
|
||||
# SQL query for these room IDs, so we'll pull all the events between
|
||||
# from/to and filter in python.
|
||||
rooms_for_as = self._get_app_service_rooms_txn(txn, service)
|
||||
room_ids_for_as = [r.room_id for r in rooms_for_as]
|
||||
|
||||
def app_service_interested(row):
|
||||
if row["room_id"] in room_ids_for_as:
|
||||
return True
|
||||
|
||||
if row["type"] == EventTypes.Member:
|
||||
if service.is_interested_in_user(row.get("state_key")):
|
||||
return True
|
||||
return False
|
||||
|
||||
ret = self._get_events_txn(
|
||||
txn,
|
||||
# apply the filter on the room id list
|
||||
[
|
||||
r["event_id"] for r in rows
|
||||
if app_service_interested(r)
|
||||
],
|
||||
get_prev_content=True
|
||||
)
|
||||
|
||||
self._set_before_and_after(ret, rows)
|
||||
|
||||
if rows:
|
||||
key = "s%d" % max(r["stream_ordering"] for r in rows)
|
||||
else:
|
||||
# Assume we didn't get anything because there was nothing to
|
||||
# get.
|
||||
key = to_key
|
||||
|
||||
return ret, key
|
||||
|
||||
results = yield self.runInteraction("get_appservice_room_stream", f)
|
||||
defer.returnValue(results)
|
||||
|
||||
@log_function
|
||||
def get_room_events_stream(self, user_id, from_key, to_key, room_id,
|
||||
limit=0, with_feedback=False):
|
||||
|
@ -184,8 +264,7 @@ class StreamStore(SQLBaseStore):
|
|||
self._set_before_and_after(ret, rows)
|
||||
|
||||
if rows:
|
||||
key = "s%d" % max([r["stream_ordering"] for r in rows])
|
||||
|
||||
key = "s%d" % max(r["stream_ordering"] for r in rows)
|
||||
else:
|
||||
# Assume we didn't get anything because there was nothing to
|
||||
# get.
|
||||
|
|
|
@ -13,12 +13,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore, Table
|
||||
from ._base import SQLBaseStore, Table, cached
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -28,10 +26,6 @@ class TransactionStore(SQLBaseStore):
|
|||
"""A collection of queries for handling PDUs.
|
||||
"""
|
||||
|
||||
# a write-through cache of DestinationsTable.EntryType indexed by
|
||||
# destination string
|
||||
destination_retry_cache = {}
|
||||
|
||||
def get_received_txn_response(self, transaction_id, origin):
|
||||
"""For an incoming transaction from a given origin, check if we have
|
||||
already responded to it. If so, return the response code and response
|
||||
|
@ -211,6 +205,7 @@ class TransactionStore(SQLBaseStore):
|
|||
|
||||
return ReceivedTransactionsTable.decode_results(txn.fetchall())
|
||||
|
||||
@cached()
|
||||
def get_destination_retry_timings(self, destination):
|
||||
"""Gets the current retry timings (if any) for a given destination.
|
||||
|
||||
|
@ -221,9 +216,6 @@ class TransactionStore(SQLBaseStore):
|
|||
None if not retrying
|
||||
Otherwise a DestinationsTable.EntryType for the retry scheme
|
||||
"""
|
||||
if destination in self.destination_retry_cache:
|
||||
return defer.succeed(self.destination_retry_cache[destination])
|
||||
|
||||
return self.runInteraction(
|
||||
"get_destination_retry_timings",
|
||||
self._get_destination_retry_timings, destination)
|
||||
|
@ -250,7 +242,9 @@ class TransactionStore(SQLBaseStore):
|
|||
retry_interval (int) - how long until next retry in ms
|
||||
"""
|
||||
|
||||
self.destination_retry_cache[destination] = (
|
||||
# As this is the new value, we might as well prefill the cache
|
||||
self.get_destination_retry_timings.prefill(
|
||||
destination,
|
||||
DestinationsTable.EntryType(
|
||||
destination,
|
||||
retry_last_ts,
|
||||
|
|
|
@ -0,0 +1,216 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
|
||||
from synapse.util.async import sleep
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FederationRateLimiter(object):
|
||||
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
|
||||
reject_limit, concurrent_requests):
|
||||
"""
|
||||
Args:
|
||||
clock (Clock)
|
||||
window_size (int): The window size in milliseconds.
|
||||
sleep_limit (int): The number of requests received in the last
|
||||
`window_size` milliseconds before we artificially start
|
||||
delaying processing of requests.
|
||||
sleep_msec (int): The number of milliseconds to delay processing
|
||||
of incoming requests by.
|
||||
reject_limit (int): The maximum number of requests that are can be
|
||||
queued for processing before we start rejecting requests with
|
||||
a 429 Too Many Requests response.
|
||||
concurrent_requests (int): The number of concurrent requests to
|
||||
process.
|
||||
"""
|
||||
self.clock = clock
|
||||
|
||||
self.window_size = window_size
|
||||
self.sleep_limit = sleep_limit
|
||||
self.sleep_msec = sleep_msec
|
||||
self.reject_limit = reject_limit
|
||||
self.concurrent_requests = concurrent_requests
|
||||
|
||||
self.ratelimiters = {}
|
||||
|
||||
def ratelimit(self, host):
|
||||
"""Used to ratelimit an incoming request from given host
|
||||
|
||||
Example usage:
|
||||
|
||||
with rate_limiter.ratelimit(origin) as wait_deferred:
|
||||
yield wait_deferred
|
||||
# Handle request ...
|
||||
|
||||
Args:
|
||||
host (str): Origin of incoming request.
|
||||
|
||||
Returns:
|
||||
_PerHostRatelimiter
|
||||
"""
|
||||
return self.ratelimiters.setdefault(
|
||||
host,
|
||||
_PerHostRatelimiter(
|
||||
clock=self.clock,
|
||||
window_size=self.window_size,
|
||||
sleep_limit=self.sleep_limit,
|
||||
sleep_msec=self.sleep_msec,
|
||||
reject_limit=self.reject_limit,
|
||||
concurrent_requests=self.concurrent_requests,
|
||||
)
|
||||
).ratelimit()
|
||||
|
||||
|
||||
class _PerHostRatelimiter(object):
|
||||
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
|
||||
reject_limit, concurrent_requests):
|
||||
self.clock = clock
|
||||
|
||||
self.window_size = window_size
|
||||
self.sleep_limit = sleep_limit
|
||||
self.sleep_msec = sleep_msec
|
||||
self.reject_limit = reject_limit
|
||||
self.concurrent_requests = concurrent_requests
|
||||
|
||||
self.sleeping_requests = set()
|
||||
self.ready_request_queue = collections.OrderedDict()
|
||||
self.current_processing = set()
|
||||
self.request_times = []
|
||||
|
||||
def is_empty(self):
|
||||
time_now = self.clock.time_msec()
|
||||
self.request_times[:] = [
|
||||
r for r in self.request_times
|
||||
if time_now - r < self.window_size
|
||||
]
|
||||
|
||||
return not (
|
||||
self.ready_request_queue
|
||||
or self.sleeping_requests
|
||||
or self.current_processing
|
||||
or self.request_times
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ratelimit(self):
|
||||
# `contextlib.contextmanager` takes a generator and turns it into a
|
||||
# context manager. The generator should only yield once with a value
|
||||
# to be returned by manager.
|
||||
# Exceptions will be reraised at the yield.
|
||||
|
||||
request_id = object()
|
||||
ret = self._on_enter(request_id)
|
||||
try:
|
||||
yield ret
|
||||
finally:
|
||||
self._on_exit(request_id)
|
||||
|
||||
def _on_enter(self, request_id):
|
||||
time_now = self.clock.time_msec()
|
||||
self.request_times[:] = [
|
||||
r for r in self.request_times
|
||||
if time_now - r < self.window_size
|
||||
]
|
||||
|
||||
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
|
||||
if queue_size > self.reject_limit:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(
|
||||
self.window_size / self.sleep_limit
|
||||
),
|
||||
)
|
||||
|
||||
self.request_times.append(time_now)
|
||||
|
||||
def queue_request():
|
||||
if len(self.current_processing) > self.concurrent_requests:
|
||||
logger.debug("Ratelimit [%s]: Queue req", id(request_id))
|
||||
queue_defer = defer.Deferred()
|
||||
self.ready_request_queue[request_id] = queue_defer
|
||||
return queue_defer
|
||||
else:
|
||||
return defer.succeed(None)
|
||||
|
||||
logger.debug(
|
||||
"Ratelimit [%s]: len(self.request_times)=%d",
|
||||
id(request_id), len(self.request_times),
|
||||
)
|
||||
|
||||
if len(self.request_times) > self.sleep_limit:
|
||||
logger.debug(
|
||||
"Ratelimit [%s]: sleeping req",
|
||||
id(request_id),
|
||||
)
|
||||
ret_defer = sleep(self.sleep_msec/1000.0)
|
||||
|
||||
self.sleeping_requests.add(request_id)
|
||||
|
||||
def on_wait_finished(_):
|
||||
logger.debug(
|
||||
"Ratelimit [%s]: Finished sleeping",
|
||||
id(request_id),
|
||||
)
|
||||
self.sleeping_requests.discard(request_id)
|
||||
queue_defer = queue_request()
|
||||
return queue_defer
|
||||
|
||||
ret_defer.addBoth(on_wait_finished)
|
||||
else:
|
||||
ret_defer = queue_request()
|
||||
|
||||
def on_start(r):
|
||||
logger.debug(
|
||||
"Ratelimit [%s]: Processing req",
|
||||
id(request_id),
|
||||
)
|
||||
self.current_processing.add(request_id)
|
||||
return r
|
||||
|
||||
def on_err(r):
|
||||
self.current_processing.discard(request_id)
|
||||
return r
|
||||
|
||||
def on_both(r):
|
||||
# Ensure that we've properly cleaned up.
|
||||
self.sleeping_requests.discard(request_id)
|
||||
self.ready_request_queue.pop(request_id, None)
|
||||
return r
|
||||
|
||||
ret_defer.addCallbacks(on_start, on_err)
|
||||
ret_defer.addBoth(on_both)
|
||||
return ret_defer
|
||||
|
||||
def _on_exit(self, request_id):
|
||||
logger.debug(
|
||||
"Ratelimit [%s]: Processed req",
|
||||
id(request_id),
|
||||
)
|
||||
self.current_processing.discard(request_id)
|
||||
try:
|
||||
request_id, deferred = self.ready_request_queue.popitem()
|
||||
self.current_processing.add(request_id)
|
||||
deferred.callback(None)
|
||||
except KeyError:
|
||||
pass
|
|
@ -295,6 +295,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
|||
|
||||
self.mock_datastore = hs.get_datastore()
|
||||
self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
|
||||
self.mock_datastore.get_app_service_by_user_id = Mock(
|
||||
return_value=defer.succeed(None)
|
||||
)
|
||||
|
||||
def get_profile_displayname(user_id):
|
||||
return defer.succeed("Frank")
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import cached
|
||||
|
||||
|
||||
class CacheDecoratorTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_passthrough(self):
|
||||
@cached()
|
||||
def func(self, key):
|
||||
return key
|
||||
|
||||
self.assertEquals((yield func(self, "foo")), "foo")
|
||||
self.assertEquals((yield func(self, "bar")), "bar")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_hit(self):
|
||||
callcount = [0]
|
||||
|
||||
@cached()
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
yield func(self, "foo")
|
||||
|
||||
self.assertEquals(callcount[0], 1)
|
||||
|
||||
self.assertEquals((yield func(self, "foo")), "foo")
|
||||
self.assertEquals(callcount[0], 1)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invalidate(self):
|
||||
callcount = [0]
|
||||
|
||||
@cached()
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
yield func(self, "foo")
|
||||
|
||||
self.assertEquals(callcount[0], 1)
|
||||
|
||||
func.invalidate("foo")
|
||||
|
||||
yield func(self, "foo")
|
||||
|
||||
self.assertEquals(callcount[0], 2)
|
||||
|
||||
def test_invalidate_missing(self):
|
||||
@cached()
|
||||
def func(self, key):
|
||||
return key
|
||||
|
||||
func.invalidate("what")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_max_entries(self):
|
||||
callcount = [0]
|
||||
|
||||
@cached(max_entries=10)
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
for k in range(0,12):
|
||||
yield func(self, k)
|
||||
|
||||
self.assertEquals(callcount[0], 12)
|
||||
|
||||
# There must have been at least 2 evictions, meaning if we calculate
|
||||
# all 12 values again, we must get called at least 2 more times
|
||||
for k in range(0,12):
|
||||
yield func(self, k)
|
||||
|
||||
self.assertTrue(callcount[0] >= 14,
|
||||
msg="Expected callcount >= 14, got %d" % (callcount[0]))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_prefill(self):
|
||||
callcount = [0]
|
||||
|
||||
@cached()
|
||||
def func(self, key):
|
||||
callcount[0] += 1
|
||||
return key
|
||||
|
||||
func.prefill("foo", 123)
|
||||
|
||||
self.assertEquals((yield func(self, "foo")), 123)
|
||||
self.assertEquals(callcount[0], 0)
|
Loading…
Reference in New Issue