Merge branch 'develop' into application-services-txn-reliability
Conflicts: synapse/storage/__init__.py
This commit is contained in:
commit
4edcbcee3b
|
@ -1,3 +1,12 @@
|
||||||
|
Changes in synapse v0.8.1 (2015-03-18)
|
||||||
|
======================================
|
||||||
|
|
||||||
|
* Disable registration by default. New users can be added using the command
|
||||||
|
``register_new_matrix_user`` or by enabling registration in the config.
|
||||||
|
* Add metrics to synapse. To enable metrics use config options
|
||||||
|
``enable_metrics`` and ``metrics_port``.
|
||||||
|
* Fix bug where banning only kicked the user.
|
||||||
|
|
||||||
Changes in synapse v0.8.0 (2015-03-06)
|
Changes in synapse v0.8.0 (2015-03-06)
|
||||||
======================================
|
======================================
|
||||||
|
|
||||||
|
|
11
README.rst
11
README.rst
|
@ -128,6 +128,17 @@ To set up your homeserver, run (in your virtualenv, as before)::
|
||||||
|
|
||||||
Substituting your host and domain name as appropriate.
|
Substituting your host and domain name as appropriate.
|
||||||
|
|
||||||
|
By default, registration of new users is disabled. You can either enable
|
||||||
|
registration in the config (it is then recommended to also set up CAPTCHA), or
|
||||||
|
you can use the command line to register new users::
|
||||||
|
|
||||||
|
$ source ~/.synapse/bin/activate
|
||||||
|
$ register_new_matrix_user -c homeserver.yaml https://localhost:8448
|
||||||
|
New user localpart: erikj
|
||||||
|
Password:
|
||||||
|
Confirm password:
|
||||||
|
Success!
|
||||||
|
|
||||||
For reliable VoIP calls to be routed via this homeserver, you MUST configure
|
For reliable VoIP calls to be routed via this homeserver, you MUST configure
|
||||||
a TURN server. See docs/turn-howto.rst for details.
|
a TURN server. See docs/turn-howto.rst for details.
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import getpass
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import urllib2
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def request_registration(user, password, server_location, shared_secret):
|
||||||
|
mac = hmac.new(
|
||||||
|
key=shared_secret,
|
||||||
|
msg=user,
|
||||||
|
digestmod=hashlib.sha1,
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"user": user,
|
||||||
|
"password": password,
|
||||||
|
"mac": mac,
|
||||||
|
"type": "org.matrix.login.shared_secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
server_location = server_location.rstrip("/")
|
||||||
|
|
||||||
|
print "Sending registration request..."
|
||||||
|
|
||||||
|
req = urllib2.Request(
|
||||||
|
"%s/_matrix/client/api/v1/register" % (server_location,),
|
||||||
|
data=json.dumps(data),
|
||||||
|
headers={'Content-Type': 'application/json'}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
f = urllib2.urlopen(req)
|
||||||
|
f.read()
|
||||||
|
f.close()
|
||||||
|
print "Success."
|
||||||
|
except urllib2.HTTPError as e:
|
||||||
|
print "ERROR! Received %d %s" % (e.code, e.reason,)
|
||||||
|
if 400 <= e.code < 500:
|
||||||
|
if e.info().type == "application/json":
|
||||||
|
resp = json.load(e)
|
||||||
|
if "error" in resp:
|
||||||
|
print resp["error"]
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def register_new_user(user, password, server_location, shared_secret):
|
||||||
|
if not user:
|
||||||
|
try:
|
||||||
|
default_user = getpass.getuser()
|
||||||
|
except:
|
||||||
|
default_user = None
|
||||||
|
|
||||||
|
if default_user:
|
||||||
|
user = raw_input("New user localpart [%s]: " % (default_user,))
|
||||||
|
if not user:
|
||||||
|
user = default_user
|
||||||
|
else:
|
||||||
|
user = raw_input("New user localpart: ")
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
print "Invalid user name"
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not password:
|
||||||
|
password = getpass.getpass("Password: ")
|
||||||
|
|
||||||
|
if not password:
|
||||||
|
print "Password cannot be blank."
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
confirm_password = getpass.getpass("Confirm password: ")
|
||||||
|
|
||||||
|
if password != confirm_password:
|
||||||
|
print "Passwords do not match"
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
request_registration(user, password, server_location, shared_secret)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Used to register new users with a given home server when"
|
||||||
|
" registration has been disabled. The home server must be"
|
||||||
|
" configured with the 'registration_shared_secret' option"
|
||||||
|
" set.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-u", "--user",
|
||||||
|
default=None,
|
||||||
|
help="Local part of the new user. Will prompt if omitted.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p", "--password",
|
||||||
|
default=None,
|
||||||
|
help="New password for user. Will prompt if omitted.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group = parser.add_mutually_exclusive_group(required=True)
|
||||||
|
group.add_argument(
|
||||||
|
"-c", "--config",
|
||||||
|
type=argparse.FileType('r'),
|
||||||
|
help="Path to server config file. Used to read in shared secret.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"-k", "--shared-secret",
|
||||||
|
help="Shared secret as defined in server config file.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"server_url",
|
||||||
|
default="https://localhost:8448",
|
||||||
|
nargs='?',
|
||||||
|
help="URL to use to talk to the home server. Defaults to "
|
||||||
|
" 'https://localhost:8448'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if "config" in args and args.config:
|
||||||
|
config = yaml.safe_load(args.config)
|
||||||
|
secret = config.get("registration_shared_secret", None)
|
||||||
|
if not secret:
|
||||||
|
print "No 'registration_shared_secret' defined in config."
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
secret = args.shared_secret
|
||||||
|
|
||||||
|
register_new_user(args.user, args.password, args.server_url, secret)
|
4
setup.py
4
setup.py
|
@ -45,7 +45,7 @@ setup(
|
||||||
version=version,
|
version=version,
|
||||||
packages=find_packages(exclude=["tests", "tests.*"]),
|
packages=find_packages(exclude=["tests", "tests.*"]),
|
||||||
description="Reference Synapse Home Server",
|
description="Reference Synapse Home Server",
|
||||||
install_requires=dependencies["REQUIREMENTS"].keys(),
|
install_requires=dependencies['requirements'](include_conditional=True).keys(),
|
||||||
setup_requires=[
|
setup_requires=[
|
||||||
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
|
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
|
||||||
"setuptools_trial",
|
"setuptools_trial",
|
||||||
|
@ -55,5 +55,5 @@ setup(
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
scripts=["synctl"],
|
scripts=["synctl", "register_new_matrix_user"],
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,4 +16,4 @@
|
||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.8.0"
|
__version__ = "0.8.1-r2"
|
||||||
|
|
|
@ -28,6 +28,12 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
AuthEventTypes = (
|
||||||
|
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
|
||||||
|
EventTypes.JoinRules,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Auth(object):
|
class Auth(object):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -166,6 +172,7 @@ class Auth(object):
|
||||||
target = auth_events.get(key)
|
target = auth_events.get(key)
|
||||||
|
|
||||||
target_in_room = target and target.membership == Membership.JOIN
|
target_in_room = target and target.membership == Membership.JOIN
|
||||||
|
target_banned = target and target.membership == Membership.BAN
|
||||||
|
|
||||||
key = (EventTypes.JoinRules, "", )
|
key = (EventTypes.JoinRules, "", )
|
||||||
join_rule_event = auth_events.get(key)
|
join_rule_event = auth_events.get(key)
|
||||||
|
@ -194,6 +201,7 @@ class Auth(object):
|
||||||
{
|
{
|
||||||
"caller_in_room": caller_in_room,
|
"caller_in_room": caller_in_room,
|
||||||
"caller_invited": caller_invited,
|
"caller_invited": caller_invited,
|
||||||
|
"target_banned": target_banned,
|
||||||
"target_in_room": target_in_room,
|
"target_in_room": target_in_room,
|
||||||
"membership": membership,
|
"membership": membership,
|
||||||
"join_rule": join_rule,
|
"join_rule": join_rule,
|
||||||
|
@ -202,6 +210,11 @@ class Auth(object):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ban_level:
|
||||||
|
ban_level = int(ban_level)
|
||||||
|
else:
|
||||||
|
ban_level = 50 # FIXME (erikj): What should we do here?
|
||||||
|
|
||||||
if Membership.INVITE == membership:
|
if Membership.INVITE == membership:
|
||||||
# TODO (erikj): We should probably handle this more intelligently
|
# TODO (erikj): We should probably handle this more intelligently
|
||||||
# PRIVATE join rules.
|
# PRIVATE join rules.
|
||||||
|
@ -212,6 +225,10 @@ class Auth(object):
|
||||||
403,
|
403,
|
||||||
"%s not in room %s." % (event.user_id, event.room_id,)
|
"%s not in room %s." % (event.user_id, event.room_id,)
|
||||||
)
|
)
|
||||||
|
elif target_banned:
|
||||||
|
raise AuthError(
|
||||||
|
403, "%s is banned from the room" % (target_user_id,)
|
||||||
|
)
|
||||||
elif target_in_room: # the target is already in the room.
|
elif target_in_room: # the target is already in the room.
|
||||||
raise AuthError(403, "%s is already in the room." %
|
raise AuthError(403, "%s is already in the room." %
|
||||||
target_user_id)
|
target_user_id)
|
||||||
|
@ -221,6 +238,8 @@ class Auth(object):
|
||||||
# joined: It's a NOOP
|
# joined: It's a NOOP
|
||||||
if event.user_id != target_user_id:
|
if event.user_id != target_user_id:
|
||||||
raise AuthError(403, "Cannot force another user to join.")
|
raise AuthError(403, "Cannot force another user to join.")
|
||||||
|
elif target_banned:
|
||||||
|
raise AuthError(403, "You are banned from this room")
|
||||||
elif join_rule == JoinRules.PUBLIC:
|
elif join_rule == JoinRules.PUBLIC:
|
||||||
pass
|
pass
|
||||||
elif join_rule == JoinRules.INVITE:
|
elif join_rule == JoinRules.INVITE:
|
||||||
|
@ -238,6 +257,10 @@ class Auth(object):
|
||||||
403,
|
403,
|
||||||
"%s not in room %s." % (target_user_id, event.room_id,)
|
"%s not in room %s." % (target_user_id, event.room_id,)
|
||||||
)
|
)
|
||||||
|
elif target_banned and user_level < ban_level:
|
||||||
|
raise AuthError(
|
||||||
|
403, "You cannot unban user &s." % (target_user_id,)
|
||||||
|
)
|
||||||
elif target_user_id != event.user_id:
|
elif target_user_id != event.user_id:
|
||||||
if kick_level:
|
if kick_level:
|
||||||
kick_level = int(kick_level)
|
kick_level = int(kick_level)
|
||||||
|
@ -249,11 +272,6 @@ class Auth(object):
|
||||||
403, "You cannot kick user %s." % target_user_id
|
403, "You cannot kick user %s." % target_user_id
|
||||||
)
|
)
|
||||||
elif Membership.BAN == membership:
|
elif Membership.BAN == membership:
|
||||||
if ban_level:
|
|
||||||
ban_level = int(ban_level)
|
|
||||||
else:
|
|
||||||
ban_level = 50 # FIXME (erikj): What should we do here?
|
|
||||||
|
|
||||||
if user_level < ban_level:
|
if user_level < ban_level:
|
||||||
raise AuthError(403, "You don't have permission to ban")
|
raise AuthError(403, "You don't have permission to ban")
|
||||||
else:
|
else:
|
||||||
|
@ -370,7 +388,7 @@ class Auth(object):
|
||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
ret = yield self.store.get_user_by_token(token=token)
|
ret = yield self.store.get_user_by_token(token)
|
||||||
if not ret:
|
if not ret:
|
||||||
raise StoreError(400, "Unknown token")
|
raise StoreError(400, "Unknown token")
|
||||||
user_info = {
|
user_info = {
|
||||||
|
@ -412,12 +430,6 @@ class Auth(object):
|
||||||
|
|
||||||
builder.auth_events = auth_events_entries
|
builder.auth_events = auth_events_entries
|
||||||
|
|
||||||
context.auth_events = {
|
|
||||||
k: v
|
|
||||||
for k, v in context.current_state.items()
|
|
||||||
if v.event_id in auth_ids
|
|
||||||
}
|
|
||||||
|
|
||||||
def compute_auth_events(self, event, current_state):
|
def compute_auth_events(self, event, current_state):
|
||||||
if event.type == EventTypes.Create:
|
if event.type == EventTypes.Create:
|
||||||
return []
|
return []
|
||||||
|
|
|
@ -60,6 +60,7 @@ class LoginType(object):
|
||||||
EMAIL_IDENTITY = u"m.login.email.identity"
|
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||||
RECAPTCHA = u"m.login.recaptcha"
|
RECAPTCHA = u"m.login.recaptcha"
|
||||||
APPLICATION_SERVICE = u"m.login.application_service"
|
APPLICATION_SERVICE = u"m.login.application_service"
|
||||||
|
SHARED_SECRET = u"org.matrix.login.shared_secret"
|
||||||
|
|
||||||
|
|
||||||
class EventTypes(object):
|
class EventTypes(object):
|
||||||
|
|
|
@ -60,7 +60,6 @@ import re
|
||||||
import resource
|
import resource
|
||||||
import subprocess
|
import subprocess
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import syweb
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -83,6 +82,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
return AppServiceRestResource(self)
|
return AppServiceRestResource(self)
|
||||||
|
|
||||||
def build_resource_for_web_client(self):
|
def build_resource_for_web_client(self):
|
||||||
|
import syweb
|
||||||
syweb_path = os.path.dirname(syweb.__file__)
|
syweb_path = os.path.dirname(syweb.__file__)
|
||||||
webclient_path = os.path.join(syweb_path, "webclient")
|
webclient_path = os.path.join(syweb_path, "webclient")
|
||||||
return File(webclient_path) # TODO configurable?
|
return File(webclient_path) # TODO configurable?
|
||||||
|
@ -130,7 +130,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
True.
|
True.
|
||||||
"""
|
"""
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
web_client = config.webclient
|
web_client = config.web_client
|
||||||
|
|
||||||
# list containing (path_str, Resource) e.g:
|
# list containing (path_str, Resource) e.g:
|
||||||
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
|
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
|
||||||
|
@ -343,7 +343,8 @@ def setup(config_options):
|
||||||
|
|
||||||
config.setup_logging()
|
config.setup_logging()
|
||||||
|
|
||||||
check_requirements()
|
# check any extra requirements we have now we have a config
|
||||||
|
check_requirements(config)
|
||||||
|
|
||||||
version_string = get_version_string()
|
version_string = get_version_string()
|
||||||
|
|
||||||
|
@ -450,6 +451,7 @@ def run(hs):
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
with LoggingContext("main"):
|
with LoggingContext("main"):
|
||||||
|
# check base requirements
|
||||||
check_requirements()
|
check_requirements()
|
||||||
hs = setup(sys.argv[1:])
|
hs = setup(sys.argv[1:])
|
||||||
run(hs)
|
run(hs)
|
||||||
|
|
|
@ -15,19 +15,46 @@
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config
|
||||||
|
|
||||||
|
from synapse.util.stringutils import random_string_with_symbols
|
||||||
|
|
||||||
|
import distutils.util
|
||||||
|
|
||||||
|
|
||||||
class RegistrationConfig(Config):
|
class RegistrationConfig(Config):
|
||||||
|
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super(RegistrationConfig, self).__init__(args)
|
super(RegistrationConfig, self).__init__(args)
|
||||||
self.disable_registration = args.disable_registration
|
|
||||||
|
# `args.disable_registration` may either be a bool or a string depending
|
||||||
|
# on if the option was given a value (e.g. --disable-registration=false
|
||||||
|
# would set `args.disable_registration` to "false" not False.)
|
||||||
|
self.disable_registration = bool(
|
||||||
|
distutils.util.strtobool(str(args.disable_registration))
|
||||||
|
)
|
||||||
|
self.registration_shared_secret = args.registration_shared_secret
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_arguments(cls, parser):
|
def add_arguments(cls, parser):
|
||||||
super(RegistrationConfig, cls).add_arguments(parser)
|
super(RegistrationConfig, cls).add_arguments(parser)
|
||||||
reg_group = parser.add_argument_group("registration")
|
reg_group = parser.add_argument_group("registration")
|
||||||
|
|
||||||
reg_group.add_argument(
|
reg_group.add_argument(
|
||||||
"--disable-registration",
|
"--disable-registration",
|
||||||
action='store_true',
|
const=True,
|
||||||
help="Disable registration of new users."
|
default=True,
|
||||||
|
nargs='?',
|
||||||
|
help="Disable registration of new users.",
|
||||||
)
|
)
|
||||||
|
reg_group.add_argument(
|
||||||
|
"--registration-shared-secret", type=str,
|
||||||
|
help="If set, allows registration by anyone who also has the shared"
|
||||||
|
" secret, even if registration is otherwise disabled.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_config(cls, args, config_dir_path):
|
||||||
|
if args.disable_registration is None:
|
||||||
|
args.disable_registration = True
|
||||||
|
|
||||||
|
if args.registration_shared_secret is None:
|
||||||
|
args.registration_shared_secret = random_string_with_symbols(50)
|
||||||
|
|
|
@ -28,7 +28,7 @@ class ServerConfig(Config):
|
||||||
self.unsecure_port = args.unsecure_port
|
self.unsecure_port = args.unsecure_port
|
||||||
self.daemonize = args.daemonize
|
self.daemonize = args.daemonize
|
||||||
self.pid_file = self.abspath(args.pid_file)
|
self.pid_file = self.abspath(args.pid_file)
|
||||||
self.webclient = True
|
self.web_client = args.web_client
|
||||||
self.manhole = args.manhole
|
self.manhole = args.manhole
|
||||||
self.soft_file_limit = args.soft_file_limit
|
self.soft_file_limit = args.soft_file_limit
|
||||||
|
|
||||||
|
@ -68,6 +68,8 @@ class ServerConfig(Config):
|
||||||
server_group.add_argument('--pid-file', default="homeserver.pid",
|
server_group.add_argument('--pid-file', default="homeserver.pid",
|
||||||
help="When running as a daemon, the file to"
|
help="When running as a daemon, the file to"
|
||||||
" store the pid in")
|
" store the pid in")
|
||||||
|
server_group.add_argument('--web_client', default=True, type=bool,
|
||||||
|
help="Whether or not to serve a web client")
|
||||||
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
||||||
type=int,
|
type=int,
|
||||||
help="Turn on the twisted telnet manhole"
|
help="Turn on the twisted telnet manhole"
|
||||||
|
|
|
@ -16,8 +16,7 @@
|
||||||
|
|
||||||
class EventContext(object):
|
class EventContext(object):
|
||||||
|
|
||||||
def __init__(self, current_state=None, auth_events=None):
|
def __init__(self, current_state=None):
|
||||||
self.current_state = current_state
|
self.current_state = current_state
|
||||||
self.auth_events = auth_events
|
|
||||||
self.state_group = None
|
self.state_group = None
|
||||||
self.rejected = False
|
self.rejected = False
|
||||||
|
|
|
@ -361,4 +361,5 @@ SERVLET_CLASSES = (
|
||||||
FederationInviteServlet,
|
FederationInviteServlet,
|
||||||
FederationQueryAuthServlet,
|
FederationQueryAuthServlet,
|
||||||
FederationGetMissingEventsServlet,
|
FederationGetMissingEventsServlet,
|
||||||
|
FederationEventAuthServlet,
|
||||||
)
|
)
|
||||||
|
|
|
@ -90,8 +90,8 @@ class BaseHandler(object):
|
||||||
event = builder.build()
|
event = builder.build()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created event %s with auth_events: %s, current state: %s",
|
"Created event %s with current state: %s",
|
||||||
event.event_id, context.auth_events, context.current_state,
|
event.event_id, context.current_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
|
@ -106,7 +106,7 @@ class BaseHandler(object):
|
||||||
# We now need to go and hit out to wherever we need to hit out to.
|
# We now need to go and hit out to wherever we need to hit out to.
|
||||||
|
|
||||||
if not suppress_auth:
|
if not suppress_auth:
|
||||||
self.auth.check(event, auth_events=context.auth_events)
|
self.auth.check(event, auth_events=context.current_state)
|
||||||
|
|
||||||
yield self.store.persist_event(event, context=context)
|
yield self.store.persist_event(event, context=context)
|
||||||
|
|
||||||
|
@ -142,7 +142,16 @@ class BaseHandler(object):
|
||||||
"Failed to get destination from event %s", s.event_id
|
"Failed to get destination from event %s", s.event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.notifier.on_new_room_event(event, extra_users=extra_users)
|
# Don't block waiting on waking up all the listeners.
|
||||||
|
d = self.notifier.on_new_room_event(event, extra_users=extra_users)
|
||||||
|
|
||||||
|
def log_failure(f):
|
||||||
|
logger.warn(
|
||||||
|
"Failed to notify about %s: %s",
|
||||||
|
event.event_id, f.value
|
||||||
|
)
|
||||||
|
|
||||||
|
d.addErrback(log_failure)
|
||||||
|
|
||||||
yield federation_handler.handle_new_event(
|
yield federation_handler.handle_new_event(
|
||||||
event, destinations=destinations,
|
event, destinations=destinations,
|
||||||
|
|
|
@ -290,6 +290,8 @@ class FederationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
logger.debug("Joining %s to %s", joinee, room_id)
|
logger.debug("Joining %s to %s", joinee, room_id)
|
||||||
|
|
||||||
|
yield self.store.clean_room_for_join(room_id)
|
||||||
|
|
||||||
origin, pdu = yield self.replication_layer.make_join(
|
origin, pdu = yield self.replication_layer.make_join(
|
||||||
target_hosts,
|
target_hosts,
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -464,11 +466,9 @@ class FederationHandler(BaseHandler):
|
||||||
builder=builder,
|
builder=builder,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.auth.check(event, auth_events=context.auth_events)
|
self.auth.check(event, auth_events=context.current_state)
|
||||||
|
|
||||||
pdu = event
|
defer.returnValue(event)
|
||||||
|
|
||||||
defer.returnValue(pdu)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -705,7 +705,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not auth_events:
|
if not auth_events:
|
||||||
auth_events = context.auth_events
|
auth_events = context.current_state
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"_handle_new_event: %s, auth_events: %s",
|
"_handle_new_event: %s, auth_events: %s",
|
||||||
|
|
|
@ -33,6 +33,10 @@ logger = logging.getLogger(__name__)
|
||||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Don't bother bumping "last active" time if it differs by less than 60 seconds
|
||||||
|
LAST_ACTIVE_GRANULARITY = 60*1000
|
||||||
|
|
||||||
|
|
||||||
# TODO(paul): Maybe there's one of these I can steal from somewhere
|
# TODO(paul): Maybe there's one of these I can steal from somewhere
|
||||||
def partition(l, func):
|
def partition(l, func):
|
||||||
"""Partition the list by the result of func applied to each element."""
|
"""Partition the list by the result of func applied to each element."""
|
||||||
|
@ -282,6 +286,10 @@ class PresenceHandler(BaseHandler):
|
||||||
if now is None:
|
if now is None:
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
|
|
||||||
|
prev_state = self._get_or_make_usercache(user)
|
||||||
|
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
|
||||||
|
return
|
||||||
|
|
||||||
self.changed_presencelike_data(user, {"last_active": now})
|
self.changed_presencelike_data(user, {"last_active": now})
|
||||||
|
|
||||||
def changed_presencelike_data(self, user, state):
|
def changed_presencelike_data(self, user, state):
|
||||||
|
|
|
@ -31,6 +31,7 @@ import base64
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import urllib
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -63,6 +64,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
||||||
|
|
||||||
if localpart:
|
if localpart:
|
||||||
|
if localpart and urllib.quote(localpart) != localpart:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"User ID must only contain characters which do not"
|
||||||
|
" require URL encoding."
|
||||||
|
)
|
||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,11 @@ outgoing_responses_counter = metrics.register_counter(
|
||||||
labels=["method", "code"],
|
labels=["method", "code"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response_timer = metrics.register_distribution(
|
||||||
|
"response_time",
|
||||||
|
labels=["method", "servlet"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HttpServer(object):
|
class HttpServer(object):
|
||||||
""" Interface for registering callbacks on a HTTP server
|
""" Interface for registering callbacks on a HTTP server
|
||||||
|
@ -169,6 +174,10 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
code, response = yield callback(request, *args)
|
code, response = yield callback(request, *args)
|
||||||
|
|
||||||
self._send_response(request, code, response)
|
self._send_response(request, code, response)
|
||||||
|
response_timer.inc_by(
|
||||||
|
self.clock.time_msec() - start, request.method, servlet_classname
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||||
|
|
|
@ -51,8 +51,8 @@ class RestServlet(object):
|
||||||
pattern = self.PATTERN
|
pattern = self.PATTERN
|
||||||
|
|
||||||
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
|
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
|
||||||
if hasattr(self, "on_%s" % (method)):
|
if hasattr(self, "on_%s" % (method,)):
|
||||||
method_handler = getattr(self, "on_%s" % (method))
|
method_handler = getattr(self, "on_%s" % (method,))
|
||||||
http_server.register_path(method, pattern, method_handler)
|
http_server.register_path(method, pattern, method_handler)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("RestServlet must register something.")
|
raise NotImplementedError("RestServlet must register something.")
|
||||||
|
|
|
@ -5,7 +5,6 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
REQUIREMENTS = {
|
REQUIREMENTS = {
|
||||||
"syutil>=0.0.3": ["syutil"],
|
"syutil>=0.0.3": ["syutil"],
|
||||||
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
|
|
||||||
"Twisted==14.0.2": ["twisted==14.0.2"],
|
"Twisted==14.0.2": ["twisted==14.0.2"],
|
||||||
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
||||||
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
||||||
|
@ -18,6 +17,19 @@ REQUIREMENTS = {
|
||||||
"pillow": ["PIL"],
|
"pillow": ["PIL"],
|
||||||
"pydenticon": ["pydenticon"],
|
"pydenticon": ["pydenticon"],
|
||||||
}
|
}
|
||||||
|
CONDITIONAL_REQUIREMENTS = {
|
||||||
|
"web_client": {
|
||||||
|
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def requirements(config=None, include_conditional=False):
|
||||||
|
reqs = REQUIREMENTS.copy()
|
||||||
|
for key, req in CONDITIONAL_REQUIREMENTS.items():
|
||||||
|
if (config and getattr(config, key)) or include_conditional:
|
||||||
|
reqs.update(req)
|
||||||
|
return reqs
|
||||||
|
|
||||||
|
|
||||||
def github_link(project, version, egg):
|
def github_link(project, version, egg):
|
||||||
|
@ -46,10 +58,11 @@ class MissingRequirementError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def check_requirements():
|
def check_requirements(config=None):
|
||||||
"""Checks that all the modules needed by synapse have been correctly
|
"""Checks that all the modules needed by synapse have been correctly
|
||||||
installed and are at the correct version"""
|
installed and are at the correct version"""
|
||||||
for dependency, module_requirements in REQUIREMENTS.items():
|
for dependency, module_requirements in (
|
||||||
|
requirements(config, include_conditional=False).items()):
|
||||||
for module_requirement in module_requirements:
|
for module_requirement in module_requirements:
|
||||||
if ">=" in module_requirement:
|
if ">=" in module_requirement:
|
||||||
module_name, required_version = module_requirement.split(">=")
|
module_name, required_version = module_requirement.split(">=")
|
||||||
|
@ -110,7 +123,7 @@ def list_requirements():
|
||||||
egg = link.split("#egg=")[1]
|
egg = link.split("#egg=")[1]
|
||||||
linked.append(egg.split('-')[0])
|
linked.append(egg.split('-')[0])
|
||||||
result.append(link)
|
result.append(link)
|
||||||
for requirement in REQUIREMENTS:
|
for requirement in requirements(include_conditional=True):
|
||||||
is_linked = False
|
is_linked = False
|
||||||
for link in linked:
|
for link in linked:
|
||||||
if requirement.replace('-', '_').startswith(link):
|
if requirement.replace('-', '_').startswith(link):
|
||||||
|
|
|
@ -27,7 +27,6 @@ from hashlib import sha1
|
||||||
import hmac
|
import hmac
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -110,14 +109,22 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
login_type = register_json["type"]
|
login_type = register_json["type"]
|
||||||
|
|
||||||
is_application_server = login_type == LoginType.APPLICATION_SERVICE
|
is_application_server = login_type == LoginType.APPLICATION_SERVICE
|
||||||
if self.disable_registration and not is_application_server:
|
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
|
||||||
|
|
||||||
|
can_register = (
|
||||||
|
not self.disable_registration
|
||||||
|
or is_application_server
|
||||||
|
or is_using_shared_secret
|
||||||
|
)
|
||||||
|
if not can_register:
|
||||||
raise SynapseError(403, "Registration has been disabled")
|
raise SynapseError(403, "Registration has been disabled")
|
||||||
|
|
||||||
stages = {
|
stages = {
|
||||||
LoginType.RECAPTCHA: self._do_recaptcha,
|
LoginType.RECAPTCHA: self._do_recaptcha,
|
||||||
LoginType.PASSWORD: self._do_password,
|
LoginType.PASSWORD: self._do_password,
|
||||||
LoginType.EMAIL_IDENTITY: self._do_email_identity,
|
LoginType.EMAIL_IDENTITY: self._do_email_identity,
|
||||||
LoginType.APPLICATION_SERVICE: self._do_app_service
|
LoginType.APPLICATION_SERVICE: self._do_app_service,
|
||||||
|
LoginType.SHARED_SECRET: self._do_shared_secret,
|
||||||
}
|
}
|
||||||
|
|
||||||
session_info = self._get_session_info(request, session)
|
session_info = self._get_session_info(request, session)
|
||||||
|
@ -255,14 +262,11 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
password = register_json["password"].encode("utf-8")
|
password = register_json["password"].encode("utf-8")
|
||||||
desired_user_id = (register_json["user"].encode("utf-8")
|
desired_user_id = (
|
||||||
if "user" in register_json else None)
|
register_json["user"].encode("utf-8")
|
||||||
if (desired_user_id
|
if "user" in register_json else None
|
||||||
and urllib.quote(desired_user_id) != desired_user_id):
|
)
|
||||||
raise SynapseError(
|
|
||||||
400,
|
|
||||||
"User ID must only contain characters which do not " +
|
|
||||||
"require URL encoding.")
|
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
(user_id, token) = yield handler.register(
|
(user_id, token) = yield handler.register(
|
||||||
localpart=desired_user_id,
|
localpart=desired_user_id,
|
||||||
|
@ -304,6 +308,51 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_shared_secret(self, request, register_json, session):
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
if not isinstance(register_json.get("mac", None), basestring):
|
||||||
|
raise SynapseError(400, "Expected mac.")
|
||||||
|
if not isinstance(register_json.get("user", None), basestring):
|
||||||
|
raise SynapseError(400, "Expected 'user' key.")
|
||||||
|
if not isinstance(register_json.get("password", None), basestring):
|
||||||
|
raise SynapseError(400, "Expected 'password' key.")
|
||||||
|
|
||||||
|
if not self.hs.config.registration_shared_secret:
|
||||||
|
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||||
|
|
||||||
|
user = register_json["user"].encode("utf-8")
|
||||||
|
|
||||||
|
# str() because otherwise hmac complains that 'unicode' does not
|
||||||
|
# have the buffer interface
|
||||||
|
got_mac = str(register_json["mac"])
|
||||||
|
|
||||||
|
want_mac = hmac.new(
|
||||||
|
key=self.hs.config.registration_shared_secret,
|
||||||
|
msg=user,
|
||||||
|
digestmod=sha1,
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
password = register_json["password"].encode("utf-8")
|
||||||
|
|
||||||
|
if compare_digest(want_mac, got_mac):
|
||||||
|
handler = self.handlers.registration_handler
|
||||||
|
user_id, token = yield handler.register(
|
||||||
|
localpart=user,
|
||||||
|
password=password,
|
||||||
|
)
|
||||||
|
self._remove_session(session)
|
||||||
|
defer.returnValue({
|
||||||
|
"user_id": user_id,
|
||||||
|
"access_token": token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
raise SynapseError(
|
||||||
|
403, "HMAC incorrect",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _parse_json(request):
|
def _parse_json(request):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.expiringcache import ExpiringCache
|
from synapse.util.expiringcache import ExpiringCache
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
|
from synapse.api.auth import AuthEventTypes
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
|
||||||
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
|
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
|
||||||
|
|
||||||
|
|
||||||
AuthEventTypes = (
|
|
||||||
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
|
|
||||||
EventTypes.JoinRules,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
SIZE_OF_CACHE = 1000
|
SIZE_OF_CACHE = 1000
|
||||||
EVICTION_TIMEOUT_SECONDS = 20
|
EVICTION_TIMEOUT_SECONDS = 20
|
||||||
|
|
||||||
|
@ -139,18 +134,6 @@ class StateHandler(object):
|
||||||
}
|
}
|
||||||
context.state_group = None
|
context.state_group = None
|
||||||
|
|
||||||
if hasattr(event, "auth_events") and event.auth_events:
|
|
||||||
auth_ids = self.hs.get_auth().compute_auth_events(
|
|
||||||
event, context.current_state
|
|
||||||
)
|
|
||||||
context.auth_events = {
|
|
||||||
k: v
|
|
||||||
for k, v in context.current_state.items()
|
|
||||||
if v.event_id in auth_ids
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
context.auth_events = {}
|
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.current_state:
|
if key in context.current_state:
|
||||||
|
@ -187,18 +170,6 @@ class StateHandler(object):
|
||||||
replaces = context.current_state[key]
|
replaces = context.current_state[key]
|
||||||
event.unsigned["replaces_state"] = replaces.event_id
|
event.unsigned["replaces_state"] = replaces.event_id
|
||||||
|
|
||||||
if hasattr(event, "auth_events") and event.auth_events:
|
|
||||||
auth_ids = self.hs.get_auth().compute_auth_events(
|
|
||||||
event, context.current_state
|
|
||||||
)
|
|
||||||
context.auth_events = {
|
|
||||||
k: v
|
|
||||||
for k, v in context.current_state.items()
|
|
||||||
if v.event_id in auth_ids
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
context.auth_events = {}
|
|
||||||
|
|
||||||
context.prev_state_events = prev_state
|
context.prev_state_events = prev_state
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
|
|
|
@ -14,15 +14,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.logutils import log_function
|
|
||||||
from synapse.api.constants import EventTypes
|
|
||||||
|
|
||||||
from .appservice import (
|
from .appservice import (
|
||||||
ApplicationServiceStore, ApplicationServiceTransactionStore
|
ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||||
)
|
)
|
||||||
|
from ._base import Cache
|
||||||
from .directory import DirectoryStore
|
from .directory import DirectoryStore
|
||||||
from .feedback import FeedbackStore
|
from .events import EventsStore
|
||||||
from .presence import PresenceStore
|
from .presence import PresenceStore
|
||||||
from .profile import ProfileStore
|
from .profile import ProfileStore
|
||||||
from .registration import RegistrationStore
|
from .registration import RegistrationStore
|
||||||
|
@ -41,11 +38,6 @@ from .state import StateStore
|
||||||
from .signatures import SignatureStore
|
from .signatures import SignatureStore
|
||||||
from .filtering import FilteringStore
|
from .filtering import FilteringStore
|
||||||
|
|
||||||
from syutil.base64util import decode_base64
|
|
||||||
from syutil.jsonutil import encode_canonical_json
|
|
||||||
|
|
||||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
|
||||||
|
|
||||||
|
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import imp
|
import imp
|
||||||
|
@ -63,16 +55,14 @@ SCHEMA_VERSION = 15
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
||||||
class _RollbackButIsFineException(Exception):
|
# times give more inserts into the database even for readonly API hits
|
||||||
""" This exception is used to rollback a transaction without implying
|
# 120 seconds == 2 minutes
|
||||||
something went wrong.
|
LAST_SEEN_GRANULARITY = 120*1000
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DataStore(RoomMemberStore, RoomStore,
|
class DataStore(RoomMemberStore, RoomStore,
|
||||||
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
|
RegistrationStore, StreamStore, ProfileStore,
|
||||||
PresenceStore, TransactionStore,
|
PresenceStore, TransactionStore,
|
||||||
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
||||||
ApplicationServiceStore,
|
ApplicationServiceStore,
|
||||||
|
@ -83,6 +73,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
PusherStore,
|
PusherStore,
|
||||||
PushRuleStore,
|
PushRuleStore,
|
||||||
ApplicationServiceTransactionStore,
|
ApplicationServiceTransactionStore,
|
||||||
|
EventsStore,
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -92,424 +83,28 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
self.min_token_deferred = self._get_min_token()
|
self.min_token_deferred = self._get_min_token()
|
||||||
self.min_token = None
|
self.min_token = None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
self.client_ip_last_seen = Cache(
|
||||||
@log_function
|
name="client_ip_last_seen",
|
||||||
def persist_event(self, event, context, backfilled=False,
|
keylen=4,
|
||||||
is_new_state=True, current_state=None):
|
|
||||||
stream_ordering = None
|
|
||||||
if backfilled:
|
|
||||||
if not self.min_token_deferred.called:
|
|
||||||
yield self.min_token_deferred
|
|
||||||
self.min_token -= 1
|
|
||||||
stream_ordering = self.min_token
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield self.runInteraction(
|
|
||||||
"persist_event",
|
|
||||||
self._persist_event_txn,
|
|
||||||
event=event,
|
|
||||||
context=context,
|
|
||||||
backfilled=backfilled,
|
|
||||||
stream_ordering=stream_ordering,
|
|
||||||
is_new_state=is_new_state,
|
|
||||||
current_state=current_state,
|
|
||||||
)
|
|
||||||
except _RollbackButIsFineException:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_event(self, event_id, check_redacted=True,
|
|
||||||
get_prev_content=False, allow_rejected=False,
|
|
||||||
allow_none=False):
|
|
||||||
"""Get an event from the database by event_id.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_id (str): The event_id of the event to fetch
|
|
||||||
check_redacted (bool): If True, check if event has been redacted
|
|
||||||
and redact it.
|
|
||||||
get_prev_content (bool): If True and event is a state event,
|
|
||||||
include the previous states content in the unsigned field.
|
|
||||||
allow_rejected (bool): If True return rejected events.
|
|
||||||
allow_none (bool): If True, return None if no event found, if
|
|
||||||
False throw an exception.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred : A FrozenEvent.
|
|
||||||
"""
|
|
||||||
event = yield self.runInteraction(
|
|
||||||
"get_event", self._get_event_txn,
|
|
||||||
event_id,
|
|
||||||
check_redacted=check_redacted,
|
|
||||||
get_prev_content=get_prev_content,
|
|
||||||
allow_rejected=allow_rejected,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not event and not allow_none:
|
|
||||||
raise RuntimeError("Could not find event %s" % (event_id,))
|
|
||||||
|
|
||||||
defer.returnValue(event)
|
|
||||||
|
|
||||||
@log_function
|
|
||||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
|
||||||
stream_ordering=None, is_new_state=True,
|
|
||||||
current_state=None):
|
|
||||||
|
|
||||||
# Remove the any existing cache entries for the event_id
|
|
||||||
self._get_event_cache.pop(event.event_id)
|
|
||||||
|
|
||||||
# We purposefully do this first since if we include a `current_state`
|
|
||||||
# key, we *want* to update the `current_state_events` table
|
|
||||||
if current_state:
|
|
||||||
txn.execute(
|
|
||||||
"DELETE FROM current_state_events WHERE room_id = ?",
|
|
||||||
(event.room_id,)
|
|
||||||
)
|
|
||||||
|
|
||||||
for s in current_state:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
"current_state_events",
|
|
||||||
{
|
|
||||||
"event_id": s.event_id,
|
|
||||||
"room_id": s.room_id,
|
|
||||||
"type": s.type,
|
|
||||||
"state_key": s.state_key,
|
|
||||||
},
|
|
||||||
or_replace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.is_state() and is_new_state:
|
|
||||||
if not backfilled and not context.rejected:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="state_forward_extremities",
|
|
||||||
values={
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"type": event.type,
|
|
||||||
"state_key": event.state_key,
|
|
||||||
},
|
|
||||||
or_replace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for prev_state_id, _ in event.prev_state:
|
|
||||||
self._simple_delete_txn(
|
|
||||||
txn,
|
|
||||||
table="state_forward_extremities",
|
|
||||||
keyvalues={
|
|
||||||
"event_id": prev_state_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
outlier = event.internal_metadata.is_outlier()
|
|
||||||
|
|
||||||
if not outlier:
|
|
||||||
self._store_state_groups_txn(txn, event, context)
|
|
||||||
|
|
||||||
self._update_min_depth_for_room_txn(
|
|
||||||
txn,
|
|
||||||
event.room_id,
|
|
||||||
event.depth
|
|
||||||
)
|
|
||||||
|
|
||||||
self._handle_prev_events(
|
|
||||||
txn,
|
|
||||||
outlier=outlier,
|
|
||||||
event_id=event.event_id,
|
|
||||||
prev_events=event.prev_events,
|
|
||||||
room_id=event.room_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
have_persisted = self._simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="event_json",
|
|
||||||
keyvalues={"event_id": event.event_id},
|
|
||||||
retcol="event_id",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata_json = encode_canonical_json(
|
|
||||||
event.internal_metadata.get_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
# If we have already persisted this event, we don't need to do any
|
|
||||||
# more processing.
|
|
||||||
# The processing above must be done on every call to persist event,
|
|
||||||
# since they might not have happened on previous calls. For example,
|
|
||||||
# if we are persisting an event that we had persisted as an outlier,
|
|
||||||
# but is no longer one.
|
|
||||||
if have_persisted:
|
|
||||||
if not outlier:
|
|
||||||
sql = (
|
|
||||||
"UPDATE event_json SET internal_metadata = ?"
|
|
||||||
" WHERE event_id = ?"
|
|
||||||
)
|
|
||||||
txn.execute(
|
|
||||||
sql,
|
|
||||||
(metadata_json.decode("UTF-8"), event.event_id,)
|
|
||||||
)
|
|
||||||
|
|
||||||
sql = (
|
|
||||||
"UPDATE events SET outlier = 0"
|
|
||||||
" WHERE event_id = ?"
|
|
||||||
)
|
|
||||||
txn.execute(
|
|
||||||
sql,
|
|
||||||
(event.event_id,)
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
|
||||||
self._store_room_member_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Feedback:
|
|
||||||
self._store_feedback_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Name:
|
|
||||||
self._store_room_name_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Topic:
|
|
||||||
self._store_room_topic_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Redaction:
|
|
||||||
self._store_redaction(txn, event)
|
|
||||||
|
|
||||||
event_dict = {
|
|
||||||
k: v
|
|
||||||
for k, v in event.get_dict().items()
|
|
||||||
if k not in [
|
|
||||||
"redacted",
|
|
||||||
"redacted_because",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="event_json",
|
|
||||||
values={
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"internal_metadata": metadata_json.decode("UTF-8"),
|
|
||||||
"json": encode_canonical_json(event_dict).decode("UTF-8"),
|
|
||||||
},
|
|
||||||
or_replace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
content = encode_canonical_json(
|
|
||||||
event.content
|
|
||||||
).decode("UTF-8")
|
|
||||||
|
|
||||||
vals = {
|
|
||||||
"topological_ordering": event.depth,
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"type": event.type,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"content": content,
|
|
||||||
"processed": True,
|
|
||||||
"outlier": outlier,
|
|
||||||
"depth": event.depth,
|
|
||||||
}
|
|
||||||
|
|
||||||
if stream_ordering is not None:
|
|
||||||
vals["stream_ordering"] = stream_ordering
|
|
||||||
|
|
||||||
unrec = {
|
|
||||||
k: v
|
|
||||||
for k, v in event.get_dict().items()
|
|
||||||
if k not in vals.keys() and k not in [
|
|
||||||
"redacted",
|
|
||||||
"redacted_because",
|
|
||||||
"signatures",
|
|
||||||
"hashes",
|
|
||||||
"prev_events",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
vals["unrecognized_keys"] = encode_canonical_json(
|
|
||||||
unrec
|
|
||||||
).decode("UTF-8")
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
"events",
|
|
||||||
vals,
|
|
||||||
or_replace=(not outlier),
|
|
||||||
or_ignore=bool(outlier),
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to persist, probably duplicate: %s",
|
|
||||||
event.event_id,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
raise _RollbackButIsFineException("_persist_event")
|
|
||||||
|
|
||||||
if context.rejected:
|
|
||||||
self._store_rejections_txn(txn, event.event_id, context.rejected)
|
|
||||||
|
|
||||||
if event.is_state():
|
|
||||||
vals = {
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"type": event.type,
|
|
||||||
"state_key": event.state_key,
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: How does this work with backfilling?
|
|
||||||
if hasattr(event, "replaces_state"):
|
|
||||||
vals["prev_state"] = event.replaces_state
|
|
||||||
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
"state_events",
|
|
||||||
vals,
|
|
||||||
or_replace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_new_state and not context.rejected:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
"current_state_events",
|
|
||||||
{
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"type": event.type,
|
|
||||||
"state_key": event.state_key,
|
|
||||||
},
|
|
||||||
or_replace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for e_id, h in event.prev_state:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="event_edges",
|
|
||||||
values={
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"prev_event_id": e_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"is_state": 1,
|
|
||||||
},
|
|
||||||
or_ignore=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for hash_alg, hash_base64 in event.hashes.items():
|
|
||||||
hash_bytes = decode_base64(hash_base64)
|
|
||||||
self._store_event_content_hash_txn(
|
|
||||||
txn, event.event_id, hash_alg, hash_bytes,
|
|
||||||
)
|
|
||||||
|
|
||||||
for prev_event_id, prev_hashes in event.prev_events:
|
|
||||||
for alg, hash_base64 in prev_hashes.items():
|
|
||||||
hash_bytes = decode_base64(hash_base64)
|
|
||||||
self._store_prev_event_hash_txn(
|
|
||||||
txn, event.event_id, prev_event_id, alg, hash_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
for auth_id, _ in event.auth_events:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="event_auth",
|
|
||||||
values={
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"auth_id": auth_id,
|
|
||||||
},
|
|
||||||
or_ignore=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
|
|
||||||
self._store_event_reference_hash_txn(
|
|
||||||
txn, event.event_id, ref_alg, ref_hash_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
def _store_redaction(self, txn, event):
|
|
||||||
# invalidate the cache for the redacted event
|
|
||||||
self._get_event_cache.pop(event.redacts)
|
|
||||||
txn.execute(
|
|
||||||
"INSERT OR IGNORE INTO redactions "
|
|
||||||
"(event_id, redacts) VALUES (?,?)",
|
|
||||||
(event.event_id, event.redacts)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
|
||||||
del_sql = (
|
|
||||||
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
|
||||||
"LIMIT 1"
|
|
||||||
)
|
|
||||||
|
|
||||||
sql = (
|
|
||||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
|
||||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
|
||||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
|
||||||
"WHERE c.room_id = ? "
|
|
||||||
) % {
|
|
||||||
"redacted": del_sql,
|
|
||||||
}
|
|
||||||
|
|
||||||
if event_type and state_key is not None:
|
|
||||||
sql += " AND s.type = ? AND s.state_key = ? "
|
|
||||||
args = (room_id, event_type, state_key)
|
|
||||||
elif event_type:
|
|
||||||
sql += " AND s.type = ?"
|
|
||||||
args = (room_id, event_type)
|
|
||||||
else:
|
|
||||||
args = (room_id, )
|
|
||||||
|
|
||||||
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
|
||||||
|
|
||||||
events = yield self._parse_events(results)
|
|
||||||
defer.returnValue(events)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_room_name_and_aliases(self, room_id):
|
|
||||||
del_sql = (
|
|
||||||
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
|
||||||
"LIMIT 1"
|
|
||||||
)
|
|
||||||
|
|
||||||
sql = (
|
|
||||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
|
||||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
|
||||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
|
||||||
"WHERE c.room_id = ? "
|
|
||||||
) % {
|
|
||||||
"redacted": del_sql,
|
|
||||||
}
|
|
||||||
|
|
||||||
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
|
|
||||||
sql += " OR s.type = 'm.room.aliases')"
|
|
||||||
args = (room_id,)
|
|
||||||
|
|
||||||
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
|
||||||
|
|
||||||
events = yield self._parse_events(results)
|
|
||||||
|
|
||||||
name = None
|
|
||||||
aliases = []
|
|
||||||
|
|
||||||
for e in events:
|
|
||||||
if e.type == 'm.room.name':
|
|
||||||
if 'name' in e.content:
|
|
||||||
name = e.content['name']
|
|
||||||
elif e.type == 'm.room.aliases':
|
|
||||||
if 'aliases' in e.content:
|
|
||||||
aliases.extend(e.content['aliases'])
|
|
||||||
|
|
||||||
defer.returnValue((name, aliases))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _get_min_token(self):
|
|
||||||
row = yield self._execute(
|
|
||||||
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
|
|
||||||
self.min_token = min(self.min_token, -1)
|
|
||||||
|
|
||||||
logger.debug("min_token is: %s", self.min_token)
|
|
||||||
|
|
||||||
defer.returnValue(self.min_token)
|
|
||||||
|
|
||||||
def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
|
def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
|
||||||
return self._simple_insert(
|
now = int(self._clock.time_msec())
|
||||||
|
key = (user.to_string(), access_token, device_id, ip)
|
||||||
|
|
||||||
|
try:
|
||||||
|
last_seen = self.client_ip_last_seen.get(*key)
|
||||||
|
except KeyError:
|
||||||
|
last_seen = None
|
||||||
|
|
||||||
|
# Rate-limited inserts
|
||||||
|
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||||
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
self.client_ip_last_seen.prefill(*key + (now,))
|
||||||
|
|
||||||
|
yield self._simple_insert(
|
||||||
"user_ips",
|
"user_ips",
|
||||||
{
|
{
|
||||||
"user": user.to_string(),
|
"user": user.to_string(),
|
||||||
|
@ -517,8 +112,9 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
"ip": ip,
|
"ip": ip,
|
||||||
"user_agent": user_agent,
|
"user_agent": user_agent,
|
||||||
"last_seen": int(self._clock.time_msec()),
|
"last_seen": now,
|
||||||
}
|
},
|
||||||
|
desc="insert_client_ip",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_user_ip_and_agents(self, user):
|
def get_user_ip_and_agents(self, user):
|
||||||
|
@ -528,38 +124,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
retcols=[
|
retcols=[
|
||||||
"device_id", "access_token", "ip", "user_agent", "last_seen"
|
"device_id", "access_token", "ip", "user_agent", "last_seen"
|
||||||
],
|
],
|
||||||
)
|
desc="get_user_ip_and_agents",
|
||||||
|
|
||||||
def have_events(self, event_ids):
|
|
||||||
"""Given a list of event ids, check if we have already processed them.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Has an entry for each event id we already have seen. Maps to
|
|
||||||
the rejected reason string if we rejected the event, else maps to
|
|
||||||
None.
|
|
||||||
"""
|
|
||||||
if not event_ids:
|
|
||||||
return defer.succeed({})
|
|
||||||
|
|
||||||
def f(txn):
|
|
||||||
sql = (
|
|
||||||
"SELECT e.event_id, reason FROM events as e "
|
|
||||||
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
|
||||||
"WHERE e.event_id = ?"
|
|
||||||
)
|
|
||||||
|
|
||||||
res = {}
|
|
||||||
for event_id in event_ids:
|
|
||||||
txn.execute(sql, (event_id,))
|
|
||||||
row = txn.fetchone()
|
|
||||||
if row:
|
|
||||||
_, rejected = row
|
|
||||||
res[event_id] = rejected
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
return self.runInteraction(
|
|
||||||
"have_events", f,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ import synapse.metrics
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from collections import namedtuple, OrderedDict
|
from collections import namedtuple, OrderedDict
|
||||||
|
import functools
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
@ -38,6 +39,8 @@ transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||||
|
|
||||||
metrics = synapse.metrics.get_metrics_for("synapse.storage")
|
metrics = synapse.metrics.get_metrics_for("synapse.storage")
|
||||||
|
|
||||||
|
sql_scheduling_timer = metrics.register_distribution("schedule_time")
|
||||||
|
|
||||||
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
|
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
|
||||||
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
|
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
|
||||||
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
|
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
|
||||||
|
@ -50,14 +53,57 @@ cache_counter = metrics.register_cache(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO(paul):
|
class Cache(object):
|
||||||
# * more generic key management
|
|
||||||
# * consider other eviction strategies - LRU?
|
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
|
||||||
def cached(max_entries=1000):
|
if lru:
|
||||||
|
self.cache = LruCache(max_size=max_entries)
|
||||||
|
self.max_entries = None
|
||||||
|
else:
|
||||||
|
self.cache = OrderedDict()
|
||||||
|
self.max_entries = max_entries
|
||||||
|
|
||||||
|
self.name = name
|
||||||
|
self.keylen = keylen
|
||||||
|
|
||||||
|
caches_by_name[name] = self.cache
|
||||||
|
|
||||||
|
def get(self, *keyargs):
|
||||||
|
if len(keyargs) != self.keylen:
|
||||||
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
|
||||||
|
if keyargs in self.cache:
|
||||||
|
cache_counter.inc_hits(self.name)
|
||||||
|
return self.cache[keyargs]
|
||||||
|
|
||||||
|
cache_counter.inc_misses(self.name)
|
||||||
|
raise KeyError()
|
||||||
|
|
||||||
|
def prefill(self, *args): # because I can't *keyargs, value
|
||||||
|
keyargs = args[:-1]
|
||||||
|
value = args[-1]
|
||||||
|
|
||||||
|
if len(keyargs) != self.keylen:
|
||||||
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
|
||||||
|
if self.max_entries is not None:
|
||||||
|
while len(self.cache) >= self.max_entries:
|
||||||
|
self.cache.popitem(last=False)
|
||||||
|
|
||||||
|
self.cache[keyargs] = value
|
||||||
|
|
||||||
|
def invalidate(self, *keyargs):
|
||||||
|
if len(keyargs) != self.keylen:
|
||||||
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
|
||||||
|
self.cache.pop(keyargs, None)
|
||||||
|
|
||||||
|
|
||||||
|
def cached(max_entries=1000, num_args=1, lru=False):
|
||||||
""" A method decorator that applies a memoizing cache around the function.
|
""" A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
The function is presumed to take one additional argument, which is used as
|
The function is presumed to take zero or more arguments, which are used in
|
||||||
the key for the cache. Cache hits are served directly from the cache;
|
a tuple as the key for the cache. Hits are served directly from the cache;
|
||||||
misses use the function body to generate the value.
|
misses use the function body to generate the value.
|
||||||
|
|
||||||
The wrapped function has an additional member, a callable called
|
The wrapped function has an additional member, a callable called
|
||||||
|
@ -68,33 +114,27 @@ def cached(max_entries=1000):
|
||||||
calling the calculation function.
|
calling the calculation function.
|
||||||
"""
|
"""
|
||||||
def wrap(orig):
|
def wrap(orig):
|
||||||
cache = OrderedDict()
|
cache = Cache(
|
||||||
name = orig.__name__
|
name=orig.__name__,
|
||||||
|
max_entries=max_entries,
|
||||||
caches_by_name[name] = cache
|
keylen=num_args,
|
||||||
|
lru=lru,
|
||||||
def prefill(key, value):
|
)
|
||||||
while len(cache) > max_entries:
|
|
||||||
cache.popitem(last=False)
|
|
||||||
|
|
||||||
cache[key] = value
|
|
||||||
|
|
||||||
|
@functools.wraps(orig)
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wrapped(self, key):
|
def wrapped(self, *keyargs):
|
||||||
if key in cache:
|
try:
|
||||||
cache_counter.inc_hits(name)
|
defer.returnValue(cache.get(*keyargs))
|
||||||
defer.returnValue(cache[key])
|
except KeyError:
|
||||||
|
ret = yield orig(self, *keyargs)
|
||||||
|
|
||||||
cache_counter.inc_misses(name)
|
cache.prefill(*keyargs + (ret,))
|
||||||
ret = yield orig(self, key)
|
|
||||||
prefill(key, ret)
|
|
||||||
defer.returnValue(ret)
|
|
||||||
|
|
||||||
def invalidate(key):
|
defer.returnValue(ret)
|
||||||
cache.pop(key, None)
|
|
||||||
|
|
||||||
wrapped.invalidate = invalidate
|
wrapped.invalidate = cache.invalidate
|
||||||
wrapped.prefill = prefill
|
wrapped.prefill = cache.prefill
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
return wrap
|
return wrap
|
||||||
|
@ -240,6 +280,8 @@ class SQLBaseStore(object):
|
||||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||||
current_context = LoggingContext.current_context()
|
current_context = LoggingContext.current_context()
|
||||||
|
|
||||||
|
start_time = time.time() * 1000
|
||||||
|
|
||||||
def inner_func(txn, *args, **kwargs):
|
def inner_func(txn, *args, **kwargs):
|
||||||
with LoggingContext("runInteraction") as context:
|
with LoggingContext("runInteraction") as context:
|
||||||
current_context.copy_to(context)
|
current_context.copy_to(context)
|
||||||
|
@ -252,6 +294,7 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
name = "%s-%x" % (desc, txn_id, )
|
name = "%s-%x" % (desc, txn_id, )
|
||||||
|
|
||||||
|
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||||
transaction_logger.debug("[TXN START] {%s}", name)
|
transaction_logger.debug("[TXN START] {%s}", name)
|
||||||
try:
|
try:
|
||||||
return func(LoggingTransaction(txn, name), *args, **kwargs)
|
return func(LoggingTransaction(txn, name), *args, **kwargs)
|
||||||
|
@ -314,7 +357,8 @@ class SQLBaseStore(object):
|
||||||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||||
# no complex WHERE clauses, just a dict of values for columns.
|
# no complex WHERE clauses, just a dict of values for columns.
|
||||||
|
|
||||||
def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
|
def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
|
||||||
|
desc="_simple_insert"):
|
||||||
"""Executes an INSERT query on the named table.
|
"""Executes an INSERT query on the named table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -323,7 +367,7 @@ class SQLBaseStore(object):
|
||||||
or_replace : bool; if True performs an INSERT OR REPLACE
|
or_replace : bool; if True performs an INSERT OR REPLACE
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_insert",
|
desc,
|
||||||
self._simple_insert_txn, table, values, or_replace=or_replace,
|
self._simple_insert_txn, table, values, or_replace=or_replace,
|
||||||
or_ignore=or_ignore,
|
or_ignore=or_ignore,
|
||||||
)
|
)
|
||||||
|
@ -347,7 +391,7 @@ class SQLBaseStore(object):
|
||||||
txn.execute(sql, values.values())
|
txn.execute(sql, values.values())
|
||||||
return txn.lastrowid
|
return txn.lastrowid
|
||||||
|
|
||||||
def _simple_upsert(self, table, keyvalues, values):
|
def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
table (str): The table to upsert into
|
table (str): The table to upsert into
|
||||||
|
@ -356,7 +400,7 @@ class SQLBaseStore(object):
|
||||||
Returns: A deferred
|
Returns: A deferred
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_upsert",
|
desc,
|
||||||
self._simple_upsert_txn, table, keyvalues, values
|
self._simple_upsert_txn, table, keyvalues, values
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -392,7 +436,7 @@ class SQLBaseStore(object):
|
||||||
txn.execute(sql, allvalues.values())
|
txn.execute(sql, allvalues.values())
|
||||||
|
|
||||||
def _simple_select_one(self, table, keyvalues, retcols,
|
def _simple_select_one(self, table, keyvalues, retcols,
|
||||||
allow_none=False):
|
allow_none=False, desc="_simple_select_one"):
|
||||||
"""Executes a SELECT query on the named table, which is expected to
|
"""Executes a SELECT query on the named table, which is expected to
|
||||||
return a single row, returning a single column from it.
|
return a single row, returning a single column from it.
|
||||||
|
|
||||||
|
@ -404,12 +448,15 @@ class SQLBaseStore(object):
|
||||||
allow_none : If true, return None instead of failing if the SELECT
|
allow_none : If true, return None instead of failing if the SELECT
|
||||||
statement returns no rows
|
statement returns no rows
|
||||||
"""
|
"""
|
||||||
return self._simple_selectupdate_one(
|
return self.runInteraction(
|
||||||
table, keyvalues, retcols=retcols, allow_none=allow_none
|
desc,
|
||||||
|
self._simple_select_one_txn,
|
||||||
|
table, keyvalues, retcols, allow_none,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
||||||
allow_none=False):
|
allow_none=False,
|
||||||
|
desc="_simple_select_one_onecol"):
|
||||||
"""Executes a SELECT query on the named table, which is expected to
|
"""Executes a SELECT query on the named table, which is expected to
|
||||||
return a single row, returning a single column from it."
|
return a single row, returning a single column from it."
|
||||||
|
|
||||||
|
@ -419,7 +466,7 @@ class SQLBaseStore(object):
|
||||||
retcol : string giving the name of the column to return
|
retcol : string giving the name of the column to return
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_select_one_onecol",
|
desc,
|
||||||
self._simple_select_one_onecol_txn,
|
self._simple_select_one_onecol_txn,
|
||||||
table, keyvalues, retcol, allow_none=allow_none,
|
table, keyvalues, retcol, allow_none=allow_none,
|
||||||
)
|
)
|
||||||
|
@ -455,7 +502,8 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
return [r[0] for r in txn.fetchall()]
|
return [r[0] for r in txn.fetchall()]
|
||||||
|
|
||||||
def _simple_select_onecol(self, table, keyvalues, retcol):
|
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||||
|
desc="_simple_select_onecol"):
|
||||||
"""Executes a SELECT query on the named table, which returns a list
|
"""Executes a SELECT query on the named table, which returns a list
|
||||||
comprising of the values of the named column from the selected rows.
|
comprising of the values of the named column from the selected rows.
|
||||||
|
|
||||||
|
@ -468,12 +516,13 @@ class SQLBaseStore(object):
|
||||||
Deferred: Results in a list
|
Deferred: Results in a list
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_select_onecol",
|
desc,
|
||||||
self._simple_select_onecol_txn,
|
self._simple_select_onecol_txn,
|
||||||
table, keyvalues, retcol
|
table, keyvalues, retcol
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_select_list(self, table, keyvalues, retcols):
|
def _simple_select_list(self, table, keyvalues, retcols,
|
||||||
|
desc="_simple_select_list"):
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of dicts.
|
||||||
|
|
||||||
|
@ -484,7 +533,7 @@ class SQLBaseStore(object):
|
||||||
retcols : list of strings giving the names of the columns to return
|
retcols : list of strings giving the names of the columns to return
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_simple_select_list",
|
desc,
|
||||||
self._simple_select_list_txn,
|
self._simple_select_list_txn,
|
||||||
table, keyvalues, retcols
|
table, keyvalues, retcols
|
||||||
)
|
)
|
||||||
|
@ -516,7 +565,7 @@ class SQLBaseStore(object):
|
||||||
return self.cursor_to_dict(txn)
|
return self.cursor_to_dict(txn)
|
||||||
|
|
||||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||||
retcols=None):
|
desc="_simple_update_one"):
|
||||||
"""Executes an UPDATE query on the named table, setting new values for
|
"""Executes an UPDATE query on the named table, setting new values for
|
||||||
columns in a row matching the key values.
|
columns in a row matching the key values.
|
||||||
|
|
||||||
|
@ -534,56 +583,76 @@ class SQLBaseStore(object):
|
||||||
get-and-set. This can be used to implement compare-and-set by putting
|
get-and-set. This can be used to implement compare-and-set by putting
|
||||||
the update column in the 'keyvalues' dict as well.
|
the update column in the 'keyvalues' dict as well.
|
||||||
"""
|
"""
|
||||||
return self._simple_selectupdate_one(table, keyvalues, updatevalues,
|
return self.runInteraction(
|
||||||
retcols=retcols)
|
desc,
|
||||||
|
self._simple_update_one_txn,
|
||||||
|
table, keyvalues, updatevalues,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
|
||||||
|
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||||
|
table,
|
||||||
|
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||||
|
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
update_sql,
|
||||||
|
updatevalues.values() + keyvalues.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
if txn.rowcount == 0:
|
||||||
|
raise StoreError(404, "No row found")
|
||||||
|
if txn.rowcount > 1:
|
||||||
|
raise StoreError(500, "More than one row matched")
|
||||||
|
|
||||||
|
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
|
||||||
|
allow_none=False):
|
||||||
|
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||||
|
", ".join(retcols),
|
||||||
|
table,
|
||||||
|
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(select_sql, keyvalues.values())
|
||||||
|
|
||||||
|
row = txn.fetchone()
|
||||||
|
if not row:
|
||||||
|
if allow_none:
|
||||||
|
return None
|
||||||
|
raise StoreError(404, "No row found")
|
||||||
|
if txn.rowcount > 1:
|
||||||
|
raise StoreError(500, "More than one row matched")
|
||||||
|
|
||||||
|
return dict(zip(retcols, row))
|
||||||
|
|
||||||
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
|
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
|
||||||
retcols=None, allow_none=False):
|
retcols=None, allow_none=False,
|
||||||
|
desc="_simple_selectupdate_one"):
|
||||||
""" Combined SELECT then UPDATE."""
|
""" Combined SELECT then UPDATE."""
|
||||||
if retcols:
|
|
||||||
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
|
||||||
", ".join(retcols),
|
|
||||||
table,
|
|
||||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
|
||||||
)
|
|
||||||
|
|
||||||
if updatevalues:
|
|
||||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
|
||||||
table,
|
|
||||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
|
||||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
|
||||||
)
|
|
||||||
|
|
||||||
def func(txn):
|
def func(txn):
|
||||||
ret = None
|
ret = None
|
||||||
if retcols:
|
if retcols:
|
||||||
txn.execute(select_sql, keyvalues.values())
|
ret = self._simple_select_one_txn(
|
||||||
|
txn,
|
||||||
row = txn.fetchone()
|
table=table,
|
||||||
if not row:
|
keyvalues=keyvalues,
|
||||||
if allow_none:
|
retcols=retcols,
|
||||||
return None
|
allow_none=allow_none,
|
||||||
raise StoreError(404, "No row found")
|
|
||||||
if txn.rowcount > 1:
|
|
||||||
raise StoreError(500, "More than one row matched")
|
|
||||||
|
|
||||||
ret = dict(zip(retcols, row))
|
|
||||||
|
|
||||||
if updatevalues:
|
|
||||||
txn.execute(
|
|
||||||
update_sql,
|
|
||||||
updatevalues.values() + keyvalues.values()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if txn.rowcount == 0:
|
if updatevalues:
|
||||||
raise StoreError(404, "No row found")
|
self._simple_update_one_txn(
|
||||||
if txn.rowcount > 1:
|
txn,
|
||||||
raise StoreError(500, "More than one row matched")
|
table=table,
|
||||||
|
keyvalues=keyvalues,
|
||||||
|
updatevalues=updatevalues,
|
||||||
|
)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
return self.runInteraction("_simple_selectupdate_one", func)
|
return self.runInteraction(desc, func)
|
||||||
|
|
||||||
def _simple_delete_one(self, table, keyvalues):
|
def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
|
||||||
"""Executes a DELETE query on the named table, expecting to delete a
|
"""Executes a DELETE query on the named table, expecting to delete a
|
||||||
single row.
|
single row.
|
||||||
|
|
||||||
|
@ -602,9 +671,9 @@ class SQLBaseStore(object):
|
||||||
raise StoreError(404, "No row found")
|
raise StoreError(404, "No row found")
|
||||||
if txn.rowcount > 1:
|
if txn.rowcount > 1:
|
||||||
raise StoreError(500, "more than one row matched")
|
raise StoreError(500, "more than one row matched")
|
||||||
return self.runInteraction("_simple_delete_one", func)
|
return self.runInteraction(desc, func)
|
||||||
|
|
||||||
def _simple_delete(self, table, keyvalues):
|
def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
|
||||||
"""Executes a DELETE query on the named table.
|
"""Executes a DELETE query on the named table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -612,7 +681,7 @@ class SQLBaseStore(object):
|
||||||
keyvalues : dict of column names and values to select the row with
|
keyvalues : dict of column names and values to select the row with
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.runInteraction("_simple_delete", self._simple_delete_txn)
|
return self.runInteraction(desc, self._simple_delete_txn)
|
||||||
|
|
||||||
def _simple_delete_txn(self, txn, table, keyvalues):
|
def _simple_delete_txn(self, txn, table, keyvalues):
|
||||||
sql = "DELETE FROM %s WHERE %s" % (
|
sql = "DELETE FROM %s WHERE %s" % (
|
||||||
|
@ -782,6 +851,13 @@ class SQLBaseStore(object):
|
||||||
return result[0] if result else None
|
return result[0] if result else None
|
||||||
|
|
||||||
|
|
||||||
|
class _RollbackButIsFineException(Exception):
|
||||||
|
""" This exception is used to rollback a transaction without implying
|
||||||
|
something went wrong.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Table(object):
|
class Table(object):
|
||||||
""" A base class used to store information about a particular table.
|
""" A base class used to store information about a particular table.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore, cached
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
|
@ -48,6 +48,7 @@ class DirectoryStore(SQLBaseStore):
|
||||||
{"room_alias": room_alias.to_string()},
|
{"room_alias": room_alias.to_string()},
|
||||||
"room_id",
|
"room_id",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="get_association_from_room_alias",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not room_id:
|
if not room_id:
|
||||||
|
@ -58,6 +59,7 @@ class DirectoryStore(SQLBaseStore):
|
||||||
"room_alias_servers",
|
"room_alias_servers",
|
||||||
{"room_alias": room_alias.to_string()},
|
{"room_alias": room_alias.to_string()},
|
||||||
"server",
|
"server",
|
||||||
|
desc="get_association_from_room_alias",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not servers:
|
if not servers:
|
||||||
|
@ -87,6 +89,7 @@ class DirectoryStore(SQLBaseStore):
|
||||||
"room_alias": room_alias.to_string(),
|
"room_alias": room_alias.to_string(),
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
},
|
},
|
||||||
|
desc="create_room_alias_association",
|
||||||
)
|
)
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -100,16 +103,22 @@ class DirectoryStore(SQLBaseStore):
|
||||||
{
|
{
|
||||||
"room_alias": room_alias.to_string(),
|
"room_alias": room_alias.to_string(),
|
||||||
"server": server,
|
"server": server,
|
||||||
}
|
},
|
||||||
|
desc="create_room_alias_association",
|
||||||
)
|
)
|
||||||
|
self.get_aliases_for_room.invalidate(room_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def delete_room_alias(self, room_alias):
|
def delete_room_alias(self, room_alias):
|
||||||
return self.runInteraction(
|
room_id = yield self.runInteraction(
|
||||||
"delete_room_alias",
|
"delete_room_alias",
|
||||||
self._delete_room_alias_txn,
|
self._delete_room_alias_txn,
|
||||||
room_alias,
|
room_alias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.get_aliases_for_room.invalidate(room_id)
|
||||||
|
defer.returnValue(room_id)
|
||||||
|
|
||||||
def _delete_room_alias_txn(self, txn, room_alias):
|
def _delete_room_alias_txn(self, txn, room_alias):
|
||||||
cursor = txn.execute(
|
cursor = txn.execute(
|
||||||
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
|
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
|
||||||
|
@ -134,9 +143,11 @@ class DirectoryStore(SQLBaseStore):
|
||||||
|
|
||||||
return room_id
|
return room_id
|
||||||
|
|
||||||
|
@cached()
|
||||||
def get_aliases_for_room(self, room_id):
|
def get_aliases_for_room(self, room_id):
|
||||||
return self._simple_select_onecol(
|
return self._simple_select_onecol(
|
||||||
"room_aliases",
|
"room_aliases",
|
||||||
{"room_id": room_id},
|
{"room_id": room_id},
|
||||||
"room_alias",
|
"room_alias",
|
||||||
|
desc="get_aliases_for_room",
|
||||||
)
|
)
|
||||||
|
|
|
@ -429,3 +429,15 @@ class EventFederationStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
return events[:limit]
|
return events[:limit]
|
||||||
|
|
||||||
|
def clean_room_for_join(self, room_id):
|
||||||
|
return self.runInteraction(
|
||||||
|
"clean_room_for_join",
|
||||||
|
self._clean_room_for_join_txn,
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clean_room_for_join_txn(self, txn, room_id):
|
||||||
|
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||||
|
|
||||||
|
txn.execute(query, (room_id,))
|
||||||
|
|
|
@ -0,0 +1,395 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2014, 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from _base import SQLBaseStore, _RollbackButIsFineException
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||||
|
|
||||||
|
from syutil.base64util import decode_base64
|
||||||
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EventsStore(SQLBaseStore):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def persist_event(self, event, context, backfilled=False,
|
||||||
|
is_new_state=True, current_state=None):
|
||||||
|
stream_ordering = None
|
||||||
|
if backfilled:
|
||||||
|
if not self.min_token_deferred.called:
|
||||||
|
yield self.min_token_deferred
|
||||||
|
self.min_token -= 1
|
||||||
|
stream_ordering = self.min_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.runInteraction(
|
||||||
|
"persist_event",
|
||||||
|
self._persist_event_txn,
|
||||||
|
event=event,
|
||||||
|
context=context,
|
||||||
|
backfilled=backfilled,
|
||||||
|
stream_ordering=stream_ordering,
|
||||||
|
is_new_state=is_new_state,
|
||||||
|
current_state=current_state,
|
||||||
|
)
|
||||||
|
self.get_room_events_max_id.invalidate()
|
||||||
|
except _RollbackButIsFineException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_event(self, event_id, check_redacted=True,
|
||||||
|
get_prev_content=False, allow_rejected=False,
|
||||||
|
allow_none=False):
|
||||||
|
"""Get an event from the database by event_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id (str): The event_id of the event to fetch
|
||||||
|
check_redacted (bool): If True, check if event has been redacted
|
||||||
|
and redact it.
|
||||||
|
get_prev_content (bool): If True and event is a state event,
|
||||||
|
include the previous states content in the unsigned field.
|
||||||
|
allow_rejected (bool): If True return rejected events.
|
||||||
|
allow_none (bool): If True, return None if no event found, if
|
||||||
|
False throw an exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred : A FrozenEvent.
|
||||||
|
"""
|
||||||
|
event = yield self.runInteraction(
|
||||||
|
"get_event", self._get_event_txn,
|
||||||
|
event_id,
|
||||||
|
check_redacted=check_redacted,
|
||||||
|
get_prev_content=get_prev_content,
|
||||||
|
allow_rejected=allow_rejected,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not event and not allow_none:
|
||||||
|
raise RuntimeError("Could not find event %s" % (event_id,))
|
||||||
|
|
||||||
|
defer.returnValue(event)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||||
|
stream_ordering=None, is_new_state=True,
|
||||||
|
current_state=None):
|
||||||
|
|
||||||
|
# Remove the any existing cache entries for the event_id
|
||||||
|
self._get_event_cache.pop(event.event_id)
|
||||||
|
|
||||||
|
# We purposefully do this first since if we include a `current_state`
|
||||||
|
# key, we *want* to update the `current_state_events` table
|
||||||
|
if current_state:
|
||||||
|
txn.execute(
|
||||||
|
"DELETE FROM current_state_events WHERE room_id = ?",
|
||||||
|
(event.room_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
for s in current_state:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
"current_state_events",
|
||||||
|
{
|
||||||
|
"event_id": s.event_id,
|
||||||
|
"room_id": s.room_id,
|
||||||
|
"type": s.type,
|
||||||
|
"state_key": s.state_key,
|
||||||
|
},
|
||||||
|
or_replace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if event.is_state() and is_new_state:
|
||||||
|
if not backfilled and not context.rejected:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="state_forward_extremities",
|
||||||
|
values={
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"type": event.type,
|
||||||
|
"state_key": event.state_key,
|
||||||
|
},
|
||||||
|
or_replace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for prev_state_id, _ in event.prev_state:
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="state_forward_extremities",
|
||||||
|
keyvalues={
|
||||||
|
"event_id": prev_state_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
outlier = event.internal_metadata.is_outlier()
|
||||||
|
|
||||||
|
if not outlier:
|
||||||
|
self._store_state_groups_txn(txn, event, context)
|
||||||
|
|
||||||
|
self._update_min_depth_for_room_txn(
|
||||||
|
txn,
|
||||||
|
event.room_id,
|
||||||
|
event.depth
|
||||||
|
)
|
||||||
|
|
||||||
|
self._handle_prev_events(
|
||||||
|
txn,
|
||||||
|
outlier=outlier,
|
||||||
|
event_id=event.event_id,
|
||||||
|
prev_events=event.prev_events,
|
||||||
|
room_id=event.room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
have_persisted = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="event_json",
|
||||||
|
keyvalues={"event_id": event.event_id},
|
||||||
|
retcol="event_id",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata_json = encode_canonical_json(
|
||||||
|
event.internal_metadata.get_dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we have already persisted this event, we don't need to do any
|
||||||
|
# more processing.
|
||||||
|
# The processing above must be done on every call to persist event,
|
||||||
|
# since they might not have happened on previous calls. For example,
|
||||||
|
# if we are persisting an event that we had persisted as an outlier,
|
||||||
|
# but is no longer one.
|
||||||
|
if have_persisted:
|
||||||
|
if not outlier:
|
||||||
|
sql = (
|
||||||
|
"UPDATE event_json SET internal_metadata = ?"
|
||||||
|
" WHERE event_id = ?"
|
||||||
|
)
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(metadata_json.decode("UTF-8"), event.event_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"UPDATE events SET outlier = 0"
|
||||||
|
" WHERE event_id = ?"
|
||||||
|
)
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(event.event_id,)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if event.type == EventTypes.Member:
|
||||||
|
self._store_room_member_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Feedback:
|
||||||
|
self._store_feedback_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Name:
|
||||||
|
self._store_room_name_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Topic:
|
||||||
|
self._store_room_topic_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Redaction:
|
||||||
|
self._store_redaction(txn, event)
|
||||||
|
|
||||||
|
event_dict = {
|
||||||
|
k: v
|
||||||
|
for k, v in event.get_dict().items()
|
||||||
|
if k not in [
|
||||||
|
"redacted",
|
||||||
|
"redacted_because",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="event_json",
|
||||||
|
values={
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"internal_metadata": metadata_json.decode("UTF-8"),
|
||||||
|
"json": encode_canonical_json(event_dict).decode("UTF-8"),
|
||||||
|
},
|
||||||
|
or_replace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
content = encode_canonical_json(
|
||||||
|
event.content
|
||||||
|
).decode("UTF-8")
|
||||||
|
|
||||||
|
vals = {
|
||||||
|
"topological_ordering": event.depth,
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"type": event.type,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"content": content,
|
||||||
|
"processed": True,
|
||||||
|
"outlier": outlier,
|
||||||
|
"depth": event.depth,
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream_ordering is not None:
|
||||||
|
vals["stream_ordering"] = stream_ordering
|
||||||
|
|
||||||
|
unrec = {
|
||||||
|
k: v
|
||||||
|
for k, v in event.get_dict().items()
|
||||||
|
if k not in vals.keys() and k not in [
|
||||||
|
"redacted",
|
||||||
|
"redacted_because",
|
||||||
|
"signatures",
|
||||||
|
"hashes",
|
||||||
|
"prev_events",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
vals["unrecognized_keys"] = encode_canonical_json(
|
||||||
|
unrec
|
||||||
|
).decode("UTF-8")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
"events",
|
||||||
|
vals,
|
||||||
|
or_replace=(not outlier),
|
||||||
|
or_ignore=bool(outlier),
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to persist, probably duplicate: %s",
|
||||||
|
event.event_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise _RollbackButIsFineException("_persist_event")
|
||||||
|
|
||||||
|
if context.rejected:
|
||||||
|
self._store_rejections_txn(txn, event.event_id, context.rejected)
|
||||||
|
|
||||||
|
if event.is_state():
|
||||||
|
vals = {
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"type": event.type,
|
||||||
|
"state_key": event.state_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: How does this work with backfilling?
|
||||||
|
if hasattr(event, "replaces_state"):
|
||||||
|
vals["prev_state"] = event.replaces_state
|
||||||
|
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
"state_events",
|
||||||
|
vals,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_new_state and not context.rejected:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
"current_state_events",
|
||||||
|
{
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"type": event.type,
|
||||||
|
"state_key": event.state_key,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for e_id, h in event.prev_state:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="event_edges",
|
||||||
|
values={
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"prev_event_id": e_id,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"is_state": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for hash_alg, hash_base64 in event.hashes.items():
|
||||||
|
hash_bytes = decode_base64(hash_base64)
|
||||||
|
self._store_event_content_hash_txn(
|
||||||
|
txn, event.event_id, hash_alg, hash_bytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
for prev_event_id, prev_hashes in event.prev_events:
|
||||||
|
for alg, hash_base64 in prev_hashes.items():
|
||||||
|
hash_bytes = decode_base64(hash_base64)
|
||||||
|
self._store_prev_event_hash_txn(
|
||||||
|
txn, event.event_id, prev_event_id, alg, hash_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
for auth_id, _ in event.auth_events:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="event_auth",
|
||||||
|
values={
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"auth_id": auth_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
|
||||||
|
self._store_event_reference_hash_txn(
|
||||||
|
txn, event.event_id, ref_alg, ref_hash_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
def _store_redaction(self, txn, event):
|
||||||
|
# invalidate the cache for the redacted event
|
||||||
|
self._get_event_cache.pop(event.redacts)
|
||||||
|
txn.execute(
|
||||||
|
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
|
||||||
|
(event.event_id, event.redacts)
|
||||||
|
)
|
||||||
|
|
||||||
|
def have_events(self, event_ids):
|
||||||
|
"""Given a list of event ids, check if we have already processed them.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Has an entry for each event id we already have seen. Maps to
|
||||||
|
the rejected reason string if we rejected the event, else maps to
|
||||||
|
None.
|
||||||
|
"""
|
||||||
|
if not event_ids:
|
||||||
|
return defer.succeed({})
|
||||||
|
|
||||||
|
def f(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT e.event_id, reason FROM events as e "
|
||||||
|
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
||||||
|
"WHERE e.event_id = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
res = {}
|
||||||
|
for event_id in event_ids:
|
||||||
|
txn.execute(sql, (event_id,))
|
||||||
|
row = txn.fetchone()
|
||||||
|
if row:
|
||||||
|
_, rejected = row
|
||||||
|
res[event_id] = rejected
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"have_events", f,
|
||||||
|
)
|
|
@ -1,47 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2014, 2015 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackStore(SQLBaseStore):
|
|
||||||
|
|
||||||
def _store_feedback_txn(self, txn, event):
|
|
||||||
self._simple_insert_txn(txn, "feedback", {
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"feedback_type": event.content["type"],
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"target_event_id": event.content["target_event_id"],
|
|
||||||
"sender": event.user_id,
|
|
||||||
})
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_feedback_for_event(self, event_id):
|
|
||||||
sql = (
|
|
||||||
"SELECT events.* FROM events INNER JOIN feedback "
|
|
||||||
"ON events.event_id = feedback.event_id "
|
|
||||||
"WHERE feedback.target_event_id = ? "
|
|
||||||
)
|
|
||||||
|
|
||||||
rows = yield self._execute_and_decode("get_feedback_for_event", sql, event_id)
|
|
||||||
|
|
||||||
defer.returnValue(
|
|
||||||
[
|
|
||||||
(yield self._parse_events(r))
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
)
|
|
|
@ -31,6 +31,7 @@ class FilteringStore(SQLBaseStore):
|
||||||
},
|
},
|
||||||
retcol="filter_json",
|
retcol="filter_json",
|
||||||
allow_none=False,
|
allow_none=False,
|
||||||
|
desc="get_user_filter",
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(json.loads(def_json))
|
defer.returnValue(json.loads(def_json))
|
||||||
|
|
|
@ -32,6 +32,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
{"media_id": media_id},
|
{"media_id": media_id},
|
||||||
("media_type", "media_length", "upload_name", "created_ts"),
|
("media_type", "media_length", "upload_name", "created_ts"),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="get_local_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
||||||
|
@ -45,7 +46,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"upload_name": upload_name,
|
"upload_name": upload_name,
|
||||||
"media_length": media_length,
|
"media_length": media_length,
|
||||||
"user_id": user_id.to_string(),
|
"user_id": user_id.to_string(),
|
||||||
}
|
},
|
||||||
|
desc="store_local_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_local_media_thumbnails(self, media_id):
|
def get_local_media_thumbnails(self, media_id):
|
||||||
|
@ -55,7 +57,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
(
|
(
|
||||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||||
"thumbnail_type", "thumbnail_length",
|
"thumbnail_type", "thumbnail_length",
|
||||||
)
|
),
|
||||||
|
desc="get_local_media_thumbnails",
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_local_thumbnail(self, media_id, thumbnail_width,
|
def store_local_thumbnail(self, media_id, thumbnail_width,
|
||||||
|
@ -70,7 +73,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"thumbnail_method": thumbnail_method,
|
"thumbnail_method": thumbnail_method,
|
||||||
"thumbnail_type": thumbnail_type,
|
"thumbnail_type": thumbnail_type,
|
||||||
"thumbnail_length": thumbnail_length,
|
"thumbnail_length": thumbnail_length,
|
||||||
}
|
},
|
||||||
|
desc="store_local_thumbnail",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_cached_remote_media(self, origin, media_id):
|
def get_cached_remote_media(self, origin, media_id):
|
||||||
|
@ -82,6 +86,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"filesystem_id",
|
"filesystem_id",
|
||||||
),
|
),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="get_cached_remote_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_cached_remote_media(self, origin, media_id, media_type,
|
def store_cached_remote_media(self, origin, media_id, media_type,
|
||||||
|
@ -97,7 +102,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"created_ts": time_now_ms,
|
"created_ts": time_now_ms,
|
||||||
"upload_name": upload_name,
|
"upload_name": upload_name,
|
||||||
"filesystem_id": filesystem_id,
|
"filesystem_id": filesystem_id,
|
||||||
}
|
},
|
||||||
|
desc="store_cached_remote_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_remote_media_thumbnails(self, origin, media_id):
|
def get_remote_media_thumbnails(self, origin, media_id):
|
||||||
|
@ -107,7 +113,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
(
|
(
|
||||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||||
"thumbnail_type", "thumbnail_length", "filesystem_id",
|
"thumbnail_type", "thumbnail_length", "filesystem_id",
|
||||||
)
|
),
|
||||||
|
desc="get_remote_media_thumbnails",
|
||||||
)
|
)
|
||||||
|
|
||||||
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
|
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
|
||||||
|
@ -125,5 +132,6 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||||
"thumbnail_type": thumbnail_type,
|
"thumbnail_type": thumbnail_type,
|
||||||
"thumbnail_length": thumbnail_length,
|
"thumbnail_length": thumbnail_length,
|
||||||
"filesystem_id": filesystem_id,
|
"filesystem_id": filesystem_id,
|
||||||
}
|
},
|
||||||
|
desc="store_remote_media_thumbnail",
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
return self._simple_insert(
|
return self._simple_insert(
|
||||||
table="presence",
|
table="presence",
|
||||||
values={"user_id": user_localpart},
|
values={"user_id": user_localpart},
|
||||||
|
desc="create_presence",
|
||||||
)
|
)
|
||||||
|
|
||||||
def has_presence_state(self, user_localpart):
|
def has_presence_state(self, user_localpart):
|
||||||
|
@ -29,6 +30,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
retcols=["user_id"],
|
retcols=["user_id"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="has_presence_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_presence_state(self, user_localpart):
|
def get_presence_state(self, user_localpart):
|
||||||
|
@ -36,6 +38,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence",
|
table="presence",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
retcols=["state", "status_msg", "mtime"],
|
retcols=["state", "status_msg", "mtime"],
|
||||||
|
desc="get_presence_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_presence_state(self, user_localpart, new_state):
|
def set_presence_state(self, user_localpart, new_state):
|
||||||
|
@ -45,7 +48,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
updatevalues={"state": new_state["state"],
|
updatevalues={"state": new_state["state"],
|
||||||
"status_msg": new_state["status_msg"],
|
"status_msg": new_state["status_msg"],
|
||||||
"mtime": self._clock.time_msec()},
|
"mtime": self._clock.time_msec()},
|
||||||
retcols=["state"],
|
desc="set_presence_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||||
|
@ -53,6 +56,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence_allow_inbound",
|
table="presence_allow_inbound",
|
||||||
values={"observed_user_id": observed_localpart,
|
values={"observed_user_id": observed_localpart,
|
||||||
"observer_user_id": observer_userid},
|
"observer_user_id": observer_userid},
|
||||||
|
desc="allow_presence_visible",
|
||||||
)
|
)
|
||||||
|
|
||||||
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
||||||
|
@ -60,6 +64,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence_allow_inbound",
|
table="presence_allow_inbound",
|
||||||
keyvalues={"observed_user_id": observed_localpart,
|
keyvalues={"observed_user_id": observed_localpart,
|
||||||
"observer_user_id": observer_userid},
|
"observer_user_id": observer_userid},
|
||||||
|
desc="disallow_presence_visible",
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_presence_visible(self, observed_localpart, observer_userid):
|
def is_presence_visible(self, observed_localpart, observer_userid):
|
||||||
|
@ -69,6 +74,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
"observer_user_id": observer_userid},
|
"observer_user_id": observer_userid},
|
||||||
retcols=["observed_user_id"],
|
retcols=["observed_user_id"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="is_presence_visible",
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_presence_list_pending(self, observer_localpart, observed_userid):
|
def add_presence_list_pending(self, observer_localpart, observed_userid):
|
||||||
|
@ -77,6 +83,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
values={"user_id": observer_localpart,
|
values={"user_id": observer_localpart,
|
||||||
"observed_user_id": observed_userid,
|
"observed_user_id": observed_userid,
|
||||||
"accepted": False},
|
"accepted": False},
|
||||||
|
desc="add_presence_list_pending",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||||
|
@ -85,6 +92,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
keyvalues={"user_id": observer_localpart,
|
keyvalues={"user_id": observer_localpart,
|
||||||
"observed_user_id": observed_userid},
|
"observed_user_id": observed_userid},
|
||||||
updatevalues={"accepted": True},
|
updatevalues={"accepted": True},
|
||||||
|
desc="set_presence_list_accepted",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_presence_list(self, observer_localpart, accepted=None):
|
def get_presence_list(self, observer_localpart, accepted=None):
|
||||||
|
@ -96,6 +104,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence_list",
|
table="presence_list",
|
||||||
keyvalues=keyvalues,
|
keyvalues=keyvalues,
|
||||||
retcols=["observed_user_id", "accepted"],
|
retcols=["observed_user_id", "accepted"],
|
||||||
|
desc="get_presence_list",
|
||||||
)
|
)
|
||||||
|
|
||||||
def del_presence_list(self, observer_localpart, observed_userid):
|
def del_presence_list(self, observer_localpart, observed_userid):
|
||||||
|
@ -103,4 +112,5 @@ class PresenceStore(SQLBaseStore):
|
||||||
table="presence_list",
|
table="presence_list",
|
||||||
keyvalues={"user_id": observer_localpart,
|
keyvalues={"user_id": observer_localpart,
|
||||||
"observed_user_id": observed_userid},
|
"observed_user_id": observed_userid},
|
||||||
|
desc="del_presence_list",
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ class ProfileStore(SQLBaseStore):
|
||||||
return self._simple_insert(
|
return self._simple_insert(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
values={"user_id": user_localpart},
|
values={"user_id": user_localpart},
|
||||||
|
desc="create_profile",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_profile_displayname(self, user_localpart):
|
def get_profile_displayname(self, user_localpart):
|
||||||
|
@ -28,6 +29,7 @@ class ProfileStore(SQLBaseStore):
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
retcol="displayname",
|
retcol="displayname",
|
||||||
|
desc="get_profile_displayname",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_profile_displayname(self, user_localpart, new_displayname):
|
def set_profile_displayname(self, user_localpart, new_displayname):
|
||||||
|
@ -35,6 +37,7 @@ class ProfileStore(SQLBaseStore):
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
updatevalues={"displayname": new_displayname},
|
updatevalues={"displayname": new_displayname},
|
||||||
|
desc="set_profile_displayname",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_profile_avatar_url(self, user_localpart):
|
def get_profile_avatar_url(self, user_localpart):
|
||||||
|
@ -42,6 +45,7 @@ class ProfileStore(SQLBaseStore):
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
retcol="avatar_url",
|
retcol="avatar_url",
|
||||||
|
desc="get_profile_avatar_url",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
||||||
|
@ -49,4 +53,5 @@ class ProfileStore(SQLBaseStore):
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
updatevalues={"avatar_url": new_avatar_url},
|
updatevalues={"avatar_url": new_avatar_url},
|
||||||
|
desc="set_profile_avatar_url",
|
||||||
)
|
)
|
||||||
|
|
|
@ -50,7 +50,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
results = yield self._simple_select_list(
|
results = yield self._simple_select_list(
|
||||||
PushRuleEnableTable.table_name,
|
PushRuleEnableTable.table_name,
|
||||||
{'user_name': user_name},
|
{'user_name': user_name},
|
||||||
PushRuleEnableTable.fields
|
PushRuleEnableTable.fields,
|
||||||
|
desc="get_push_rules_enabled_for_user",
|
||||||
)
|
)
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
|
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
|
||||||
|
@ -201,7 +202,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
yield self._simple_delete_one(
|
yield self._simple_delete_one(
|
||||||
PushRuleTable.table_name,
|
PushRuleTable.table_name,
|
||||||
{'user_name': user_name, 'rule_id': rule_id}
|
{'user_name': user_name, 'rule_id': rule_id},
|
||||||
|
desc="delete_push_rule",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -209,7 +211,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
yield self._simple_upsert(
|
yield self._simple_upsert(
|
||||||
PushRuleEnableTable.table_name,
|
PushRuleEnableTable.table_name,
|
||||||
{'user_name': user_name, 'rule_id': rule_id},
|
{'user_name': user_name, 'rule_id': rule_id},
|
||||||
{'enabled': enabled}
|
{'enabled': enabled},
|
||||||
|
desc="set_push_rule_enabled",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -114,7 +114,9 @@ class PusherStore(SQLBaseStore):
|
||||||
ts=pushkey_ts,
|
ts=pushkey_ts,
|
||||||
lang=lang,
|
lang=lang,
|
||||||
data=data
|
data=data
|
||||||
))
|
),
|
||||||
|
desc="add_pusher",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("create_pusher with failed: %s", e)
|
logger.error("create_pusher with failed: %s", e)
|
||||||
raise StoreError(500, "Problem creating pusher.")
|
raise StoreError(500, "Problem creating pusher.")
|
||||||
|
@ -123,7 +125,8 @@ class PusherStore(SQLBaseStore):
|
||||||
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
|
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
|
||||||
yield self._simple_delete_one(
|
yield self._simple_delete_one(
|
||||||
PushersTable.table_name,
|
PushersTable.table_name,
|
||||||
dict(app_id=app_id, pushkey=pushkey)
|
{"app_id": app_id, "pushkey": pushkey},
|
||||||
|
desc="delete_pusher_by_app_id_pushkey",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -131,7 +134,8 @@ class PusherStore(SQLBaseStore):
|
||||||
yield self._simple_update_one(
|
yield self._simple_update_one(
|
||||||
PushersTable.table_name,
|
PushersTable.table_name,
|
||||||
{'app_id': app_id, 'pushkey': pushkey},
|
{'app_id': app_id, 'pushkey': pushkey},
|
||||||
{'last_token': last_token}
|
{'last_token': last_token},
|
||||||
|
desc="update_pusher_last_token",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -140,7 +144,8 @@ class PusherStore(SQLBaseStore):
|
||||||
yield self._simple_update_one(
|
yield self._simple_update_one(
|
||||||
PushersTable.table_name,
|
PushersTable.table_name,
|
||||||
{'app_id': app_id, 'pushkey': pushkey},
|
{'app_id': app_id, 'pushkey': pushkey},
|
||||||
{'last_token': last_token, 'last_success': last_success}
|
{'last_token': last_token, 'last_success': last_success},
|
||||||
|
desc="update_pusher_last_token_and_success",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -148,7 +153,8 @@ class PusherStore(SQLBaseStore):
|
||||||
yield self._simple_update_one(
|
yield self._simple_update_one(
|
||||||
PushersTable.table_name,
|
PushersTable.table_name,
|
||||||
{'app_id': app_id, 'pushkey': pushkey},
|
{'app_id': app_id, 'pushkey': pushkey},
|
||||||
{'failing_since': failing_since}
|
{'failing_since': failing_since},
|
||||||
|
desc="update_pusher_failing_since",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from sqlite3 import IntegrityError
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore, cached
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(SQLBaseStore):
|
class RegistrationStore(SQLBaseStore):
|
||||||
|
@ -39,7 +39,10 @@ class RegistrationStore(SQLBaseStore):
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem adding this.
|
StoreError if there was a problem adding this.
|
||||||
"""
|
"""
|
||||||
row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
|
row = yield self._simple_select_one(
|
||||||
|
"users", {"name": user_id}, ["id"],
|
||||||
|
desc="add_access_token_to_user",
|
||||||
|
)
|
||||||
if not row:
|
if not row:
|
||||||
raise StoreError(400, "Bad user ID supplied.")
|
raise StoreError(400, "Bad user ID supplied.")
|
||||||
row_id = row["id"]
|
row_id = row["id"]
|
||||||
|
@ -48,7 +51,8 @@ class RegistrationStore(SQLBaseStore):
|
||||||
{
|
{
|
||||||
"user_id": row_id,
|
"user_id": row_id,
|
||||||
"token": token
|
"token": token
|
||||||
}
|
},
|
||||||
|
desc="add_access_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -91,6 +95,11 @@ class RegistrationStore(SQLBaseStore):
|
||||||
"get_user_by_id", self.cursor_to_dict, query, user_id
|
"get_user_by_id", self.cursor_to_dict, query, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
# TODO(paul): Currently there's no code to invalidate this cache. That
|
||||||
|
# means if/when we ever add internal ways to invalidate access tokens or
|
||||||
|
# change whether a user is a server admin, those will need to invoke
|
||||||
|
# store.get_user_by_token.invalidate(token)
|
||||||
def get_user_by_token(self, token):
|
def get_user_by_token(self, token):
|
||||||
"""Get a user from the given access token.
|
"""Get a user from the given access token.
|
||||||
|
|
||||||
|
@ -115,6 +124,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
keyvalues={"name": user.to_string()},
|
keyvalues={"name": user.to_string()},
|
||||||
retcol="admin",
|
retcol="admin",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="is_server_admin",
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(res if res else False)
|
defer.returnValue(res if res else False)
|
||||||
|
|
|
@ -29,7 +29,7 @@ class RejectionsStore(SQLBaseStore):
|
||||||
"event_id": event_id,
|
"event_id": event_id,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
"last_check": self._clock.time_msec(),
|
"last_check": self._clock.time_msec(),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_rejection_reason(self, event_id):
|
def get_rejection_reason(self, event_id):
|
||||||
|
@ -40,4 +40,5 @@ class RejectionsStore(SQLBaseStore):
|
||||||
"event_id": event_id,
|
"event_id": event_id,
|
||||||
},
|
},
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="get_rejection_reason",
|
||||||
)
|
)
|
||||||
|
|
|
@ -15,11 +15,9 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from sqlite3 import IntegrityError
|
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
|
||||||
from ._base import SQLBaseStore, Table
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
|
@ -27,8 +25,9 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
OpsLevel = collections.namedtuple("OpsLevel", (
|
OpsLevel = collections.namedtuple(
|
||||||
"ban_level", "kick_level", "redact_level")
|
"OpsLevel",
|
||||||
|
("ban_level", "kick_level", "redact_level",)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,13 +46,15 @@ class RoomStore(SQLBaseStore):
|
||||||
StoreError if the room could not be stored.
|
StoreError if the room could not be stored.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
yield self._simple_insert(RoomsTable.table_name, dict(
|
yield self._simple_insert(
|
||||||
room_id=room_id,
|
RoomsTable.table_name,
|
||||||
creator=room_creator_user_id,
|
{
|
||||||
is_public=is_public
|
"room_id": room_id,
|
||||||
))
|
"creator": room_creator_user_id,
|
||||||
except IntegrityError:
|
"is_public": is_public,
|
||||||
raise StoreError(409, "Room ID in use.")
|
},
|
||||||
|
desc="store_room",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||||
raise StoreError(500, "Problem creating room.")
|
raise StoreError(500, "Problem creating room.")
|
||||||
|
@ -66,9 +67,11 @@ class RoomStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
A namedtuple containing the room information, or an empty list.
|
A namedtuple containing the room information, or an empty list.
|
||||||
"""
|
"""
|
||||||
query = RoomsTable.select_statement("room_id=?")
|
return self._simple_select_one(
|
||||||
return self._execute(
|
table=RoomsTable.table_name,
|
||||||
"get_room", RoomsTable.decode_single_result, query, room_id,
|
keyvalues={"room_id": room_id},
|
||||||
|
retcols=RoomsTable.fields,
|
||||||
|
desc="get_room",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -143,7 +146,7 @@ class RoomStore(SQLBaseStore):
|
||||||
"event_id": event.event_id,
|
"event_id": event.event_id,
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"topic": event.content["topic"],
|
"topic": event.content["topic"],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _store_room_name_txn(self, txn, event):
|
def _store_room_name_txn(self, txn, event):
|
||||||
|
@ -158,8 +161,45 @@ class RoomStore(SQLBaseStore):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_room_name_and_aliases(self, room_id):
|
||||||
|
del_sql = (
|
||||||
|
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
||||||
|
"LIMIT 1"
|
||||||
|
)
|
||||||
|
|
||||||
class RoomsTable(Table):
|
sql = (
|
||||||
|
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
||||||
|
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
||||||
|
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
||||||
|
"WHERE c.room_id = ? "
|
||||||
|
) % {
|
||||||
|
"redacted": del_sql,
|
||||||
|
}
|
||||||
|
|
||||||
|
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
|
||||||
|
sql += " OR s.type = 'm.room.aliases')"
|
||||||
|
args = (room_id,)
|
||||||
|
|
||||||
|
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
||||||
|
|
||||||
|
events = yield self._parse_events(results)
|
||||||
|
|
||||||
|
name = None
|
||||||
|
aliases = []
|
||||||
|
|
||||||
|
for e in events:
|
||||||
|
if e.type == 'm.room.name':
|
||||||
|
if 'name' in e.content:
|
||||||
|
name = e.content['name']
|
||||||
|
elif e.type == 'm.room.aliases':
|
||||||
|
if 'aliases' in e.content:
|
||||||
|
aliases.extend(e.content['aliases'])
|
||||||
|
|
||||||
|
defer.returnValue((name, aliases))
|
||||||
|
|
||||||
|
|
||||||
|
class RoomsTable(object):
|
||||||
table_name = "rooms"
|
table_name = "rooms"
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
|
@ -167,5 +207,3 @@ class RoomsTable(Table):
|
||||||
"is_public",
|
"is_public",
|
||||||
"creator"
|
"creator"
|
||||||
]
|
]
|
||||||
|
|
||||||
EntryType = collections.namedtuple("RoomEntry", fields)
|
|
||||||
|
|
|
@ -212,7 +212,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
return self._simple_select_onecol(
|
return self._simple_select_onecol(
|
||||||
"room_hosts",
|
"room_hosts",
|
||||||
{"room_id": room_id},
|
{"room_id": room_id},
|
||||||
"host"
|
"host",
|
||||||
|
desc="get_joined_hosts_for_room",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_members_by_dict(self, where_dict):
|
def _get_members_by_dict(self, where_dict):
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -82,7 +84,7 @@ class StateStore(SQLBaseStore):
|
||||||
if context.current_state is None:
|
if context.current_state is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
state_events = context.current_state
|
state_events = dict(context.current_state)
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
state_events[(event.type, event.state_key)] = event
|
state_events[(event.type, event.state_key)] = event
|
||||||
|
@ -122,3 +124,33 @@ class StateStore(SQLBaseStore):
|
||||||
},
|
},
|
||||||
or_replace=True,
|
or_replace=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_current_state(self, room_id, event_type=None, state_key=""):
|
||||||
|
del_sql = (
|
||||||
|
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
||||||
|
"LIMIT 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
||||||
|
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
||||||
|
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
||||||
|
"WHERE c.room_id = ? "
|
||||||
|
) % {
|
||||||
|
"redacted": del_sql,
|
||||||
|
}
|
||||||
|
|
||||||
|
if event_type and state_key is not None:
|
||||||
|
sql += " AND s.type = ? AND s.state_key = ? "
|
||||||
|
args = (room_id, event_type, state_key)
|
||||||
|
elif event_type:
|
||||||
|
sql += " AND s.type = ?"
|
||||||
|
args = (room_id, event_type)
|
||||||
|
else:
|
||||||
|
args = (room_id, )
|
||||||
|
|
||||||
|
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
||||||
|
|
||||||
|
events = yield self._parse_events(results)
|
||||||
|
defer.returnValue(events)
|
||||||
|
|
|
@ -35,7 +35,7 @@ what sort order was used:
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore, cached
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
@ -413,12 +413,32 @@ class StreamStore(SQLBaseStore):
|
||||||
"get_recent_events_for_room", get_recent_events_for_room_txn
|
"get_recent_events_for_room", get_recent_events_for_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached(num_args=0)
|
||||||
def get_room_events_max_id(self):
|
def get_room_events_max_id(self):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_room_events_max_id",
|
"get_room_events_max_id",
|
||||||
self._get_room_events_max_id_txn
|
self._get_room_events_max_id_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_min_token(self):
|
||||||
|
row = yield self._execute(
|
||||||
|
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
|
||||||
|
self.min_token = min(self.min_token, -1)
|
||||||
|
|
||||||
|
logger.debug("min_token is: %s", self.min_token)
|
||||||
|
|
||||||
|
defer.returnValue(self.min_token)
|
||||||
|
|
||||||
|
def get_next_stream_id(self):
|
||||||
|
with self._next_stream_id_lock:
|
||||||
|
i = self._next_stream_id
|
||||||
|
self._next_stream_id += 1
|
||||||
|
return i
|
||||||
|
|
||||||
def _get_room_events_max_id_txn(self, txn):
|
def _get_room_events_max_id_txn(self, txn):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT MAX(stream_ordering) as m FROM events"
|
"SELECT MAX(stream_ordering) as m FROM events"
|
||||||
|
|
|
@ -46,15 +46,19 @@ class TransactionStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_received_txn_response(self, txn, transaction_id, origin):
|
def _get_received_txn_response(self, txn, transaction_id, origin):
|
||||||
where_clause = "transaction_id = ? AND origin = ?"
|
result = self._simple_select_one_txn(
|
||||||
query = ReceivedTransactionsTable.select_statement(where_clause)
|
txn,
|
||||||
|
table=ReceivedTransactionsTable.table_name,
|
||||||
|
keyvalues={
|
||||||
|
"transaction_id": transaction_id,
|
||||||
|
"origin": origin,
|
||||||
|
},
|
||||||
|
retcols=ReceivedTransactionsTable.fields,
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
txn.execute(query, (transaction_id, origin))
|
if result and result.response_code:
|
||||||
|
return result["response_code"], result["response_json"]
|
||||||
results = ReceivedTransactionsTable.decode_results(txn.fetchall())
|
|
||||||
|
|
||||||
if results and results[0].response_code:
|
|
||||||
return (results[0].response_code, results[0].response_json)
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -90,12 +90,16 @@ class LruCache(object):
|
||||||
def cache_len():
|
def cache_len():
|
||||||
return len(cache)
|
return len(cache)
|
||||||
|
|
||||||
|
def cache_contains(key):
|
||||||
|
return key in cache
|
||||||
|
|
||||||
self.sentinel = object()
|
self.sentinel = object()
|
||||||
self.get = cache_get
|
self.get = cache_get
|
||||||
self.set = cache_set
|
self.set = cache_set
|
||||||
self.setdefault = cache_set_default
|
self.setdefault = cache_set_default
|
||||||
self.pop = cache_pop
|
self.pop = cache_pop
|
||||||
self.len = cache_len
|
self.len = cache_len
|
||||||
|
self.contains = cache_contains
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
result = self.get(key, self.sentinel)
|
result = self.get(key, self.sentinel)
|
||||||
|
@ -114,3 +118,6 @@ class LruCache(object):
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.len()
|
return self.len()
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return self.contains(key)
|
||||||
|
|
|
@ -16,6 +16,10 @@
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
|
||||||
|
_string_with_symbols = (
|
||||||
|
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def origin_from_ucid(ucid):
|
def origin_from_ucid(ucid):
|
||||||
return ucid.split("@", 1)[1]
|
return ucid.split("@", 1)[1]
|
||||||
|
@ -23,3 +27,9 @@ def origin_from_ucid(ucid):
|
||||||
|
|
||||||
def random_string(length):
|
def random_string(length):
|
||||||
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
|
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
|
||||||
|
|
||||||
|
|
||||||
|
def random_string_with_symbols(length):
|
||||||
|
return ''.join(
|
||||||
|
random.choice(_string_with_symbols) for _ in xrange(length)
|
||||||
|
)
|
||||||
|
|
|
@ -17,7 +17,79 @@
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.storage._base import cached
|
from synapse.storage._base import Cache, cached
|
||||||
|
|
||||||
|
|
||||||
|
class CacheTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.cache = Cache("test")
|
||||||
|
|
||||||
|
def test_empty(self):
|
||||||
|
failed = False
|
||||||
|
try:
|
||||||
|
self.cache.get("foo")
|
||||||
|
except KeyError:
|
||||||
|
failed = True
|
||||||
|
|
||||||
|
self.assertTrue(failed)
|
||||||
|
|
||||||
|
def test_hit(self):
|
||||||
|
self.cache.prefill("foo", 123)
|
||||||
|
|
||||||
|
self.assertEquals(self.cache.get("foo"), 123)
|
||||||
|
|
||||||
|
def test_invalidate(self):
|
||||||
|
self.cache.prefill("foo", 123)
|
||||||
|
self.cache.invalidate("foo")
|
||||||
|
|
||||||
|
failed = False
|
||||||
|
try:
|
||||||
|
self.cache.get("foo")
|
||||||
|
except KeyError:
|
||||||
|
failed = True
|
||||||
|
|
||||||
|
self.assertTrue(failed)
|
||||||
|
|
||||||
|
def test_eviction(self):
|
||||||
|
cache = Cache("test", max_entries=2)
|
||||||
|
|
||||||
|
cache.prefill(1, "one")
|
||||||
|
cache.prefill(2, "two")
|
||||||
|
cache.prefill(3, "three") # 1 will be evicted
|
||||||
|
|
||||||
|
failed = False
|
||||||
|
try:
|
||||||
|
cache.get(1)
|
||||||
|
except KeyError:
|
||||||
|
failed = True
|
||||||
|
|
||||||
|
self.assertTrue(failed)
|
||||||
|
|
||||||
|
cache.get(2)
|
||||||
|
cache.get(3)
|
||||||
|
|
||||||
|
def test_eviction_lru(self):
|
||||||
|
cache = Cache("test", max_entries=2, lru=True)
|
||||||
|
|
||||||
|
cache.prefill(1, "one")
|
||||||
|
cache.prefill(2, "two")
|
||||||
|
|
||||||
|
# Now access 1 again, thus causing 2 to be least-recently used
|
||||||
|
cache.get(1)
|
||||||
|
|
||||||
|
cache.prefill(3, "three")
|
||||||
|
|
||||||
|
failed = False
|
||||||
|
try:
|
||||||
|
cache.get(2)
|
||||||
|
except KeyError:
|
||||||
|
failed = True
|
||||||
|
|
||||||
|
self.assertTrue(failed)
|
||||||
|
|
||||||
|
cache.get(1)
|
||||||
|
cache.get(3)
|
||||||
|
|
||||||
|
|
||||||
class CacheDecoratorTestCase(unittest.TestCase):
|
class CacheDecoratorTestCase(unittest.TestCase):
|
||||||
|
|
|
@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
self.mock_txn.fetchone.return_value = ("Old Value",)
|
self.mock_txn.fetchone.return_value = ("Old Value",)
|
||||||
|
|
||||||
ret = yield self.datastore._simple_update_one(
|
ret = yield self.datastore._simple_selectupdate_one(
|
||||||
table="tablename",
|
table="tablename",
|
||||||
keyvalues={"keycol": "TheKey"},
|
keyvalues={"keycol": "TheKey"},
|
||||||
updatevalues={"columname": "New Value"},
|
updatevalues={"columname": "New Value"},
|
||||||
|
|
|
@ -44,7 +44,7 @@ class RoomStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_room(self):
|
def test_get_room(self):
|
||||||
self.assertObjectHasAttributes(
|
self.assertDictContainsSubset(
|
||||||
{"room_id": self.room.to_string(),
|
{"room_id": self.room.to_string(),
|
||||||
"creator": self.u_creator.to_string(),
|
"creator": self.u_creator.to_string(),
|
||||||
"is_public": True},
|
"is_public": True},
|
||||||
|
|
Loading…
Reference in New Issue